How to handle job migration of 3rd party tasks?

I have a use case where I am submitting jobs to SLURM in order to run third party, multithreaded software, with one task per job. My use case is very similar to:

This is generally done by passing arguments to the Cluster so that dask thinks each worker has a single thread, while using job keyword arguments to make the scheduler book multiple CPUs. Dask then sends only one task per job, and the called software can use all of the cores.
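A minimal sketch of that kind of configuration, assuming dask-jobqueue's SLURMCluster (the resource numbers below are placeholders, not my actual values):

from dask_jobqueue import SLURMCluster
from distributed import Client

cluster = SLURMCluster(
    cores=1,              # dask sees a single thread, so it schedules one task per worker
    processes=1,          # one worker process per SLURM job
    memory="16GB",        # placeholder
    job_cpu=8,            # but SLURM books 8 CPUs for the multithreaded third party software
    walltime="04:00:00",  # placeholder
)
cluster.scale(jobs=4)     # placeholder number of jobs
client = Client(cluster)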

However, to the main question of this post: since dask is not actually running the computation itself, only keeping track of the inputs, outputs, and completion state, migrating a task to a new worker requires restarting it from scratch. How can this be circumvented? It is suboptimal because if a task starts on a worker that is soon gracefully killed to avoid the SLURM timeout, the current state of the computation is not migrated to the new worker; the compute already spent on it is simply wasted.

Two solutions I can see but do not know how to implement:

  • gracefully stop the worker at the end of every single task from within the worker, e.g. a new SLURM job is submitted for each task
  • migrate the state of the third party software

Any suggestions?
Thanks for any assistance.

EDIT:
I tried to attack strategy 1 using a worker plugin:

import random

import distributed.diagnostics.plugin


class KillerNannyPlugin(distributed.diagnostics.plugin.WorkerPlugin):
    """Kills the worker after a task is completed.

    Transitions to the "memory" or "error" states occur after the "executing" state and
    trigger this plugin. This ensures that each task gets a new Worker.

    This should be a nanny plugin to be more dask friendly, but those don't trigger
    transitions as of 11.08.22.

    Parameters
    ----------
    max_stagger_seconds : float
        Attenuates how long to wait after a task before closing the worker.
        The actual wait time is 1 + max_stagger_seconds * X, where X is drawn from [0, 1],
        which ensures that data is not lost to workers closing at the same time.
    """
    def __init__(self, max_stagger_seconds: float = 10):
        self.max_stagger_seconds = max_stagger_seconds

    def setup(self, worker):
        self.worker = worker

    def transition(self, key, start, finish, *args, **kwargs):
        # Once the task's result has been released from this worker's memory,
        # schedule a graceful shutdown after a randomly staggered delay.
        if start == 'memory' and finish == 'released':
            self.worker.io_loop.call_later(
                1 + random.random() * self.max_stagger_seconds,
                self.worker.close_gracefully,
            )

This successfully causes the worker to close after it completes a task. I had to stagger the shutdowns so that multiple workers would not close at the same time. The side effect is that about 20% of tasks are repeated with this strategy, I assume because a worker was killed before it could send its result back to as_completed, which defeats the original purpose of trying not to waste computation time.

If that is the case, it could be that my start and finish transition states are off. Is there a way to ensure the result has been gathered before closing the worker?
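One direction I have been wondering about is driving the shutdown from the client side instead, retiring a worker only once its result has actually been gathered. An untested sketch (run_third_party and args are placeholders, cluster is configured as above):

from distributed import Client, as_completed

client = Client(cluster)
futures = [client.submit(run_third_party, arg) for arg in args]  # placeholders

for future in as_completed(futures):
    result = future.result()  # the result is now safely on the client
    # find which worker holds the (now gathered) result and retire it gracefully
    holders = client.who_has([future]).get(future.key, [])
    if holders:
        client.retire_workers(workers=list(holders))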

Again I appreciate any help.

Hi @evankomp,

I answered in Restart cluster job on task completion · Issue #597 · dask/dask-jobqueue · GitHub. I think the best strategy would be to improve the --lifetime option handling to take into account the case where we want to wait for the end of a task before stopping a worker, as mentioned in Enhancement Request - Dask Workers lifetime option not waiting for job to finish · Issue #3141 · dask/distributed · GitHub.
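For context, the existing --lifetime mechanism is typically configured like this (a sketch; the durations are placeholders, and older dask-jobqueue versions use the extra keyword instead of worker_extra_args):

from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="16GB",
    walltime="01:00:00",
    # ask workers to restart themselves before the SLURM walltime is hit;
    # today this does not wait for a running task to finish, which is what we'd like to improve
    worker_extra_args=["--lifetime", "55m", "--lifetime-stagger", "4m"],
)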

I didn’t have time to look deeply into your proposed Plugin solution, though.