Callback when all computations of a certain type are finished


I have an operation on a Dask Bag that requires a GPU. I use the resources syntax to mark the workers that have a GPU (I only have a limited number). When the GPU op runs, the function calls get_worker to get the worker instance and checks whether it has a model attribute. If it does, the function continues and uses the model stored in that attribute. If the attribute is not present, the model is first loaded into the worker's model attribute, after which the function continues. This way the model only needs to be loaded once per worker.
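The load-once-per-worker pattern described above might look roughly like the minimal sketch below. It assumes a single in-process LocalCluster worker that advertises a resource named "GPU" (the name is arbitrary), uses client.map instead of a Dask Bag for brevity, and uses a hypothetical load_model function in place of the real model loading.

```python
from dask.distributed import Client, LocalCluster, get_worker


def load_model():
    # Hypothetical stand-in for an expensive GPU model load.
    return lambda x: x * 2


def gpu_op(x):
    worker = get_worker()
    # Load the model only the first time this worker runs a GPU op,
    # then reuse the copy cached as an attribute on the Worker object.
    if not hasattr(worker, "model"):
        worker.model = load_model()
    return worker.model(x)


def run_demo():
    # One in-process worker that advertises a single "GPU" resource.
    with LocalCluster(n_workers=1, processes=False, resources={"GPU": 1},
                      dashboard_address=None) as cluster, Client(cluster) as client:
        # Restrict these tasks to workers that have the "GPU" resource.
        futures = client.map(gpu_op, range(3), resources={"GPU": 1})
        return client.gather(futures)
```

Because the model is stored on the Worker object, it lives as long as the worker does, which is why it has to be removed explicitly once the GPU work is done.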

This works very well. However, I would like to remove the model from the workers once all GPU ops have been processed. Is there some way to fire a callback once all ops with a certain name have been processed?


Hi @MaximLippeveld, thanks for the question! I think Dask Futures will be helpful here, and there are a number of functions for managing futures. For instance, you could call client.gather(futures), where futures is the list of futures for the GPU operations you're interested in, and then call your cleanup function to remove the models. You might also find this section helpful if you need to wait for futures to complete.
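A minimal sketch of the gather-then-clean-up approach, assuming one in-process worker with a "GPU" resource and a toy model in place of the real one. The hypothetical remove_model helper is run on every worker with client.run, which passes in the Worker instance when the function has a parameter named dask_worker. Note that client.gather blocks the calling thread until every GPU future has finished.

```python
from dask.distributed import Client, LocalCluster, get_worker


def gpu_op(x):
    worker = get_worker()
    if not hasattr(worker, "model"):
        worker.model = lambda v: v + 1  # hypothetical toy model
    return worker.model(x)


def remove_model(dask_worker):
    # client.run fills in dask_worker with the Worker instance.
    if hasattr(dask_worker, "model"):
        del dask_worker.model
    # Report whether a model is still attached after cleanup.
    return hasattr(dask_worker, "model")


def run_demo():
    with LocalCluster(n_workers=1, processes=False, resources={"GPU": 1},
                      dashboard_address=None) as cluster, Client(cluster) as client:
        futures = client.map(gpu_op, range(3), resources={"GPU": 1})
        results = client.gather(futures)  # blocks until all GPU ops are done
        still_loaded = client.run(remove_model)  # {worker_address: bool}
        return results, still_loaded
```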

I did look into futures, but the issue is that I would like other operations to proceed while the GPU operations are running. As I understand it, that isn't possible with client.gather, since it blocks until all the futures are done. Am I correct?

I think it might be feasible to do what I want using a SchedulerPlugin that keeps track of the number of GPU operations, a WorkerPlugin that holds on to the model, and a done-callback on the GPU futures. The callback would call a method on the SchedulerPlugin that decrements a counter. When the counter reaches zero, the SchedulerPlugin could unregister the WorkerPlugins holding the models. Might be a bit complex though…
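The counter idea can also be sketched on the client side without any plugins. This is a substitute for, not an implementation of, the SchedulerPlugin/WorkerPlugin design: each GPU future gets a done-callback that decrements a shared counter, and the callback that brings it to zero runs a hypothetical remove_model helper on all workers via client.run. Dask executes done-callbacks in a separate thread, which is what makes the blocking client.run call inside the callback reasonable. The cluster setup, the "GPU" resource name, and the toy model are assumptions for the sketch.

```python
import threading

from dask.distributed import Client, LocalCluster, get_worker


def gpu_op(x):
    worker = get_worker()
    if not hasattr(worker, "model"):
        worker.model = lambda v: v * 10  # hypothetical toy model
    return worker.model(x)


def remove_model(dask_worker):
    # No-op on workers that never loaded a model.
    if hasattr(dask_worker, "model"):
        del dask_worker.model


def has_model(dask_worker):
    return hasattr(dask_worker, "model")


def run_demo():
    with LocalCluster(n_workers=1, processes=False, resources={"GPU": 1},
                      dashboard_address=None) as cluster, Client(cluster) as client:
        gpu_futures = client.map(gpu_op, range(3), resources={"GPU": 1})
        remaining = len(gpu_futures)
        lock = threading.Lock()
        cleaned_up = threading.Event()

        def on_done(_future):
            nonlocal remaining
            # Decrement under a lock so only one callback sees zero.
            with lock:
                remaining -= 1
                is_last = remaining == 0
            if is_last:
                client.run(remove_model)
                cleaned_up.set()

        for fut in gpu_futures:
            fut.add_done_callback(on_done)

        # Other (non-GPU) work proceeds without waiting on the GPU futures.
        other_result = client.submit(sum, [1, 2, 3]).result()

        # Only needed here so the demo can inspect the workers afterwards.
        cleaned_up.wait(timeout=30)
        return other_result, cleaned_up.is_set(), client.run(has_model)
```

One caveat of this client-side version: the counter lives in the client process, so it only tracks futures submitted by this client.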

Regarding client.gather(): this could work if you have a list of futures for your GPU operations, which depends on your setup. As for the plugins you mentioned, I don't think they will help, since it sounds like you want to manipulate the scheduling itself.

Another, more complicated option is to use the relatively recently added graph manipulation utilities in dask.graph_manipulation: specifically bind, to add implicit dependencies to a Dask collection, and wait_on, to ensure that dependents of a collection wait on another, unrelated collection. You can also check out this Stack Overflow post for another example.

Below is a minimal reproducible example of how this might work, but client.gather() is definitely the simpler option if it's feasible.

```python
import dask
from dask import delayed
from dask.graph_manipulation import bind, wait_on

@delayed
def gpu_func(x):
    # placeholder for the operation that needs the GPU
    return x

@delayed
def non_gpu_func(x):
    # placeholder for an operation that can run on any worker
    return x

@delayed
def cleanup_gpu():
    # placeholder: remove the model from the GPU workers here
    return None

gpu_delayeds = [gpu_func(x) for x in range(3)]
non_gpu_delayeds = [non_gpu_func(x) for x in range(3)]
delayeds = gpu_delayeds + non_gpu_delayeds

# new_cleanup is a copy of cleanup_gpu() that depends on all gpu_delayeds
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)
# ensure dependents of new_cleanup wait until gpu_delayeds are done
wait_delayeds = wait_on(new_cleanup, *gpu_delayeds)
```

Here’s the original task graph of your GPU and non-GPU operations:

And after binding, the cleanup function depends on the GPU operations, so it will remove the model only once all of them have finished:
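To check the ordering without a graph visualization, the bound cleanup can be computed alongside the other tasks while each task records when it runs. This sketch uses the synchronous local scheduler and a module-level log list purely for illustration; the log entries are made up for the example.

```python
import dask
from dask import delayed
from dask.graph_manipulation import bind

log = []  # records the order in which tasks actually run


@delayed
def gpu_func(x):
    log.append(f"gpu-{x}")
    return x


@delayed
def non_gpu_func(x):
    log.append(f"cpu-{x}")
    return x


@delayed
def cleanup_gpu():
    log.append("cleanup")


gpu_delayeds = [gpu_func(x) for x in range(3)]
non_gpu_delayeds = [non_gpu_func(x) for x in range(3)]
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)

# Computing new_cleanup also computes the GPU tasks it was bound to.
results = dask.compute(new_cleanup, *non_gpu_delayeds, scheduler="sync")
```

Whatever order the scheduler picks for the other tasks, "cleanup" should appear in log only after every "gpu-…" entry.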
