Using map_blocks() to predict with a 2D keras model on a 3D dask array

I’ve trained a custom keras model on annotated 2D slices from 3D CT-scan images. I now want to predict on a whole 3D image, i.e. apply the model.predict() function along the z-dimension of the array.

I know that I can do this using a for-loop or by converting to xarray and using apply_ufunc(), but I’m wondering if there is a way to do it using dask and map_blocks() or some other method. My goal is to make the prediction go faster.
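
For reference, the plain for-loop version I’m trying to speed up looks roughly like this (just a sketch; model and volume stand in for my trained model and the loaded 3D image):

    import numpy as np

    # volume has shape (n_slices, 512, 512, 3); predict one z-slice at a time
    predictions = []
    for z in range(volume.shape[0]):
        batch = np.expand_dims(volume[z], axis=0)  # add the batch dimension
        predictions.append(model.predict(batch, verbose=0))
    result = np.concatenate(predictions, axis=0)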

I’ve tried to make a minimal reproducible example using a pre-trained model and random data, but I seem to be doing something wrong with map_blocks().

    from keras.applications import vgg16
    import dask.array as darray
    import numpy as np

    input_shape = (512, 512, 3)
    VGG16_weight = "imagenet"
    VGG16 = vgg16.VGG16(include_top=False, weights=VGG16_weight, input_shape=input_shape)

    arr = darray.from_array(np.random.randint(0, (2**16)-1, size=(7000, 512, 512, 3), dtype=np.uint32))  # dummy RGB 3d-array
    sz, sy, sx, zb = arr.shape
    
    # Testing on a single slice, this works

    slice = arr[0, :, :, :]  
    slice_exp = np.expand_dims(slice, axis=0)  # expand batch dimension
    slice_pred = VGG16.predict(slice_exp)
    print(slice_pred)

    # Testing on full stack

    arr_exp = np.expand_dims(arr, axis=0)  
    arr_exp = arr_exp.rechunk((1, sz, sy, sx, zb))
    arr_pred = darray.map_blocks(lambda x: VGG16.predict(x), arr_exp, chunks=(1, sz, sy, sx, zb))
    arr_pred = arr_pred.compute_chunk_sizes()
    print(arr_pred)

Error:

---> 41 arr_pred = arr_pred.compute_chunk_sizes()
     42 print(arr_pred)

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/dask/array/core.py:1501, in Array.compute_chunk_sizes(self)
   1496 c.append(tuple(chunk_shapes[s]))
   1498 # map_blocks assigns numpy dtypes
   1499 # cast chunk dimensions back to python int before returning
   1500 x._chunks = tuple(
-> 1501     tuple(int(chunk) for chunk in chunks) for chunks in compute(tuple(c))[0]
   1502 )
   1503 return x

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/dask/base.py:664, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    661 postcomputes.append(x.__dask_postcompute__())
    663 with shorten_traceback():
--> 664     results = schedule(dsk, keys, **kwargs)
    666 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

Cell In[4], line 40, in main.<locals>.<lambda>(x)
     38 arr_exp = np.expand_dims(arr, axis=0)
     39 arr_exp = arr_exp.rechunk((1, sz, sy, sx, zb))
---> 40 arr_pred = darray.map_blocks(lambda x: VGG16.predict(x), arr_exp, dtype=np.uint32, chunks=(1, sz, sy, sx, zb))
     41 arr_pred = arr_pred.compute_chunk_sizes()
     42 print(arr_pred)

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119 filtered_tb = _process_traceback_frames(e.__traceback__)
    120 # To get the full stack trace, call:
    121 # keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119 filtered_tb = _process_traceback_frames(e.__traceback__)
    120 # To get the full stack trace, call:
    121 # keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

ValueError: as_list() is not defined on an unknown TensorShape.

Hi @khyll,

I played a bit with your example without vgg16 in the loop, and I have to admit there are some parts I don’t understand.

In the line where you build arr, you are creating a 20 GB array with NumPy before wrapping it in Dask; you should use darray.random instead. That line does not even run in my environment because of a memory issue.

I don’t get the expand_dims/rechunk part: why are you adding a dimension? And above all, why are you rechunking the whole array into a single chunk?

I also think the chunks kwarg you pass to map_blocks is not right, and you shouldn’t need it at all if the chunk shape doesn’t change.

I guess the whole question is how to build appropriate chunks for applying your method. In the end, you want blocks that are not chunked along the sy, sx and zb dimensions, but only along sz. Then you can use map_blocks, and inside each block either loop over the slices or hand the whole block to your prediction function (see the sketch at the end of this post). Building chunks of size (1, 512, 512, 3) would be doable, but would make them a little too small.

You shouldn’t have to use the compute_chunk_sizes() call at all.

Here is my toy example (not working at the end):

import dask.array as darray
import numpy as np

input_shape = (512, 512, 3)
arr = darray.random.randint(0, (2**16)-1, size=(7000, 512, 512, 3), dtype=np.uint32)  # dummy RGB 3d-array
sz, sy, sx, zb = arr.shape

arr_exp = np.expand_dims(arr, axis=0)
arr_chunk = arr_exp.rechunk((1, sz, sy, sx, zb))
arr_pred = darray.map_blocks(lambda x: x * 2, arr_chunk, chunks=(1, sz, sy, sx, zb))
arr_pred = arr_pred.compute_chunk_sizes()
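
Going further, here is a rough sketch of what I would try with the real model: chunk only along the first (sz) axis so that every block is a small batch of full slices, then let map_blocks call predict() on each block (predict() already accepts a batch, so the loop over slices can be left to keras). I have not run this with VGG16 myself, and the output chunk shape (10, 16, 16, 512) is an assumption based on VGG16 with include_top=False downsampling 512x512 inputs by a factor of 32, so check it against model.output_shape:

import dask.array as darray
import numpy as np
from keras.applications import vgg16

input_shape = (512, 512, 3)
model = vgg16.VGG16(include_top=False, weights="imagenet", input_shape=input_shape)

# Dummy data, chunked only along the z axis; 10 divides 7000 evenly, so every
# block has the same shape and the output chunks below stay consistent.
arr = darray.random.randint(
    0, 255, size=(7000, 512, 512, 3), chunks=(10, 512, 512, 3)
).astype(np.float32)

def predict_block(block):
    # block has shape (10, 512, 512, 3): already a valid batch for predict()
    return model.predict(block, verbose=0)

arr_pred = arr.map_blocks(
    predict_block,
    dtype=np.float32,
    # predict() changes the slice shape, so the output chunks must be declared
    chunks=(10, 16, 16, 512),
    # an explicit meta keeps dask from calling predict() on a tiny dummy array
    # while it builds the graph
    meta=np.empty((0, 16, 16, 512), dtype=np.float32),
)
print(arr_pred)            # lazy dask array, shape (7000, 16, 16, 512)
# arr_pred[:10].compute()  # compute just the first block to check it runs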