Has anyone found a good way to leverage dask for tensorflow prediction/inference?

Context:

- We use a Helm-deployed Dask on Kubernetes.
- We need to run inference over a spatial mosaic. Flattened, our shape would be something like `da.random.random((int(1.6e9), 30, 15, 15))`, with the first dim being the number of examples. That is roughly 42 TB of data. With 32 workers @ 4 CPUs and 16 GB each, I can iterate over all of this data in a few minutes using `numpy.lib.stride_tricks.sliding_window_view` and `map_blocks`.
- The 1.6e9 examples come from a 40000 x 40000 spatial mosaic with 3 features, using chips of size 15x15.
- We make a single-pixel prediction from each chip, so that is a lot of data to pass around for such a small point of inference…
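To make the windowing step concrete, here is a minimal sketch on a toy mosaic (a 40x40 stand-in for the 40000x40000 one, with 3 bands; the exact array shapes are illustrative, not taken from our pipeline):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy stand-in for the 40000 x 40000 mosaic with 3 bands.
mosaic = np.random.random((40, 40, 3)).astype(np.float32)

# One 15x15 chip per valid centre pixel; this is a zero-copy view.
chips = sliding_window_view(mosaic, (15, 15), axis=(0, 1))
print(chips.shape)  # (26, 26, 3, 15, 15)

# Flatten to (n_examples, 15, 15, bands). Note the reshape copies,
# which is why the full-size version is processed chunk-by-chunk with
# map_blocks rather than materialised in one go.
flat = np.moveaxis(chips, 2, -1).reshape(-1, 15, 15, 3)
print(flat.shape)  # (676, 15, 15, 3)
```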

Goal:

- Ideally, we use one instance of TensorFlow per node so that TensorFlow can take advantage of the node's resources most effectively.
- Sometimes we use GPUs for inference if the model is big enough, so ideally we use a heterogeneous set of workers (CPU workers for data ops and GPU workers for inference, transferring data between them).

Issues:

- How can we load one model per node (and therefore only one model per GPU) with Dask? Can we use Actors and submit to GPU workers by specifying the `workers` argument? Will tasks in other threads play nicely?
- As models grow larger, we cannot use the naive approach of loading the model on CPU (with a homogeneous worker group) for each task/chunk. Refer to the code below.
- With a batch size of 2**14, that is almost 100k tasks, and we would have to load a (potentially large) model that many times!
- We have to load the model within each task because TensorFlow does not like to be pickled and prefers to run in the thread where it was loaded.
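One hedged way around the per-task reload is to memoize the load once per worker process, e.g. with `functools.lru_cache`. This assumes one worker process per node and that calling `predict` from worker threads is acceptable (or that workers run with a single thread); the TensorFlow call is replaced by a stand-in loader here so the sketch is self-contained:

```python
import functools

LOAD_COUNT = {"n": 0}  # counts real loads, to show the cache works

@functools.lru_cache(maxsize=1)
def load_model_cached(path):
    # Stand-in for `tf.keras.models.load_model(path)`. Because the result
    # is memoized, each worker process pays the load cost once, not once
    # per chunk. Note this shares one model across worker threads, so it
    # assumes the model tolerates that (or one-thread workers).
    LOAD_COUNT["n"] += 1
    return {"path": path}

def predict_cached(x):
    model = load_model_cached("/tmp/temp_model")  # cheap after first call
    return x  # real code: model.predict(x, batch_size=len(x))

predict_cached([1, 2, 3])
predict_cached([4, 5, 6])
print(LOAD_COUNT["n"])  # 1
```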

Example of how we can solve this today, but not in the general case and not very well:

```
import dask.array as da
import numpy as np
import tensorflow as tf
# generate a lazy dask array about 3gb--it will be chunked automatically
# this would be a view into flattened chips i.e. extracted from a spatial mosaic
gb = 3
n = int(gb * 1e9 / (15 ** 2 * 30 * 4))
data = da.random.random((n, 15, 15, 30)).astype(np.float32)
# build a function to load the model then predict
def predict_good(x):
    model = tf.keras.models.load_model("/tmp/temp_model")
    return model.predict(x, batch_size=len(x))

# tensorflow wants to run in the thread it is initialized
# tensorflow also does not like to be pickled
model = tf.keras.models.load_model("/tmp/temp_model")

def predict_bad(x):
    return model.predict(x, batch_size=len(x))
# apply the model over the chunks of the dataset
# each chunk has to load the model--this will be slow for large models
r = data.map_blocks(predict_good, dtype="float32", drop_axis=(1,2))
r.compute().shape
# apply the model over the chunks of the dataset
# this will cause us to pickle the model
r = data.map_blocks(predict_bad, dtype="float32", drop_axis=(1,2))
r.compute().shape
```
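For the heterogeneous-worker part of the question, one option is Dask's worker resources: annotate the inference tasks so the distributed scheduler only places them on workers started with e.g. `--resources "GPU=1"`, while preprocessing stays on the CPU workers. A minimal sketch (the predictor is a stand-in that collapses each chip batch to one value per example; annotations take effect only under the distributed scheduler and are ignored by the local one, so this also runs on a laptop):

```python
import dask
import dask.array as da
import numpy as np

# Stand-in predictor: collapses each (n, 15, 15, 30) chip batch to one
# value per example, mimicking the single-pixel prediction.
def predict_stub(x):
    return x.mean(axis=(1, 2, 3)).astype(np.float32)

data = da.random.random(
    (1024, 15, 15, 30), chunks=(256, 15, 15, 30)
).astype(np.float32)

# Hypothetical heterogeneous setup: with the distributed scheduler, only
# workers started with `--resources "GPU=1"` may run these tasks.
with dask.annotate(resources={"GPU": 1}):
    preds = data.map_blocks(predict_stub, dtype="float32", drop_axis=(1, 2, 3))

print(preds.compute().shape)  # (1024,)
```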

Has anyone successfully used dask for inference like this?

## Alternatives:

- Use a model server to decouple model inference from dask and only use dask to preprocess and send data to some model endpoint that scales easily. Tools such as Yatai, Seldon-Core, KServe, and Jina would do the trick!