Dask Array with PyTorch

Hi All,
I have a question about the Dask Array PyTorch example.

In Step 4:

# Apply UNet featurization
out = da.map_blocks(unet_featurize, imgs, model, dtype=np.float32, chunks=(1, 1, imgs.shape[2], imgs.shape[3], 16), new_axis=-1)

Why is the chunk shape (1, 1, imgs.shape[2], imgs.shape[3], 16)?
I am confused about why there is a 16 at the end.


Hi @zeroth, welcome to the Dask community!

In this example, we are applying a pretrained model to a Dask Array, using map_blocks to apply the model to each chunk of data. As explained in Step 2:

This UNet model takes in a 2D image and returns a 2D x 16 array

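So we expect the model to add a new dimension of length 16 to every chunk. Each chunk of imgs holds a single 2D image of shape (imgs.shape[2], imgs.shape[3]) (the first two chunk sizes are 1), so after featurization each output chunk has shape (1, 1, imgs.shape[2], imgs.shape[3], 16). That is what we tell map_blocks through the chunks argument, while new_axis=-1 tells it that this extra dimension is appended as the last axis.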
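To see the shape bookkeeping without loading a model, here is a minimal sketch. It assumes a hypothetical stand-in, fake_featurize, in place of the example's unet_featurize: instead of running a UNet, it just repeats each 2D image 16 times along a new trailing axis, so each (1, 1, H, W) block becomes a (1, 1, H, W, 16) block. The input size (a 3 x 2 stack of 64 x 64 images) and the name N_FEATURES are also made up for illustration, and the sketch passes new_axis=4, the explicit index of the new last axis in the 5-D output.

```python
import numpy as np
import dask.array as da

# Hypothetical number of output features per pixel (16 for the UNet in the example).
N_FEATURES = 16

def fake_featurize(block):
    # Hypothetical stand-in for unet_featurize: map a (1, 1, H, W) block
    # to a (1, 1, H, W, N_FEATURES) float32 block by repeating the image
    # along a new trailing axis.
    out = np.repeat(block[..., np.newaxis], N_FEATURES, axis=-1)
    return out.astype(np.float32)

# Assumed input: 3 x 2 stack of 64 x 64 images, one image per chunk.
imgs = da.random.random((3, 2, 64, 64), chunks=(1, 1, 64, 64))

out = da.map_blocks(
    fake_featurize,
    imgs,
    dtype=np.float32,
    # Output chunk shape: one image per chunk, plus the new feature axis.
    chunks=(1, 1, imgs.shape[2], imgs.shape[3], N_FEATURES),
    new_axis=4,  # index of the new last axis in the 5-D output
)

print(out.shape)      # (3, 2, 64, 64, 16)
print(out.chunksize)  # (1, 1, 64, 64, 16)
```

The lazy array already reports the 5-D shape before anything is computed, because Dask builds it from the chunks you declare; if the function actually returned a different feature count, the declared metadata would be wrong.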

Does that make things clearer?

Hi @guillaumeeb,
Thanks for the explanation.
This makes sense.
