Optimistic Memory Scheduling

Hi,

I have some Python code that can sometimes use a high amount of memory, but normally doesn't. Determining its memory footprint before running isn't a straightforward heuristic or calculation, so that problem is currently unsolved.

I'd like to set up a few workers with high memory limits relative to the regular workers, optimistically try to run each task anywhere on the cluster, and on a MemoryError reschedule the task to be reattempted on a worker with higher memory capacity.

I've read through the distributed documentation, under the "build understanding" section, so I'm aware there are a few options available for scheduling work. However, none of them stands out as giving me the kind of behaviour described above.

Could anyone offer any suggestions on where they think this kind of logic might best sit?

I don’t want to just limit all runs using the worker resources feature, as most of the time memory isn’t an issue, and I want to use the full cluster.

Thanks for any suggestions.

Hi @freebie,

Did you read through https://distributed.dask.org/en/stable/resources.html?

With resource settings and a bit of try/catch logic on the client side, I think you should be able to achieve what you want, if I understood it correctly.
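Roughly something like this, as a quick and untested sketch (the function, argument, and scheduler address are placeholders, and "high-memory" is just an example resource name matching a worker started with --resources 'high-memory=1'):

from dask.distributed import Client, KilledWorker

client = Client("tcp://scheduler-address:8786")  # placeholder address


def run_with_fallback(fn, *args):
    # Optimistically run anywhere first.
    future = client.submit(fn, *args, pure=False)
    try:
        return future.result()
    except KilledWorker:
        # Retry pinned to the workers advertising the high-memory resource.
        future = client.submit(fn, *args, pure=False,
                               resources={"high-memory": 1})
        return future.result()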

Hi @guillaumeeb ,

Sorry for the slow response, had to find the time to put together an example.

Yeah, this was one of the methods I was considering, and it would maybe look something like this:

import time
import asyncio

import numpy as np

from dask import delayed
from dask.distributed import Client, KilledWorker, get_worker


def memory_hard(result):
    """
    A function requiring more memory than some workers can handle.
    """
    worker = get_worker()
    print("running", worker.id, worker.name, result)
    _ = np.ones((2**32), dtype=np.uint8)  # memory hog
    time.sleep(5)  # give it time to be killed
    return result


async def suppress(awaitable):
    """
    Suppresses exceptions from an awaitable
    and returns the exception instead of raising it.
    """
    try:
        return await awaitable
    except Exception as e:
        return e


async def key(name, awaitable):
    """
    Wraps an awaitable to return its result along with
    an identifying name as a (name, result) tuple.
    """
    return name, await awaitable


def make_job(client, name, job_meta):
    """
    Creates a dask delayed from given job_meta.
    Also wraps the job so that exceptions are suppressed
    (returned instead of raised), and all results
    are returned as (job_name, job_result).
    """
    fn = job_meta["fn"]
    d = delayed(fn, pure=False)(*job_meta["args"], **job_meta["kwargs"])
    job = client.compute(d, resources=job_meta["resources"])
    job = suppress(job)
    job = key(name, job)
    return job


async def main():
    async with Client("10.0.0.101:8786", asynchronous=True) as client:
        # Collection of all jobs to run. Function, arguments, and run meta.
        all_jobs = {
            "job_1": {"fn": memory_hard, "args": [1], "kwargs": {}, "resources": {"low-memory": 1}},
            "job_2": {"fn": memory_hard, "args": [2], "kwargs": {}, "resources": {"low-memory": 1}},
            "job_3": {"fn": memory_hard, "args": [3], "kwargs": {}, "resources": {"high-memory": 1}},
        }

        # dictionaries to track all remaining jobs (todo)
        # and completed jobs (done).
        todo = {n: make_job(client, n, m) for n, m in all_jobs.items()}
        done = {}

        # keep trying until todo is empty
        while todo:
            print("todo iteration", todo)

            for completed in asyncio.as_completed(todo.values()):
                name, result = await completed

                # was this job run with the low-memory resource?
                low_mem_run = "low-memory" in all_jobs[name]["resources"]

                # rerun with the high-memory resource if it looks to have been killed.
                if isinstance(result, KilledWorker) and low_mem_run:
                    print("rescheduling", name)
                    all_jobs[name]["resources"] = {"high-memory": 1}
                    todo[name] = make_job(client, name, all_jobs[name])
                # failed on high resource, or other reason. give up.
                elif isinstance(result, Exception):
                    raise result
                # completed with no problems.
                else:
                    del todo[name]
                    done[name] = result
                    print("success", result)

        return done


if __name__ == "__main__":
    asyncio.run(main())

So I think this works, but it has a few drawbacks which I would like to avoid:

  1. If you inspect the KilledWorker exception you’ll see a message along the lines of "Attempted to run task memory_hard-1234 on 3 different workers, but all those workers died while running it.". These attempts are separate from the retries=0 that client.compute can take as an argument. It would be nicer to intercept on one of these re-attempts, instead of having to let the work try and fail up to 3 times (as the function may not hit memory problems until quite far in); see the config sketch after this list.

  2. The KilledWorker exception doesn’t (as far as I’m aware) give me the reason for the worker being killed, so I would need to assume any KilledWorker is caused by a MemoryError. There may be a way to find this information, but I’m currently unaware of it.

  3. Finally, at least in my example code, work that failed has to wait to be re-run on the next iteration of the while loop, which can serialise re-attempts that don’t need to be serialised. The code also generally feels quite complex and cumbersome. These two points might be fixable if I spend more time on it versus this quickly hashed-out example.
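For point 1, the knob that seems relevant is the scheduler's allowed-failures setting, which controls how many times a task is retried after its worker dies before KilledWorker is raised; setting it to 0 should make the scheduler give up after the first death. A minimal sketch (note it has to be set where the scheduler runs, so setting it from the client process alone won't affect a remote scheduler):

import dask

# Raise KilledWorker after the first worker death instead of the default
# three attempts.
dask.config.set({"distributed.scheduler.allowed-failures": 0})

# Equivalent environment variable when starting the scheduler from the CLI:
#   DASK_DISTRIBUTED__SCHEDULER__ALLOWED_FAILURES=0 dask scheduler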

I think ideally hooking into the re-attempts in a scheduler plugin (if possible) might be the cleanest/optimal way to get the behaviour.
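Something like this rough, untested sketch is what I have in mind, assuming the SchedulerPlugin transition hook and the scheduler TaskState.resource_restrictions attribute behave the way I expect ("high-memory" matching the worker --resources tag):

from distributed.diagnostics.plugin import SchedulerPlugin


class EscalateOnWorkerDeath(SchedulerPlugin):
    """Pin a task to the high-memory workers after its worker dies."""

    def __init__(self):
        self.scheduler = None

    async def start(self, scheduler):
        self.scheduler = scheduler

    def transition(self, key, start, finish, *args, **kwargs):
        # A task bouncing from "processing" back to "released" or "waiting"
        # without finishing is usually a worker that died underneath it.
        if start == "processing" and finish in ("released", "waiting"):
            ts = self.scheduler.tasks.get(key)
            if ts is not None and not ts.resource_restrictions:
                # Hypothetical escalation: restrict the retry to workers
                # advertising the high-memory resource.
                ts.resource_restrictions = {"high-memory": 1}


# Registered from the client; the plugin itself runs on the scheduler:
# client.register_scheduler_plugin(EscalateOnWorkerDeath())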

Happy to hear any thoughts on the above.

Forgot to include, these are the commands I used to start the local cluster for the test code:

dask worker 10.0.0.101:8786 --name low --memory-limit 500MB --nworkers 1 --resources 'low-memory=1'
dask worker 10.0.0.101:8786 --name high --memory-limit 5GB --nworkers 1 --resources 'high-memory=1'

See What do KilledWorker exceptions mean in Dask? - Stack Overflow.

I have to admit I didn’t know this, but looking at Why did my worker die? — Dask.distributed 2023.11.0 documentation, it looks like you are right. Do you see the MemoryError on the killed workers? Maybe this is a Dask improvement that should be made.
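In the meantime, you might be able to confirm the cause from the client by pulling the nanny logs after a kill; a quick, untested sketch (the exact log wording probably differs between versions):

from dask.distributed import Client

client = Client("10.0.0.101:8786")

# Fetch recent nanny logs and grep them for memory related messages.
for addr, logs in client.get_worker_logs(nanny=True).items():
    for level, message in logs:
        if "memory" in message.lower():
            print(addr, level, message)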

Would Dask’s as_completed function be of any help?
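A rough, untested sketch of what I mean, reusing the all_jobs dict from your example (kwargs omitted for brevity) and assuming, as you already do, that a KilledWorker is memory related:

from dask.distributed import Client, KilledWorker, as_completed

client = Client("10.0.0.101:8786")

futures = {}  # future -> job name
for name, meta in all_jobs.items():
    fut = client.submit(meta["fn"], *meta["args"],
                        resources=meta["resources"], pure=False)
    futures[fut] = name

done = {}
ac = as_completed(list(futures))
for fut in ac:
    name = futures.pop(fut)
    try:
        done[name] = fut.result()
    except KilledWorker:
        if all_jobs[name]["resources"] == {"high-memory": 1}:
            raise  # already escalated once, give up
        # Resubmit pinned to the high-memory workers and feed the new
        # future straight back into the same iterator.
        all_jobs[name]["resources"] = {"high-memory": 1}
        retry = client.submit(all_jobs[name]["fn"], *all_jobs[name]["args"],
                              resources={"high-memory": 1}, pure=False)
        futures[retry] = name
        ac.add(retry)

That way a failed job's retry is queued immediately, without waiting for the rest of the batch, which I think addresses your third point.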

Really useful stuff, thank you!

I will have another look and see if I can get a cleaner solution with what you’ve pointed to.

RE: Do you see the MemoryError on the killed workers?
I do not believe so, unless there’s something special about catching MemoryErrors that I’m unaware of. I tried putting a try/except in the worker code, but I didn’t see it firing. I may have missed it though, so I’ll give it another look.
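One thing I might try next, assuming the documented distributed.worker.memory settings: if the nanny kills the worker process once it crosses the terminate fraction of --memory-limit, then a MemoryError would never be raised inside my task at all. An untested sketch of relaxing that threshold (it needs to be set where the workers start, e.g. via config file or environment variable):

import dask

# Let the worker keep allocating past the usual ~95% cut-off so that a real
# MemoryError has a chance to surface (or set a higher fraction instead).
dask.config.set({"distributed.worker.memory.terminate": False})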

Thanks again for your help and pointers, very much appreciated.
