Using da.delayed for Zarr processing: memory overhead & how to do it better?

For completeness, here’s the full version of our current solution:

The essential part was to split the whole process into a single task per chunk and use the region parameter of to_zarr to write just that specific ROI to the zarr file for any given task.
Also, when calling to_zarr with compute=False (in order to generate all the dask tasks that we can then process), it's no longer necessary to wrap the function in dask.delayed (in other tests as well, when the output goes through to_zarr, dask.delayed isn't actually required, because to_zarr already keeps the computation lazy).
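As a minimal illustration of that point (the array shape and chunking here are just placeholders):

import dask.array as da

arr = da.random.random((4, 1000, 1000), chunks=(1, 500, 500))

# No dask.delayed needed: with compute=False, to_zarr returns a dask
# Delayed object instead of writing immediately.
delayed_write = arr.to_zarr("example.zarr", compute=False, overwrite=True)
delayed_write.compute()  # the actual write happens only here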

Here is the code using region properly:
import dask.array as da
import zarr

# SIZE (the ROI edge length) and shift_img (the per-ROI processing function)
# are defined in the earlier posts.

def process_zarr_regions(input_zarr, inplace=False, overwrite=True):
    if inplace and not overwrite:
        raise ValueError('If zarr is processed in place, overwrite needs to be True')

    out_zarr = f"out_{input_zarr}"
    data_old = da.from_zarr(input_zarr)

    # Prepare output zarr file
    if inplace:
        new_zarr = zarr.open(input_zarr)
    else:
        new_zarr = zarr.create(
            shape=data_old.shape,
            chunks=data_old.chunksize,
            dtype=data_old.dtype,
            store=da.core.get_mapper(out_zarr),
            overwrite=overwrite,
        )
    n_c, n_z, n_y, n_x = data_old.shape

    print(f"Input file: {input_zarr}")
    print(f"Output file: {out_zarr}")
    print("Array shape:", data_old.shape)
    print("Array chunks:", data_old.chunks)
    print(f"Image size: ({SIZE},{SIZE})")

    # Build one region per channel, z-plane and (SIZE x SIZE) tile
    tasks = []
    regions = []
    for i_c in range(n_c):
        for i_z in range(n_z):
            for i_y in range(0, n_y - 1, SIZE):
                for i_x in range(0, n_x - 1, SIZE):
                    regions.append(
                        (slice(i_c, i_c + 1), slice(i_z, i_z + 1),
                         slice(i_y, i_y + SIZE), slice(i_x, i_x + SIZE))
                    )

    # One delayed to_zarr write per region
    for region in regions:
        data_new = shift_img(data_old[region])
        task = data_new.to_zarr(url=new_zarr, region=region, compute=False, overwrite=overwrite)
        tasks.append(task)

    # Compute tasks sequentially
    # TODO: Figure out how to run tasks in parallel where safe => batching
    # (where they don't read/write from/to the same chunk)
    for task in tasks:
        task.compute()
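For reference, a minimal setup to run this on dummy data could look like the following; SIZE and shift_img are just placeholders standing in for the definitions from the earlier posts, and the array/ROI sizes are arbitrary:

import dask.array as da

SIZE = 2000  # ROI edge length (placeholder value)

def shift_img(img):
    # Dummy stand-in for the real per-ROI processing function
    return img + 1

# Create a small test zarr file to process
test_data = da.random.random((2, 2, 4000, 4000), chunks=(1, 1, 2000, 2000))
test_data.to_zarr("test_input.zarr", overwrite=True)

process_zarr_regions("test_input.zarr", inplace=False, overwrite=True)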

Here is how the memory profiles compare for the largest example (the (2, 16, 16000, 16000) case; it scales as before for the smaller examples):

Summary:

  1. Our initial indexing approach was the worst (blue line) => long runtime & high memory usage
  2. The mapblocks approach is a mix: much shorter runtime, intermediate memory usage
  3. The new region approach runs all tasks sequentially. Thus, runtime is comparable to the initial indexing approach, while memory usage is very low (even lower than mapblocks)
  4. There is a bit of variability between runs, and saving to a new zarr file may be slightly faster than overwriting the ROIs in the existing file. But the overhead is quite acceptable for our use-case.

Downsides compared to the mapblocks implementation:
We no longer have dask handling overlapping indices for us. For what we are planning to do, we don't want to write the same pixel position multiple times, so we need to do that bookkeeping ourselves (a simple check is sketched below).
Also, the indexing approach has some overhead that scales with the number of ROIs. Even without calling compute, building the tasks scales from ~1 s (for 256 ROIs in the (2, 2, 16000, 16000) case) to ~6 s (for 2048 ROIs in the (2, 16, 16000, 16000) case), while mapblocks just seems to vary a bit and would likely stay roughly constant even for higher numbers of ROIs.
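A simple (quadratic) overlap check could look like this; it's just an illustration, not part of the benchmarked code, and it assumes all slices have explicit integer start/stop values (as the regions built above do):

from itertools import combinations

def slices_overlap(slc_a, slc_b):
    # Two slices along one axis overlap if neither ends before the other starts
    return slc_a.start < slc_b.stop and slc_b.start < slc_a.stop

def regions_overlap(region_a, region_b):
    # Two N-dimensional regions overlap only if they overlap along every axis
    return all(slices_overlap(a, b) for a, b in zip(region_a, region_b))

def check_no_overlaps(regions):
    for region_a, region_b in combinations(regions, 2):
        if regions_overlap(region_a, region_b):
            raise ValueError(f"Overlapping ROIs: {region_a} and {region_b}")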


For our use-case, we will probably mostly have dozens to hundreds of ROIs, so that should be fine. But if one applies this logic to an arbitrary number of ROIs, scaling isn't great.


Room for improvement:
I'm currently running all the ROIs sequentially. One could come up with a way to batch them, and I'd assume this would decrease runtime and increase memory usage. Such batching is non-trivial though, because we can only batch tasks that don't write to the same underlying chunk of the zarr array, so working that out may be a bit tricky (one possible approach is sketched below). Thus, achieving the same runtimes as mapblocks will not be trivial (but if we find a good way to do this, we can tune the runtime vs. memory usage trade-off explicitly).
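One possible sketch of such batching, assuming the region slices have explicit integer start/stop values and using the output array's chunksize as the chunk grid; chunks_touched and batch_tasks are hypothetical helpers that we haven't implemented or benchmarked:

import math
from itertools import product
import dask

def chunks_touched(region, chunksize):
    # Indices of all output chunks that a region overlaps
    per_axis = [
        range(slc.start // size, math.ceil(slc.stop / size))
        for slc, size in zip(region, chunksize)
    ]
    return set(product(*per_axis))

def batch_tasks(tasks, regions, chunksize):
    # Greedily group delayed writes so that no two writes in a batch
    # touch the same output chunk
    batches = []  # list of (tasks_in_batch, chunks_used_by_batch)
    for task, region in zip(tasks, regions):
        touched = chunks_touched(region, chunksize)
        for batch, used in batches:
            if not (used & touched):
                batch.append(task)
                used |= touched
                break
        else:
            batches.append(([task], touched))
    return [batch for batch, _ in batches]

# Batches run one after another, but the writes inside a batch run in parallel:
# for batch in batch_tasks(tasks, regions, data_old.chunksize):
#     dask.compute(*batch)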


Things to be tested
We'll now need to test implementing this approach for our real-world use-cases again. Typically, the data is a bit bigger, though we normally have fewer ROIs (~100). The actual run functions will be doing much more than our dummy function at the moment, so the overhead may become negligible, but we'll need to test that. We're curious to see how this performs, whether we can use the inplace version, and how easy it will be to do some parallelization.


In conclusion: This is the solution we need for our use-case of arbitrary ROIs. Splitting the processing into distinct tasks and using region in to_zarr enables this use case. There are some trade-offs with complexity, overhead & runtime, but those trade-offs should be worth it for the ROI flexibility we're gaining in our use-case.


Thanks a lot to Davis Bennett for the support, to everyone else on the napari Zulip, and to @jni here on the forum as well!

PS: A learning for anyone attempting this as well: the region parameter only accepts tuples of slice objects, not e.g. lists, single integers, or the like.
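For example (just illustrating the point above; the slice bounds are arbitrary):

region = (slice(0, 1), slice(0, 1), slice(0, 2000), slice(0, 2000))   # works
# region = (0, 0, slice(0, 2000), slice(0, 2000))                     # fails: integers are not accepted
# region = [slice(0, 1), slice(0, 1), slice(0, 2000), slice(0, 2000)] # fails: needs to be a tuple, not a list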
