Climate dataset: xarray & Dask repeatedly failing

Hello,

I’m fairly new to Dask, although I’ve been using xarray for a few months now. Currently, I’m struggling to process a 10.5 GB dataset of half-hourly data that I need to sum to an hourly basis:

Opening the dataset...
Total number of time steps: 346224
Size of each processing chunk: 346 time steps
<xarray.Dataset>
Dimensions:           (time: 346224, lat: 90, lon: 90)
Coordinates:
  * time              (time) datetime64[ns] 2002-01-01T00:15:00 ... 2021-09-3...
  * lat               (lat) float32 -18.53 -18.51 -18.49 ... -16.74 -16.72
  * lon               (lon) float32 -150.4 -150.4 -150.4 ... -148.5 -148.5
Data variables:
    precipitationCal  (time, lat, lon) float32 dask.array<chunksize=(346224, 10, 10), meta=np.ndarray>

In short, I need to run something like:

hourly_data = data.resample(time="1H").sum()

I tried so many scripts over the last three days that it would be too long to list them all; in short (a rough sketch of these variants follows the list):

  • chunk over time only (1000; 100); then time/lat/lon; then lat/lon only (30; 10)
  • automatic LocalCluster / explicitly defined cluster: 2, 4, 9 workers / memory 4, 2, 1 GiB per worker
  • loop over time indices via .isel(time=slice(start_idx, end_idx)); save each result as NetCDF; roll over the whole dataset
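
For reference, here is a rough sketch of the kind of variants I tried (chunk sizes and cluster settings taken from the list above; this is not one of my exact scripts):

import xarray as xr
from dask.distributed import Client, LocalCluster

# Chunking variants tried (values from the list above)
chunks_time_only  = {'time': 1000}                        # also {'time': 100}
chunks_all_dims   = {'time': 1000, 'lat': 30, 'lon': 30}
chunks_space_only = {'lat': 30, 'lon': 30}                # also 10 x 10

# Cluster variants tried
cluster = LocalCluster(n_workers=4, memory_limit='2GiB')  # also 2 / 9 workers, 4 / 1 GiB
client = Client(cluster)

data = xr.open_dataset('IMERG_2002_2021_INTERPOLATED_RAW.nc', chunks=chunks_space_only)
hourly_data = data.resample(time="1H").sum()
# ... then written out with hourly_data.to_netcdf(...)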

My PC isn’t that bad: 16 GB RAM, 20 cores, running Ubuntu 22.04.

Nevertheless, I can’t get the RAM freed after each iteration of the loop despite using .close(); the RAM keeps filling up and the script eventually crashes.

I’m really out of options; I can’t see why I cannot get xarray & Dask to handle this dataset and process it as intended.

I would gladly welcome a little help. Thank you :blush:


Would you like more information? Here is the current script I’m trying to run:

from dask.distributed import Client, LocalCluster

# Create a Dask cluster with specific settings
cluster = LocalCluster(
    n_workers=9,           # Number of workers
    memory_limit='1GiB'     # Memory limit per worker
)

# Connect to the Dask cluster
client = Client(cluster)

# Print the Dask dashboard link
print(f"Dask dashboard: {client.dashboard_link}")

import os
import xarray as xr
import numpy as np

# Define chunk size (adjust based on available memory and dataset size)
chunk_size = {
    'lat': 10,
    'lon': 10
}

# Define IMERG file path and prefix
imerg_path = '/media/pmauger/LaCie/01_DATA/IMERG'
file_name = "IMERG_2002_2021_INTERPOLATED_RAW.nc"

# Output directory for processed chunks
output_dir = '/media/pmauger/LaCie/01_DATA/IMERG/chunks'

# Open the dataset
print("Opening the dataset...")
imerg = xr.open_dataset(os.path.join(imerg_path, file_name), chunks=chunk_size)

# Test
#imerg = imerg.isel(time=slice(0, 480))

# Calculate the total number of time steps
num_time_steps = len(imerg.time)
print(f"Total number of time steps: {num_time_steps}")

# Calculate the size of each processing chunk (approximately one-thousandth of the total time steps)
chunk_size_proc = num_time_steps // 1000
print(f"Size of each processing chunk: {chunk_size_proc} time steps")

# Loop through the dataset in chunks
for i in range(0, num_time_steps, chunk_size_proc):
    # Define the start and end indices for the current chunk
    start_idx = i
    end_idx = min(i + chunk_size_proc, num_time_steps)
    print(f"Processing chunk from index {start_idx} to {end_idx - 1}...")

    # Generate output file name based on chunk index
    output_file_path = os.path.join(output_dir, f'chunk_{start_idx}_{end_idx}.nc')
    print(f"Saving processed chunk to: {output_file_path}")
    
    # Resample to hourly data (sum 2 * 1/2 hourly)
    imerg_chunk = imerg.isel(time=slice(start_idx, end_idx)).resample(time="1H").sum()

    # Save the processed chunk to a NetCDF file
    imerg_chunk.to_netcdf(output_file_path)

# List all processed chunk files
processed_files = sorted([os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.startswith('chunk_')])

print(f"Processed {len(processed_files)} chunks.")

# Concatenate processed chunks along the time dimension using Dask
final_dataset = xr.open_mfdataset(processed_files, combine='nested', concat_dim='time')

# Save the final concatenated dataset to a NetCDF file
output_final_path = os.path.join(imerg_path, 'final_dataset.nc')
print(f"Saving final dataset to: {output_final_path}")
final_dataset.to_netcdf(output_final_path)

print(f'Final dataset saved: {output_final_path}')

It runs perfectly if I un-comment the slicing of the original imerg dataset to the first 480 time steps. Nevertheless, I can see the memory increasing at each iteration, so I cannot run it over almost a thousand loop iterations, which would be ridiculous.

Hi @Arty, welcome to Dask community!

From what I understand, you shouldn’t have to write any loops; just rely on Xarray and Dask to stream your data and resample it. In your case, you should use chunks that span the whole time dimension and chunk only on lat and lon, since you need full time series.

What happens with simple Xarray code like you proposed at first, using this kind of chunking scheme?
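
Something like this, I mean (just a minimal sketch reusing the paths from your script; the output file name is made up):

import os
import xarray as xr

# Chunk only on lat and lon so that every chunk holds the full time series
imerg = xr.open_dataset(
    os.path.join(imerg_path, file_name),
    chunks={'lat': 10, 'lon': 10},
)

hourly = imerg.resample(time="1H").sum()
hourly.to_netcdf(os.path.join(imerg_path, 'IMERG_hourly.nc'))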

When I don’t use a loop, the script eventually crashes because it runs out of memory: the time-resampling operation seems to break the time chunking. Hence I tried to chunk only on lat/lon, but it didn’t get any better: the script still eventually crashes from memory exhaustion (I can clearly see in the system monitor that I reach 100% memory usage).

When I try with smaller datasets (like a tenth of the original size, time-sliced), it works fine though.

My understanding of Dask was indeed that it should handle such a computation.

I finally managed to get something to work: I split the dataset into 48 equal parts, load each part in a loop, compute the resampling/summing for each part and save it, and then combine all those NetCDF chunks into one (a rough sketch is below).
By the way, it’s way quicker this way, but I get failures when I compute larger dataset parts (like 1/24th).
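
A rough sketch of this approach, reusing the names from my earlier script (the part count of 48 is simply what worked for me; I keep an even number of half-hourly steps per part so that part boundaries fall on full hours):

n_parts = 48
part_size = -(-len(imerg.time) // n_parts)  # ceiling division
part_size += part_size % 2                  # force an even number of half-hourly steps

for i in range(n_parts):
    part = imerg.isel(time=slice(i * part_size, (i + 1) * part_size))
    hourly = part.resample(time="1H").sum()
    hourly.to_netcdf(os.path.join(output_dir, f'part_{i:02d}.nc'))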

So I’m still wondering what is happening and what I am doing wrong.

I did some more tests and still cannot understand where the problem could come from:

When computing over 1 year’s worth of data, the runtime is 24 seconds; over the whole dataset (20 years), it takes 6404 seconds (almost 2 hours!), with exactly the same local cluster and chunk configuration.
I got good performance for 1, 2 and 3 years in a row (77 seconds for the latter), but when I increase the time extent of the dataset further, it starts to take a disproportionate time to compute.

I did succeed in running the simple command .coarsen(time=2).sum() ** over the whole dataset without my computer crashing ***. But I still think something is off: it’s not normal that such an “easy” computation takes so long, even if the dataset can be considered “big” for a laptop.

I am currently running another test over 15 years ****: already 20 minutes in, and I can barely hear the laptop’s fans. The cores are running slowly, mostly under 10% load, as I can see from the system monitor. Memory is 50% free, so clearly not overloaded.

I’d really like to better understand what’s going on, to avoid further issues like the excessive runtimes I’m facing right now. Feel free to propose any tests I could run that would help understand what is happening.

Thank you

** which yields the same result as .resample(time='1H').sum(), except that I have to shift the timestamps afterward
*** chunks: 100, 30, 30 / 9 workers, 1.2 GiB each
**** chunks: 1, 5, 5 / 9 workers, 1.2 GiB each

This does not account for merging the data into one file (or shifting the timestamps), but the script below manages to get the resampling done in less than 5 minutes for 15 years’ worth of data:

# Iterate over years from 2002 to 2016
for year in range(2002, 2017):
    print(f'RUNNING: {year}')
    
    # Select year
    imerg_year = imerg.sel(time=imerg['time.year'] == year)

    # Resample the data to hourly with .coarsen
    coarse = imerg_year.coarsen(time=2).sum()
    
    # Specify the path to save the new NetCDF file in the same directory as the original files
    output_filename = f'IMERG_{year}_INTERPOLATED_d02_HOURLY.nc'
    output_file_path = os.path.join(imerg_path, output_filename)
    
    # Save the new dataset to the NetCDF file
    coarse.to_netcdf(output_file_path)
    
    # Print confirmation message
    print(f'Dataset for year {year} has been saved to: {output_file_path}')
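
For completeness, here is a sketch of the merge and timestamp shift that the script above leaves out (it reuses imerg_path from my earlier script; the output file name is made up, and the 30-minute shift assumes coarsen’s default coordinate averaging of the 00:15 / 00:45 time stamps):

import glob
import os
import numpy as np
import xarray as xr

# Gather the per-year files and concatenate them along time
yearly_files = sorted(glob.glob(os.path.join(imerg_path, 'IMERG_*_INTERPOLATED_d02_HOURLY.nc')))
merged = xr.open_mfdataset(yearly_files, combine='nested', concat_dim='time')

# coarsen(time=2) labels each hour with the mean of the two half-hourly
# time stamps (00:30 for 00:15 and 00:45); shift back to the start of the hour
merged = merged.assign_coords(time=merged.time - np.timedelta64(30, 'm'))

merged.to_netcdf(os.path.join(imerg_path, 'IMERG_2002_2016_HOURLY.nc'))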

Okay, I’m currently guessing this comes from the initial chunking of your data, which is probably chunked by date inside the NetCDF file. If you want a single chunk along time, the first operation is to gather everything and rechunk the array, which sounds simple but can be really expensive in memory. This is why packages such as rechunker were created.
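
If you try it, the API looks roughly like this (only a sketch: please check the rechunker documentation for the exact target_chunks format for Datasets, and the Zarr store paths here are placeholders):

import xarray as xr
from rechunker import rechunk

ds = xr.open_dataset('IMERG_2002_2021_INTERPOLATED_RAW.nc')

# Plan a rewrite of the variable with a single chunk along time,
# staying under a fixed memory budget per task
plan = rechunk(
    ds,
    target_chunks={'precipitationCal': {'time': len(ds.time), 'lat': 10, 'lon': 10}},
    max_mem='1GiB',
    target_store='imerg_single_time_chunk.zarr',
    temp_store='imerg_rechunk_tmp.zarr',
)
plan.execute()

# The rechunked store can then be opened with xr.open_zarr and resampled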

You might also get a better answer on the Pangeo forum.

Thanks Guillaume. I’ll run some more tests.

Just to let you know, I installed flox (which I should have done earlier, my bad…) and it really improves the resampling process, as advertised. Nevertheless, I still have issues when the dataset’s size increases.
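
In case it helps others: xarray picks up flox automatically once it is installed; the choice can also be made explicit (assuming a reasonably recent xarray version):

import xarray as xr

# flox accelerates groupby/resample reductions when available;
# the option below just makes its use explicit
with xr.set_options(use_flox=True):
    hourly = imerg.resample(time="1H").sum()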

For example, I ran a few tests with the following configuration (I didn’t go further than 11 years because I’d been disconnected from JupyterLab several times at that point):

client = Client(n_workers=6, memory_limit='2GiB')
chunks={'time': 2*8760, 'lat': 5, 'lon': 5}

With another client configuration:

client = Client(n_workers=4, memory_limit='3GiB')

We can clearly observe a break in runtime between 8 and 9 years and above with both client configurations, but similar runtimes overall. This raises a question though: I don’t understand how the number of workers and the memory per worker affect the runtime. I would have expected that the more workers, the shorter the runtime, but that’s not the case. I can also see from the system monitor that all CPUs are running high during both tests.

Out of curiosity, I ran 2 other tests:

client = Client(n_workers=1, memory_limit='12GiB')

And:

client = Client(n_workers=2, memory_limit='6GiB')

Even though it would be more robust to run these tests several times and average them, the difference made by the cluster configuration does not seem too important (except when using only 1 worker).

I’ll try using rechunker then.

I would also appreciate advice about the interp function: are there packages suited to improving Dask performance for this type of computation?

For all these questions, it would be nice to have advice from the experts: cc @dcherian @rabernat!