Tracking the lifetime of a cluster-wide shared resource

I have a use case where I need to track the lifetime of a cluster-wide shared resource. Basically, some dask.delayed tasks will allocate a resource that is visible to the entire cluster (for now, this is POSIX shared memory on a single-machine/multi-process cluster, but in the future it will be other things). When the object returned by such a task is sent (or duplicated) to other workers, they will all have a view of this shared resource. Once the task key has been forgotten by the cluster, it is safe to free the shared resource.

To manage this, I was thinking that I could do the following:

  1. Create a Scheduler plugin that keeps track of the association between task keys and shared resources to eventually free.
  2. Whenever a task state changes to “forgotten”, the scheduler plugin will check if this key is associated with a shared resource, and if so, remove the key from the list of resource users and free the resource if there are no more task keys attached to the shared resource.
  3. Whenever a delayed function returns a new object attached to a shared resource, it will send a message (via a distributed.Queue) to the scheduler plugin with the handle of the shared resource as well as the task key that will be associated with it.

(Note that it is possible for two distinct task keys to be associated with one shared resource, which is why this looks like reference-counting.)
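
Concretely, I imagine parts 1 and 2 looking roughly like this (the class name, the track() entry point, and free_shared_resource() are placeholders; the delivery mechanism for part 3 is exactly the piece I'm missing):

import distributed

class ShmTracker(distributed.SchedulerPlugin):
    def __init__(self):
        self.key_to_resource = {}   # task key -> shared-resource handle
        self.resource_users = {}    # handle -> set of task keys still alive

    def track(self, key, handle):
        # Would be fed by the messages from part 3.
        self.key_to_resource[key] = handle
        self.resource_users.setdefault(handle, set()).add(key)

    def transition(self, key, start, finish, *args, **kwargs):
        if finish != "forgotten":
            return
        handle = self.key_to_resource.pop(key, None)
        if handle is None:
            return
        users = self.resource_users[handle]
        users.discard(key)
        if not users:
            del self.resource_users[handle]
            free_shared_resource(handle)   # placeholder for the actual cleanup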

The main problem with this approach is that I don’t know how to do part 3. How does a delayed function look up the key for the task it is currently executing? What I need has a similar flavor to the get_client() function, but looks up the current task key.

Any suggestions? Is there a better way to approach this?

You might want get_worker().get_current_task(). More broadly, a bunch of state gets tossed into a thread-local at distributed.worker.thread_state.
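
For example, a delayed function could look its own key up and push it onto a queue, along these lines (make_shared_resource, allocate_shm, and the queue name are placeholders):

from distributed import Queue, get_worker

def make_shared_resource(args):
    resource = allocate_shm(args)             # placeholder: create the shared resource
    key = get_worker().get_current_task()     # key of the task currently executing
    # Equivalently: from distributed.worker import thread_state; key = thread_state.key
    Queue("shm-registry").put({"key": key, "handle": resource.name})
    return resource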

I wonder if this is actually a good application for Actors, though I don't know how well our GC operates on those.

I think this approach generally makes sense, but I wonder if you can simplify it. I'm not sure why you need the Queue to message the scheduler plugin, though. That sounds like you're trying to implement your own reference-counting on the scheduler, when I think you could just use what's built in.

Also, I’m assuming here that you’re creating the resource in a task in a dask graph. If you’re using the Futures API directly through Client.submit, etc., this might be a little simpler.

What if you did something like:

  1. Create a Future for the resource. This could actually be the resource itself (if it's pickleable), or just a dummy object (say, the path to the shared memory as a string). If you're in a task, maybe use get_client().scatter(thing) (see the sketch below).
  2. Return that Future from the task
  3. Other tasks receive the Future as input and use it as usual
  4. Those tasks eventually let the Future go out of scope and at some point the scheduler realizes nothing needs it anymore
  5. Add a SchedulerPlugin which watches for keys transitioning to "forgotten". If the key is one of your resources (a few ways to check this, more later), then run your cleanup logic, either on the scheduler (easy) or by creating a new task to run on a worker (trickier).

This way, you leverage the distributed reference-counting logic of Futures. The only thing that API doesn't offer is triggering some event when the refcount goes to zero; you can add that yourself with the SchedulerPlugin.
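
For step 1, that might look roughly like this inside a task (allocate_shm and ShmHandle are placeholders for however you create the resource and whatever small, pickleable descriptor you wrap it in):

from distributed import get_client

def make_shared_resource(args):
    shm = allocate_shm(args)             # placeholder: create the shared memory
    handle = ShmHandle(shm.name)         # placeholder: small, pickleable descriptor
    # Scatter the handle so it lives on the cluster under its own key;
    # whoever holds the returned Future keeps that key alive.
    future = get_client().scatter(handle)
    return future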

Now, how can the plugin tell which keys matter?

  • Some pattern of the key name. scatter will name the keys f"{type(x).__name__}-{token}", so if your object has a specially-named type, then just check key.startswith("ShmWrapperType")
  • Use client.submit() instead, where you can specify the key= (see the sketch after this list)
  • Use client.submit(), but wrap in with dask.config.set({"annotations.release_me": True}), and have your SchedulerPlugin look for scheduler.tasks[key].annotations.get("release_me") on the tasks going to "forgotten"
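
For the explicit key= route, for example (make_shared_memory is a placeholder, and the prefix is whatever your plugin watches for):

import uuid

fut = client.submit(
    make_shared_memory,                          # placeholder resource-creating function
    key=f"shm-release-me-{uuid.uuid4().hex}",    # recognizable prefix, unique suffix
)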

Or if you’re using delayed, a simpler idea:

import dask
import distributed
from dask import delayed
from dask.base import tokenize

magic_shm_prefix = "magic-shm-release-me"

# Give the resource-creating task a recognizable key; include its inputs in the
# token so that distinct resources get distinct keys.
shm = delayed(make_shared_memory)(..., dask_key_name=f"{magic_shm_prefix}-{tokenize(...)}")
tasks = [delayed(func)(shm, x) for x in things_to_process]
more_tasks = [delayed(other_func)(shm, y) for y in other_things]


class Releaser(distributed.SchedulerPlugin):
    def transition(self, key, start, finish, *args, **kwargs):
        # Runs once nothing in the graph needs the resource task anymore.
        if finish == "released" and key.startswith(magic_shm_prefix):
            release_shared_memory(...)

Basically, have a special task that makes the shared resource. Everything that needs the resource depends on this task. When all tasks depending on it are done, the task will be transitioned to "released". If you know the name/pattern of the key, you can watch for this and take some action when it happens.
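
For completeness, the plugin needs to be registered with the scheduler, e.g. from the client (the exact method name varies a bit across distributed versions; register_scheduler_plugin has been around for a while):

client = distributed.Client()
client.register_scheduler_plugin(Releaser())

dask.compute(*tasks, *more_tasks)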

If your resources are complex, or there are multiple of them, a tricky bit is that you need a handle on the resource itself (the path to it, or some other metadata) in order to actually release it. This is a bit harder. You could use annotations again to add that metadata to the "magic-shm-release-me" task, like:

with dask.annotate(shm_path="/dev/shm/foo"):
    shm = delayed(make_shared_memory)(..., dask_key_name=f"{magic_shm_prefix}-{tokenize(...)}")

and then pull the annotations out in your plugin? A bit hacky, but it might work for this case.
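
A rough sketch of that, assuming the plugin grabs a reference to the scheduler in its start hook so it can read task annotations (and note that graph optimization can drop annotations for delayed in some dask versions, so check that they actually survive):

class AnnotatedReleaser(distributed.SchedulerPlugin):
    def start(self, scheduler):
        self.scheduler = scheduler   # keep a handle so transition() can look up tasks

    def transition(self, key, start, finish, *args, **kwargs):
        if finish != "released" or not key.startswith(magic_shm_prefix):
            return
        ts = self.scheduler.tasks.get(key)
        annotations = (ts.annotations or {}) if ts else {}
        shm_path = annotations.get("shm_path")
        if shm_path:
            release_shared_memory(shm_path)   # placeholder cleanup, as above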