Hi @guillaumeeb ,
Sorry for the slow response, had to find the time to put together an example.
Yeah, this was one of the methods I was considering, and it would maybe look something like this:
```python
import time
import asyncio

import numpy as np

from dask import delayed
from dask.distributed import Client, KilledWorker, get_worker


def memory_hard(result):
    """
    A function requiring more memory than some workers can handle.
    """
    worker = get_worker()
    print("running", worker.id, worker.name, result)
    _ = np.ones((2**32), dtype=np.uint8)  # memory hog
    time.sleep(5)  # give it time to be killed
    return result


async def suppress(awaitable):
    """
    Suppresses exceptions from an awaitable,
    returning the exception instead of raising it.
    """
    try:
        return await awaitable
    except Exception as e:
        return e


async def key(name, awaitable):
    """
    Wraps an awaitable to return its result along with
    an identifying name as a (name, result) tuple.
    """
    return name, await awaitable


def make_job(client, name, job_meta):
    """
    Creates a dask delayed from the given job_meta.
    Also wraps the job so that exceptions are suppressed
    (returned instead of raised), and all results
    are returned as (job_name, job_result).
    """
    fn = job_meta["fn"]
    d = delayed(fn, pure=False)(*job_meta["args"], **job_meta["kwargs"])
    job = client.compute(d, resources=job_meta["resources"])
    job = suppress(job)
    job = key(name, job)
    return job


async def main():
    async with Client("10.0.0.101:8786", asynchronous=True) as client:
        # Collection of all jobs to run. Function, arguments, and run meta.
        all_jobs = {
            "job_1": {"fn": memory_hard, "args": [1], "kwargs": {}, "resources": {"low-memory": 1}},
            "job_2": {"fn": memory_hard, "args": [2], "kwargs": {}, "resources": {"low-memory": 1}},
            "job_3": {"fn": memory_hard, "args": [3], "kwargs": {}, "resources": {"high-memory": 1}},
        }
        # Dictionaries to track all remaining jobs (todo)
        # and completed jobs (done).
        todo = {n: make_job(client, n, m) for n, m in all_jobs.items()}
        done = {}
        # Keep trying until todo is empty.
        while todo:
            print("todo iteration", todo)
            # Snapshot the values, since todo is mutated in the loop body.
            for completed in asyncio.as_completed(list(todo.values())):
                name, result = await completed
                # Was this run made with the low-memory resource?
                low_mem_run = "low-memory" in all_jobs[name]["resources"]
                # Rerun on a high-memory worker if it looks to have been killed.
                if isinstance(result, KilledWorker) and low_mem_run:
                    print("rescheduling", name)
                    all_jobs[name]["resources"] = {"high-memory": 1}
                    todo[name] = make_job(client, name, all_jobs[name])
                # Failed on the high-memory resource, or for another reason: give up.
                elif isinstance(result, Exception):
                    raise result
                # Completed with no problems.
                else:
                    del todo[name]
                    done[name] = result
                    print("success", result)
        return done


if __name__ == "__main__":
    asyncio.run(main())
```
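(For this to do anything, the workers need to be started with the corresponding resources declared, e.g. `dask-worker 10.0.0.101:8786 --resources "low-memory=1"` for the small workers and `--resources "high-memory=1"` for the big ones.)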
So I think this works, but it has a few drawbacks that I'd want to try and avoid:
- If you inspect the `KilledWorker` exception you'll see a message along the lines of "Attempted to run task memory_hard-1234 on 3 different workers, but all those workers died while running it.". These attempts are separate from the `retries=0` argument that `client.compute` can take. It would be nicer to intercept after one of these re-attempts, instead of having to let the work try and fail up to 3 times, as the function may not hit memory problems until quite far in (see the `allowed-failures` sketch after this list).
- The `KilledWorker` exception doesn't (as far as I'm aware) tell me why the worker was killed, so I would need to assume that any `KilledWorker` is caused by a `MemoryError`. There may be a way to find this information, but I'm currently unaware of one.
- Finally, at least in my example code, work that failed has to wait for the next iteration of the while loop to be re-run, which can serialise re-attempts that don't need to be serialised. The code also generally feels quite complex and cumbersome. These two points might be fixable if I spent more time on the code beyond this quickly hashed-out example (see the loop sketch after this list).
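For the first point, one knob I've found is the scheduler's `allowed-failures` setting, which controls how many worker deaths are tolerated before `KilledWorker` is raised (the "3 different workers" in the message). A minimal sketch, with the caveat that it's read at scheduler startup, so it has to be set in the scheduler process's environment rather than client-side against an already-running scheduler:

```python
import dask

# Raise KilledWorker after the first worker death instead of the third.
# Note this is scheduler-wide, so it also affects unrelated tasks.
dask.config.set({"distributed.scheduler.allowed-failures": 0})
```

(The same can be done with the `DASK_DISTRIBUTED__SCHEDULER__ALLOWED_FAILURES` environment variable.)

And for the third point, here's a rough sketch of reworking the loop with `asyncio.wait` so rescheduled jobs re-join the pending set immediately instead of waiting for the next batch; it reuses `make_job`, `all_jobs`, and the async `client` from the example above:

```python
# Handle jobs as they finish; a rescheduled job is added straight back
# into the pending set rather than waiting for the next while iteration.
pending = {asyncio.ensure_future(make_job(client, n, m)) for n, m in all_jobs.items()}
done = {}
while pending:
    finished, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
    for task in finished:
        name, result = task.result()
        low_mem_run = "low-memory" in all_jobs[name]["resources"]
        if isinstance(result, KilledWorker) and low_mem_run:
            all_jobs[name]["resources"] = {"high-memory": 1}
            pending.add(asyncio.ensure_future(make_job(client, name, all_jobs[name])))
        elif isinstance(result, Exception):
            raise result
        else:
            done[name] = result
```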
I think ideally hooking into the re-attempts from a scheduler plugin (if that's possible) might be the cleanest/optimal way to get this behaviour.
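For reference, here's roughly what I had in mind, though this is untested and rests on assumptions on my part: that the scheduler `TaskState` attributes `suspicious` (how many workers died while holding the task) and `resource_restrictions` (what `resources=` in `client.compute` sets) behave the way I expect, and that mutating the restrictions mid-flight actually takes effect.

```python
from distributed.diagnostics.plugin import SchedulerPlugin


class RescheduleOnWorkerDeath(SchedulerPlugin):
    """Sketch: move a task onto high-memory workers after its first
    suspicious worker death, rather than waiting for KilledWorker."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def transition(self, key, start, finish, *args, **kwargs):
        ts = self.scheduler.tasks.get(key)
        if ts is None:
            return
        # A worker died while running this task, and it was on low-memory:
        # bump it to the high-memory resource before it is retried.
        if ts.suspicious > 0 and "low-memory" in (ts.resource_restrictions or {}):
            ts.resource_restrictions = {"high-memory": 1}


def dask_setup(scheduler):
    # e.g. saved as reschedule.py and loaded with:
    #   dask-scheduler --preload reschedule.py
    scheduler.add_plugin(RescheduleOnWorkerDeath(scheduler))
```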
Happy to hear any thoughts on the above.