I have been using a distributed dask cluster to execute a parameter search in parallel for some scikit-learn models. My goal is to use the DaskDistributedBackend whilst also supporting the ability to halt execution by checking a stopping condition after each model in the parameter search has been trained. To achieve this I have subclassed DaskDistributedBackend and overridden the _collect method so that I can inject some custom behaviour as tasks complete. Here is the gist of the changes I have made so far:
    import asyncio

    # within my subclass of DaskDistributedBackend
    async def _collect(self):
        while self._continue:
            async for future, result in self.waiting_futures:
                cf_future = self._results.pop(future)
                callback = self._callbacks.pop(future)
                if future.status == "error":
                    typ, exc, tb = result
                    cf_future.set_exception(exc)
                else:
                    cf_future.set_result(result)
                    callback(result)
                # check the stopping condition after each completed task,
                # and abort if it is met
                has_stopped = stop_check()
                if has_stopped:
                    self.abort()
            await asyncio.sleep(0.01)

    def abort(self):
        # some method to kill the remaining work and free the workers
        self.stop_call()         # ?
        self.abort_everything()  # ?
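Here stop_check() is just a stand-in for my actual stopping criterion; for the purposes of this question it can be thought of as something as simple as checking a shared flag (the STOP_EVENT name below is purely illustrative):

    import threading

    # illustrative placeholder for the stopping criterion used in _collect;
    # the real check inspects intermediate results, but a shared event
    # captures the idea
    STOP_EVENT = threading.Event()

    def stop_check():
        return STOP_EVENT.is_set()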
The parameter search is then started with:

    with joblib.parallel_backend("my_custom_dask_backend"):
        model.fit(...)
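For the string name to resolve, the subclass is registered with joblib beforehand, roughly like this (MyDaskBackend is a placeholder name for my subclass, and the import path for the base class is simply what works in my joblib version):

    import joblib
    from joblib._dask import DaskDistributedBackend

    class MyDaskBackend(DaskDistributedBackend):
        # contains the overridden _collect and the abort method shown above
        ...

    # make the subclass available under the name used in parallel_backend(...)
    joblib.register_parallel_backend("my_custom_dask_backend", MyDaskBackend)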
The issue lies in properly ending the training and freeing up the dask workers for future work. I see that DaskDistributedBackend implements a couple of methods that look like stopping mechanisms, namely stop_call() and abort_everything(), but the results of calling them are not what I expected. With stop_call() the remaining tasks seem to continue executing as normal, with no effect. With abort_everything() execution does seem to stop, but the code then waits indefinitely at

    python/3.10.2/lib/python3.10/threading.py", line 320, in wait
        waiter.acquire()

It seems to be waiting for some lock to be released, but I have no idea how or where it expects this to come from.
Any guidance on how to get around this would be much appreciated.