Aborting a task from the DaskDistributedBackend

I have been using a distributed dask cluster to execute a parameter search in parallel for some scikit-learn models. My goal is to use the DaskDistributedBackend while also being able to halt execution early: after each model in the parameter search has been trained, I check a stopping condition and, if it is met, abort the rest of the search.

To achieve this I have subclassed DaskDistributedBackend and overridden the _collect method so that I can inject some custom behaviour as tasks complete. Here is the gist of the changes I have made so far:

    # within my subclass
    async def _collect(self):
        while self._continue:
            async for future, result in self.waiting_futures:
                cf_future = self._results.pop(future)
                callback = self._callbacks.pop(future)
                if future.status == "error":
                    typ, exc, tb = result
                    cf_future.set_exception(exc)
                else:
                    cf_future.set_result(result)
                    callback(result)
                    # check the stopping condition; if it is met, abort()
                    has_stopped = stop_check()  # my own stopping-condition helper
                    if has_stopped:
                        self.abort()
            await asyncio.sleep(0.01)

    def abort(self):
        # some method to kill the remaining work
        self.stop_call()  # ?
        self.abort_everything()  # ?
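Here stop_check() is just a stand-in for whatever condition I use to decide the search should end early; as a made-up example, it could be something like:

    # Hypothetical stopping condition, shared via a module-level variable that the
    # result callback would update as scores come in.
    best_score_so_far = None

    def stop_check():
        # End the search once a "good enough" score has been observed.
        return best_score_so_far is not None and best_score_so_far >= 0.95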

The parameter search then begins with:

    with joblib.parallel_backend("my_custom_dask_backend"):
        model.fit(...)
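For completeness, the subclass is registered under that name before the search starts, roughly like this, where MyDaskBackend stands in for whatever the subclass is actually called:

    from joblib import register_parallel_backend

    # MyDaskBackend is just a placeholder name here for the subclass shown above
    register_parallel_backend("my_custom_dask_backend", MyDaskBackend)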

The issue lies in being able to properly end the training and free up the dask workers for future work. I see that DaskDistributedBackend implements a couple of methods that look like they are meant for stopping, namely stop_call() and abort_everything(), but the results of calling them are not what I expected.

With stop_call() the execution of the remaining tasks continues as normal, with no visible effect. With abort_everything() execution does seem to stop, but the code then waits indefinitely at "python/3.10.2/lib/python3.10/threading.py", line 320, in wait, on waiter.acquire(). It seems to be waiting for some lock to be released, but I have no idea how or where that release is supposed to come from.

Any guidance on how to get around this would be much appreciated.

Hi,

Looking a bit at the code, I think you should first break from the for loop to release the self.waiting_futures variable before calling your abort method.

Could you tell me if it fixes the issue?

Hi, thanks for getting back to me.

I edited the _collect method to try to achieve what you mentioned; here's roughly how it looks now:

    async def _collect(self):
        while self._continue:
            async for future, result in self.waiting_futures:
                cf_future = self._results.pop(future)
                callback = self._callbacks.pop(future)
                if future.status == "error":
                    typ, exc, tb = result
                    cf_future.set_exception(exc)
                else:
                    await self.progress_did_update(1)  # a custom method of mine to track training progress
                    # inside ^, self.aborted is set to True once progress reaches 30%, while I test out this code
                    cf_future.set_result(result)
                    callback(result)
                    if self.aborted:
                        break

            if self.aborted:
                # debugpy.listen(("localhost", 5678))
                # debugpy.wait_for_client()
                # breakpoint()
                self.stop_call()
                self.abort_everything()

With this, the goal was to run the abort code outside of the for loop (and eventually outside the while loop too), hoping for different results. Unfortunately the dask call stack for the worker appears to be doing the same thing.

After tracing the code, it looks like the dask traceback does not show the full picture compared with what debugpy shows me. I see execution go into a file called base_events.py, from asyncio I think, where it loops around

            while True:
                self._run_once()
                if self._stopping:
                    break

(around line 594) in run_forever() for a while, where it seems to expect self._stopping to be set to True, but I do not see yet where this attribute is supposed to be set from. After some number of iterations the traceback stops abruptly and I see the same call stack from dask as before.
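For what it's worth, from what I can see in the CPython 3.10 source, self._stopping only gets set by the event loop's stop() method:

    # asyncio/base_events.py (CPython 3.10), abridged
    def stop(self):
        self._stopping = True

so presumably nothing in the abort path I am calling ever asks the loop to stop.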