Understanding Work Stealing

I’m trying to understand work stealing, with the plan to allow running workers in different datacenters, but prevent any but essential transfers between them.

I expected to be able to prevent stealing on a case-by-case basis by overriding _can_steal in stealing.py.

I set up an experiment where a subset of workers load data, and then run some calculations on it. What I see is that with DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING="False", the calculations are only run on the workers that loaded the data. But with _can_steal always returning False, tasks are dispatched to other workers.
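For reference, the same switch can be flipped from Python configuration instead of the environment variable (a minimal sketch; work stealing is a scheduler-side setting, so it has to be in effect where the scheduler starts):

import dask

# Equivalent of DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING="False"
dask.config.set({"distributed.scheduler.work-stealing": False})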

Am I misunderstanding data locality and work stealing?

More experiments reveal that _can_steal is only called if _has_restrictions has already returned True. Removing that requirement (it’s only an optimisation), I can now control work stealing for each task and worker, at least with a customised copy of stealing.py.

I’ll look at putting together a PR


@dmcg Thanks for the question! I believe this is related to Multi-region clusters.

Would you be able to share a minimal version of your experiment using LocalCluster? It’ll be super helpful for someone in the future!

If I understand correctly, you’re looking for a way to allow work-stealing within each datacenter, but not between datacenters, right?

Overriding the private API (like with _can_steal) isn’t generally recommended… But, assuming you know which workers are in which datacenter, maybe you can set the workers parameter in Client.compute for the specific computations. Or, use annotations:

with dask.annotate(workers=[...]):
    ...
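
For the Client.compute route, a minimal sketch might look like this (the scheduler address, worker address and toy computation below are placeholders):

import dask
from dask.distributed import Client

client = Client('scheduler-address:8786')
x = dask.delayed(sum)([1, 2, 3])

# Restrict this particular computation to the listed workers only
future = client.compute(x, workers=['tcp://dc1-worker:34567'],
                        allow_other_workers=False)
print(future.result())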

Would that work?

Yes, this is part of some work on distributed data across datacenters, regions or organisations, and having the compute tied as far as possible to the place where the data resides.

Instead of using workers we are (ab)using resources. So each worker in an organisation is assigned that organisation as a resource on startup:

dask-worker $scheduler_host --resources "org-$org_name=1"

Then we can pin computations to an organisation by annotating with the required organisation as resource:

from dask import annotate


def in_org(name):
    return annotate(resources={f'org-{name}': 1})


with in_org('metoffice'):    
    predictions = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})
        
with in_org('eumetsat'):
    comparison_dataset = xarray.open_dataset('/data/eumetsat/measurements.nc').chunk({"latitude": 10})

averages = predictions.mean('realization', keep_attrs=True)

diff = (averages.height[0] - comparison_dataset).to_array()[0,...,0]
    
    
diffed = diff.compute(optimize_graph=False)

This works well (provided all the workers can see each other) but doesn’t prevent work stealing between organisations, which would be costly in terms of data transfer. We can disable work stealing altogether, but that makes computation in an organisation less effective. So…

… I’d like to gain more control over work-stealing. Something like:

from distributed.stealing import WorkStealing, _can_steal


class OrganisationAwareWorkStealing(WorkStealing):
    def potential_thieves_for(self, ts, idle, sat):
        return [ws for ws in idle if can_steal(ws, ts, sat)]


def can_steal(thief, ts, victim):
    return False if org_of(thief) != org_of(victim) else _can_steal(thief, ts, victim)


def org_of(worker):
    orgs = [item for item in worker.resources if item.startswith('org')]
    return orgs[0] if orgs else None

where potential_thieves_for is a new method in WorkStealing:

    def potential_thieves_for(self, ts, idle, sat):
        if _has_restrictions(ts):
            return [ws for ws in idle if _can_steal(ws, ts, sat)]
        else:
            return idle

designed as an extension point for this and other applications.

Does this make sense? If so, it would be helpful to know how to install OrganisationAwareWorkStealing - Configuring Scheduler Extensions

Hey @dmcg this is such an interesting and unique use case!

I’d love to set up a time to learn more about it and see how our OSS Dask engineers at Coiled might be able to give you some tips.

Are you free to talk sometime this week or next? Here’s a link where you can find a time that works best for you – Meetings


That would be grand - I’ve booked a slot


And in the meantime submitted a PR: Remove duplication from stealing (was Allow subclasses to control work stealing) by dmcg · Pull Request #5787 · dask/distributed · GitHub


If you have resource restrictions on tasks, but those tasks are getting stolen to workers that don’t have appropriate restrictions, that sounds like a bug. Work stealing should respect resource constraints just like anything else.

I’m guessing the problem is instead that your downstream tasks (averages, diff) aren’t annotated with a particular organization. Would it be possible to do so? With this resources model you’re using (which seems reasonable enough to me), you’d need to annotate every task that you insist runs with that particular resource, otherwise you’re implicitly saying “this can run anywhere”.

I understand changing the stealing logic would make this work in most cases, but that’s only because of an implementation detail in the scheduler’s current task-assignment logic: downstream tasks are only assigned to the workers that hold their dependencies, so as long as you place the root tasks in the right organization, all their downstream tasks will recursively happen to run in the same organization. But we’ve actually discussed changing this (https://github.com/dask/distributed/pull/4925). If you really do have an explicit requirement on where a certain task runs, the semantically-correct thing to do would be to annotate it as such.

If you don’t feel like doing this manually, you could also probably write a custom optimization function which traverses the High-Level Graph and propagates the resource annotation from parent tasks to children (except where a child has parents with multiple resources).
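
Something along these lines might be a starting point. It is a rough, untested sketch: it assumes each HighLevelGraph Layer carries an annotations dict, it only copies a resources annotation onto unannotated child layers whose annotated parents all agree, and a full version would walk the layers in topological order so the annotation propagates transitively.

from dask.highlevelgraph import HighLevelGraph


def propagate_resources(hlg: HighLevelGraph) -> HighLevelGraph:
    # hlg.dependencies maps a layer name to the names of the layers it depends on
    for name, layer in hlg.layers.items():
        parent_resources = {
            frozenset((hlg.layers[dep].annotations or {}).get("resources", {}).items())
            for dep in hlg.dependencies.get(name, ())
        }
        parent_resources.discard(frozenset())  # ignore parents with no resource annotation
        own = (layer.annotations or {}).get("resources")
        # Propagate only when the child is unannotated and all annotated parents agree
        if own is None and len(parent_resources) == 1:
            layer.annotations = dict(layer.annotations or {})
            layer.annotations["resources"] = dict(next(iter(parent_resources)))
    return hlg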


Thanks all.

On further investigation we don’t see resource restrictions being ignored for the initial data loading tasks (these are not stolen). We could annotate downstream tasks to ensure that they are not stolen, but our trials show that with our modest change to the stealing, they don’t have to be. Which seems like a cheap usability win for very little effort.
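
For reference, annotating the downstream tasks as well, as suggested above, would look roughly like this (a sketch using the in_org helper and dataset names from the notebook below):

# Pin the downstream calculations explicitly, instead of relying on stealing behaviour
with in_org('metoffice'):
    averages = dataset.mean('realization', keep_attrs=True)

with in_org('eumetsat'):
    diff = averages.isel(height=5) - comparison_dataset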

The following Jupyter notebook shows our results.

Understanding Work Stealing

This sheet aims to show our problems with the current work stealing, and the behaviour of a system with our revised version.

It uses Docker containers to pretend to have 3 locations.

  1. This computer, where the client is running
  2. Metoffice
  3. Eumetsat

The idea is that tasks accessing data available in metoffice should be run there, and similarly with eumetsat. The algorithm should be defined, and results rendered, here.

We use Docker to host a scheduler and workers in such a way that all data in /data is visible to the client machine (this one) for metadata visibility, but the metoffice container can only see /data/metoffice, and the eumetsat container only /data/eumetsat. This keeps our data separation honest.

Setup

First some imports

from time import sleep
import dask
from dask.distributed import performance_report, get_task_stream
import pytest
import ipytest
import xarray
import matplotlib.pyplot as plt

ipytest.autoconfig()

Clear out the cluster (if docker stop complains that it requires at least 1 argument, nothing was running)

%%bash
docker stop $(docker ps -a -q)
docker container prune --force
e395ef09a505
b0c29d9421c8
63f33eaf89f4
Deleted Containers:
e395ef09a50506098a46a88feb476de9317db9400db7ac73191236962e038e7e
b0c29d9421c8e1f11b65f039e30e749e21921509ae1248ef2367c2973b24397c
63f33eaf89f43f96ac2cb6b6fe7d4925fb75a8766ee0faf33067bc84c25f1cf2

Total reclaimed space: 66.78kB

Our Vision

Here we show an xarray calculation run on metoffice and eumetsat workers.

Run Up a Cluster

First run a scheduler. We’re going to run our version that prevents cross-organisation work-stealing.

The docker container is available at Docker Hub. It is basically conda install -c conda-forge python=3.10 dask distributed iris xarray bottleneck, plus some patches detailed later.

%%bash --bg
docker run --network host metoffice/irisxarray resource-aware-scheduler.py

Start 4 metoffice workers, in separate processes. These have a resource org-metoffice=1

%%bash --bg
scheduler_host=localhost:8786
org_name=metoffice
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources org-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

Start 4 eumetsat workers, in separate processes, and with org-eumetsat=1

%%bash --bg
scheduler_host=localhost:8786
org_name=eumetsat
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources org-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

And a client to talk to them - you can click through to the Dashboard on http://localhost:8787/status

from dask.distributed import Client
import dask

client = Client('localhost:8786')
/home/ec2-user/miniconda3/envs/irisxarray/lib/python3.10/site-packages/distributed/client.py:1096: VersionMismatchWarning: Mismatched versions found

+-------------+----------------+----------------+----------------+
| Package     | client         | scheduler      | workers        |
+-------------+----------------+----------------+----------------+
| dask        | 2022.01.0      | 2022.01.1      | 2022.01.1      |
| distributed | 2022.01.0      | 2022.01.1      | 2022.01.1      |
| python      | 3.10.2.final.0 | 3.10.0.final.0 | 3.10.0.final.0 |
+-------------+----------------+----------------+----------------+
  warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

Define Some Conveniences

Here are some utilities to work with workers in organisations

import os.path
from dask import annotate, delayed


def in_org(name):
    return annotate(resources={f'org-{name}': 1})


@delayed
def my_org():
    return os.environ['ORG_NAME']


@delayed
def tree(dir):
    result = []
    for path, dirs, files in os.walk(dir):
        result = result + [path]
        result = result + [os.path.join(path, file) for file in files]
    return result

We can use them to see which workers can see which data

with in_org('metoffice'):
    metoffice_data = tree('/data')
    
metoffice_data.compute()    
['/data',
 '/data/metoffice',
 '/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc']
with in_org('eumetsat'):
    eumetsat_data = tree('/data')
    
eumetsat_data.compute()    
['/data', '/data/eumetsat', '/data/eumetsat/observations.nc']

And we can identify each worker on the task stream observable at http://localhost:8787/status

def show_all_workers():
    my_org().compute(workers='metoffice-0')
    my_org().compute(workers='metoffice-1')
    my_org().compute(workers='metoffice-2')
    my_org().compute(workers='metoffice-3')
    my_org().compute(workers='eumetsat-0')
    my_org().compute(workers='eumetsat-1')
    my_org().compute(workers='eumetsat-2')
    my_org().compute(workers='eumetsat-3')

show_all_workers()

Run A Computation

Here we pin accessing metoffice data to its workers, and eumetsat to its workers. After that we let Dask work things out.

%%time
with in_org('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

with in_org('eumetsat'):
    comparison_dataset = xarray.open_dataset('/data/eumetsat/observations.nc').chunk({"latitude": 10})    
    
averages = dataset.mean('realization', keep_attrs=True)
diff = averages.isel(height=5) - comparison_dataset

show_all_workers()
diffed = diff.compute(optimize_graph=False)    
    
fig = plt.figure(figsize=(6, 6))
plt.imshow(diffed.to_array()[0,...,0], origin='lower')
CPU times: user 327 ms, sys: 66.6 ms, total: 393 ms
Wall time: 4.38 s

<matplotlib.image.AxesImage at 0x7fdd5016ab30>

[inline image: plot of the diff]

We know that at the very least, only the right workers can load the data from file. After that, Dask’s preference for workers that already hold the data means the mean calculation will tend to run on the metoffice workers. But when they get busy, work stealing might be happening.

Let’s see whether we can observe any work stealing.

Work Stealing With Our Revisions

Remember that we are currently running our revised scheduler.

We’ll run just the averages bit. That should not be run on eumetsat, as the data isn’t there.

With the mildly hacked scheduler we’re running at the moment, that works fine - you can see it’s only run on the first 4 workers in the task stream at http://localhost:8787/status

%%time
with in_org('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

averages = dataset.mean('realization', keep_attrs=True)
show_all_workers()
with get_task_stream() as ts:
    averaged = averages.compute(optimize_graph=False)     
CPU times: user 313 ms, sys: 148 ms, total: 461 ms
Wall time: 3.45 s

Here are the workers it’s run on

workers = set((each['worker'] for each in ts.data))
assert len(workers) == 4
workers
{'tcp://127.0.0.1:35315',
 'tcp://127.0.0.1:38781',
 'tcp://127.0.0.1:39893',
 'tcp://127.0.0.1:41725'}

Work Stealing Without Our Revisions

If we don’t use our hacked scheduler though, it goes wonky. Let’s see that by killing everything

!./docker-reset.sh
a40d54a3a528
b49f01a92f1f
a361494a9e20
Deleted Containers:
a40d54a3a528c9a0cd6b8b3e8af2defe1d81f58980af0a1fe3aaeccf1b8f83f6
b49f01a92f1f6f6e79c07b837cc247c55480f09e97ecdbb8fff15f1f4871c240
a361494a9e209850f9e514ec2e604122a6bf30498a95f132bca3647342ea4f75

Total reclaimed space: 571.5MB

and run up with the standard scheduler

%%bash --bg
docker run --network host metoffice/irisxarray dask-scheduler
/home/ec2-user/miniconda3/envs/irisxarray/lib/python3.10/site-packages/distributed/client.py:1096: VersionMismatchWarning: Mismatched versions found

+-------------+-----------+-----------+---------+
| Package     | client    | scheduler | workers |
+-------------+-----------+-----------+---------+
| dask        | 2022.01.0 | 2022.01.1 | None    |
| distributed | 2022.01.0 | 2022.01.1 | None    |
+-------------+-----------+-----------+---------+
  warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

Start 4 metoffice workers, in separate processes

%%bash --bg
scheduler_host=localhost:8786
org_name=metoffice
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources org-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

Start 4 eumetsat workers, in separate processes

%%bash --bg
scheduler_host=localhost:8786
org_name=eumetsat
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources org-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

And a client to talk to them

from dask.distributed import Client
import dask

client = Client('localhost:8786')
/home/ec2-user/miniconda3/envs/irisxarray/lib/python3.10/site-packages/distributed/client.py:1096: VersionMismatchWarning: Mismatched versions found

+-------------+----------------+----------------+----------------+
| Package     | client         | scheduler      | workers        |
+-------------+----------------+----------------+----------------+
| dask        | 2022.01.0      | 2022.01.1      | 2022.01.1      |
| distributed | 2022.01.0      | 2022.01.1      | 2022.01.1      |
| python      | 3.10.2.final.0 | 3.10.0.final.0 | 3.10.0.final.0 |
+-------------+----------------+----------------+----------------+
  warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

Now run the calculation viewing the task stream

with in_org('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

averages = dataset.mean('realization', keep_attrs=True)
show_all_workers()
with get_task_stream() as ts:
    averaged = averages.compute(optimize_graph=False)     

That turns out to have been run on every worker

workers = set((each['worker'] for each in ts.data))
assert len(workers) == 8
workers
{'tcp://127.0.0.1:35075',
 'tcp://127.0.0.1:36303',
 'tcp://127.0.0.1:37557',
 'tcp://127.0.0.1:38653',
 'tcp://127.0.0.1:39673',
 'tcp://127.0.0.1:39953',
 'tcp://127.0.0.1:40385',
 'tcp://127.0.0.1:46845'}

Our Revisions

To be honest, I don’t know why our revisions have such an effect on work stealing.

Given our PR, resource-aware-scheduler.py looks like this

#!/usr/bin/env python

import click
import dask
from distributed.cli.dask_scheduler import main as scheduler_main
from distributed.stealing import WorkStealing


# Needs the hacked version of stealing.py to be installed in the conda environment to work
class OrganisationAwareWorkStealing(WorkStealing):

    def potential_thieves_for(self, ts, idle, sat):
        if ts.annotations.get('cross_org_stealing_allow', False):
            return super().potential_thieves_for(ts, idle, sat)
        else:
            return [ws for ws in idle if self.org_aware_can_steal(ws, ts, sat)]

    def org_aware_can_steal(self, thief, ts, victim):
        return False if _org_of(thief) != _org_of(victim) else self.can_steal(thief, ts, victim)


def _org_of(worker):
    orgs = [item for item in worker.resources if item.startswith('org')]
    return orgs[0] if orgs else None


@click.command()
def dask_setup(scheduler):
    scheduler.add_plugin(OrganisationAwareWorkStealing(scheduler))


if __name__ == '__main__':
    with dask.config.set({
        'distributed.scheduler.work-stealing': False,
        'distributed.scheduler.preload': __file__}
    ):
        scheduler_main()

I have yet to work out why this prevents work stealing when the standard resource constraint in stealing.py

for resource, value in ts.resource_restrictions.items():
    try:
        supplied = thief.resources[resource]
    except KeyError:
        return False
    else:
        if supplied < value:
            return False

does not.


So, I now remember why this scheme works well even though the ‘downstream’ tasks are not annotated. It’s because our org_aware_can_steal takes no notice of the task when deciding if a worker can steal; it only uses the resources attached to the thief and victim workers to establish whether or not they are in the same organisation. So it forbids work stealing from one organisation to another for any task.

As work stealing is a separate mechanism from task scheduling, this doesn’t prevent the movement of data to workers, or the scheduling of tasks where the data is; it just stops idle workers from grabbing tasks across organisations, on the assumption that that is likely to be prohibitively expensive.

Thanks for the detailed experiment @dmcg. It seems like your tweaks do make your particular test case work. I’m still not sure though that they’ll address other scenarios.

As I mentioned before, I still think annotating downstream tasks is the only correct way to implement this. Changing work stealing logic as you did may seem to work right now, but it would be using private APIs, and brittle to many possible changes/refactorings in scheduler logic.

Additionally, it does nothing to prevent tasks from simply being scheduled in places you don’t want them. Task scheduling is far more sensitive to worker busyness than to avoiding transfers, and tends to underestimate transfer cost already, which in your case would be a massive underestimate (see Scheduler underestimates data transfer cost for small transfers · Issue #5324 · dask/distributed · GitHub). I’d be curious if your diff calculation reliably runs in eumetsat, or if sometimes parts of the comparison_dataset get transferred over to metoffice (especially on a busy real-world cluster).
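
One way to check, building on the task-stream inspection from the notebook (a sketch; it assumes client.scheduler_info() reports each worker’s resources):

# Map each worker address seen in the task stream back to the org-* resource
# it advertises. Assumes scheduler_info() includes a 'resources' entry per worker.
info = client.scheduler_info()['workers']


def org_of_address(addr):
    orgs = [r for r in info[addr].get('resources', {}) if r.startswith('org-')]
    return orgs[0] if orgs else None


with get_task_stream() as ts:
    diffed = diff.compute(optimize_graph=False)

print({org_of_address(each['worker']) for each in ts.data})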

I wrote a little utility to do this. I haven’t tested it on real clusters at all, but I’d be curious if it works for you (would also welcome contributions):

Thanks for helping my understanding @gjoseph92 . We don’t expect running a cluster across datacenters to be a trivial thing, but given the huge potential of not having to assemble all the required data in one place, it seems like it’s worth a punt.

I’m sure that you appreciate that our reason for preferring not to annotate downstream tasks is to relieve the client of knowledge of where the data is altogether. So instead of

with in_org('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

with in_org('eumetsat'):
    comparison_dataset = xarray.open_dataset('/data/eumetsat/observations.nc').chunk({"latitude": 10})    
    
averages = dataset.mean('realization', keep_attrs=True)
diff = averages.isel(height=5) - comparison_dataset

show_all_workers()
diffed = diff.compute(optimize_graph=False)    

researchers could just write

dataset = catalog.open_dataset('/some-predictions/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})
comparison_dataset = catalog.open_dataset('/some-observations/observations.nc').chunk({"latitude": 10})    
    
averages = dataset.mean('realization', keep_attrs=True)
diff = averages.isel(height=5) - comparison_dataset

show_all_workers()
diffed = diff.compute()    

and everything would just work.
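
The catalog above is hypothetical. Roughly, it would map a logical path to the organisation holding the data and apply the annotation itself, something like this sketch (the mapping and paths here are made up):

import xarray
from dask import annotate

# Hypothetical mapping from logical path prefixes to (organisation, real data root)
_LOCATIONS = {
    '/some-predictions': ('metoffice', '/data/metoffice'),
    '/some-observations': ('eumetsat', '/data/eumetsat'),
}


class Catalog:
    def open_dataset(self, logical_path, **kwargs):
        for prefix, (org, root) in _LOCATIONS.items():
            if logical_path.startswith(prefix):
                real_path = logical_path.replace(prefix, root, 1)
                # Annotate the loading tasks with the owning organisation's resource
                with annotate(resources={f'org-{org}': 1}):
                    return xarray.open_dataset(real_path, **kwargs)
        raise KeyError(f'No known location for {logical_path}')


catalog = Catalog()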

Of course this could get horrendously inefficient, but that’s just a truth of distributed computing - we aim for simple and have ways of tuning how things are done when hints are required.

No doubt a more resilient system would need some representation of ‘locality’ inside Dask, or at least some formal cooperation from the scheduler. Frankly I’m amazed that we managed to get this far with so few changes to internal code, which is a huge tribute to Dask.

Thanks so much for taking the time to do this (and pool is such a better name than org ;-), I’ll try it out.

With a minor tweak this seems to work at least as well as the hacked scheduler @gjoseph92, thank you!

We’ll take this version into some cross-datacentre trials

Using Worker Pools

This sheet is to try out Gabe’s worker pools.

It uses Docker containers to pretend to have 3 locations.

  1. This computer, where the client is running
  2. Metoffice
  3. Eumetsat

The idea is that tasks accessing data available in metoffice should be run there, and similarly with eumetsat. The algorithm should be defined, and results rendered, here.

We use Docker to host a scheduler and workers in such a way that all data in /data is visible to the client machine (this one) for metadata visibility, but the metoffice container can only see /data/metoffice, and the eumetsat container only /data/eumetsat. This keeps our data separation honest.

Setup

First some imports

from time import sleep
import dask
from dask.distributed import performance_report, get_task_stream
import pytest
import ipytest
import xarray
import matplotlib.pyplot as plt

ipytest.autoconfig()

Clear out the cluster (if docker stop complains that it requires at least 1 argument, nothing was running)

%%bash
docker stop $(docker ps -a -q)
docker container prune --force
84f934c59253
2e22679604f7
fb890cf81510
Deleted Containers:
84f934c592536d19ffbe1228a2417b328e229dfe8f0ce63070b6e97386b89503
2e22679604f73ef3241afe374e86cb5bbdc8314b14682bf2690497b5ec35b857
fb890cf815104d6f232ffce56ed1d4fca5472b7c55491d40da14707f6914f51b

Total reclaimed space: 24.79kB

Our Vision

Here we show an xarray calculation run on metoffice and eumetsat workers.

Run Up a Cluster

First run a scheduler. This time we only need the standard scheduler.

%%bash --bg
docker run --network host metoffice/irisxarray dask-scheduler

Start 4 metoffice workers, in separate processes. These have a resource pool-metoffice=1

%%bash --bg
scheduler_host=localhost:8786
org_name=metoffice
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources pool-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

Start 4 eumetsat workers, in separate processes, and with pool-eumetsat=1

%%bash --bg
scheduler_host=localhost:8786
org_name=eumetsat
worker_cmd="dask-worker $scheduler_host --nprocs 4 --nthreads 1 --resources pool-$org_name=1 --name $org_name"
docker run --network host --mount type=bind,source="$(pwd)"/$org_name-data,target=/data/$org_name --env ORG_NAME=$org_name metoffice/irisxarray $worker_cmd

And a client to talk to them - you can click through to the Dashboard on http://localhost:8787/status

from dask.distributed import Client
import dask

client = Client('localhost:8786')
/home/ec2-user/miniconda3/envs/irisxarray/lib/python3.10/site-packages/distributed/client.py:1096: VersionMismatchWarning: Mismatched versions found

+-------------+----------------+----------------+----------------+
| Package     | client         | scheduler      | workers        |
+-------------+----------------+----------------+----------------+
| distributed | 2022.01.0      | 2022.01.1      | 2022.01.1      |
| python      | 3.10.2.final.0 | 3.10.0.final.0 | 3.10.0.final.0 |
+-------------+----------------+----------------+----------------+
  warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

We can identify each worker on the task stream observable at http://localhost:8787/status

import os
from dask import annotate, delayed

@delayed
def my_org():
    return os.environ['ORG_NAME']


def show_all_workers():
    my_org().compute(workers='metoffice-0')
    my_org().compute(workers='metoffice-1')
    my_org().compute(workers='metoffice-2')
    my_org().compute(workers='metoffice-3')
    my_org().compute(workers='eumetsat-0')
    my_org().compute(workers='eumetsat-1')
    my_org().compute(workers='eumetsat-2')
    my_org().compute(workers='eumetsat-3')

show_all_workers()

Run A Computation

Here we pin accessing metoffice data to its workers, and eumetsat to its workers. After that we let Dask work things out with the propagate_pools context manager.

%%time
from dask_worker_pools import pool, propagate_pools, visualize_pools

with pool('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

with pool('eumetsat'):
    comparison_dataset = xarray.open_dataset('/data/eumetsat/observations.nc').chunk({"latitude": 10})    
    
averages = dataset.mean('realization', keep_attrs=True)
diff = averages.isel(height=5) - comparison_dataset

with propagate_pools():
    show_all_workers()
    diffed = diff.compute(optimize_graph=False)    
    
fig = plt.figure(figsize=(6, 6))
plt.imshow(diffed.to_array()[0,...,0], origin='lower')
CPU times: user 297 ms, sys: 53.1 ms, total: 350 ms
Wall time: 4.88 s

<matplotlib.image.AxesImage at 0x7f4fa95bb8b0>

[inline image: plot of the diff]

We know that at the very least, only the right workers can load the data from file. After that, Dask’s preference for workers that already hold the data means the mean calculation will tend to run on the metoffice workers.

Let’s run just the averages bit. That should not be run on eumetsat, as the data isn’t there. You can see it’s only run on the first 4 workers in the task stream at http://localhost:8787/status

%%time

with pool('metoffice'):    
    dataset = xarray.open_dataset('/data/metoffice/000490262cdd067721a34112963bcaa2b44860ab.nc').chunk({"latitude": 10})

averages = dataset.mean('realization', keep_attrs=True)


with propagate_pools():
    show_all_workers()
    with get_task_stream() as ts:
        averaged = averages.compute()     
CPU times: user 305 ms, sys: 150 ms, total: 455 ms
Wall time: 3.54 s

Here are the workers it’s run on

workers = set((each['worker'] for each in ts.data))
assert len(workers) == 4
workers
{'tcp://127.0.0.1:34111',
 'tcp://127.0.0.1:39351',
 'tcp://127.0.0.1:40753',
 'tcp://127.0.0.1:46619'}

If we don’t run the calculation inside propagate_pools, then work-stealing happens

%%time

show_all_workers()
with get_task_stream() as ts:
    averaged = averages.compute(optimize_graph=False)     
CPU times: user 324 ms, sys: 149 ms, total: 473 ms
Wall time: 4.12 s

Note that we had to specify optimize_graph=False here. Without that, the data cannot even be loaded, because the resource annotation can be lost from the loading tasks and then some of these will be run on eumetsat and fail.

Here are the workers it’s run on

workers = set((each['worker'] for each in ts.data))
assert len(workers) == 8
workers
{'tcp://127.0.0.1:34111',
 'tcp://127.0.0.1:37409',
 'tcp://127.0.0.1:37991',
 'tcp://127.0.0.1:39351',
 'tcp://127.0.0.1:40753',
 'tcp://127.0.0.1:42839',
 'tcp://127.0.0.1:45975',
 'tcp://127.0.0.1:46619'}

Something I should note on my utility: propagate_pools() also automatically sets "optimization.fuse.active": False so that annotations aren’t lost (this is a little more fine-grained than optimize_graph=False, since it still allows high-level graph optimizations to happen, which can help a bit).
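
If you ever need the same effect by hand, the flag can also be set around a compute call (a quick sketch):

import dask

# Disabling low-level fusion keeps layer annotations from being fused away,
# while high-level graph optimizations still run.
with dask.config.set({"optimization.fuse.active": False}):
    averaged = averages.compute()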

Looking forward to hearing how this goes! I do think that for real-world cases, you’ll want to change the logic like I started doing in Avoid restricting when not worth it by gjoseph92 · Pull Request #1 · gjoseph92/dask-worker-pools · GitHub to avoid over-restricting things.
