Reading data from netcdf or zarr files loads all data into memory

Hello!

I am working with climate model output (mostly in netCDF format) and ran into some issues when reading these data and calculating averages with a distributed scheduler. Looking at the Dask dashboard, I find that basically all of the data is read into memory at some point. Doing the same with the default (threaded) scheduler seems to work fine. Creating the data directly with dask.array instead of reading it from files also seems to work fine (with a distributed scheduler).
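
For reference, running with the default threaded scheduler here just means not creating a Client in the scripts below (or forcing it explicitly), roughly:

print(avg.compute(scheduler="threads"))  # threaded scheduler instead of the distributed Client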

Here is a screenshot of an example dashboard:

I investigated three different setups (each uses ~45 GiB of data and a LocalCluster with 2 workers of 4 GiB memory each).

  1. NetCDF data read with xarray: The workers’ memory keeps filling up and eventually they start spilling to disk (see screenshot above). At some point, memory usage goes down again. No warnings/errors raised.
MWE
from pathlib import Path
from dask.distributed import Client
import dask.array as da
import xarray as xr

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    nc_paths = [
        out_dir / "1.nc",
        out_dir / "2.nc",
        out_dir / "3.nc",
        out_dir / "4.nc",
        out_dir / "5.nc",
    ]
    for path in nc_paths:
        arr = xr.DataArray(
            da.random.random(shape, chunks=chunks),
            name="x",
            dims=("time", "lat", "lon"),
        )
        arr.to_netcdf(path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    dss = [xr.open_dataset(p, chunks={"time": 10, "lat": -1, "lon": -1}) for p in nc_paths]
    arrs = [ds.x.data for ds in dss]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
Dask graph
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f03ec0b1c70>
 0. ones_like-c7c9cf4f6871a4068d6bda6f4e5141ea
 1. sum-142584b6bc297f1d60f68230643654e5
 2. sum-aggregate-4939d78fcafe476a8465a60e20df2c65
 3. original-open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
 4. open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
 5. original-open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
 6. open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
 7. original-open_dataset-x-1bafe17e170d74a2bbae3a9518808761
 8. open_dataset-x-1bafe17e170d74a2bbae3a9518808761
 9. original-open_dataset-x-61de7e5d23d744435a3d0116d35e6333
 10. open_dataset-x-61de7e5d23d744435a3d0116d35e6333
 11. original-open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
 12. open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
 13. concatenate-8f55d06266aaeabe8c000c30ce2a4048
 14. mul-d8a6804c719b299cad52f97707b11a86
 15. sum-012c92287950bccb2aba104d10dea9fa
 16. sum-aggregate-8a1d67d086a9a1d7841895d269782fd7
 17. truediv-8338ef8062243673d277f81853de876d
  2. Zarr data read with dask.array.from_zarr: This case is very similar to the netCDF case. The workers’ memory keeps filling up and eventually they start spilling to disk (see screenshot above). At some point, memory usage goes down again. No warnings/errors raised.
MWE
from pathlib import Path
from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    zarr_paths = [
        out_dir / "1.zarr",
        out_dir / "2.zarr",
        out_dir / "3.zarr",
        out_dir / "4.zarr",
        out_dir / "5.zarr",
    ]
    for path in zarr_paths:
        arr = da.random.random(shape, chunks=chunks)
        da.to_zarr(arr, path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    arrs = [da.from_zarr(p) for p in zarr_paths]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
Dask graph
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7fae1a6c8980>
 0. ones_like-b2eeb017d903fe0162dbf9fabb034160
 1. sum-bb0b49a151ee1ec3649751f9df4dc22d
 2. sum-aggregate-799a254cbdfcab0da6691f0d1f8abc3a
 3. original-from-zarr-c97365cf561afe64ebf973c34a876e62
 4. from-zarr-c97365cf561afe64ebf973c34a876e62
 5. original-from-zarr-6d55f3c94c52e1e2179024850b15b00a
 6. from-zarr-6d55f3c94c52e1e2179024850b15b00a
 7. original-from-zarr-4e162aaecc691c159a5d43200b078036
 8. from-zarr-4e162aaecc691c159a5d43200b078036
 9. original-from-zarr-92b1e21aaf6078648da539f59f16adb2
 10. from-zarr-92b1e21aaf6078648da539f59f16adb2
 11. original-from-zarr-76c9fa38a30197be5753e24414e5aee7
 12. from-zarr-76c9fa38a30197be5753e24414e5aee7
 13. concatenate-c92880ca8542262e5ae5848e7d65c648
 14. mul-eed970fe66ae9b8888ccb4b1c1830a65
 15. sum-3e71a2b5934567d7f2ad5eb642933a5f
 16. sum-aggregate-90c144806d5ebbe9efba7664c1927918
 17. truediv-4636ae24fac9a860e7a1882e7dd30811
  3. Creating the data with da.random.random(shape, chunks=chunks) without saving/loading it: This works perfectly fine; memory usage per worker basically always stays below 2 GiB and the dashboard looks clean.
MWE
from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Create data
    arrs = [da.random.random(shape, chunks=chunks) for _ in range(5)]  # same number of arrays as files above

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
Dask graph
HighLevelGraph with 13 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f84bb681280>
  0. ones_like-b2eeb017d903fe0162dbf9fabb034160
  1. sum-bb0b49a151ee1ec3649751f9df4dc22d
  2. sum-aggregate-a4961dcbcb818a25426264c800426259
  3. random_sample-b773ea283c4c668852b9a00421c3055d
  4. random_sample-ef776b5e484b45a0e538b513523db6cb
  5. random_sample-87da63b5fee3774ddb748ed498c9ec89
  6. random_sample-e37dcf6cb5d6d272ca5ab3f1d1c2df00
  7. random_sample-32b9374ec13f6972c0f19c85eb6fd784
  8. concatenate-f892adc0c62d53897e0e5293ecdbb26e
  9. mul-bcbc2daf6a3f04985c3af0a5ad0d7d5d
  10. sum-ddf13f4205911e5bb7cbfe019e1c7f4d
  11. sum-aggregate-e8835cb74a96aed802fd656f976893b5
  12. truediv-373a8a1ea714c9051d5ea535ae246621

This behavior is more or less independent of the scheduler parameters and of the actual calculation performed. I also tested this with a dask_jobqueue.SLURMCluster and found the same behavior.
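
For completeness, the SLURM test was set up roughly like this (a minimal sketch; the resource values are placeholders, not my exact configuration):

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Placeholder resources, roughly matching the LocalCluster test above
cluster = SLURMCluster(cores=2, processes=1, memory="4GiB", walltime="01:00:00")
cluster.scale(jobs=2)  # two single-process worker jobs
client = Client(cluster)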

I would appreciate any help on this! Thank you so much!

Hi @schlunma, welcome to the Dask Discourse forum!

I can confirm that I reproduce the issue you highlight when executing your code. I’m not sure why, but something in the resulting graph triggers loading all of the data into memory.

However, just changing the code a bit, using:

sum_res = da.sum(arr * weights, axis=(1, 2)).compute()
weight_res = da.sum(weights, axis=(1, 2)).compute()
avg_res = sum_res / weight_res

It works well and streams the computation.

Also, you say in a code comment:

da.average behaves identically

If I replace your code with da.average(arr, axis=(1,2)), I see no memory problem either.
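
For reference, the unweighted variant I ran is roughly this (a sketch, reusing arr from your first MWE):

avg = da.average(arr, axis=(1, 2))  # no weights
print(avg.compute())  # worker memory stays low here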

Hi @guillaumeeb, thanks for your quick response!

Sorry, I should have given more details. I also get the problematic behavior with a weighted average, i.e.,

avg = da.average(arr, axis=(1, 2), weights=weights)

I can also confirm that the code snippet you sent works well. Unfortunately, I cannot really implement this directly, since my code goes through other libraries (mainly Iris) as an intermediate layer rather than using Dask directly. In addition, since the basic multiplication operation is already affected, the problem is not limited to weighted averages. Do you have any idea why your adapted code snippet works well?
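
For context, my real pipeline does roughly the following through Iris (purely schematic; it assumes a CF-compliant file with bounded latitude/longitude coordinates, unlike the dummy files in the MWEs above):

import iris
import iris.analysis
import iris.analysis.cartography

# Iris loads the data lazily, i.e. as a Dask array under the hood
cube = iris.load_cube("model_output.nc")  # placeholder filename
# Area weights need latitude/longitude coordinates with bounds
weights = iris.analysis.cartography.area_weights(cube)
# Weighted horizontal mean, analogous to the pure-Dask computation above
avg_cube = cube.collapsed(["latitude", "longitude"], iris.analysis.MEAN, weights=weights)
print(avg_cube.data)  # realizes the result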

Would you have a suggestion on how to proceed? To me, this looks like a bug in Dask (distributed). Would it make sense to open an issue there?

Thanks for your help!

Just because I simplified the task graph by computing it in two separate steps, which is OK for this simple case.

I did some more tests today, and I confirm this behavior. I tried to make the test smaller to debug and saw that if I read just one Zarr file, the problem disappears. Maybe it comes from the concatenate step? I really don’t know what the problem could be, as the Dask graph looks good and trying to optimize it doesn’t change anything.
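
For reference, the reduced test I mean is roughly this (a sketch based on your second MWE, reading only a single store and skipping the concatenate):

import dask.array as da
from dask.distributed import Client

if __name__ == '__main__':
    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    arr = da.from_zarr("tmp_dir/1.zarr")  # a single store written by your MWE
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))
    print(avg.compute())  # worker memory stays low in this case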

So I’m at a loss, and I would recommend opening an issue; this computation should be entirely streamable.

Thanks @guillaumeeb for your further tests!

I just opened an issue here: Reading data from zarr or netcdf files loads all data into memory with distributed scheduler · Issue #8969 · dask/distributed · GitHub