Best way to process a large array given a smaller binary mask

Hi, I am working with whole slide images using napari-lazy-openslide.
So I have a large dask array corresponding to an image with shape (X, Y, C) and a binary mask with shape (x, y), where x < X and y < Y.

I need to slice the image using the binary mask, process the sliced array, and generate a new array with shape (X, Y) that is zero everywhere except in the regions corresponding to the binary mask.
Ideally, in terms of some pseudo code:

out = some_function(image[fit(binary_mask)])


image -> shape (X, Y, C)
binary_mask -> shape (x, y)
fit(binary_mask) -> shape (X, Y, C) (it scales the mask and adds a channel axis)
out -> shape (X, Y)
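To make the intended shapes concrete, here is a tiny NumPy sketch of what I mean by fit (np.kron is just one way to upscale the mask; the numbers are illustrative):

```python
import numpy as np

# a 2x2 mask upscaled by (h, w) = (3, 3), then broadcast to 3 channels
mask = np.array([[True, False], [False, True]])
upscaled = np.kron(mask, np.ones((3, 3), dtype=bool))      # shape (6, 6)
fitted = np.broadcast_to(upscaled[:, :, None], (6, 6, 3))  # shape (6, 6, 3)

print(upscaled.shape, fitted.shape)  # (6, 6) (6, 6, 3)
```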

I am trying to achieve that this way:

import numpy as np
import dask.array as da

image = da.random.randint(0, 255, (100, 100, 3), dtype="uint8")  # fake image
mask = da.zeros((10, 10), dtype="bool")
mask[0, :] = True  # just to have some nonzero values

def fit(array, h, w):
    # upscale a 2D mask by (h, w) and add a channel axis
    x, y = array.shape
    m, n = x * h, y * w
    rescaled = da.broadcast_to(array.reshape(x, 1, y, 1), (x, h, y, w)).reshape(m, n)
    return da.stack([rescaled] * 3, axis=2)

fitted_mask = fit(mask, 10, 10)

masked_image = image[fitted_mask].compute_chunk_sizes().reshape(-1, 3)  # this mirrors my real use case

out = da.map_blocks(lambda array: np.ones(array.shape[:1]), masked_image, drop_axis=[1], meta=np.array(()))

processed_image = da.zeros(fitted_mask.shape[:2])  # final output is a binary image
processed_image = da.piecewise(processed_image, [fitted_mask[:,:, 0]], [out])

This works fine; however, an error occurs when the arrays have multiple chunks:

processed_image = da.zeros(fitted_mask.shape[:2])
processed_image = da.piecewise(processed_image.rechunk(2, 2, 3), [fitted_mask[:, :, 0].rechunk(2, 2)], [out.rechunk(2)])

#it generates ValueError: NumPy boolean array indexing assignment cannot assign 100 input values to the 0 output values where the mask is true.
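If I read the traceback correctly, per chunk this reduces to np.piecewise being handed an array in funclist while that chunk's condition selects zero elements. A minimal NumPy reproduction of what I think happens (my guess at the mechanism):

```python
import numpy as np

x = np.zeros(4)
cond = np.zeros(4, dtype=bool)  # a chunk where the mask is all False
values = np.arange(4)           # but the funclist entry still carries 4 values

try:
    np.piecewise(x, [cond], [values])
except ValueError as e:
    print(e)  # same ValueError as above, with 4 and 0 instead of 100 and 0
```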

I have probably misunderstood how to use da.piecewise. Moreover, I am not sure this is the right approach performance-wise, since the fit function and the rechunks can be heavy operations.
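For comparison, the only chunk-safe formulation I have found so far is da.where, but it applies the processing to every pixel rather than only to the masked region (the per-pixel mean over channels here is just a placeholder for my real processing):

```python
import numpy as np
import dask.array as da

image = da.random.randint(0, 255, (100, 100, 3), dtype="uint8", chunks=(50, 50, 3))
mask = np.kron(np.eye(2, dtype=bool), np.ones((50, 50), dtype=bool))  # (100, 100)
fitted = da.from_array(mask, chunks=(50, 50))

# placeholder processing over the channel axis, applied everywhere
processed = image.mean(axis=2)
out = da.where(fitted, processed, 0)  # zeros outside the mask
print(out.shape)  # (100, 100)
```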

Please note that I am avoiding da.map_blocks on the image itself, since the image has to be read from the file system and I assume that would hurt performance.

Any help is appreciated, thanks.