Hello!
I am working with climate model output (mostly in NetCDF format) and came across some issues when reading these data and calculating averages with a distributed scheduler. Looking at the Dask dashboard, I find that essentially all of the data is read into memory at some point. Doing the same with the default (threaded) scheduler seems to work fine (a minimal version of that comparison is sketched after the list of cases below), and creating the data directly with `dask.array` instead of reading it from files also seems to work fine with a distributed scheduler.
Here is a screenshot of an example dashboard:
I investigated different setups (all use ~45 GiB of data and a `LocalCluster` with 2 workers with 4 GiB of memory each).
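For reference, the ~45 GiB figure follows directly from the array shape and the float64 dtype used in the MWEs below; this is just a back-of-the-envelope size check, not part of the actual workflow:

```python
# Back-of-the-envelope size check for the dummy data used in the MWEs below
# (float64, 5 arrays/files of shape (365, 1280, 2560), chunks of (10, 1280, 2560)).
shape = (365, 1280, 2560)
n_files = 5
bytes_per_element = 8  # float64

file_size = shape[0] * shape[1] * shape[2] * bytes_per_element
chunk_size = 10 * shape[1] * shape[2] * bytes_per_element

print(f"per file:  {file_size / 2**30:.1f} GiB")             # ~8.9 GiB
print(f"total:     {n_files * file_size / 2**30:.1f} GiB")   # ~44.6 GiB
print(f"per chunk: {chunk_size / 2**20:.0f} MiB")            # ~250 MiB
```

So each file is larger than a single worker's 4 GiB memory limit, but each individual chunk is only ~250 MiB.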
- NetCDF data read with `xarray`: The workers' memory keeps filling up and eventually they start spilling to disk (see screenshot above). At some point, memory usage goes down again. No warnings/errors are raised.
MWE
```python
from pathlib import Path

from dask.distributed import Client
import dask.array as da
import xarray as xr

if __name__ == '__main__':
    # Create dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    out_dir.mkdir(exist_ok=True)
    nc_paths = [
        out_dir / "1.nc",
        out_dir / "2.nc",
        out_dir / "3.nc",
        out_dir / "4.nc",
        out_dir / "5.nc",
    ]
    for path in nc_paths:
        arr = xr.DataArray(
            da.random.random(shape, chunks=chunks),
            name="x",
            dims=("time", "lat", "lon"),
        )
        arr.to_netcdf(path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    dss = [xr.open_dataset(p, chunks={"time": 10, "lat": -1, "lon": -1}) for p in nc_paths]
    arrs = [ds.x.data for ds in dss]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
```
Dask graph
```
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f03ec0b1c70>
0. ones_like-c7c9cf4f6871a4068d6bda6f4e5141ea
1. sum-142584b6bc297f1d60f68230643654e5
2. sum-aggregate-4939d78fcafe476a8465a60e20df2c65
3. original-open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
4. open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
5. original-open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
6. open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
7. original-open_dataset-x-1bafe17e170d74a2bbae3a9518808761
8. open_dataset-x-1bafe17e170d74a2bbae3a9518808761
9. original-open_dataset-x-61de7e5d23d744435a3d0116d35e6333
10. open_dataset-x-61de7e5d23d744435a3d0116d35e6333
11. original-open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
12. open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
13. concatenate-8f55d06266aaeabe8c000c30ce2a4048
14. mul-d8a6804c719b299cad52f97707b11a86
15. sum-012c92287950bccb2aba104d10dea9fa
16. sum-aggregate-8a1d67d086a9a1d7841895d269782fd7
17. truediv-8338ef8062243673d277f81853de876d
```
- Zarr data read with `dask.array.from_zarr`: This case is very similar to the NetCDF case. The workers' memory keeps filling up and eventually they start spilling to disk (see screenshot above). At some point, memory usage goes down again. No warnings/errors are raised.
MWE
```python
from pathlib import Path

from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Create dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    out_dir.mkdir(exist_ok=True)
    zarr_paths = [
        out_dir / "1.zarr",
        out_dir / "2.zarr",
        out_dir / "3.zarr",
        out_dir / "4.zarr",
        out_dir / "5.zarr",
    ]
    for path in zarr_paths:
        arr = da.random.random(shape, chunks=chunks)
        da.to_zarr(arr, path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    arrs = [da.from_zarr(p) for p in zarr_paths]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
```
Dask graph
```
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7fae1a6c8980>
0. ones_like-b2eeb017d903fe0162dbf9fabb034160
1. sum-bb0b49a151ee1ec3649751f9df4dc22d
2. sum-aggregate-799a254cbdfcab0da6691f0d1f8abc3a
3. original-from-zarr-c97365cf561afe64ebf973c34a876e62
4. from-zarr-c97365cf561afe64ebf973c34a876e62
5. original-from-zarr-6d55f3c94c52e1e2179024850b15b00a
6. from-zarr-6d55f3c94c52e1e2179024850b15b00a
7. original-from-zarr-4e162aaecc691c159a5d43200b078036
8. from-zarr-4e162aaecc691c159a5d43200b078036
9. original-from-zarr-92b1e21aaf6078648da539f59f16adb2
10. from-zarr-92b1e21aaf6078648da539f59f16adb2
11. original-from-zarr-76c9fa38a30197be5753e24414e5aee7
12. from-zarr-76c9fa38a30197be5753e24414e5aee7
13. concatenate-c92880ca8542262e5ae5848e7d65c648
14. mul-eed970fe66ae9b8888ccb4b1c1830a65
15. sum-3e71a2b5934567d7f2ad5eb642933a5f
16. sum-aggregate-90c144806d5ebbe9efba7664c1927918
17. truediv-4636ae24fac9a860e7a1882e7dd30811
```
- Creating the data with `da.random.random(shape, chunks=chunks)` without saving/loading it: This works perfectly fine; memory usage per worker stays below ~2 GiB essentially the whole time and the dashboard looks clean.
MWE
```python
from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Dummy data parameters
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Create data directly (5 arrays, matching the file-based examples above)
    arrs = [da.random.random(shape, chunks=chunks) for _ in range(5)]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average behaves identically
    print(avg.dask)
    print(avg.compute())
```
Dask graph
```
HighLevelGraph with 13 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f84bb681280>
0. ones_like-b2eeb017d903fe0162dbf9fabb034160
1. sum-bb0b49a151ee1ec3649751f9df4dc22d
2. sum-aggregate-a4961dcbcb818a25426264c800426259
3. random_sample-b773ea283c4c668852b9a00421c3055d
4. random_sample-ef776b5e484b45a0e538b513523db6cb
5. random_sample-87da63b5fee3774ddb748ed498c9ec89
6. random_sample-e37dcf6cb5d6d272ca5ab3f1d1c2df00
7. random_sample-32b9374ec13f6972c0f19c85eb6fd784
8. concatenate-f892adc0c62d53897e0e5293ecdbb26e
9. mul-bcbc2daf6a3f04985c3af0a5ad0d7d5d
10. sum-ddf13f4205911e5bb7cbfe019e1c7f4d
11. sum-aggregate-e8835cb74a96aed802fd656f976893b5
12. truediv-373a8a1ea714c9051d5ea535ae246621
```
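For comparison, the threaded-scheduler run mentioned at the top is simply the NetCDF example without a distributed `Client`. A minimal sketch (assuming the dummy files from the first MWE already exist in `tmp_dir`):

```python
from pathlib import Path

import dask.array as da
import xarray as xr

# Same computation as the NetCDF MWE, but on the default (threaded) scheduler:
# no distributed Client is created. Assumes tmp_dir/1.nc ... tmp_dir/5.nc exist.
nc_paths = [Path("tmp_dir") / f"{i}.nc" for i in range(1, 6)]
dss = [xr.open_dataset(p, chunks={"time": 10, "lat": -1, "lon": -1}) for p in nc_paths]
arrs = [ds.x.data for ds in dss]

arr = da.concatenate(arrs, axis=0)
weights = da.ones_like(arr, chunks=arr.chunks)
avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))
print(avg.compute())  # runs with the default threaded scheduler
```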
This behavior is more or less independent of the scheduler parameters and of the actual calculation that is performed. I also tested this with a `dask_jobqueue.SLURMCluster` (sketched below) and found the same behavior.
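For reference, the `SLURMCluster` test looked roughly like the following; the resource values here are placeholders rather than the exact settings used:

```python
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Rough sketch of the SLURMCluster setup; cores/memory/walltime are placeholders.
cluster = SLURMCluster(
    cores=4,            # threads per SLURM job
    processes=2,        # workers per SLURM job
    memory="8GiB",      # memory per SLURM job (split across its workers)
    walltime="01:00:00",
)
cluster.scale(jobs=1)   # request one SLURM job
client = Client(cluster)
# ...then run the same computation as in the MWEs above...
```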
I would appreciate any help on this! Thank you so much!