Manage garbage collection of Workers

Problem Description

Hello people, I want to parallelise what I thought was a simple task. I have a lot of geographical datasets with global coverage (each between 3 and 4 GB in size) which I want to crop into smaller parts (tiles), resample to a slightly different resolution and save to disk as netcdf. While working on this task I came across dask’s Client() (which creates a local cluster in my case) and the delayed() function. I understood its workings in a way that it builds a processing graph of whatever is inside delayed(); appending these to a list task_list gives me a list of tasks which are in the end processed in parallel when I pull the trigger with task_list.compute(). I set up my code iteratively so that 100 subsets are created per global dataset, which means 100 delayed tasks per iteration.

But when the computation starts I noticed through the dask dashboard (what an amazing tool) that a lot of unmanaged memory accumulates. After doing some research on this topic I stumbled upon a Coiled blog post which mentions garbage collection: “If calling gc.collect() on your workers makes your unmanaged memory drop, you should investigate your data for circular references and/or tweak the gc settings on the workers through a Worker Plugin”. To test this, I manually added gc.collect() to the function which cuts and saves the subset (data_crop()). The amount of unmanaged memory dropped significantly (see the animations of runs with and without manual garbage collection), but manual garbage collection caused another problem. After roughly 5 datasets I get the warning

2024-06-25 11:40:01,186 - distributed.utils_perf - WARNING - full garbage collections took 12% CPU time recently (threshold: 10%)

which gets worse after every dataset (i.e. after every 100 processed tasks); after processing 36 datasets I end up with ~33 % CPU time spent on garbage collection.

The Coiled blog post also mentions Worker Plugins to fix this in case of a significant drop in unmanaged memory when calling garbage collection manually, and refers to a dask support page about Worker Plugins. Sadly, there is no example for a Worker Plugin related to garbage collection, and while the how-to is probably easy for anyone into parallel programming, I don’t even know where to start; the best I can piece together from the docs is the skeleton below. I’m still not 100 % sure if a Worker Plugin is necessary in my case or if it is a conceptual problem in my code. Thanks in advance for any help and hints on how to deal with this issue.
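A minimal skeleton of what I imagine such a gc-tweaking plugin could look like (the gc.set_threshold value is an arbitrary placeholder; I have no idea what a sensible setting would be, or whether this is the intended approach at all):

import gc
from dask.distributed import WorkerPlugin

class GCTuner(WorkerPlugin):
    def __init__(self, threshold0=100_000):
        self.threshold0 = threshold0

    def setup(self, worker):
        # runs once in each worker process when the plugin is registered
        gc.set_threshold(self.threshold0)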

Code

from numpy import arange, nan
import xarray as xr
import rioxarray as rxr
import rasterio as rio
from rasterio.enums import Resampling
from pathlib import Path
from dask.distributed import Client
from dask.diagnostics import ProgressBar
from dask import delayed, compute
from re import findall
import gc


# this function is later wrapped inside `delayed()`
def data_crop(
    data,
    ds_meta: dict,  # is only used for the output filenaming convention
    extent: tuple,  # (minx, miny, maxx, maxy)
    target_res: float,
    resampling,
    save_dir=None
    ):
    minx, miny, maxx, maxy = extent
    exp_extent = (minx - 2 * target_res, miny - 2 * target_res, maxx + 2 * target_res, maxy + 2 * target_res)
    trg_grid = create_trgraster(exp_extent, res=target_res, trg_name=generate_prefix(extent))
    data = data.rio.clip_box(*exp_extent)
    data = data.rio.reproject_match(trg_grid, resampling=resampling)
    data = data.rio.clip_box(*extent)
    gc.collect()  # gc only frees memory if it is executed in this function

    if save_dir is not None:
        if Path(save_dir).exists():
            data.to_netcdf(Path(save_dir, set_filename(ds_meta, extent, target_res)))
        else:
            raise FileNotFoundError(f"directory {save_dir} does not exist!")
    else:
        return data


def tile_processor(
    data,
    aoi: tuple,
    tile_size: float,
    ds_meta: dict,
    target_res: float,
    resampling,
    mode: int,
    save_dir=None
    ):
    t = generate_subtiles(aoi, tile_size=tile_size)  # list of all tile extents in format (minx, miny, maxx, maxy) as tuple
    tasks = []

    for tile in t:  # iterate through all tile extents
        task = delayed(data_crop)(data, ds_meta, tile, target_res, resampling, save_dir)
        tasks.append(task)

    if mode == 0:
        with ProgressBar():
            compute(*tasks)
    elif mode == 1:
        return tasks


def bulk_processor(
    files: list,  # list of file paths
    aoi: tuple,  # global extent of the tiling grid
    tile_size: float,  # size of each tile in degrees
    target_res: float,
    ds_meta: dict,
    resampling: Resampling,
    parallel: bool,
    save_dir=None
):
    bulk_tasks = []

    for file_path in files:
        with xr.open_dataset(file_path, engine="rasterio") as ds:
            ds_meta.update(ds_date=findall(r"\d{8}", file_path)[0])

            if parallel:
                tasks = tile_processor(ds, aoi, tile_size, ds_meta, target_res, resampling, 1, save_dir)
                bulk_tasks.extend(tasks)
            else:
                tile_processor(ds, aoi, tile_size, ds_meta, target_res, resampling, 0, save_dir)

    if parallel:
        with ProgressBar():
            compute(*bulk_tasks)


if __name__ == "__main__":

    file_dir = ...
    tile_size = .5
    trg_res = 0.003125
    write_dir = ...
    # xmin, ymin, xmax, ymax
    aoi = (-55, -14, -50, -9)

    # list files
    files = filter_date(file_dir, "2020")  # custom function which gives me only datasets from 2020

    data_meta = ... 

    client = Client(n_workers=10, threads_per_worker=1)
    client.amm.start()
    bulk_processor(files, aoi, tile_size, trg_res, data_meta, Resampling.bilinear, False, write_dir)
    client.close()

Behaviour with and without manual garbage collection

System

  • OS: Debian 12
  • Python 3.9.19
  • dask 2024.6.0

I could partly solve my issue with a custom worker plugin.

import gc
import os
import psutil
from dask.distributed import WorkerPlugin

class MemoryHandler(WorkerPlugin):
    def __init__(self, memory_threshold_mb):
        self.memory_threshold_mb = memory_threshold_mb

    def setup(self, worker):
        self.process = psutil.Process(os.getpid())

    def transition(self, key, start, finish, *args, **kwargs):
        if finish == "memory":
            memory_info = self.process.memory_info()
            unmanaged_memory_mb = memory_info.rss / (1024 * 1024)

            if unmanaged_memory_mb >= self.memory_threshold_mb:
                gc.collect()
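
With this in place, the plugin is registered on the client; something like the following (the 4096 MB threshold is just an illustrative value):

client.register_plugin(MemoryHandler(memory_threshold_mb=4096))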

If a worker process exceeds the pre-defined memory threshold, gc.collect() is executed. This works well for most of my pipeline’s processing time. However, it can happen that a worker’s memory stays just below the threshold, which causes the plugin to run after basically every new task. As a result this worker gets really slow, causing the garbage collection CPU time warning. Does anyone else have a better idea?

Hi @dkin, welcome to the Dask Discourse forum!

Did the memory get released automatically at some point? If you do nothing and let Python handle it, do you encounter memory errors?

If so, you probably have a problem in your code: some data is still referenced in memory that you no longer need. You should try to find it, or the library causing the problem.
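
As a quick check, you can trigger a collection on every worker straight from the client; client.run executes a function in each worker process and returns one result per worker:

import gc

client.run(gc.collect)  # returns {worker_address: number of unreachable objects found}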

I cannot access the images you provided.

Updated Links

Did the memory get released automatically at some point? If you do nothing and let Python handle it, do you encounter memory errors?

At some point, yes, but usually too late, and from what I’m experiencing it’s random. I also thought about xarray being the problem, as the dataset remains in memory despite calling close() or using a context manager. But if that were the case, I don’t understand why manually triggering gc.collect() helps. However, I can’t find a suitable solution for that problem, and the current state is more of a workaround than a real fix, so any help is still highly appreciated.

One thing I notice while taking a closer look at your code is that you mix Xarray usage on the client side with Delayed tasks operating on that client-loaded Xarray data.

How many files do you have, and how many tiles per file?

A couple of hundred files, and 100 tiles per file (at the moment, but likely 1000 per file in the future).

I think it would be better to parallelize either only by tile or only by file. Processing the files sequentially and parallelizing over the tiles using Xarray with a Dask backend should have the lowest memory footprint.
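
Something along these lines (a rough sketch reusing the helper names from your post; the chunk sizes are illustrative):

for path in files:
    # chunks=... makes Xarray return lazy, Dask-backed arrays instead of loading into client memory
    ds = xr.open_dataset(path, engine="rasterio", chunks={"x": 2048, "y": 2048})
    for tile in generate_subtiles(aoi, tile_size=tile_size):
        data_crop(ds, ds_meta, tile, target_res, resampling, save_dir)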

That’s what I already have: I iterate sequentially through the files and “create” each tile as a delayed task, which I compute once every tile is a delayed object.

What I mean is that you open an Xarray Dataset for each file on the client side, create subtiles from this Dataset, and submit them as Delayed calls. This looks like a complex mix.

Do you mean submitting all to-be-processed tiles combined? So, let’s say 100 tiles per file with 10 files to be processed would give a combined task list of 1000 delayed tasks, which are then computed.

I was under the impression this is what you were doing already. So no: I was suggesting not to create one Delayed per tile, but only one Delayed per file.
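
In code, that would look roughly like this (a sketch reusing the names from your post; process_file is a hypothetical helper):

def process_file(path, aoi, tile_size, ds_meta, target_res, resampling, save_dir):
    # one task per file: the worker opens the file itself and writes all its tiles
    with xr.open_dataset(path, engine="rasterio") as ds:
        for tile in generate_subtiles(aoi, tile_size=tile_size):
            data_crop(ds, ds_meta, tile, target_res, resampling, save_dir)

tasks = [delayed(process_file)(p, aoi, tile_size, ds_meta, target_res, resampling, save_dir) for p in files]
compute(*tasks)

The large client-side Dataset is then never embedded in the task graph; each worker only receives a file path.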