TensorFlow Inference

Has anyone found a good way to leverage dask for tensorflow prediction/inference?


  • We use Helm deployed dask on kubernetes
  • We need to make inference over a spatial mosaic. Our shape when flattened would be something like da.random.random((int(1.6e9), 30, 15, 15)) with the first dim being the number of examples. That is like 42Tb of data. With 32 workers @ 4cpu and 16gb I can iterate over all this data in a few minutes using numpy.lib.stride_tricks.sliding_window_view and map_blocks.
  • 1.6e9 is a spatial mosaic of 40000 x 40000 with 3 features using chips of size 15x15
  • We make a single pixel prediction from a single chip–lot’s of data to pass around for such a small point of inference…
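As a toy illustration of the chip extraction described above (a tiny stand-in mosaic with 3 features, and a per-chunk reduction standing in for the real model):

```python
import numpy as np
import dask.array as da
from numpy.lib.stride_tricks import sliding_window_view

# Small stand-in mosaic: 100 x 100 pixels with 3 features
# (the real one is 40000 x 40000, giving ~1.6e9 chips).
mosaic = np.random.random((100, 100, 3)).astype(np.float32)

# Extract a 15 x 15 chip around every valid pixel:
# result has shape (86, 86, 1, 15, 15, 3).
chips = sliding_window_view(mosaic, (15, 15, 3))

# Flatten to (n_examples, 15, 15, 3) -- one chip per pixel prediction.
chips = chips.reshape(-1, 15, 15, 3)

# Wrap in dask and map a per-chunk function over the examples,
# mimicking the map_blocks-based iteration described above.
darr = da.from_array(chips, chunks=(1024, 15, 15, 3))
result = darr.map_blocks(
    lambda b: b.mean(axis=(1, 2, 3)),  # stand-in for model inference
    dtype="float32",
    drop_axis=(1, 2, 3),
)
print(result.compute().shape)  # (7396,)
```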


  • Ideally, we use one instance of tensorflow per node in order to let tensorflow take advantage of resources most effectively.
  • Sometimes we use GPUs for inference if the model is big enough, so ideally we would use a set of heterogeneous workers (CPU workers for data ops and GPU workers for inference, transferring data between them).


  • How can we load one model per node (and therefore only one model per GPU) with dask? Can we use Actors and submit to GPU workers by specifying the workers argument? Will tasks in other threads play nicely?
  • As models grow larger, we cannot use the naive approach (with a homogeneous worker group) of loading the model on CPU for each task/chunk. Refer to the code below.
  • With a batch size of 2**14, that is almost 100k tasks. We would have to load a (potentially large) model that many times!
  • We have to load the model within each task because tensorflow does not like to be pickled and prefers to run in the thread in which it was loaded.
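On the Actors question, a minimal sketch of the pattern, using an in-process cluster and a stand-in model (the ModelHost class and its lambda are hypothetical placeholders for tf.keras.models.load_model):

```python
import numpy as np
from dask.distributed import Client

class ModelHost:
    # One instance lives on one worker; its methods all run in that
    # worker's dedicated actor thread, which suits TensorFlow's
    # preference to stay in the thread where it was loaded.
    def __init__(self, path):
        # Stand-in for: self.model = tf.keras.models.load_model(path)
        self.model = lambda x: x.sum(axis=(1, 2, 3))

    def predict(self, x):
        return self.model(x)

client = Client(processes=False)  # in-process cluster just for illustration
# On a real cluster you could pin the actor to a GPU worker with the
# workers= argument, e.g. client.submit(..., actor=True, workers=[...]).
host = client.submit(ModelHost, "/tmp/temp_model", actor=True).result()
preds = host.predict(np.ones((4, 15, 15, 30), dtype="float32")).result()
print(preds.shape)  # (4,)
client.close()
```

Each actor method call is shipped to the worker holding the model, so the model is loaded exactly once per actor rather than once per task.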

Example of how we can solve this today, but not in the general case and not very well:

import dask.array as da
import numpy as np
import tensorflow as tf

# generate a lazy dask array of about 3 GB -- it will be chunked automatically
# this would be a view into flattened chips, i.e. extracted from a spatial mosaic
gb = 3
n = int(gb * 1e9 / (15 ** 2 * 30 * 4))
data = da.random.random((n, 15, 15, 30)).astype(np.float32)

# build a function to load the model then predict
def predict_good(x):
    model = tf.keras.models.load_model("/tmp/temp_model")
    return model.predict(x, batch_size=len(x))

# tensorflow wants to run in the thread where it was initialized
# tensorflow also does not like to be pickled
model = tf.keras.models.load_model("/tmp/temp_model")
def predict_bad(x):
    return model.predict(x, batch_size=len(x))

# apply the model over the chunks of the dataset
# each chunk has to load the model--this will be slow for large models
r = data.map_blocks(predict_good, dtype="float32", drop_axis=(1,2))

# apply the model over the chunks of the dataset
# this will cause us to pickle the model
r = data.map_blocks(predict_bad, dtype="float32", drop_axis=(1,2))

Has anyone successfully used dask for inference like this?


  • Use a model server to decouple model inference from dask, and use dask only to preprocess and send data to a model endpoint that scales easily. Tools such as Yatai, Seldon-Core, KServe, and Jina would do the trick!

I have found some success with something like:

predict = lambda x, model, batch_size: model.predict(x, batch_size=batch_size)
model = dask.delayed(tf.keras.models.load_model)(model_path)
preds = data.map_blocks(predict, model, batch_size=...)

However, it is pretty challenging to have a good mental model of what is happening.

From what I understand, when there are more tasks than workers, one keras model ends up loaded per task slot, which translates to one per thread. So, with 12 blocks on 4 workers of 2 threads each, 8 models will be loaded. Keras can handle this, but I think the optimal approach is one model per node/pod, so it can use all available resources, rather than many copies of the model across threads. (This is probably why there isn’t much talk about using dask and tensorflow together.)

To alleviate this, and to attempt to run only one model per worker, I have tried to use dask.annotate resources:

with dask.annotate(resources={'units': 1}): # where each worker is given 1 unit during deployment
    model = dask.delayed(tf.keras.models.load_model)(model_path)

But I am afraid the annotation gets lost in the graph built by map_blocks/blockwise. Am I missing something?
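For what it is worth, annotating the map_blocks call itself (rather than the delayed load) does record the annotation on the blockwise layer. A sketch, assuming a dask version that builds high-level graphs:

```python
import dask
import dask.array as da

x = da.ones((8, 15, 15, 30), chunks=(4, 15, 15, 30))

# Annotate the blockwise computation itself, not the delayed model load.
with dask.annotate(resources={"units": 1}):
    y = x.map_blocks(lambda b: b + 1)

# On versions using high-level graphs, the annotation is recorded on the
# layer produced by map_blocks and can be inspected directly.
graph = y.__dask_graph__()
if hasattr(graph, "layers"):
    print(graph.layers[y.name].annotations)

# Note: low-level graph fusion can drop annotations before they reach the
# distributed scheduler; the dask docs suggest
#   dask.config.set({"optimization.fuse.active": False})
# when relying on them.
```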

This is especially important with GPUs. By default, TensorFlow takes all available GPU memory, in line with the one-model-per-device approach they seem to recommend. Then, when multiple models are loaded on a worker in different threads, GPU resources get exhausted and I see transient errors during prediction. One way to alleviate this is to set allow_growth=True, which lets models share GPU memory more easily.
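For reference, in TF 2.x the allow_growth behaviour is set through tf.config (a configuration fragment; the loop is simply a no-op on CPU-only machines):

```python
import tensorflow as tf

# Let TensorFlow claim GPU memory incrementally instead of grabbing it
# all up front, so several models/processes can share a device.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
```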

Has anyone had success with this? I think I would prefer to specify that only one task can run per worker, but the only ways I know to do this are through resource annotation or worker saturation.

@gjoseph92 (I hope you don’t mind I ping you), but have you had success with using worker saturation in the map_blocks context?