How to concatenate xarray-datasets read from zarr with dask

Hello,

I have a large number (~4400) of zarr files (each about 3 MB) that contain outputs from several model runs. Each of them contains several variables that I organized in a hierarchical fashion with xarray-datatree. I am interested in the statistics of quantities that I can compute from the original outputs.
In order to do this, I need to stack the datasets along a new dimension. The concatenated dataset is quite large (about 17 GB in memory, based on what I see in the dask dashboard). It still just about fits into memory, but some computations might then exceed the memory limit, so I think chunking it with dask is a good option. The approach that I have used so far worked for a smaller dataset, but with my larger dataset I am running into trouble. These are my questions:

  • How can I concatenate the dataset without reading everything into memory at once?
  • How can I save the concatenated dataset for future use?

Additionally, I would also like to better understand why the approach that I tried is not working very well. I have the feeling that I lack some background dask knowledge to understand the problems, so any hints on this are welcome, too.

My current approach is as follows:

  1. Read in datasets individually on a dask client:
    futures = client.map(
        read_model_outputs_from_zarr,
        paths,
        chunks="auto",
    )
    
    read_model_outputs_from_zarr is a custom function that reads in the zarr files (containing hierarchically organized data) with datatree and then flattens them into xarray datasets. The chunks argument is passed down to xarray’s open_dataset function.
  2. Concatenate the individual datasets:
    result = client.submit(
        xr.concat, futures, pd.Index(model_run_ids, name="model_run_id"), coords="all"
    )
    ds = client.gather(result)
    
  3. Write the concatenated dataset to disk, so I do not have to repeat steps 1 and 2 the next time (instead, I can read in the data easily using xr.open_dataset):
    rechunked = ds.chunk(chunks).unify_chunks()
    rechunked.to_zarr(os.path.join(base_path, "all_outputs.zarr"))
    
    I have to rechunk the dataset first because otherwise zarr complains that the chunks are not uniform.

These are the problems I am currently facing, and that I want to avoid:

  • When I concatenate the datasets, a single worker uses a lot of memory, and I need to set a very large memory limit for the workers in order to be able to do the concatenation. Overall, the memory usage for the concatenation is very high. It seems to me that the whole dataset is in memory at once, which is exactly what I want to avoid with dask.
  • Writing the concatenated file to disk does not finish even after several hours, and the file size did not increase during that whole time. When I interrupted the writing process, dask seemed to be busy with some kind of optimization, but I do not understand what it is doing. This is from the traceback:
    File ~/mambaforge/envs/pymc/lib/python3.10/site-packages/dask/array/core.py:1170, in store(***failed resolving arguments***)
       1168 # Optimize all sources together
       1169 sources_hlg = HighLevelGraph.merge(*[e.__dask_graph__() for e in sources])
    -> 1170 sources_layer = Array.__dask_optimize__(
       1171     sources_hlg, list(core.flatten([e.__dask_keys__() for e in sources]))
       1172 )
       1173 sources_name = "store-sources-" + tokenize(sources)
       1174 layers = {sources_name: sources_layer}
    
    File ~/mambaforge/envs/pymc/lib/python3.10/site-packages/dask/array/optimization.py:48, in optimize(dsk, keys, fuse_keys, fast_functions, inline_functions_fast_functions, rename_fused_keys, **kwargs)
         45 if not isinstance(dsk, HighLevelGraph):
         46     dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
    ---> 48 dsk = optimize_blockwise(dsk, keys=keys)
         49 dsk = fuse_roots(dsk, keys=keys)
         50 dsk = dsk.cull(set(keys))
    

Some more details:

  • The dimensions in the datasets do not all have the same shapes. This is because some of the simulations did not finish, so the length of the “time” dimension varies between datasets. In the extreme case, there might even be a few datasets that do not contain a single time point, that is, the variables are missing. However, xr.concat can handle this by filling the missing variables or missing time points with NaN.
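
For illustration, here is a minimal sketch of that fill behaviour with two toy runs of different length. The helper make_run and the variable name "y" are made up for this example, and join="outer" is passed explicitly rather than relying on the default:

```python
import numpy as np
import pandas as pd
import xarray as xr


def make_run(n_time):
    # Hypothetical stand-in for one model run with n_time time points.
    return xr.Dataset(
        {"y": ("time", np.arange(n_time, dtype=float))},
        coords={"time": np.arange(n_time)},
    )


runs = [make_run(3), make_run(5)]  # the first run "did not finish"

# Stack along a new dimension; the outer join pads the shorter run with NaN.
combined = xr.concat(
    runs,
    dim=pd.Index([0, 1], name="model_run_id"),
    join="outer",
)

n_missing = int(combined["y"].sel(model_run_id=0).isnull().sum())
```

The shorter run ends up with NaN at the two time points it never reached, so the stacked variable has one row per run and the length of the longest run along "time".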

I am sorry that this post is a bit long, but I hope it contains enough information to understand my problem and questions. I appreciate advice on any aspect of my problem.

Thanks!

Hi @astoeriko, welcome to the Dask community!

Thanks for the detailed post, I’ll try to answer some questions.

You are using the Futures API, which is immediate rather than lazy. This can be useful in some situations, but not in your case. The client.map call triggers the loading of every file, distributed among the Dask Workers. This means that all your data will be loaded into memory, though potentially on different Workers.

If you want to avoid this, you need to use the Delayed API, so that data loading can be done in a streaming way when you trigger the computation.
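
To make the eager/lazy difference concrete, here is a small sketch using only dask.delayed. The function read_one and the paths are placeholders and not the real reader; the point is only that nothing runs until compute is called:

```python
import dask

loaded = []  # records which "files" have actually been read


@dask.delayed
def read_one(path):
    # Placeholder for the real zarr reader; here it only records the call.
    loaded.append(path)
    return len(path)


# Building the delayed objects does not read anything yet.
lazy_reads = [read_one(p) for p in ["a.zarr", "bb.zarr"]]
nothing_read_yet = (loaded == [])

# Work only happens once a computation is triggered.
total = dask.delayed(sum)(lazy_reads)
result = total.compute(scheduler="synchronous")
```

With client.map, by contrast, each read would start as soon as the call is made.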

With client.submit(xr.concat, ...), you’re telling just one Worker to do the concatenation, which means that this Worker will need to gather all the chunks from the other Workers to perform this operation.

Moreover, by gathering the final result with client.gather, you are trying to transfer the whole result from the Worker that did the concatenation to your main process, which will also be expensive in memory and time.

Instead, you want to call the xr.concat function directly from your main process on the Delayed objects that you build beforehand.
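
One possible reading of this advice, as a rough sketch: as far as I know, xr.concat expects xarray objects rather than raw Delayed objects, so the example below swaps in dask-backed datasets (similar to what opening a store with chunks would give). The helper fake_lazy_run is invented for the example and builds its data in memory instead of reading zarr:

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr


def fake_lazy_run(run_id, n_time=4):
    # Assumption: stands in for a lazily opened, dask-backed model output.
    data = da.full((n_time,), float(run_id), chunks=n_time)
    return xr.Dataset(
        {"y": ("time", data)},
        coords={"time": np.arange(n_time)},
    )


run_ids = [0, 1, 2]
lazy_runs = [fake_lazy_run(i) for i in run_ids]

# Called in the main process: this only builds a task graph.
combined = xr.concat(lazy_runs, dim=pd.Index(run_ids, name="model_run_id"))
is_lazy = isinstance(combined["y"].data, da.Array)

# Data is only materialised for the (small) reduced result.
means = combined["y"].mean("time").compute()
```

Here the concatenated variable stays a dask array, and only the per-run means are ever brought into memory.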

The end of your workflow should be correct, even if rechunking can be expensive. But maybe there is some other way to handle the non-uniform chunk sizes.
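
As a small illustration of where non-uniform chunks come from, the sketch below concatenates two toy arrays with different chunk sizes and then rechunks the result. The sizes are arbitrary and chosen only for the example:

```python
import dask.array as da
import xarray as xr

# Two toy pieces whose chunk sizes differ along "time".
a = xr.DataArray(da.zeros(10, chunks=3), dims="time")
b = xr.DataArray(da.zeros(10, chunks=4), dims="time")

# Concatenation keeps each piece's own chunking, so the result is irregular.
combined = xr.concat([a, b], dim="time")
irregular_chunks = combined.chunks

# Rechunking to a fixed size gives the regular layout that zarr expects.
uniform = combined.chunk({"time": 5})
uniform_chunks = uniform.chunks
```

After a rechunk like this, the to_zarr call from the original workflow should accept the data, though the rechunk itself adds tasks to the graph.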

Let me know if this is enough for you to try implementing your workflow differently.

Oh, I see—thanks for clarifying! I will give it a try and report back.

Sorry that it took me some time to get back to this.

I tested using the delayed API: I can now read in and concatenate the files lazily by using the @dask.delayed decorator for the function that reads in the data. However, I have the impression that this still does not really solve my problem. As soon as I trigger a compute, the whole dataset still needs to be held in memory during the concatenation step.
To solve this, I had to change the order in which I do things. Previously, I tried:

  1. Read many datasets lazily,
  2. concatenate datasets,
  3. do computations that reduce the size of the dataset.

During step 2, the whole big dataset has to be loaded (as soon as a compute is triggered). Instead, I have now reversed the order of steps 2 and 3. That is, I first reduce the size of each individual dataset and do the concatenation at the very end, so that I never need to hold the full dataset in memory.
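
Roughly, the reordered workflow looks like the sketch below. The functions read_run and reduce_run are simplified stand-ins (the real reader opens zarr stores, and the real reduction is whatever statistic is needed), and the mean over "time" is just an example reduction:

```python
import dask
import numpy as np
import pandas as pd
import xarray as xr


@dask.delayed
def read_run(run_id):
    # Stand-in for reading one model output; runs differ in length.
    n_time = 3 + run_id
    return xr.Dataset({"y": ("time", np.full(n_time, float(run_id)))})


@dask.delayed
def reduce_run(ds):
    # Example reduction that shrinks each run before anything is combined.
    return ds.mean("time")


run_ids = [0, 1, 2]
reduced = [reduce_run(read_run(i)) for i in run_ids]

# Only the small, reduced datasets come back to the main process.
(small,) = dask.compute(reduced, scheduler="synchronous")

# Concatenation happens last, on data that easily fits into memory.
summary = xr.concat(small, dim=pd.Index(run_ids, name="model_run_id"))
```

Because each run is reduced right after it is read, the full-size data for one run can be released before the next one is needed.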

I think I kind of expected that dask would be able to detect that it can change the order of the operations automagically, but I realize that this is probably more than could be expected.
The only drawback of my current approach is that I need to know in advance which computations I want to do. If I want to compute an additional quantity from my original data, I have to repeat all three steps—and step 1 is currently by far the slowest.