client.gather() could work if you have a list of futures for your GPU operations, which depends on your setup. Regarding the plugins you mentioned, I don’t think they will help you, since it sounds like you’d like to manipulate the scheduling itself.
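As a rough sketch of the client.gather() route, assuming you submit the GPU work as futures: gpu_func and cleanup_gpu below are hypothetical stand-ins for your real functions, and the threads-only local cluster is just for demonstration. The idea is that gather blocks until every GPU future has finished, so anything submitted afterwards (here, the cleanup) cannot start earlier.

```python
from dask.distributed import Client


def gpu_func(x):
    # Hypothetical stand-in for a real GPU operation
    return x * 2


def cleanup_gpu():
    # Hypothetical stand-in for removing the model / freeing GPU memory
    return "cleaned"


def run_gpu_then_cleanup():
    # Threads-only local cluster with the dashboard disabled, only for this demo
    with Client(processes=False, dashboard_address=None) as client:
        gpu_futures = client.map(gpu_func, range(3))
        # gather blocks until all GPU futures have finished
        gpu_results = client.gather(gpu_futures)
        # only now is the cleanup submitted, so it cannot run before the GPU work
        cleanup_result = client.submit(cleanup_gpu).result()
    return gpu_results, cleanup_result


gpu_results, cleanup_result = run_gpu_then_cleanup()
```

The trade-off is that the ordering lives in your client code rather than in the task graph, so it only applies if you can afford to block at that point.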
Another, more complicated option is to take advantage of the relatively recently added utilities for graph manipulation in dask.graph_manipulation, specifically bind, to add implicit dependencies to a Dask collection, and wait_on, to ensure dependents of a collection wait on another unrelated collection. You can also check out this SO post for another example.
Below is a minimal reproducible example of how this might work, but using client.gather() is definitely the simpler option if that’s possible.
import dask
from dask import delayed
from dask.graph_manipulation import bind, wait_on


@delayed
def gpu_func(x):
    return x


@delayed
def non_gpu_func(x):
    return x


@delayed
def cleanup_gpu():
    pass


gpu_delayeds = [gpu_func(x) for x in range(3)]
non_gpu_delayeds = [non_gpu_func(x) for x in range(3)]
delayeds = gpu_delayeds + non_gpu_delayeds

# new_cleanup is a copy of cleanup_gpu() that now depends on gpu_delayeds
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)

# wait_on returns copies of its inputs; anything built on top of them
# waits until new_cleanup and all gpu_delayeds are done
wait_delayeds = wait_on(new_cleanup, *gpu_delayeds)
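If you want to convince yourself of the ordering without drawing the graph, here is a small sketch that records the execution order. The events list is my own addition for illustration, and it only works with an in-process scheduler such as "synchronous" or threads, not with separate worker processes.

```python
import dask
from dask import delayed
from dask.graph_manipulation import bind

events = []  # illustration only: records the order in which tasks run


@delayed
def gpu_func(x):
    events.append(f"gpu-{x}")
    return x


@delayed
def cleanup_gpu():
    events.append("cleanup")


gpu_delayeds = [gpu_func(x) for x in range(3)]

# the returned copy of cleanup_gpu() depends on every GPU task
new_cleanup = bind(cleanup_gpu(), gpu_delayeds)

# compute everything in one call so the GPU tasks are shared, not re-run
results = dask.compute(*gpu_delayeds, new_cleanup, scheduler="synchronous")
```

After this runs, "cleanup" should be the last entry in events, with the three GPU entries before it in whatever order the scheduler picked.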
Here’s the original task graph of your GPU and non-GPU operations:
And afterwards, the cleanup function depends on the GPU operations, so it will remove the model, but not until after the GPU operations are finished: