I’ve trained a custom Keras model on annotated 2D slices from 3D CT scans. I now want to run prediction on a whole 3D image, i.e. apply model.predict() along the z-dimension of the array.
I know I can do this with a for-loop, or by converting to xarray and using apply_ufunc(), but I’m wondering whether there is a way to do it with dask and map_blocks() (or some other method). My goal is to speed up the prediction.
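For reference, the slice-by-slice for-loop version works fine; roughly like this (a minimal sketch: model and stack are my placeholder names, and I've shrunk the dummy stack to 10 slices so it runs quickly):

from keras.applications import vgg16
import numpy as np

model = vgg16.VGG16(include_top=False, weights="imagenet", input_shape=(512, 512, 3))
stack = np.random.randint(0, 2**16 - 1, size=(10, 512, 512, 3), dtype=np.uint32)

# Predict each z-slice separately, adding a batch dimension of 1 per slice.
pred = np.concatenate(
    [model.predict(np.expand_dims(stack[z], axis=0), verbose=0)
     for z in range(stack.shape[0])],
    axis=0,
)
print(pred.shape)  # (10, 16, 16, 512)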
I’ve tried to put together a minimal reproducible example using a pre-trained model and random data, but I seem to be doing something wrong with map_blocks().
from keras.applications import vgg16
import dask.array as darray
import numpy as np
input_shape = (512, 512, 3)
VGG16_weight = "imagenet"
VGG16 = vgg16.VGG16(include_top=False, weights=VGG16_weight, input_shape=input_shape)
arr = darray.from_array(np.random.randint(0, (2**16)-1, size=(7000, 512, 512, 3), dtype=np.uint32)) # dummy RGB 3d-array
sz, sy, sx, zb = arr.shape
# Testing on a single slice, this works
slice_arr = arr[0, :, :, :]  # avoid shadowing the built-in `slice`
slice_exp = np.expand_dims(slice_arr, axis=0)  # add a batch dimension
slice_pred = VGG16.predict(slice_exp)
print(slice_pred)
# Testing on full stack
arr_exp = np.expand_dims(arr, axis=0)
arr_exp = arr_exp.rechunk((1, sz, sy, sx, zb))
arr_pred = darray.map_blocks(lambda x: VGG16.predict(x), arr_exp, dtype=np.uint32, chunks=(1, sz, sy, sx, zb))
arr_pred = arr_pred.compute_chunk_sizes()
print(arr_pred)
Error:

---> 41 arr_pred = arr_pred.compute_chunk_sizes()
     42 print(arr_pred)

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/dask/array/core.py:1501, in Array.compute_chunk_sizes(self)
   1496     c.append(tuple(chunk_shapes[s]))
   1498 # map_blocks assigns numpy dtypes
   1499 # cast chunk dimensions back to python int before returning
   1500 x._chunks = tuple(
-> 1501     tuple(int(chunk) for chunk in chunks) for chunks in compute(tuple(c))[0]
   1502 )
   1503 return x

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/dask/base.py:664, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    661     postcomputes.append(x.__dask_postcompute__())
    663 with shorten_traceback():
--> 664     results = schedule(dsk, keys, **kwargs)
    666 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

Cell In[4], line 40, in main.<locals>.<lambda>(x)
     38 arr_exp = np.expand_dims(arr, axis=0)
     39 arr_exp = arr_exp.rechunk((1, sz, sy, sx, zb))
---> 40 arr_pred = darray.map_blocks(lambda x: VGG16.predict(x), arr_exp, dtype=np.uint32, chunks=(1, sz, sy, sx, zb))
     41 arr_pred = arr_pred.compute_chunk_sizes()
     42 print(arr_pred)

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # keras.config.disable_traceback_filtering()
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/anaconda3/envs/ct-env/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # keras.config.disable_traceback_filtering()
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

ValueError: as_list() is not defined on an unknown TensorShape.
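From reading the map_blocks docs, I suspect dask cannot infer the output geometry here, since VGG16 with include_top=False maps a (n, 512, 512, 3) block to (n, 16, 16, 512), so the input chunks can't be reused. Below is the kind of call I imagine is needed (untested sketch: keeping z as the batch axis instead of adding a fifth dimension, the chunk size of 50 slices, and the use of meta are all my guesses), but I'd like to know whether this is the right pattern:

# Untested sketch: treat z as the batch axis (no expand_dims), declare the
# output block geometry, and pass `meta` so dask doesn't probe the model
# with a fake zero-size block. 50 slices per block is an arbitrary choice.
arr = arr.rechunk((50, sy, sx, zb))
arr_pred = arr.map_blocks(
    lambda block: VGG16.predict(block.astype(np.float32), verbose=0),
    chunks=(50, 16, 16, 512),  # VGG16 downsamples 512 -> 16 over 5 poolings
    meta=np.empty((0, 16, 16, 512), dtype=np.float32),
)
print(arr_pred.compute())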