How to parallelize several loops on huge climate datasets using dask.delayed

Hello,

I’ve been struggling with this for more than a week now and have tried so many things that I finally gave up and came here to ask for your help.

Context: I’ve got outputs from several WRF model configurations that I want to compare with meteorological station data. Although I managed to write a script that gets me where I want, it runs slowly: 16 configurations are compared with 25 stations on 7 variables.
Note: the outputs consist of monthly files of hourly data over 90x90 grid cells, 36 months for each configuration. When fully loaded, a 3-year hourly dataset is around 6 GB.
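The size figures above can be sanity-checked with a quick estimate. This sketch assumes float32 values and 365-day years (neither is stated in the post) and counts only the 7 compared variables. It suggests the full 3-year grid for those variables is roughly 6 GB, while a single grid-point time series for the same variables is under 1 MB, which is why selecting one point before loading should be cheap.

```python
# Back-of-the-envelope memory estimate for one WRF configuration.
# Assumptions (not stated in the post): 365-day years, float32 values,
# and only the 7 compared variables are counted.
N_LAT, N_LON = 90, 90            # grid cells, from the post
HOURS = 3 * 365 * 24             # 3 years of hourly data
BYTES_PER_VALUE = 4              # float32 (assumed)
N_VARS = 7                       # variables compared with the stations

full_grid_per_var = N_LAT * N_LON * HOURS * BYTES_PER_VALUE
one_point_per_var = HOURS * BYTES_PER_VALUE

print(f"full grid, one variable: {full_grid_per_var / 1e6:,.0f} MB")
print(f"full grid, {N_VARS} variables: {N_VARS * full_grid_per_var / 1e9:.1f} GB")
print(f"one grid point, {N_VARS} variables: {N_VARS * one_point_per_var / 1e6:.2f} MB")
```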

So, of course, I thought of Dask to parallelize the script, for example to process several configurations simultaneously. I read and watched many tutorials on how to use dask.delayed… Unfortunately, the examples are usually rather simple and nowhere near what I’m trying to achieve here.

To briefly describe my current process (without dask.delayed), have a look at the outline below:

for config in configs:
    lazy load config_ds
    preprocess config_ds to fit xarray/pandas standards

    for station in stations:
        load station_ds 
        select config_ds grid-point nearest to station_ds coordinates

        for variable in variables:
            select config_ds[variable] at the times where station_ds[variable] is not NaN (valid_times); if there are no valid_times: continue
            process metrics comparing config_ds[variable(valid_times)] against station_ds[variable(valid_times)] 
Actual script:
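# Main Part
import os

import numpy as np
import pandas as pd
import xarray as xr

# Note: new_time_axis(), latitude_norm and longitude_norm are defined elsewhere (not shown).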

# Dictionary to map MF variable names to WRF variable names
var_mapping = {
    'RR1': 'TOTAL_RAIN',
    'FF': 'SPDUV10MEAN',
    'T': 'T2MEAN',
    'PSTAT': 'PSFC',
    'GLO': 'GHI',
    'U10': 'U10MEAN',
    'V10': 'V10MEAN'
}

def group_configurations(directories):
    """Group configuration directories based on their prefixes and last parts."""
    grouped_configurations = {}
    for directory in directories:
        parts = directory.split('_')
        prefix = "_".join(parts[:-1])  # Extract the prefix
        last_part = parts[-1]  # Extract the last part (A or D)

        if prefix not in grouped_configurations:
            grouped_configurations[prefix] = {'A': None, 'D': None}
        grouped_configurations[prefix][last_part] = directory
        
    return grouped_configurations

def load_wrf_dataset(config_dir, timestep, sample_length):
    """Lazy load and preprocess WRF datasets from the specified configuration directory."""
    # List and sort files based on the timestep
    files_in_dir = os.listdir(config_dir)
    wrf_files = sorted([os.path.join(config_dir, f) for f in files_in_dir
                        if timestep in f and f.endswith('.nc')])
    
    # Lazy load and concatenate WRF datasets along 'Time' dimension
    wrf_dataset = new_time_axis(xr.open_mfdataset(wrf_files[sample_length], combine='nested', concat_dim='Time',
                                                  chunks={'south_north': 1, 'west_east': 1, 'Time': None}))
    
    # Preprocess the dataset (assign coordinates, drop unwanted variables, etc.)
    wrf_dataset = wrf_dataset.assign_coords(latitude=("south_north", latitude_norm), longitude=("west_east", longitude_norm))
    wrf_dataset = wrf_dataset.drop_vars(["XLAT", "XLONG", "GHIC", "Q2MEAN", "Times", "idx", "sz"])
    wrf_dataset = wrf_dataset.rename({'Time': 'time'}).set_coords('time')
    wrf_dataset = wrf_dataset.swap_dims({'south_north': 'latitude', 'west_east': 'longitude'})
    wrf_dataset = wrf_dataset.chunk({"time": -1})  # Set appropriate chunking for time
    
    return wrf_dataset

def compute_metrics(wrf_data, mf_data, mf_var):
    """Compute various metrics comparing WRF data and MF data."""
    if mf_var == 'T':
        wrf_data -= 273.15  # Convert Kelvin to Celsius
    elif mf_var == 'GLO':
        wrf_data *= 1000  # Convert kWh/m² to Wh/m²
    elif mf_var == 'PSTAT':
        wrf_data /= 100
        mf_data /= 100
        
    # Compute metrics
    mf_avg = np.mean(mf_data)
    wrf_avg = np.mean(wrf_data)
    mbe = wrf_avg - mf_avg
    mae = np.mean(np.abs(wrf_data - mf_data))
    rmse = np.sqrt(np.mean((wrf_data - mf_data) ** 2))
    correlation = np.corrcoef(mf_data, wrf_data)[0, 1]
    mf_std = np.std(mf_data)
    wrf_std = np.std(wrf_data)
    diff_std = wrf_std - mf_std

    return {
        'mf_avg': mf_avg,
        'wrf_avg': wrf_avg,
        'mbe': mbe,
        'mae': mae,
        'rmse': rmse,
        'correlation': correlation,
        'mf_std': mf_std,
        'wrf_std': wrf_std,
        'diff_std': diff_std
    }

def main(directories_to_iterate, timestep, sample_length, mf, variables, wrf_path):
    """Main function to orchestrate the loading and processing of WRF and MF datasets."""
    grouped_configurations = group_configurations(directories_to_iterate)

    metrics_data = []

    for i, (prefix, config_data) in enumerate(grouped_configurations.items()):
        for last_part, configuration in config_data.items():
            if configuration is None:
                continue

            print(f'RUNNING: {configuration}')
            config_dir = os.path.join(wrf_path, configuration)
            
            # Load the WRF dataset
            wrf_dataset = load_wrf_dataset(config_dir, timestep, sample_length)

            for station in mf.station.values:
                lat_mf = mf.sel(station=station).latitude.values
                lon_mf = mf.sel(station=station).longitude.values
                station_data = mf.sel(station=station)

                print(f'\tRUNNING: {station}')

                # Find the nearest grid point to the station
                wrf_nearest = wrf_dataset.sel(latitude=lat_mf, longitude=lon_mf, method='nearest')

                for mf_var in variables:
                    # Filter out times where the current variable is NaN in the MF dataset
                    valid_times = station_data[mf_var].dropna(dim='time').time
                    if len(valid_times) == 0:
                        continue

                    # Select only the valid times
                    wrf_time_filtered = wrf_nearest.sel(time=valid_times)

                    # Extract the corresponding WRF variable and MF data
                    wrf_data = wrf_time_filtered[var_mapping[mf_var]].values
                    mf_data = station_data[mf_var].sel(time=valid_times).values
                    
                    # Compute metrics
                    metrics = compute_metrics(wrf_data, mf_data, mf_var)

                    # Append metrics to the list
                    metrics_data.append({
                        'station': station,
                        'configuration': configuration,
                        'variable': mf_var,
                        **metrics,
                        'valid_times': len(valid_times)
                    })

    # Create DataFrame from the list of dictionaries
    metrics_df = pd.DataFrame(metrics_data)
    return metrics_df

# Assuming you have already defined the necessary variables
metrics_df = main(directories_to_iterate, timestep, sample_length, mf, variables, wrf_path)
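In the sequential script, the nearest grid point is selected with one .sel() call per station. xarray also supports vectorized (pointwise) indexing: when the indexers are DataArrays that share a new dimension, a single .sel(..., method="nearest") returns one grid point per station. A minimal sketch on synthetic data, assuming the WRF dataset already has 1-D latitude/longitude dimension coordinates (as after the swap_dims step in load_wrf_dataset); the dataset values, station IDs and coordinates are made up:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a preprocessed WRF dataset with 1-D "latitude" and
# "longitude" dimension coordinates (values are made up for illustration).
lat = np.linspace(40.0, 49.0, 10)
lon = np.linspace(-5.0, 4.0, 10)
rng = np.random.default_rng(0)
wrf = xr.Dataset(
    {"T2MEAN": (("time", "latitude", "longitude"),
                rng.normal(280.0, 5.0, (48, lat.size, lon.size)))},
    coords={"time": np.arange(48), "latitude": lat, "longitude": lon},
)

# Hypothetical station coordinates, indexed by a shared "station" dimension.
station_ids = ["S1", "S2", "S3"]
station_lat = xr.DataArray([43.1, 45.9, 48.2], dims="station",
                           coords={"station": station_ids})
station_lon = xr.DataArray([-1.2, 0.4, 2.9], dims="station",
                           coords={"station": station_ids})

# Vectorized (pointwise) indexing: because both indexers share the "station"
# dimension, this picks ONE grid point per station (not a 3x3 block).
wrf_at_stations = wrf.sel(latitude=station_lat, longitude=station_lon,
                          method="nearest")

print(dict(wrf_at_stations["T2MEAN"].sizes))   # sizes of time and station
print(wrf_at_stations["latitude"].values)      # nearest grid latitudes
```

The result carries a station dimension, so the station loop could in principle be replaced by this single selection followed by one .load(); whether that is faster on the real files has not been tested here.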

Using dask.delayed

I had plenty of different issues when trying to adapt the script with dask.delayed, but the most common one was (and still is) memory management. I’m confused, because the goal here is to load only time series (grid cell x variables) of the model datasets so as not to overload the memory. Nevertheless, when running my dask.delayed-adapted scripts, either they run very slowly (i.e. inefficiently), or they eventually crash after blowing up the memory.
Please see below my latest attempt, which runs on a Dask LocalCluster (8 workers with 1 GiB each; I also tested other configurations such as 4 workers / 2 GiB, 2 workers / 5 GiB, …):

Script attempting to use dask.delayed:
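import os

import dask
import numpy as np
import pandas as pd
import xarray as xr

# Note: new_time_axis() is defined elsewhere (not shown).

# Dictionary to map MF variable names to WRF variable names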
var_mapping = {
    'RR1': 'TOTAL_RAIN',
    'FF': 'SPDUV10MEAN',
    'T': 'T2MEAN',
    'PSTAT': 'PSFC',
    'GLO': 'GHI',
    'U10': 'U10MEAN',
    'V10': 'V10MEAN'
}

def group_configurations(directories):
    """Group configuration directories based on their prefixes and last parts."""
    grouped_configurations = {}
    for directory in directories:
        parts = directory.split('_')
        prefix = "_".join(parts[:-1])
        last_part = parts[-1]

        if prefix not in grouped_configurations:
            grouped_configurations[prefix] = {'A': None, 'D': None}
        grouped_configurations[prefix][last_part] = directory
        
    return grouped_configurations

def load_wrf_time_series(config_dir, timestep, sample_length, lat, lon, latitude_norm, longitude_norm):
    """
    Load and preprocess time series data for all relevant variables at a specific grid cell.
    
    Parameters:
    - config_dir: Path to the WRF files.
    - timestep: Timestep indicator in the filenames.
    - sample_length: Slice object indicating the files to load.
    - lat, lon: Latitude and longitude of the location to load.
    - latitude_norm, longitude_norm: Normalized latitude and longitude coordinates for preprocessing.
    
    Returns:
    - wrf_data: Lazy-loaded data array for the specific time series and variable.
    """
    # List and sort the relevant files based on timestep
    files_in_dir = os.listdir(config_dir)
    wrf_files = sorted(
        [os.path.join(config_dir, f) for f in files_in_dir if timestep in f and f.endswith('.nc')]
    )[sample_length]

    # Lazy load and concatenate WRF datasets along 'Time' dimension
    wrf_dataset = new_time_axis(
        xr.open_mfdataset(wrf_files, combine='nested', concat_dim='Time',
                          chunks={'south_north': 1, 'west_east': 1})
    )

    # Preprocess the dataset (assign coordinates, drop unwanted variables, etc.)
    wrf_dataset = wrf_dataset.assign_coords(
        latitude=("south_north", latitude_norm), longitude=("west_east", longitude_norm)
    )
    wrf_dataset = wrf_dataset.drop_vars(["XLAT", "XLONG", "GHIC", "Q2MEAN", "Times", "idx", "sz"], errors='ignore')
    wrf_dataset = wrf_dataset.rename({'Time': 'time'}).set_coords('time')
    wrf_dataset = wrf_dataset.swap_dims({'south_north': 'latitude', 'west_east': 'longitude'})

    # Select the required variables and location, retaining only necessary data
    wrf_data = wrf_dataset.sel(
        latitude=lat, longitude=lon, method='nearest'
    ).load()  # Load to reduce memory usage

    return wrf_data

def compute_metrics(wrf_data, mf_data, mf_var):
    """Compute various metrics comparing WRF data and MF data."""
    if mf_var == 'T':
        wrf_data -= 273.15  # Convert Kelvin to Celsius
    elif mf_var == 'GLO':
        wrf_data *= 1000  # Convert kWh/m² to Wh/m²
    elif mf_var == 'PSTAT':
        wrf_data /= 100
        mf_data /= 100
        
    # Compute metrics
    mf_avg = np.mean(mf_data)
    wrf_avg = np.mean(wrf_data)
    mbe = wrf_avg - mf_avg
    mae = np.mean(np.abs(wrf_data - mf_data))
    rmse = np.sqrt(np.mean((wrf_data - mf_data) ** 2))
    correlation = np.corrcoef(mf_data, wrf_data)[0, 1]
    mf_std = np.std(mf_data)
    wrf_std = np.std(wrf_data)
    diff_std = wrf_std - mf_std

    return {
        'mf_avg': mf_avg,
        'wrf_avg': wrf_avg,
        'mbe': mbe,
        'mae': mae,
        'rmse': rmse,
        'correlation': correlation,
        'mf_std': mf_std,
        'wrf_std': wrf_std,
        'diff_std': diff_std
    }

def process_station_data(config_dir, timestep, sample_length, station_data, variables, station, configuration, latitude_norm, longitude_norm):
    """Process all variables for a single station under a specific configuration."""
    metrics = []

    # Extract MF coordinates
    lat_mf = station_data.latitude.values
    lon_mf = station_data.longitude.values
    
    # Load and preprocess WRF data for the specific grid cell
    wrf_data_series = load_wrf_time_series(
        config_dir, timestep, sample_length, lat_mf, lon_mf, latitude_norm, longitude_norm
    )

    for mf_var in variables:
        valid_times = station_data[mf_var].dropna(dim='time').time
        
        # Skip processing if no valid times
        if len(valid_times) == 0:
            continue

        # Filter WRF data by valid times
        wrf_time_filtered = wrf_data_series.sel(time=valid_times)

        # Extract MF data
        mf_data = station_data[mf_var].sel(time=valid_times).values

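        # Compute metrics on the WRF variable matching this MF variable
        # (wrf_time_filtered is a Dataset, so the variable must be selected first)
        metrics_dict = compute_metrics(wrf_time_filtered[var_mapping[mf_var]].values, mf_data, mf_var)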

        # Append metrics to the list
        metrics.append({
            'station': station,
            'configuration': configuration,
            'variable': mf_var,
            **metrics_dict,
            'valid_times': len(valid_times)
        })

    return metrics

def main(directories_to_iterate, timestep, sample_length, mf, variables, wrf_path, latitude_norm, longitude_norm):
    """Main function to orchestrate the loading and processing of WRF and MF datasets."""
    grouped_configurations = group_configurations(directories_to_iterate)
    metrics_data = []

    # Define a list to hold the delayed tasks
    tasks = []

    for prefix, config_data in grouped_configurations.items():
        for last_part, configuration in config_data.items():
            if configuration is None:
                continue

            print(f'RUNNING: {configuration}')
            config_dir = os.path.join(wrf_path, configuration)

            # Create a delayed task for processing all stations under this configuration
            tasks.append(
                dask.delayed(process_stations)(
                    config_dir, timestep, sample_length, mf, variables, configuration, latitude_norm, longitude_norm
                )
            )

    # Compute all tasks in parallel
    results = dask.compute(*tasks)

    # Flatten the results into metrics_data list
    for result in results:
        metrics_data.extend(result)

    # Create DataFrame from the list of dictionaries
    metrics_df = pd.DataFrame(metrics_data)
    return metrics_df

def process_stations(config_dir, timestep, sample_length, mf, variables, configuration, latitude_norm, longitude_norm):
    """Process all stations for a given configuration."""
    metrics = []
    
    # Process each station sequentially within this configuration's task
    for station in mf.station.values:
        station_data = mf.sel(station=station)
        print(f'\tRUNNING: {station}')

        # Process data for the station
        station_metrics = process_station_data(config_dir, timestep, sample_length, station_data, variables, station, configuration, latitude_norm, longitude_norm)
        metrics.extend(station_metrics)

    return metrics

metrics_df = main(directories_to_iterate, timestep, sample_length, mf, variables, wrf_path, latitude_norm, longitude_norm)
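compute_metrics in both scripts converts units with in-place operators (-=, *=, /=). These can modify the arrays passed in, and /= raises an error if the station data happen to be stored as integers. A minimal non-mutating variant, sketched under the assumption that both inputs are 1-D numeric arrays of equal length without NaNs (compute_metrics_pure is a hypothetical name):

```python
import numpy as np

def compute_metrics_pure(wrf_data, mf_data, mf_var):
    """Same metrics as compute_metrics, but without in-place updates.

    Assumes 1-D numeric inputs of equal length with no NaNs.
    """
    # Work on float arrays; the conversions below build new arrays
    # instead of modifying the caller's ones.
    wrf = np.asarray(wrf_data, dtype=float)
    mf = np.asarray(mf_data, dtype=float)
    if mf_var == "T":
        wrf = wrf - 273.15          # Kelvin -> Celsius
    elif mf_var == "GLO":
        wrf = wrf * 1000            # kWh/m² -> Wh/m²
    elif mf_var == "PSTAT":
        wrf = wrf / 100             # both divided by 100, as in the original
        mf = mf / 100               # (presumably Pa -> hPa)

    diff = wrf - mf
    return {
        "mf_avg": mf.mean(),
        "wrf_avg": wrf.mean(),
        "mbe": wrf.mean() - mf.mean(),
        "mae": np.abs(diff).mean(),
        "rmse": np.sqrt((diff ** 2).mean()),
        "correlation": np.corrcoef(mf, wrf)[0, 1],
        "mf_std": mf.std(),
        "wrf_std": wrf.std(),
        "diff_std": wrf.std() - mf.std(),
    }

# Usage with made-up values: the input arrays are left untouched.
wrf_k = np.array([273.15, 274.15, 276.15])
mf_c = np.array([0.0, 1.0, 2.0])
metrics = compute_metrics_pure(wrf_k, mf_c, "T")
print(metrics["mbe"], metrics["rmse"])
```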

Some more insight into my problem: when I run the first (sequential) script, memory usage never goes above about 300 MB; with the dask.delayed-adapted script, it usually climbs to 10-11 GB… This is one of the main mysteries for me.
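One factor that may be worth checking (an assumption, not a confirmed diagnosis) is the size of the task graph implied by chunks={'south_north': 1, 'west_east': 1}. Each chunk becomes at least one task, and the scheduler keeps some bookkeeping for every task; in the delayed script, this graph is also rebuilt for every station of every configuration. A quick count, assuming the same chunking is applied to each of the 36 monthly files and counting only the 7 compared variables:

```python
# Rough count of dask chunks implied by chunks={'south_north': 1, 'west_east': 1}
# on 36 monthly files of a 90x90 grid.
# Assumptions: every file is opened with this chunking; only 7 variables counted.
N_LAT, N_LON = 90, 90
N_FILES = 36          # monthly files per configuration
N_VARS = 7

chunks_per_var_per_file = N_LAT * N_LON           # one chunk per grid cell
chunks_per_var = chunks_per_var_per_file * N_FILES
chunks_per_config = chunks_per_var * N_VARS

print(f"{chunks_per_var_per_file:,} chunks per variable per file")
print(f"{chunks_per_var:,} chunks per variable")
print(f"{chunks_per_config:,} chunks per configuration")
```

If this is a significant contributor, coarser spatial chunks (or no spatial chunking for a 90x90 grid) would shrink the graph considerably; that trade-off would need measuring on the real files.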

I really want to understand how to use the full potential of Dask properly, not only for myself but also for some young colleagues of mine, so that I can teach them some tips for their own work. That is why I spent more than a week on this without doing anything else. But I have (loads of) other work to do and can’t spend much more time on it, so I hope someone will be willing to help.

Thanks a lot :pray:

I would recommend using xarray with a dask cluster to do work like this.

Here’s an example notebook which extracts a time series at a specified lon,lat location from a massive model output.

I would use dask.delayed in that case.
Each loop level should have its own delayed call.
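The suggestion to use one delayed per loop level can be sketched as one dask.delayed call per (configuration, station) pair, all computed together with a single dask.compute. process_pair and the configuration/station names are hypothetical placeholders; in real use the function body would open and subset the WRF data, which is where memory use has to be kept under control.

```python
import itertools

import dask

# Hypothetical placeholders for the real configuration directories,
# station identifiers and variables.
configurations = ["cfg_A", "cfg_D"]
stations = ["S1", "S2", "S3"]
variables = ["T", "FF"]

def process_pair(configuration, station, variables):
    """Stand-in for the real work on one (configuration, station) pair:
    returns one metrics row per variable."""
    return [{"configuration": configuration, "station": station, "variable": v}
            for v in variables]

# Build the task list lazily: nothing runs until dask.compute is called.
tasks = [
    dask.delayed(process_pair)(cfg, st, variables)
    for cfg, st in itertools.product(configurations, stations)
]

# Run all tasks; each result is a list of rows, flattened afterwards.
results = dask.compute(*tasks)
rows = [row for result in results for row in result]
print(len(tasks), len(rows))
```

Whether per-station granularity pays off depends on how expensive the work inside each task is: if every task reopens all 36 files, grouping stations per configuration may be cheaper.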

Thanks for the link/notebook. I read it, but it uses remote data and a remote cluster, while I’m using local data and a local cluster.

I want to understand how to make the best use of my own PC with Xarray & Dask. I’d be glad if you had time to look at the second (hidden) script, which is a dask.delayed adaptation of the first (sequential) one. Both run on a local cluster with the specs I mentioned (edited) in the main post.

Thanks for your time.

I use the same pattern for local data and local cluster. That’s the great thing about using xarray, dask and fsspec!

Thank you. I already tried this (please see the hidden dask.delayed script in the main post). My problem is that this dask.delayed script either runs slowly or fails due to memory management issues.

I beg your pardon if I’ve missed something, but I don’t see any call to the dask.delayed function in your notebook. I’m trying to figure out:

How can I properly use the dask.delayed utility on a local cluster?

As mentioned in the main post (and in the attached scripts, though they’re currently hidden), even though I’m using lazy loading/computation and chunking, I’m still running into memory overload issues. My challenge is this: I don’t understand why the sequential script uses only around 300 MB of memory, while the parallelized version uses up to 10-11 GB, more than 30 times that amount.

If you have a moment to review my scripts, I’d really appreciate any insights you might have.

Hi @Arty, I’ve got several remarks and questions.

What is slow in this case? How long does a single (configuration, station) pair take?

So 3 years is the complete dataset for one model configuration? From what I see in your code, you never load all of this data.

Your initial code is based on Xarray, with a lazy loading mechanism (which already uses Dask Array under the hood) through open_mfdataset and the chunks kwarg. I think this is what @rsignell was suggesting, and you already do it. But since you have several configurations and stations, you iterate over models, and then over stations, selecting inside the model’s Xarray Dataset. I’m not an Xarray expert, but I think your data is actually loaded at this point:

This means the data is already completely filtered (by grid point, time and variable) when it is loaded.

Your distributed code introduces Delayed and a LocalCluster. First, this means that Xarray, with its Dask Array backend, will also use this cluster, which translates into tasks launched from within tasks, and that can be hard to manage. Beyond that, I see two main changes in your Delayed code:

  • You create the Xarray Dataset in process_station_data instead of outside the first for loop over configurations, so this creation is repeated for each station.
  • You use .load() (to reduce memory usage??) on data selected only by spatial coordinates, not by time or variable. This means all the data at those coordinates is loaded, which may be more than you really need.

I think all of this explains why you go from 300 MB to 10 GB. With 8 workers, you would already expect up to 8x more memory, and each Delayed task also loads more data.
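The two changes pointed out above (create the dataset once per configuration rather than once per station, and narrow the data to the needed variables before calling .load()) can be combined into one delayed task per configuration. The sketch below uses synthetic data: open_configuration is a hypothetical stand-in for open_mfdataset plus the preprocessing, and the variable names and station coordinates are made up. Passing scheduler="synchronous" to .load() (xarray forwards keyword arguments to dask.compute) is one assumed way to avoid launching tasks from within tasks; it is not the only option.

```python
import dask
import numpy as np
import xarray as xr

def open_configuration(seed):
    """Hypothetical stand-in for xr.open_mfdataset(...) plus preprocessing:
    returns a lazily evaluated, dask-backed dataset with made-up values."""
    rng = np.random.default_rng(seed)
    shape = (48, 10, 10)  # time, latitude, longitude
    ds = xr.Dataset(
        {name: (("time", "latitude", "longitude"), rng.normal(size=shape))
         for name in ["T2MEAN", "PSFC", "Q2MEAN"]},
        coords={"time": np.arange(48),
                "latitude": np.linspace(40.0, 49.0, 10),
                "longitude": np.linspace(-5.0, 4.0, 10)},
    )
    return ds.chunk({"time": 24})

def process_configuration(seed, wrf_vars, lats, lons):
    """One task per configuration: open once, keep only the needed variables
    and station points, then load just that small subset."""
    ds = open_configuration(seed)
    station_lat = xr.DataArray(lats, dims="station")
    station_lon = xr.DataArray(lons, dims="station")
    subset = ds[wrf_vars].sel(latitude=station_lat, longitude=station_lon,
                              method="nearest")
    # Compute xarray's internal dask graph in this thread, inside this task,
    # instead of submitting nested tasks to the scheduler.
    return subset.load(scheduler="synchronous")

tasks = [
    dask.delayed(process_configuration)(seed, ["T2MEAN", "PSFC"],
                                        [43.1, 45.9], [-1.2, 0.4])
    for seed in range(3)  # three made-up "configurations"
]
results = dask.compute(*tasks)
print([dict(r.sizes) for r in results])
```

With the 16 configurations of the original problem, this would give 16 tasks, each opening its files once and holding only the station time series for the compared variables, instead of reopening the files for each of the 25 stations.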