Callback when all computations of a certain type are finished

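As for client.gather(), it could work if you have a list of futures for your GPU operations, which depends on your setup. gather() blocks until every future in the list has finished, so anything you run after it (such as the cleanup) is guaranteed to come after the GPU work. Regarding the plugins you mentioned, I don’t think they will help you, since it sounds like you’d like to manipulate the scheduling itself.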
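A minimal sketch of the client.gather() route, assuming your GPU work can be submitted as futures from a distributed Client. gpu_func and cleanup_gpu are hypothetical stand-ins for your own functions, and the in-process Client(processes=False) is only there to keep the example self-contained:

```python
from dask.distributed import Client


def gpu_func(x):
    # hypothetical stand-in for a GPU-bound task
    return x * 2


def cleanup_gpu():
    # hypothetical stand-in for removing the model / freeing GPU memory
    return "cleaned up"


# in-process cluster (threads only), just so the sketch runs anywhere
with Client(processes=False) as client:
    gpu_futures = client.map(gpu_func, range(3))
    # gather() blocks until every GPU future has finished
    gpu_results = client.gather(gpu_futures)
    # all GPU work is done at this point, so cleanup is safe
    cleanup_status = client.submit(cleanup_gpu).result()

print(gpu_results, cleanup_status)
```

The trade-off is that the client code waits at gather(), so nothing else you submit afterwards can overlap with the GPU tasks. That is why the graph-level approach below is more involved.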

Another, more complicated option is to take advantage of the relatively recently added graph-manipulation utilities in dask.graph_manipulation. Specifically, bind adds implicit dependencies to a Dask collection, and wait_on ensures that the dependents of a collection wait on another, unrelated collection. You can also check out this Stack Overflow post for another example.
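To see what bind does on its own, here is a small sketch; the calls list and the function bodies are made up for illustration. cleanup_gpu() takes no arguments, so by itself it has no dependencies. After bind, computing it forces the GPU tasks to run first. The synchronous scheduler is used so that tasks run one at a time in the main thread:

```python
from dask import delayed
from dask.graph_manipulation import bind

calls = []  # records the order in which tasks actually run


@delayed
def gpu_func(x):
    calls.append(f"gpu {x}")
    return x


@delayed
def cleanup_gpu():
    calls.append("cleanup")
    return "cleaned up"


gpu_delayeds = [gpu_func(x) for x in range(3)]

# bind() returns a new version of cleanup_gpu() that depends on every task in gpu_delayeds
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)

status = new_cleanup.compute(scheduler="sync")
print(calls)  # the three "gpu ..." entries (in some order) come before "cleanup"
```

Note that bind does not pass the GPU results into cleanup_gpu; it only adds an ordering dependency, which is exactly what a cleanup step needs.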

Below is a minimal reproducible example of how this might work, but using client.gather() is definitely the simpler option if that’s possible.

```python
import dask
from dask import delayed
from dask.graph_manipulation import bind, wait_on

@delayed
def gpu_func(x):
    return x

@delayed
def non_gpu_func(x):
    return x

@delayed
def cleanup_gpu():
    pass

gpu_delayeds = [gpu_func(x) for x in range(3)]
non_gpu_delayeds = [non_gpu_func(x) for x in range(3)]
delayeds = gpu_delayeds + non_gpu_delayeds

# new_cleanup is now dependent on gpu_delayeds
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)
# ensure dependents of the returned collections wait until new_cleanup and gpu_delayeds are done
wait_delayeds = wait_on(new_cleanup, *gpu_delayeds)
```

Here’s the original task graph of your GPU and non-GPU operations:

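And here’s the graph afterwards: the cleanup task now depends on the GPU operations, and the GPU results returned by wait_on in turn depend on the cleanup. So the cleanup, which will remove the model, can’t run until all GPU operations are finished: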
