Overlapping Computations and Cloning

I am curious to hear thoughts on cloning blocks for a function which uses overlapping computations.

Common examples are things like rechunk or map_overlap, where some key is needed by multiple processes. In my case I have arrays which are usually 4/5 dimensional and I want to perform some action on the entire dataset. Part of the problem is that Dask does not release memory for keys that still need to be reused. For an N-dimensional dataset that is chunked in more than 1 dimension this becomes a large problem, as the number of keys held in memory starts to spike.

Reloading the data from disk is a possible solution here. A naive approach would be to do something like:


import dask.array as da
from dask.graph_manipulation import clone

overlapped_data = da.overlap.overlap(data, depth=overlapped_depth, boundary=boundary)

data_clones = da.concatenate([clone(b, omit=overlapped_data) for b in overlapped_data.blocks])

mapped = data_clones.map_blocks(func, **kwargs)
trimmed_data = da.overlap.trim_internal(mapped, overlapped_depth)

In this case the entire task tree is cloned for one dimension of the data, resulting in the data being reloaded from disk for one of the chunked dimensions. This reduces the memory usage by a fair amount.

I was wondering if there is a way to generalize this or make it more efficient?

Hi @CSSFrancis, welcome here!

This looks like a very interesting question, but it’s a bit hard to grasp. Could you come up with a small reproducible example using generated data so we could play with it?

If I understand correctly, with a standard approach you are running into memory problems?

@guillaumeeb Let me see if I can go into a bit more detail.

Currently I am trying to take a large 4/5-D array and apply something like a Gaussian filter to the entire dataset.

The entire workflow is something like:

  1. Load the data from a compressed chunked storage
  2. Overlap the dataset
  3. Apply some function
  4. Trim the data
  5. Save the chunked data

The original dataset is usually too large to fit comfortably into RAM.
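A minimal sketch of that workflow might look something like this (zarr storage, depth=2, and a Gaussian filter are just placeholders for illustration; the file names are hypothetical):

import dask.array as da
from scipy.ndimage import gaussian_filter

data = da.from_zarr("data.zarr")                                   # 1. load from chunked storage
overlapped = da.overlap.overlap(data, depth=2, boundary=None)      # 2. overlap the dataset
mapped = overlapped.map_blocks(gaussian_filter, sigma=1)           # 3. apply some function
trimmed = da.overlap.trim_overlap(mapped, depth=2, boundary=None)  # 4. trim the data
trimmed.to_zarr("filtered.zarr")                                   # 5. save the chunked data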

The major problem comes when you are trying to do overlapping computations on a dataset that is chunked in more than one dimension, especially when trying to run in parallel.

Let’s start with the simplest scheduling case, which is just operating on 1 core.

If I have a dataset where data.shape == (64, 64, 256, 256) with a chunking structure of data.chunks == (16, 16, 256, 256), then I would imagine that the scheduler would load and operate in a serpentine manner to maximize the reuse of keys and reduce the total memory held:

+---+---+---+---+    +----+----+----+----+
| a | b | c | d |    |  1 |  2 |  3 |  4 |
+---+---+---+---+    +----+----+----+----+
| e | f | g | h |    |  8 |  7 |  6 |  5 |
+---+---+---+---+    +----+----+----+----+
| i | j | k | l |    |  9 | 10 | 11 | 12 |
+---+---+---+---+    +----+----+----+----+
| m | n | o | p |    | 16 | 15 | 14 | 13 |
+---+---+---+---+    +----+----+----+----+

  1. Operate on chunk a (load a,b,e,f into memory) (a,b,e,f in memory)
  2. Operate on chunk b (load c,g into memory) (a,b,e,f,c,g in memory)
  3. Operate on chunk c (load d,h into memory) (a,b,e,f,c,g,d,h in memory)
  4. And so on…

The issue is that chunk a doesn’t leave memory until the computation is two rows past the first row, so assuming logical scheduling the maximum number of chunks loaded at one time will be 3 times the number of chunks in a row of the chunk grid. In this case that would be 12. Of course this gets much worse when you have more chunked dimensions or more chunks in one dimension.

For data.shape == (1024, 1024, 256, 256) with a chunking structure of data.chunks == (16, 16, 256, 256), the total number of chunks held by one core is equal to 3*64, or 192 chunks at a minimum, assuming that the scheduling is ideal for memory consumption. In practice I don’t believe that this is the case, as introducing multiple cores often results in copied keys etc. Even running on 8 cores with 15 GB of RAM per core I routinely see the above process fail or get stuck in an endless cycle of dropping and recreating keys. This can be solved somewhat by chunking in only one dimension, but then as you increase the size of the dataset you have to increase the memory per core.
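To put rough numbers on that (assuming float64 data, which is an assumption on my part):

chunk_bytes = 16 * 16 * 256 * 256 * 8      # one (16, 16, 256, 256) chunk is 128 MiB
min_resident = 3 * (1024 // 16)            # ~3 rows of chunks held at once = 192 chunks
print(min_resident * chunk_bytes / 2**30)  # ~24 GiB resident per core, more than the 15 GB available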

My solution to this is to clone the keys for the data loading process.

Using the example from above you could:

  1. Operate on chunk a (load a,b,e,f into memory) (a,b,e,f in memory)
  2. Operate on chunk b (load c,g into memory) (a,b,e,f,c,g in memory)
  3. Operate on chunk c (load d,h into memory, drop a,e from memory) (b,c,d,f,g,h in memory)
  4. Operate on chunk f (reload a,e and load i,j,k into memory, drop d,h) (a,b,c,e,f,g,i,j,k in memory)

In this case every core ideally only has 9 chunks in memory at the same time, at the added cost of each key being loaded 3 times rather than once. For systems which aren’t I/O bound or penalized for repeated I/O this seems better than spiking memory usage to multiple times the size of the input array.
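With the same 128 MiB chunks as above, that works out to roughly:

chunk_bytes = 16 * 16 * 256 * 256 * 8
print(9 * chunk_bytes / 2**30)  # ~1.1 GiB resident per core, vs ~24 GiB without cloning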

My question has two parts.

  1. Does this make sense? Is there a way to handle something like this in the settings, or should I be manipulating the graph in some way by cloning the tasks that load the data?

  2. If I do need to manipulate the graph, what are my best options? I can play around with things a little bit, but I am kind of confused by the clone and bind functions.

I’ll try and create a minimal working example and post that a little later today.

Something like this is kind of what I had in mind, if you replace the test array with loading something from disk:

import dask.array as da
from dask.graph_manipulation import clone
from scipy.ndimage import gaussian_filter

test_arr = da.ones((64, 64), chunks=(16, 16))
display(test_arr)  # assumes a Jupyter/IPython session
overlapped_array = da.overlap.overlap(test_arr, depth=2, boundary=None)
overlapped_array_cloned = da.concatenate([clone(b) for b in overlapped_array.blocks])  # clone only one dimension of blocks
filtered_overlapped = overlapped_array_cloned.map_blocks(gaussian_filter, sigma=1)
filtered = da.overlap.trim_overlap(filtered_overlapped, depth=2, boundary=None)
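As a quick sanity check (my own addition), the cloned pipeline should produce the same values as the identical pipeline built without clone, since clone only duplicates tasks in the graph rather than changing the computation:

import numpy as np

plain = da.overlap.trim_overlap(
    overlapped_array.map_blocks(gaussian_filter, sigma=1), depth=2, boundary=None
)
assert np.allclose(filtered.compute(), plain.compute())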

This definitely makes sense to me! I’m just wondering something (I did not try your code yet), when you say:

Isn’t that what it is about? Releasing some memory and reading some keys again from disk? Or does Dask’s default behavior try to keep data in memory as long as possible?

I guess the problem is that the Dask scheduler is not advanced enough to schedule chunk processing optimally (e.g. in the serpentine way). Workers will get a bunch of tasks to process, and I don’t think they will necessarily get neighboring ones. But I might be wrong here.

I don’t see anything like that. But again I’m not really an expert in Dask Array.

Thanks for the example, does it work? I haven’t tested it yet.

Also, could you provide the more direct approach that you’d try to use but that blows up the memory?

Just ran your code.

I’m not sure what you are trying to do here:

It looks like this is just cloning the entire overlapped array block by block. It really makes the graph more complex, and I cannot see the purpose of it.

Hmm, it’s possible that this workflow no longer works due to the HighLevelGraph abstraction. In the past I think that each block had a unique key, so cloning each block would result in the block being recreated.

In this case that means recreating the array full of ones for that one block. I’m not fully versed in the HighLevelGraph abstraction, so I might have to look into that further.

You can probably ignore the comment above…

It does appear to be working as intended. I just had to cull part of the task graph to make things look nice.

import dask.array as da
from dask import optimize
from dask.graph_manipulation import clone
from scipy.ndimage import gaussian_filter

test_arr = da.ones((32, 32), chunks=(16, 16))
overlapped_array = da.overlap.overlap(test_arr, depth=2, boundary=None)
filtered_overlapped = overlapped_array.map_blocks(gaussian_filter, sigma=1)
filtered = da.overlap.trim_overlap(filtered_overlapped, depth=2, boundary=None)
filtered.visualize()

overlapped_array_cloned = da.concatenate([clone(b, assume_layers=True) for b in overlapped_array.blocks])  # clone only one dimension of blocks
filtered_overlapped = overlapped_array_cloned.map_blocks(gaussian_filter, sigma=1)
filtered = da.overlap.trim_overlap(filtered_overlapped, depth=2, boundary=None)
optimize(filtered)[0].visualize()

So you can see that the second graph is a bit more embarrassingly parallel, because the data creation is duplicated rather than reused many times. In one dimension the problem is entirely separable, but in the other dimension we still have a lot of cross-talk. If we take this to the extreme:

import numpy as np
import dask.array as da
from dask import optimize
from dask.graph_manipulation import clone
from scipy.ndimage import gaussian_filter

from functools import reduce
from operator import mul

def reshape(lst, shape):
    """Nest a flat list into the given shape, e.g. reshape(list(range(4)), (2, 2)) -> [[0, 1], [2, 3]]."""
    if len(shape) == 1:
        return lst
    n = reduce(mul, shape[1:])
    return [reshape(lst[i*n:(i+1)*n], shape[1:]) for i in range(len(lst)//n)]

test_arr = da.ones((32, 32), chunks=(16, 16))
overlapped_array = da.overlap.overlap(test_arr, depth=2, boundary=None)
overlapped_array_cloned = da.block(reshape([clone(b, assume_layers=True) for b in overlapped_array.blocks.ravel()], (2, 2)))  # clone 2 dimensions of blocks
filtered_overlapped = overlapped_array_cloned.map_blocks(gaussian_filter, sigma=1)
filtered = da.overlap.trim_overlap(filtered_overlapped, depth=2, boundary=None)
optimize(filtered)[0].visualize()

The problem is now entirely embarrassingly parallel at the cost of each of our original keys being generated 4 separate times.

When I try to do this with data I loaded myself I get a weird error when trying to compute, but when I do this with an array of ones I don’t get any error.

This appears to be related to Error when computing a cloned graph from xarray.open_dataset · Issue #9621 · dask/dask · GitHub

File ~/micromamba/envs/dev/lib/python3.9/site-packages/dask/array/core.py:120, in getter()
    118     lock.acquire()
    119 try:
--> 120     c = a[b]
    121     # Below we special-case `np.matrix` to force a conversion to
    122     # `np.ndarray` and preserve original Dask behavior for `getter`,
    123     # as for all purposes `np.matrix` is array-like and thus
    124     # `is_arraylike` evaluates to `True` in that case.
    125     if asarray and (not is_arraylike(c) or isinstance(c, np.matrix)):

TypeError: string indices must be integers

I have to admit that your skills in Dask Arrays and graph manipulation are far above mine :smile:!

I still don’t get what the magic is under:

That makes Dask build two independent graphs (in this simple case) and read or produce the input blocks several times.

Anyway, this is really interesting and promising! If I understand correctly, with this strategy and the more extreme one, you are able to choose whether Dask should try to keep chunks in memory or build/read them again in order to lower memory pressure.

This is great, I’ll try to ping some more experienced folks than me on this topic. I don’t think there is such a mechanism built-in. cc @mrocklin @rjzamora @jrbourbeau.

Were you able to test the non-extreme strategy? Did you get good performance?

Hi @CSSFrancis,

I routinely see the above process fail or get stuck in an endless cycle of dropping and recreating keys.

This to me is a much bigger problem than everything else.

Dask is designed to cope with the use case you described by spilling/unspilling locally to disk on the workers. “dropping and recreating keys” is a symptom of workers being killed because they reach the distributed.worker.memory.terminate threshold (95% of memory_limit); you’ll find notice of that in the nannies’ logs.
What I suspect is happening is that your disk is lagging behind; e.g. your tasks produce more memory than the spill-to-disk activity can consume. This is normally OK, due to the distributed.worker.memory.pause threshold (80%) which will block any further tasks from being executed. However, it won’t work if the heap of one of your tasks is more than (terminate - pause) * memory_limit / threads_per_worker.
For example, if you have 4 threads per worker, 15 GiB memory limit, and the default terminate and pause thresholds (0.95 and 0.8 respectively), then if your tasks take more than 0.55 GiB each worth of heap you risk killing your workers.
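Spelling out that arithmetic:

memory_limit = 15              # GiB per worker
threads_per_worker = 4
terminate, pause = 0.95, 0.80
print((terminate - pause) * memory_limit / threads_per_worker)  # 0.5625 GiB, i.e. the ~0.55 GiB above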
Try tweaking distributed.worker.memory so that pause kicks in earlier (a paused worker is easy to spot as it will appear red in the memory plot on the dashboard).
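For example (a sketch; 0.60 is just a placeholder value, and it needs to be set where the workers are started, e.g. via dask.config.set before creating the cluster or in distributed.yaml):

import dask

# Pause workers (stop starting new tasks) at 60% memory use instead of the default 80%.
dask.config.set({"distributed.worker.memory.pause": 0.60})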

If I understand correctly, what you’re asking for is some heuristic that automatically detects fast tasks with a very large output and realizes that it is more convenient to compute such tasks multiple times than to keep them in memory. There is nothing like that right now, but it’s a very well known missing feature that crops up with some regularity.

Hi @crusaderky thank you for your response!

Dask is designed to cope with the use case you described by spilling/unspilling locally to disk on the workers. “dropping and recreating keys” is a symptom of workers being killed because they reach the distributed.worker.memory.terminate threshold (95% of memory_limit); you’ll find notice of that in the nannies’ logs.

What I suspect is happening is that your disk is lagging behind; e.g. your tasks produce more memory than the spill-to-disk activity can consume. This is normally OK, due to the distributed.worker.memory.pause threshold (80%) which will block any further tasks from being executed. However, it won’t work if the heap of one of your tasks is more than (terminate - pause) * memory_limit / threads_per_worker.

For example, if you have 4 threads per worker, 15 GiB memory limit, and the default terminate and pause thresholds (0.95 and 0.8 respectively), then if your tasks take more than 0.55 GiB each worth of heap you risk killing your workers.

Hmm, this is interesting. I’ll have to look at pausing the workers earlier. Spilling to disk seems to exacerbate the situation, so in many cases it is probably better to just pause that worker. These tasks do consume a lot of memory, and quite possibly the best thing to do would be to look at the underlying code in scipy and see if there is a better way to handle the memory there.

I finally got around to profiling the other day and realized that the new Dask scheduling, which throttles tasks that create data, seems to run pretty close to the ideal case, which is exciting! Admittedly I had first done this before Dask 2022.11.0 and the performance was much, much worse. Now it seems like I have far fewer problems, with the small added cost of more communication between workers.
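For reference, I believe (an assumption on my part) the new behavior is the scheduler-side root task queuing introduced around 2022.11.0, which can be tuned through the distributed.scheduler.worker-saturation setting:

import dask

# 1.1 is the default; larger values (or "inf") send more root tasks to each
# worker up front, which is closer to the old pre-2022.11.0 behavior.
dask.config.set({"distributed.scheduler.worker-saturation": 1.1})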

That would be the ideal case but might be difficult to accomplish. I was thinking more about how to achieve that manually and whether there is a better way of doing it than cloning keys. Something like a way to take a task graph which includes a lot of cross-talk and create something that is more embarrassingly parallel. An extreme example would be a total rechunk, which requires the entire dataset to be loaded into RAM. You could potentially reduce the total RAM usage by reloading the chunks along some dimension. Maybe this type of discussion is better suited to a package like Rechunker — Rechunker 0.5.2.dev2+g0ac43e8 documentation, though.