Prevent dask array from `compute()` behavior

I’m using dask to process whole-slide imaging (WSI) data and it has been very helpful!

I’ve been working on affine-transforming the input WSI. My strategy is to pre-allocate the output (transformed) image as a chunked image and work out the mapping between each output chunk and the input WSI. Before switching to dask, I used a zarr array as the input image, and it behaved as expected: with a chunked input WSI, only the requested region was loaded into RAM while iterating over the output image chunks.

After switching to a dask array as the input WSI, I’m seeing that the whole input dask array seems to be computed (as if by compute()) in each iteration over the output chunks. Here’s the graph of the following snippet; note the “finalize” node and its direct child (the rectangular box).

import dask.array as da
import numpy as np

ref_img = da.from_array(np.eye(2), chunks=1)
out_img = da.empty_like(ref_img)

out_img.map_blocks(
    lambda x, y: np.atleast_2d(y[0, 0])+1,
    y=ref_img,
    dtype=ref_img.dtype
).visualize('da-as-source.png')

If a numpy array is used as the input image instead, here are the snippet and graph:

out_img.map_blocks(
    lambda x, y: np.atleast_2d(y[0, 0])+1,
    y=np.eye(2),
    dtype=ref_img.dtype
).visualize('npa-as-source.png')

(The graph for this case, npa-as-source.png, is linked rather than embedded because new users can only embed one image.)

I was expecting the dask array to behave like the numpy array, i.e. to be sent to each task without the “finalize” step, so that each task would just call get_item and only the relevant data would be touched (in the snippet, the one pixel in the upper-left corner).

Is this the expected behavior? How might I change it?

Thanks in advance!

Hi @Yu-AnChen

Please note that in the following statement:

out_img.map_blocks(lambda x: x[0, 0]+1, x=ref_img, dtype=ref_img.dtype)

each chunk of out_img will be the input to the lambda function, which I suspect is not what you really want, right?

I’m guessing you rather want the chunks of ref_img to be the input. In that case, the following statement will suffice:

out_img = da.map_blocks(lambda x: x + 1, ref_img, dtype=ref_img.dtype)

Naturally, in the latter case, pre-allocating out_img is not necessary.

Thank you @ParticularMiner for replying! In short, I do want to map over the output image, which is probably not well illustrated in the minimal example.

I’m using reverse mapping (see the schematic above): each block of out_img is computed from the corresponding region of src_img (it usually isn’t a first-block-to-first-block situation).

I did need to fix a few places in the minimal example so that it actually runs under compute() (the graph doesn’t change, as I checked), so thank you for commenting :slight_smile:

@Yu-AnChen

If your mappings are linear, you may consider using b = dask.array.matmul(a, b) where a would represent the transformation matrix of the (forward/reverse) mapping, and b the image (source/destination).
dask.array.matmul — Dask documentation

A more complicated mapping will likely require the use of dask.array.core.blockwise(). dask.array.blockwise — Dask documentation

Thank you @ParticularMiner. Yes, I’m only doing linear transformations and I have an implementation that gives me the right output. But the issue is that when I tried to peek at the result of one of the blocks by calling map_blocks_result.blocks[Y, X], I saw that it computes my lazy dask array first and then performs get_item.

Here’s a hopefully better but still minimal example -

import dask.array as da
import numpy as np
import zarr


np_img = np.eye(3) * np.arange(3)
zarr_img = zarr.array(np_img, chunks=1)
da_img = da.from_array(np_img, chunks=1)
out_img = da.empty_like(da_img)

def shift_up_left(block, moving_img, n_blocks_up_left, block_info=None):
    Y, X = block_info[None]['array-location']
    assert n_blocks_up_left >= 0, '`n_blocks_up_left` must >= 0'
    n = int(n_blocks_up_left)
    try:
        lower_right = moving_img[Y[0] + n, X[0] + n]
    except IndexError:
        lower_right = 0
    return np.atleast_2d(lower_right)


da_out = out_img.map_blocks(
    shift_up_left,
    moving_img=da_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

np_out = out_img.map_blocks(
    shift_up_left,
    moving_img=np_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

zarr_out = out_img.map_blocks(
    shift_up_left,
    moving_img=zarr_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)


da_out.visualize('da.png')
np_out.visualize('np.png')
zarr_out.visualize('zarr.png')
# just to validate it runs as expected
In [2]: (da_out == np_out).compute()
Out[2]: 
array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]])

In [3]: (da_out == zarr_out).compute()
Out[3]: 
array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]])

The simplified function shift_up_left shifts the chunks in the target image up and left by n_blocks_up_left. Here are the three graphs; again, note the “finalize” node, which is present only in the dask-array-as-input graph.

There are just 9 chunks in the demo; in the real case there are thousands of chunks on disk, and I’d like to read only the necessary chunks when I peek at the results of a few blocks. I have no issue firing .compute() to get the whole transformed image. The main problem is that when I’m only checking a few blocks, it should not need to go through that “finalize” node, I imagine. (Although even when computing the whole image, I’m probably touching more data than ends up in the transformed image.)

Hi Yu-AnChen,

Sorry, I went to bed before I saw your reply.

It is apparent to me that you require special dask functions to solve your problem. You could use a combination of dask.array.pad() and slicing:

https://docs.dask.org/en/stable/generated/dask.array.pad.html?highlight=pad

import numpy as np
import dask.array as da


def shift_up_left(img, n):
    return da.pad(img[n:, n:], ((0, n), (0, n)))


np_img = np.eye(3) * np.arange(3)
da_img = da.from_array(np_img, chunks=1)
n_blocks_up_left = 2
out_img = shift_up_left(da_img, n_blocks_up_left)
out_img.compute()

Or, as I suggested in my previous post, you could use transformation matrices, which in your case would be kth-diagonal matrices which matrix-multiply your images from the left and right. It would be interesting to know which method is more performant.
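
For what it's worth, the two suggestions can be checked against each other on the toy image. The sketch below assumes a square image and a non-negative integer shift; shift_via_pad and shift_via_matmul are illustrative names, not dask functions. Left-multiplying by a matrix with ones on the n-th superdiagonal moves rows up, and right-multiplying by one with ones on the n-th subdiagonal moves columns left.

```python
import numpy as np
import dask.array as da


def shift_via_pad(img, n):
    # Drop the first n rows/columns, then zero-pad at the bottom/right.
    return da.pad(img[n:, n:], ((0, n), (0, n)))


def shift_via_matmul(img, n):
    # Assumes a square image. eye(size, k=n) @ img moves rows up by n;
    # img @ eye(size, k=-n) moves columns left by n.
    size = img.shape[0]
    chunks = img.chunks
    rows_up = da.from_array(np.eye(size, k=n), chunks=chunks)
    cols_left = da.from_array(np.eye(size, k=-n), chunks=chunks)
    return da.matmul(da.matmul(rows_up, img), cols_left)


np_img = np.eye(3) * np.arange(3)
da_img = da.from_array(np_img, chunks=1)

pad_result = shift_via_pad(da_img, 2).compute()
matmul_result = shift_via_matmul(da_img, 2).compute()
```

On the toy image both results should hold the single nonzero value 2 in the upper-left corner. Which one is faster on a real WSI is left open here; the matrix products do considerably more arithmetic than slicing and padding.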

The thing with map_blocks() is that whenever it receives any keyword argument whose value is a dask array (like moving_img=da_array), it converts the entire array into a numpy array before passing it into the supplied user function, hence the structure of the graph you are visualizing. This is usually not the ideal way to use dask.array as it is potentially memory-intensive. To avoid such computation, we do not pass an array as a keyword argument, but rather as a positional argument.

But after doing so, each branch of the graph will receive a separate chunk of da_array. Clearly, that is not what you want either, since you would like to access content from da_array chunks in branches other than the current branch. Such a requirement is often not trivial, which is why I suggested the dask special functions above to solve your problem.
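
A small sketch may make the difference concrete. It relies only on the behavior described above; count_received is a hypothetical helper that fills each output block with the number of elements of the second array it was handed, and the 4x4 test array is made up for the demonstration.

```python
import numpy as np
import dask.array as da

src = da.from_array(np.arange(16.0).reshape(4, 4), chunks=2)
out = da.zeros_like(src)


def count_received(block, other):
    # Fill the output block with the element count of `other`,
    # so the result records what each task actually received.
    return np.full(block.shape, other.size, dtype=float)


# Keyword argument: the dask array is materialized before the call.
as_kwarg = out.map_blocks(count_received, other=src, dtype=float).compute()

# Positional argument: each task gets only the matching chunk.
as_positional = out.map_blocks(count_received, src, dtype=float).compute()
```

With the keyword form every block should report 16 (the whole 4x4 array), and with the positional form 4 (one 2x2 chunk).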

Thank you @ParticularMiner so much for pointing out that

The thing with map_blocks() is that whenever it receives any keyword argument whose value is a dask array (like moving_img=da_array ), it converts the entire array into a numpy array before passing it into the supplied user function

this is exactly the issue! (Could you point me to where this is happening/documented?)

I’m re-using the above to demo your fix/solution

import dask.array as da
import numpy as np
import zarr
import tqdm.dask

np_img = np.eye(3) * np.arange(3)
zarr_img = zarr.array(np_img, chunks=1)
da_img = da.from_array(np_img, chunks=1)
out_img = da.empty_like(da_img)

def shift_up_left(block, moving_img, n_blocks_up_left, block_info=None):
    Y, X = block_info[None]['array-location']
    assert n_blocks_up_left >= 0, '`n_blocks_up_left` must >= 0'
    n = int(n_blocks_up_left)
    try:
        lower_right = moving_img[Y[0] + n, X[0] + n]
    except IndexError:
        lower_right = 0
    return np.atleast_2d(lower_right)


da_out = out_img.map_blocks(
    shift_up_left,
    moving_img=da_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

np_out = out_img.map_blocks(
    shift_up_left,
    moving_img=np_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

zarr_out = out_img.map_blocks(
    shift_up_left,
    moving_img=zarr_img,
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

for o, name in zip(
    [da_out, np_out, zarr_out],
    ['da', 'np', 'zarr']
):
    with tqdm.dask.TqdmCallback(
        ascii=True, desc=f"{name} as kwarg"
    ):
        o.compute()

First, when the dask, numpy, and zarr arrays are each passed as a keyword argument, tqdm reports one additional computation step for the dask array (10 steps) compared with the others (9 steps). I believe that extra step is the “finalize” node in the graph above.

da as kwarg: 100%|############| 10/10 [00:00<00:00, 10348.64it/s]
np as kwarg: 100%|##############| 9/9 [00:00<00:00, 18669.01it/s]
zarr as kwarg: 100%|############| 9/9 [00:00<00:00, 15357.50it/s]

After passing the dask array as a positional argument, it’ll be 9 steps now -

with tqdm.dask.TqdmCallback(
    ascii=True, desc='da as positional arg'
):
    out_img.map_blocks(
        shift_up_left,
        da_img,
        n_blocks_up_left=2,
        dtype=da_img.dtype
    ).compute()
da as positional arg: 100%|#####| 9/9 [00:00<00:00, 18113.60it/s]

Thanks so much for the helpful discussion and insights!!

@Yu-AnChen

You’re very welcome! It wasn’t so much a fix as a cautionary explanation, though: after changing the relevant parameter to a positional argument, you are faced with the problem of how to access content in chunks that are not in the current branch. As a result, each block gets shifted by the given vector leaving blank spaces at the edges of each block! Which I cannot imagine is correct.
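
To illustrate the caveat, here is a minimal sketch (not the code from this thread) that shifts by one pixel in two ways: once over the whole array with da.pad and slicing, and once block by block through a positional map_blocks argument. shift_block is a hypothetical helper, and the 4x4 test image is made up for the demonstration.

```python
import numpy as np
import dask.array as da

img = da.from_array(np.arange(16.0).reshape(4, 4), chunks=2)
n = 1


def shift_block(block):
    # Shifts within one chunk only; it cannot see neighbouring chunks.
    return np.pad(block[n:, n:], ((0, n), (0, n)))


whole_array = da.pad(img[n:, n:], ((0, n), (0, n))).compute()
per_block = da.map_blocks(shift_block, img, dtype=img.dtype).compute()

# Positions where the per-block shift lost data from a neighbouring chunk.
lost = (whole_array != 0) & (per_block == 0)
```

The two results should differ: wherever lost is True, the whole-array shift pulled a value across a chunk boundary that the per-block version could not reach.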

… Could you point me to where this is happening/documented?

The docs for map_blocks() state:

**kwargs
Other keyword arguments to pass to function. Values must be constants (not dask.arrays)

(dask.array.map_blocks — Dask documentation)

And from the source code, you can trace the following chain of calls:
dask.array.core.map_blocks() → dask.array.blockwise.blockwise() → dask.delayed.unpack_collections() → dask.delayed.finalize()

In particular, in blockwise():

Thanks for the references! Indeed, the behavior of passing the moving_img dask array as a positional arg is not what I wanted. As you also mentioned -

As a result, each block gets shifted by the given vector leaving blank spaces at the edges of each block! Which I cannot imagine is correct.

After some further hacking around, my current solution is to use functools.partial to pass in the dask array moving_img and here’s a glance at the workaround and performance. It’s definitely degraded :frowning: - probably due to the complexity of the graph.

In [1]: %paste
import dask.array as da
import numpy as np
import zarr
import tqdm.dask

np_img = np.eye(100) * np.arange(100)
zarr_img = zarr.array(np_img, chunks=1)
da_img = da.from_array(np_img, chunks=1)
out_img = da.empty_like(da_img)


def shift_up_left(block, moving_img, n_blocks_up_left, block_info=None):
    Y, X = block_info[None]['array-location']
    assert n_blocks_up_left >= 0, '`n_blocks_up_left` must >= 0'
    n = int(n_blocks_up_left)
    try:
        lower_right = moving_img[Y[0] + n, X[0] + n]
    except IndexError:
        lower_right = 0
    return np.atleast_2d(lower_right)


def shift_up_left_np(block, moving_img, n_blocks_up_left, block_info=None):
    return np.asarray(
        shift_up_left(block, moving_img, n_blocks_up_left, block_info)
    )


import functools

da_out = out_img.map_blocks(
    functools.partial(shift_up_left, moving_img=da_img),
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

da_out_np = out_img.map_blocks(
    functools.partial(shift_up_left_np, moving_img=da_img),
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

np_out = out_img.map_blocks(
    functools.partial(shift_up_left, moving_img=np_img),
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

np_out_np = out_img.map_blocks(
    functools.partial(shift_up_left_np, moving_img=np_img),
    n_blocks_up_left=2,
    dtype=da_img.dtype
)

import itertools

for p1, p2 in itertools.combinations(
    [np.asarray(o) for o in (da_out, da_out_np, np_out, np_out_np)],
    2
):
    assert np.all(p1 == p2)
## -- End pasted text --

In [2]: %timeit np.asarray(da_out)
3 s ± 24.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [3]: %timeit np.asarray(da_out_np)
10.9 s ± 87.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [4]: %timeit np.asarray(np_out)
558 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit np.asarray(np_out_np)
559 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In my actual use case, the functools.partial approach is helpful for peeking: it lets me quickly get the result of map_blocks_result.blocks[Y, X]. I need to do more comparisons to decide whether I need two modes: 1) persist moving_img as a numpy array and 2) pass a dask array as moving_img for checking/debugging purposes. If the performance degradation is not acceptable in the targeted usages, I’ll need a switch between the two modes. (And just a note: I’m applying an affine transformation to each block; the above shift_up_left is just for demo/testing.)

Again, thanks for the helpful discussion!

@Yu-AnChen

After some further hacking around, my current solution is to use functools.partial to pass in the dask array

Yeah … I’ve been cautioned against using dask arrays that way in an earlier post (follow this link). But if it works for you, then that’s fine by me.

All the best with your work!
