Dask Arrays with TensorFlow

I have not seen any good examples demonstrating the use of dask arrays with tensorflow. My particular use case is that I have a small-to-medium model that will train on a gpu worker, but I would like all cpu workers to be responsible for loading dask arrays into memory, probably applying simple ops, and sending the results to the single gpu machine. Training datasets are usually < 500gb and I am expecting < 5 epochs.


  • I would like the gpu to be fully utilized by offloading the data loading and preprocessing to non-gpu worker(s).
  • I am able to fit the training op on a single gpu worker. Worst case, I select bigger gpus and increase the number of gpus per worker (limited to 8; I haven’t seen gpu nodes with more gpus than that). Basically, if I have a single gpu worker with 8 large gpus that is responsible for catching data processed on other workers and is primarily focused on model training, I should meet my training time requirements, provided the network can support the data transfer.
  • If I need to scale beyond 8, then I will likely look into tf’s distribution strategies, the parameter server approach, etc.


  • We are deploying with Helm on GKE
  • No special networking options are enabled
  • I have tried to accomplish this with multiple worker groups and resource annotations, but I am failing to exceed a 0.33 GB/s transfer rate. (I’ll add some code in a bit)
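To put the 0.33 figure against the stated budget, here is a rough estimate. It assumes the rate is in gigabytes per second (which is what the profiling snippet below computes) and that the full dataset crosses the network once per epoch, with no caching on the gpu worker:

```python
dataset_gb = 500          # upper bound on training-set size from the post
epochs = 5                # upper bound on number of epochs from the post
observed_gb_per_s = 0.33  # measured transfer rate (assumed gigabytes per second)

# time spent purely moving data, ignoring compute overlap
epoch_seconds = dataset_gb / observed_gb_per_s
total_hours = epochs * epoch_seconds / 3600

print(f"~{epoch_seconds / 60:.0f} min per epoch, ~{total_hours:.1f} h of transfer in total")
# prints "~25 min per epoch, ~2.1 h of transfer in total"
```

If the gpu can consume batches faster than this, the transfer rather than the model becomes the bottleneck, which is the concern raised in the bullets above.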


  1. Is this a reasonable task for dask?
  2. Is anyone doing this? I have a hard time believing I am the first to try this.
  3. Any rough guesses on the maximum data transfer rates from a group of workers to a single worker?
  4. Are there any recommended network optimizations to help speed up this data transfer?
  5. Is there any existing code that profiles data transfer between workers on separate nodes?
  6. Would bypassing dask’s communication with a zmq-based protocol like NVlabs/tensorcom (GitHub) help speed up data transfer?

Here is a code snippet I am using to profile numpy array transfer between two worker types, using resource annotations and the additional worker groups deployment method with Helm:

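import time

import dask
import numpy as np

# assumes `client` is a dask.distributed.Client connected to the cluster, where the
# cpu workers advertise the "units" resource and the gpu worker "additional_units"


def load_data(i):
    # this will likely be reading from xarray and calling compute to load in
    # memory wherever this is running
    batch = np.random.uniform(size=(1024, 42, 15, 15)).astype(np.float32) + i
    return batch


def train(data):
    # not training, but it is a function that will run elsewhere to catch data
    return data.nbytes


start = time.time()
futures = []
with dask.annotate(resources={'units': 1}):
    for i in range(500):
        futures.append(client.submit(load_data, i))

results = []
with dask.annotate(resources={'additional_units': 1}):
    for future in futures:
        # passing a future moves its data to whichever worker runs train
        results.append(client.submit(train, future))

# wait for every transfer to finish before stopping the clock
total_bytes = sum(client.gather(results))
elapsed = time.time() - start
gbps = total_bytes / 1e9 / elapsed  # gigabytes (not gigabits) per second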

I can see the load_data function computing, then the data transfer, and then the train function running on the correct workers in the dask dashboard. I am also noticing that increasing the number of workers that run load_data increases the measured gbps.
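For scale, the batch shape in load_data fixes how much data each task moves. A quick back-of-the-envelope check in plain Python (float32 is 4 bytes per element):

```python
# one batch from load_data: shape (1024, 42, 15, 15) of float32 (4 bytes per element)
batch_bytes = 1024 * 42 * 15 * 15 * 4
# the profiling loop submits 500 such batches
total_bytes = 500 * batch_bytes

print(f"{batch_bytes / 1e6:.1f} MB per batch, {total_bytes / 1e9:.2f} GB in total")
# prints "38.7 MB per batch, 19.35 GB in total"
```

So each load_data to train hop is a ~39 MB message, and the reported gbps is this total divided by the elapsed wall-clock time.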

Now, all I need to do is wrap this logic in an iterator to build a dataset with tf.data.Dataset.from_generator, like:

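from typing import List, Optional

import tensorflow as tf
from distributed import get_client


class DataIterator:
    """Yield computed dask.delayed batches, keeping up to `prefetch` in flight."""

    def __init__(self, prefetch: int = 1, delayed_objs: Optional[List] = None):
        # copy the list so popping does not mutate the caller's list
        # (this also avoids a shared mutable default argument)
        self.delayed_objs = list(delayed_objs or [])
        self._futures: list = []
        self._prefetch = prefetch
        # get_client() works both inside a dask task and from a regular client process
        self._client = get_client()

    def _dask_prefetch(self):
        # submit delayed objects until `prefetch` futures are in flight or none remain
        while len(self._futures) < self._prefetch and self.delayed_objs:
            delayed = self.delayed_objs.pop(0)
            # resources= pins loading/preprocessing to the cpu ("units") workers
            self._futures.append(self._client.compute(delayed, resources={"units": 1}))

    def __iter__(self):
        return self

    def __next__(self):
        self._dask_prefetch()
        if not self._futures:
            raise StopIteration
        # oldest future first, so batches arrive in the order they were given
        return self._futures.pop(0).result()


def get_remote_dask_generator_tf_dataset(
    delayed_objs: List,
    prefetch: int = 20,
) -> tf.data.Dataset:
    # assumption: every batch has the same shape and dtype, so the spec is
    # inferred from the first one
    tensor_spec = tf.TensorSpec.from_tensor(tf.convert_to_tensor(delayed_objs[0].compute()))
    dataset = tf.data.Dataset.from_generator(
        lambda: DataIterator(prefetch=prefetch, delayed_objs=delayed_objs),
        output_signature=tensor_spec,
    )
    return dataset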

where I simply provide a list of delayed objects to the iterator.
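The core of such an iterator is a bounded prefetch window: keep up to `prefetch` loads in flight, hand results back in submission order, and top the window up as batches are consumed. Below is a minimal, framework-agnostic sketch of that pattern. It uses the standard library's concurrent.futures as a stand-in for the dask client, and `prefetching_iterator` is a hypothetical helper, not a dask or TensorFlow API:

```python
from collections import deque
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetching_iterator(
    executor: Executor, tasks: Iterable[Callable[[], T]], prefetch: int = 2
) -> Iterator[T]:
    """Yield task results in submission order, keeping up to `prefetch` tasks in flight."""
    if prefetch < 1:
        raise ValueError("prefetch must be >= 1")
    pending = deque()
    task_iter = iter(tasks)

    def top_up() -> None:
        # submit new tasks until the window is full or the tasks run out
        while len(pending) < prefetch:
            try:
                task = next(task_iter)
            except StopIteration:
                return
            pending.append(executor.submit(task))

    top_up()
    while pending:
        future = pending.popleft()  # oldest first, so output order matches input order
        top_up()                    # refill the window before blocking on this result
        yield future.result()


with ThreadPoolExecutor(max_workers=4) as pool:
    squares = [lambda i=i: i * i for i in range(6)]
    print(list(prefetching_iterator(pool, squares, prefetch=3)))
    # prints [0, 1, 4, 9, 16, 25]
```

In principle the same function could be driven by the cluster through distributed's `Client.get_executor()`, which returns a concurrent.futures-compatible executor (for example `client.get_executor(resources={"units": 1})` to keep the work on the cpu workers), and then wrapped with `tf.data.Dataset.from_generator(lambda: prefetching_iterator(...), output_signature=...)`; that combination is untested here.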

Some weird things are happening with tf complaining about the iterator, but I figured it would be useful to get some code down in hopes of getting some discussion, concerns, etc. going.

Of course, I can build tfrecords from dask arrays or xarray offline and then build a tf.data.Dataset from the tfrecords, but I am hoping to serve up dask arrays directly. That would take advantage of them potentially being persisted, avoid the file transfer, and allow easier preprocessing on non-training workers/VMs.

@ljstrnadiii Thanks, that’s an interesting question.

  1. Is this a reasonable task for dask?
  2. Is anyone doing this? I have a hard time believing I am the first to try this.

We think the machine learning part is certainly reasonable; however, we’re not certain about moving the data across workers.

We think some Dask developers on the NVIDIA team might have a better understanding of GPUs. @jacobtomlinson do you or maybe Charles (who isn’t on Discourse) have thoughts on this?

We found:

which might have some useful nuggets.

(@ncclementi and I looked at this)


@pavithraes (and @ncclementi) thanks for the response here.

I don’t really expect optimal data transfer from cpu workers directly to gpu memory. I figure the gpu workers will have enough resources to load the data onto the gpu as fast as they can receive it and decode it into arrays, but I could easily be wrong there. TensorFlow and PyTorch both support NCCL (GitHub - NVIDIA/nccl: optimized primitives for collective multi-GPU communication) in their distributed comms, which seems like a further optimization. Another NVIDIA project that didn’t get much attention is NVlabs/tensorcom (GitHub), which uses zmq to send encoded arrays over most zmq protocols. I have successfully used it in the past to serve up training data at around 8 gbps, which was fast enough for me at the time. Scaling up was just a matter of creating an additional pod that served data by writing into a tcp socket. I never found the limit, but mostly because I didn’t look for it.

Speaking of limits, I did actually run a network profiler between dask workers on different nodes and verified the limit was around 31 gbps using iperf, as recommended in “Benchmarking higher bandwidth VM instances” (Compute Engine documentation, Google Cloud).
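For comparison with the earlier 0.33 figure, the iperf result can be converted to bytes (iperf reports bits per second, and there are 8 bits per byte). This sketch assumes the 0.33 figure is gigabytes per second, as computed by the profiling snippet:

```python
link_gbit_per_s = 31.0                   # iperf result between nodes (gigabits per second)
link_gbyte_per_s = link_gbit_per_s / 8   # 8 bits per byte -> gigabytes per second
observed_gbyte_per_s = 0.33              # earlier Dask measurement (assumed GB/s)

# fraction of the node-to-node link the first benchmark was using
utilization = observed_gbyte_per_s / link_gbyte_per_s

print(f"link ~{link_gbyte_per_s:.3f} GB/s, Dask reaching ~{utilization:.1%} of it")
# prints "link ~3.875 GB/s, Dask reaching ~8.5% of it"
```

The 0.33 figure came from one particular worker configuration, and adding load_data workers was already raising it, so this only suggests the first benchmark was well below the node-to-node ceiling, not where dask itself tops out.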

I’ll check out the examples, thanks for sharing!
