Specify to the Dask Resource Manager that a given task uses a huge amount of RAM

Hi :wave:

I’m having an issue distributing a deep learning process over a big raster image (I’m using xarray as a Dask interface).
I use a big image of about 1GB in size; I read it with xarray and chunk it into 10 different chunks of about 100MB each.
Then, for each chunk, I do some deep learning magic using ONNX Runtime. To do so I use xarray’s apply_ufunc (a dask_gufunc wrapper, like map_blocks), which applies my inference function predict to each chunk.
At this point I get a nice computation graph with 10 tasks, one per chunk.

However, my predict function uses a lot of RAM: for a 100MB chunk it peaks at about 300MB (because of convolutions and other fancy stuff going on inside it). This leads to high memory usage and a crash when running with little RAM: at the start of the computation each worker gets assigned too many tasks, because Dask thinks it can handle them, and eventually it blows up.

So my question is: how can I tell the Dask resource manager NOT to give too many predict tasks to my workers?

Currently I’m on a LocalCluster, and I want to ensure that my workload can run with approximately 3GB of RAM.
I found that using Worker Resources, I can specify resources for a specific submit (i.e. client.submit(process, d, resources={'MEMORY': "200MB"})) or with a dask.annotate context. However, it seems that I can’t specify constraints on apply_ufunc tasks. EDIT: it does seem to work using the dask.annotate context (see the post below), but the scheduler still doesn’t handle it well.
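For reference, here is a minimal sketch of the two forms I’m talking about (toy function and toy data, not my real workload):

import dask
import dask.array as da
from distributed import Client

client = Client()  # placeholder client just for the sketch

def double(x):
    return x * 2

# Per-submit constraint on a single future.
# NOTE: these constraints only take effect on workers that actually declare a MEMORY resource.
future = client.submit(double, 21, resources={"MEMORY": 0.2e9})

# Annotation context, which also covers collection operations such as map_blocks.
arr = da.ones(4, chunks=1)
with dask.annotate(resources={"MEMORY": 0.2e9}):
    doubled = arr.map_blocks(double)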

Notebook version of my example: notebook version

Here is a minimal reproducible example:

Define Cluster and workers

# Load dask cluster and monitoring.
from distributed.client import Client
from distributed.deploy.local import LocalCluster

cluster = LocalCluster(
    # n_workers=4,
    # threads_per_worker=1,
    # processes=True,
    memory_limit="500MB"
)
client = Client(cluster)

Define custom function

With resnet18-v1-7.onnx from here (40MB)

import numpy as np
import onnxruntime as ort

def forward_onnx(image_tiled: np.ndarray) -> np.ndarray:
    # Collapse all leading (tile) dimensions into a single batch dimension.
    *batch, c, h, w = image_tiled.shape
    image_tiled_batched = image_tiled.reshape(np.prod(batch), c, h, w)

    # CPU inference session for the ResNet18 ONNX model.
    model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])

    outputs = model_session.run(
        output_names=["resnetv15_dense0_fwd"],
        input_feed={"data": image_tiled_batched.astype(np.float32)},
    )
    # Restore the original tile dimensions, with the feature vector as the last axis.
    output_unbatched = outputs[0].reshape(*batch, -1).astype(np.float32)
    return output_unbatched
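As a sanity check, here is roughly how I test the input/output shapes on a dummy batch (assuming resnet18-v1-7.onnx has been downloaded next to the notebook):

# Dummy chunk shaped (x, y, band, y_tile, x_tile), with band playing the channel role.
dummy = np.random.rand(2, 3, 3, 224, 224).astype(np.float32)
features = forward_onnx(dummy)
print(features.shape)  # expected (2, 3, 1000): one 1000-feature vector per tile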

Load image

With image.tif from here (1GB)

import rioxarray
image = rioxarray.open_rasterio(  # type: ignore
    "./image.tif", 
    parse_coordinates=True, 
    #chunks={"x": "auto", "y": "auto"}
)

Make tiles and rolling windows with rolling

shift = 224
input_size = 224
output_size = 1000

image_tiled = (
    # (band, y, x).
    image.transpose()  # type: ignore # (x, y, band).
    .rolling(  # Rolling object, future computation of sliding_window_view
        dim={"x": input_size, "y": input_size}
    )  # Rolling x->16, y->16.
    .construct(  # Construct the rolling view and apply stride
        x="x_tile", y="y_tile", stride=shift,
    )  # (x, y, band, x_tile, y_tile).
    .chunk(  # Auto chunk (chunk_size ~ jobs_memory).
        ("auto", "auto", -1, -1, -1), merge_chunks=False
    )  # (x(chunked), y(chunked), band, x_tile, y_tile).
)
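To make the windowing explicit, here is a tiny toy version of the same rolling + construct pattern (small random data, not the real image), just to show the resulting dimensions:

import numpy as np
import xarray as xr

toy = xr.DataArray(np.zeros((4, 8, 8)), dims=("band", "y", "x")).transpose()  # (x, y, band)
toy_tiled = (
    toy.rolling(dim={"x": 4, "y": 4})  # 4x4 windows
    .construct(x="x_tile", y="y_tile", stride=2)  # keep every 2nd window position
)
print(toy_tiled.dims)   # ('x', 'y', 'band', 'x_tile', 'y_tile')
print(toy_tiled.shape)  # (4, 4, 4, 4, 4)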

Compute Deep Learning for each tile on each chunk (as a batch)

I tried to annotate that these tasks could use up to 300MB each, but this doesn’t seem to work.

# Submit the future features view into the computation graph (async).
from dask import annotate
import xarray as xr
with annotate(resources={'MEMORY': 0.300e9}):
    image_features: xr.DataArray = (
        xr.apply_ufunc(  # Call the Dask-parallelized gufunc
            forward_onnx,
            image_tiled,
            input_core_dims=[["band", "y_tile", "x_tile"]],
            output_core_dims=[["features"]],
            keep_attrs=True,
            dask="parallelized",
            output_dtypes=[np.float32],
            dask_gufunc_kwargs={"output_sizes": {"features": output_size}},
        )  # (x, y, features)
        .stack(xy=["x", "y"])  # (features, x×y('xy'))
        .transpose("xy", ...)  # (x×y('xy'), features)
        # .chunk(("auto", -1), merge_chunks=False)  # (xy(chunked), features)
    )

Compute

image_features_np = image_features.as_numpy()

According to .__dask_graph__(), my annotation (here with 100MB) is well formatted in the graph.
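For completeness, this is roughly how I checked it (iterating over the layers of the high-level graph):

# Print the annotations carried by each layer of the high-level graph.
hlg = image_features.__dask_graph__()
for name, layer in hlg.layers.items():
    print(name, layer.annotations)
# The layer produced by apply_ufunc should carry {'resources': {'MEMORY': ...}}.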

I don’t understand why, even with the resource annotation, the Dask scheduler gives too many tasks to my workers, even though it’s explicitly too much for them.

The MEMORY resource is just an arbitrary label. Unless you start at least some of your workers stating that they offer that resource, the task will never get scheduled.

Everything else looks good. However beware that automatic graph optimizations may drop annotations; call compute()/persist() with optimize_graph=False to disable them.
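On the xarray side that looks something like this (the extra keyword argument is forwarded to dask.compute/dask.persist):

# Trigger computation without graph optimization so the annotations are preserved.
result = image_features.compute(optimize_graph=False)
# or keep the result lazily on the cluster:
persisted = image_features.persist(optimize_graph=False)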


Okay, I see. I thought that declaring memory_limit="500MB" in my cluster would automatically annotate my workers with that amount of memory.

About the graph optimization, you may talk about this issue: Graph optimization loses annotations · Issue #7036 · dask/dask · GitHub.

However, even with MEMORY specified in my dask.config (as shown here):

with dask.config.set({"distributed.worker.resources.MEMORY": 0.500e9}):
    cluster = LocalCluster(
        # n_workers=4,
        # threads_per_worker=1,
        # processes=True,
        memory_limit="500MB"
    )

And image_features.persist(optimize_graph=False).

I still have memory issues, with the scheduler assigning too many tasks :frowning:

Ideally, I would like to guarantee that my scheduler never assigns tasks that require too much memory.

A question about annotate: does it annotate each task individually? That is, will an annotation of 10MB ensure that my worker has 10MB of memory available to handle each task? Or is the 10MB for the whole layer?

Annotations are for each task.

Works fine for me:

import time

import dask
import dask.array as da
import distributed

with dask.config.set({"distributed.worker.resources.MEMORY": 0.500e9}):
    client = distributed.Client(n_workers=1, threads_per_worker=2)

def f(x, i):
    time.sleep(1)
    return x + i

a = da.ones(2, chunks=1)
b = a.map_blocks(f, 1, dtype=a.dtype)

%time b.compute(optimize_graph=False)
CPU times: user 103 ms, sys: 35.8 ms, total: 139 ms
Wall time: 1.23 s

with dask.annotate(resources={"MEMORY": 0.400e9}):
    c = a.map_blocks(f, 2, dtype=a.dtype)

%time c.compute(optimize_graph=False)
CPU times: user 34.7 ms, sys: 11 ms, total: 45.7 ms
Wall time: 2.03 s
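The annotated run is slower because, with 400MB reserved per task and only 500MB of MEMORY declared on the worker, the two blocks can no longer run on the two threads at the same time, hence roughly 2 s instead of roughly 1.2 s.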

To debug your issue, after you’ve run the workflow, execute:

>>> client.run(lambda dask_worker: dask_worker.state.total_resources)
{'tcp://127.0.0.1:32999': {'MEMORY': 500000000.0}}
>>> client.run(lambda dask_worker: list(dask_worker.state.stimulus_log))
{'tcp://127.0.0.1:32999': [
  [...]
  # Note resource_restrictions=None
  ComputeTaskEvent(stimulus_id='compute-task-1666610020.7323325', key="('f-a34c10f6cca53ea8c0a0ccc8a00bec00', 0)", who_has={"('ones_like-5a84efbc02a961a1580dfeae33096db7', 0)": ('tcp://127.0.0.1:46737',)}, nbytes={"('ones_like-5a84efbc02a961a1580dfeae33096db7', 0)": 8}, priority=(0, 1, 1), duration=0.5, run_spec=SerializedTask(function=None, args=None, kwargs=None, task=None), function=None, args=None, kwargs=None, resource_restrictions={}, actor=False, annotations={}),
  [...]
  # Note resource_restrictions={'MEMORY': 400000000.0}
  ComputeTaskEvent(stimulus_id='compute-task-1666610022.750609', key="('f-7cb58bd08d211da8bf47490093dc32e4', 0)", who_has={"('ones_like-5a84efbc02a961a1580dfeae33096db7', 0)": ('tcp://127.0.0.1:46737',)}, nbytes={"('ones_like-5a84efbc02a961a1580dfeae33096db7', 0)": 8}, priority=(0, 1, 1), duration=1.0011636018753052, run_spec=SerializedTask(function=None, args=None, kwargs=None, task=None), function=None, args=None, kwargs=None, resource_restrictions={'MEMORY': 400000000.0}, actor=False, annotations={'resources': {'MEMORY': 400000000.0}}),
  [...]
]}