Memory accumulation using client.map - how can I avoid this?

Hi, I am a new Dask user, so I wanted to post here to see if anyone with more experience can tell me whether I am missing something. This is my first time posting in a forum like this, so any suggestions are appreciated.

Dask version: 2022.2.0
netCDF4 version: 1.5.6
numpy version: 1.20.3

Task: Take a satellite image covering a large domain and split it up into small images that I like to call patches. Each patch is saved to a netCDF file, and additional data is then added to it. This is run on a lot of images, so I am using Dask to run things in parallel and keep the computational time down.

Issue: When I run my existing code with client = Client(n_workers=8, threads_per_worker=1, memory_limit='5GB'), the memory on the workers themselves stays reasonable for my task (< 1 GB), but the main program's memory (local memory?) keeps climbing. It does not climb quickly, but over a large number of iterations it adds up, growing from about 1 GB to over 3 GB. This ends up causing the workers to fail, because the local memory and the worker memory together reach the 5 GB limit. It seems like the workers or futures are saving something in the local process, and I cannot stop it from accumulating.

What I have tried: Aggressively deleting variables, and calling client.restart() followed by deleting variables. Neither of these works.

I’ve tried to create a minimal reproducible example (see below). This is my first time doing something like this, so I hope it makes sense; it seems a little ridiculous to have it save a bunch of random-number matrices, but hopefully the idea comes through. I can share my larger code if needed, but it involves accessing data I cannot share (on a remote server). The memory usage does not get as high as in my full program, since I am creating pseudo-arrays of random numbers and have cut out a lot of processing, but you can still see the memory increase even in this example.

Does anyone have any suggestions on what I can do to avoid the local memory increasing?

import sys, os
from datetime import datetime, timedelta
import numpy as np
from netCDF4 import Dataset
from dask.distributed import Client
import random

os.environ['HDF5_USE_FILE_LOCKING'] = "FALSE"

def get_list_of_dates(year, month, day, hour, minutes):
    dts = []
    for h in hour:
        for m in minutes:
            dt = datetime(year, month, day, h, m)
            dts.append(dt)
    return dts

def get_pseudo_patch_info(num_patches):  # gets the information needed for the patches at one datetime
    patch_info = []
    all_regions = ['northwest', 'west', 'west_north_central', 'southwest', 'east_north_central', 'south', 'southeast',
                   'central', 'northeast', 'oconus_west', 'oconus_east', 'oconus_north', 'oconus_south', 'oconus_other']
    for i in range(num_patches):
        region = random.sample(all_regions, 1)[0]
        center_lon = random.randint(-180, 180)
        center_lat = random.randint(-90, 90)
        percent_10C = random.randint(0, 100)
        percent_no_cov = random.randint(0, 100)
        info = [center_lon, center_lat, region, percent_10C, percent_no_cov]
        patch_info.append(info)
    return patch_info

def get_patch(patch_info, dt, rootoutdir):  # use the patch information to save the patch
    center_lon = patch_info[0]
    center_lat = patch_info[1]
    region = patch_info[2]
    percent_10C = patch_info[3]
    percent_no_cov = patch_info[4]

    NCoutdir = dt.strftime(os.path.join(rootoutdir, ('data_netcdfs/%Y/%m/%d/' + region + '/')))
    if not os.path.exists(NCoutdir):
        os.makedirs(NCoutdir)
    sat_array = np.random.randint(200, size=(200, 200))

    NCfilepath = os.path.join(NCoutdir, dt.strftime('%Y_%m_%d_%H_%M_%j_' + str(center_lon) + str(center_lat) + str(percent_10C) + str(percent_no_cov) + '.nc'))

    root = Dataset(NCfilepath, 'w', format='NETCDF4')
    root.description = 'Data Patch for Convective Initiation Model (CONUS)'
    root.center_latitude = center_lat
    root.center_longitude = center_lon
    root.center_region = region
    root.percent_60max_neg10C_greater30dBZ = percent_10C
    root.percent_60max_neg10C_noCoverage = percent_no_cov
    sat_group = root.createGroup('satellite')
    sat_group.createDimension('y', sat_array.shape[0])
    sat_group.createDimension('x', sat_array.shape[1])
    sat_data = sat_group.createVariable('Dummy', 'float32', ('y', 'x'))
    sat_data[:,:] = sat_array
    root.close()
    return NCfilepath

def add_to_nc(NCfilepath):  # add additional data to the netCDF for the patch (i.e. add radar to the satellite patch)
    root = Dataset(NCfilepath, 'r+')
    radar_array = np.random.randint(200, size=(200, 200))
    radargroup = root.createGroup('radar')
    radargroup.createDimension('y', radar_array.shape[0])
    radargroup.createDimension('x', radar_array.shape[1])
    rad = radargroup.createVariable('Dummy2', 'float32', ('y', 'x'))
    rad[:, :] = radar_array
    root.close()
    return


if __name__ == "__main__":

    client = Client(n_workers=8, threads_per_worker=1, memory_limit='5GB')  # Call Dask Client
    year = 2019
    month = 6
    day = 22
    hour = np.arange(0, 23, 1)
    minutes = np.arange(1, 56, 5)
    dts = get_list_of_dates(year, month, day, hour, minutes)
    rootoutdir = os.environ['PWD']+'/test/'
    print(rootoutdir)
    if not os.path.exists(rootoutdir):
        os.makedirs(rootoutdir)

    patch_info_dicts = {}  # dictionary where the key is the datetime and the value is a list of all the patches I want for that datetime
    for single_dt in dts:
        num_patches = random.randint(1, 210)
        dict_key = single_dt.strftime('%Y%m%d%H%M%S')
        info = get_pseudo_patch_info(num_patches)
        patch_info_dicts[dict_key] = info
        del single_dt, info, dict_key

    counter = 0
    for dt in dts:  # access the patches through the datetime key and collect the patches and all info needed
        print(counter, dt)
        dict_key = dt.strftime('%Y%m%d%H%M%S')
        patch_info = patch_info_dicts[dict_key]
        patch_res = client.map(get_patch, patch_info, dt=dt, rootoutdir=rootoutdir)
        radar_res = client.map(add_to_nc, patch_res)
        final_res = client.gather(radar_res)
        counter = counter + 1
        del patch_info, patch_res, radar_res, final_res

The starting memory of the main program is 239324, and the final value is the one highlighted below (I can only upload one image as a new user; the other entries in the image show the memory of the 8 workers, PyCharm, etc.):

Thanks for the question and the reproducer @sbradshaw! From your description, it sounds like there could be a memory leak, where perhaps netCDF4 is not adequately releasing memory that’s no longer in use. There are a few solutions you could try to improve this:

  1. netCDF4 seems to have some caching options; perhaps there is an argument to turn off caching.
  2. xarray has some chunking functionality that may help, especially in the get_patch function.
  3. Manually trimming memory using malloc_trim, which is particularly helpful for libraries that don’t release memory very well (see the sketch below).
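
As a minimal sketch of option 3 (assuming a Linux/glibc environment where libc.so.6 is available, and a client variable like the one in your script):

import ctypes


def trim_memory() -> int:
    # ask glibc to return freed memory to the operating system
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)


# run the trim on every worker in the cluster; calling trim_memory() directly
# in the main process does the same for the client/scheduler side
client.run(trim_memory)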

Please feel free to follow up with any other questions!

Thanks @scharlottej13. I’m working with @sbradshaw and have been able to narrow down her script to:

import sys, os
import psutil
import numpy as np
from dask.distributed import Client, LocalCluster
import random
from time import sleep
import ctypes


def trim_memory():
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)


def current_memory():
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / (1024 * 1024)  # MB
    return mem


def get_patch(index, dt):
    return None


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=8, threads_per_worker=1, memory_limit='5GB')
    client = Client(cluster)  # Call Dask Client
    #client = Client(n_workers=8, threads_per_worker=1, memory_limit='5GB')  # Call Dask Client
    print("Dask Dashboard: ", client.dashboard_link)
    dts = list(range(23 * 11))

    initial_mem = curr_mem = current_memory()
    print(f"Initial memory: {curr_mem}")
    print("Sleeping...")
    sleep(5)
    curr_mem = current_memory()
    print(f"Memory after sleep: {curr_mem}")
    counter = 0
    num_patches = 210
    for dt in dts:
        prev_mem = curr_mem
        curr_mem = current_memory()
        print(f"| Count: {counter:>03d} | Current memory usage: {current_memory():>0.05f} | Memory delta: {curr_mem - prev_mem:>0.05f} |")
        patch_res = client.map(get_patch, [dt + x for x in range(num_patches)], dt=dt)
        final_res = client.gather(patch_res)
        counter += 1
        del patch_res, final_res
        sleep(3.0)
        if counter % 5 == 0:
            print(trim_memory())  # trim the memory of the local (client/scheduler) process
        if counter % 10 == 0:
            client.run(trim_memory)  # trim the memory of each worker process
    print(f"Initial memory: {initial_mem:0.05f} -> Final memory: {curr_mem:0.05f} => Change: {curr_mem - initial_mem:0.05f}")

With this, we see output like the following with a LocalCluster on both macOS and Ubuntu:

Dask Dashboard:  http://127.0.0.1:8787/status
Initial memory: 109.40234375
Sleeping...
Memory after sleep: 110.6640625
| Count: 000 | Current memory usage: 110.66406 | Memory delta: 0.00000 |
| Count: 001 | Current memory usage: 116.46875 | Memory delta: 5.80469 |
| Count: 002 | Current memory usage: 117.16406 | Memory delta: 0.69531 |
| Count: 003 | Current memory usage: 117.89062 | Memory delta: 0.72656 |
| Count: 004 | Current memory usage: 118.39844 | Memory delta: 0.50781 |
1
| Count: 005 | Current memory usage: 117.90625 | Memory delta: -0.49219 |
| Count: 006 | Current memory usage: 119.18750 | Memory delta: 1.28125 |
| Count: 007 | Current memory usage: 119.44531 | Memory delta: 0.25781 |
| Count: 008 | Current memory usage: 120.45703 | Memory delta: 1.01172 |
| Count: 009 | Current memory usage: 120.96094 | Memory delta: 0.50391 |
1
| Count: 010 | Current memory usage: 120.70703 | Memory delta: -0.25391 |
| Count: 011 | Current memory usage: 121.96875 | Memory delta: 1.26172 |
...

The memory being measured here is what psutil reports as RSS. As the output shows, the memory increases over time as client.map/client.gather are called. This script does not use any library outside the standard library except dask and psutil. The rate of increase seems to depend on the number of items passed to client.map that get converted to futures/jobs. Any ideas?

@djhoese and @sbradshaw thanks for providing a more minimal example! @ian and I were able to reproduce the issue, though I’m not quite sure why this is happening. I altered your snippet slightly, and even removing the del lines did not affect the memory on the client or the number of references held by the client (when we would expect both to increase):

import sys, os, gc
import psutil
import numpy as np
from dask.distributed import Client, LocalCluster
import random
import ctypes


def current_memory():
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / (1024 * 1024)  # MB
    return mem


def get_patch(index, dt):
    return None


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=8, threads_per_worker=1, memory_limit='5GB')
    client = Client(cluster)
    dts = list(range(23 * 11))

    initial_mem = curr_mem = current_memory()
    print(f"Initial memory: {curr_mem}")
    curr_mem = current_memory()
    counter = 0
    num_patches = 210
    for dt in dts:
        prev_mem = curr_mem
        curr_mem = current_memory()
        print(f"| Count: {counter:>03d} | Current memory usage: {current_memory():>0.05f} | Memory delta: {curr_mem - prev_mem:>0.05f} |")
        patch_res = client.map(get_patch, [dt + x for x in range(num_patches)], dt=dt)
        final_res = client.gather(patch_res)
        print(f'number of client references {len(client.futures)}')
        counter += 1
        if counter % 10 == 0:
            gc.collect()

A similar issue has recently been reported in dask/distributed, so it seems you may not be the only one facing this problem! I’m not sure how helpful this is, but in the short term one workaround might be to reduce the amount of time any single client is running.
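
As a rough sketch (reusing the names from the original script; BATCH_SIZE is just a placeholder to tune), you could tear down and recreate the cluster every few batches of dates:

from dask.distributed import Client, LocalCluster

BATCH_SIZE = 20  # placeholder; tune for your workload

for start in range(0, len(dts), BATCH_SIZE):
    batch = dts[start:start + BATCH_SIZE]
    # a fresh cluster/client per batch, so any state accumulated in the
    # client/scheduler process is discarded along with it
    cluster = LocalCluster(n_workers=8, threads_per_worker=1, memory_limit='5GB')
    client = Client(cluster)
    for dt in batch:
        patch_info = patch_info_dicts[dt.strftime('%Y%m%d%H%M%S')]
        patch_res = client.map(get_patch, patch_info, dt=dt, rootoutdir=rootoutdir)
        client.gather(client.map(add_to_nc, patch_res))
    client.close()
    cluster.close()

Starting a LocalCluster is not free, so the batches should be large enough that the restart overhead stays small.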

As I mentioned on the issue @scharlottej13 linked, could you try disabling all the log parameters mentioned in Set scheduler log sizes automatically based on available memory · Issue #5570 · dask/distributed · GitHub and see if any of them make a difference?
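
Roughly something like this before creating the cluster (the exact config keys vary between distributed versions, so please cross-check against the list in the issue):

import dask
from dask.distributed import Client, LocalCluster

# shrink the scheduler's internal logs so they cannot grow without bound
dask.config.set({
    "distributed.scheduler.transition-log-length": 0,
    "distributed.scheduler.events-log-length": 0,
    "distributed.admin.log-length": 0,
})

cluster = LocalCluster(n_workers=8, threads_per_worker=1, memory_limit='5GB')
client = Client(cluster)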

@gjoseph92 We have tried disabling the log parameters you mentioned, without any luck. Do you have any other recommendations for a workaround? I have also tried forcing a client.restart() after the client.gather() call, followed by a gc.collect(), since my code calls client.map() over many iterations. This did not reduce the accumulated memory; it only seemed to shut down the workers and restart them (and the workers were not the ones holding the accumulated memory).
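
For reference, a sketch of the pattern I tried (the restart interval of 10 here is just illustrative):

import gc

for counter, dt in enumerate(dts):
    patch_info = patch_info_dicts[dt.strftime('%Y%m%d%H%M%S')]
    patch_res = client.map(get_patch, patch_info, dt=dt, rootoutdir=rootoutdir)
    final_res = client.gather(client.map(add_to_nc, patch_res))
    if counter % 10 == 0:
        client.restart()  # only shuts down and restarts the workers
        gc.collect()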

This is not correct: you set up 8 workers that can each use up to 5 GB, for a total of 40 GB, plus whatever the client + scheduler take (with the LocalCluster you’re using, the client and scheduler run in the same process).
If your laptop has less than 40 GB plus whatever RAM is taken by your macOS desktop, the Dask client, and the Dask scheduler, the OS will start swapping, which is going to hurt performance badly. Note that whenever a worker reaches 5 GB * 0.6 = 3 GB of used RAM, it will start spilling to disk, so you’ll hopefully never reach the full 5 GB. This is visible on the dashboard as a gray area in the memory chart.
This is explained at Worker Memory Management — Dask.distributed 2023.11.0+22.gdc06ce4 documentation.
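
For completeness, those thresholds live in the worker memory configuration; a sketch of the relevant keys (the values shown are the documented defaults):

import dask

# fractions of the per-worker memory_limit at which each action kicks in
dask.config.set({
    "distributed.worker.memory.target": 0.60,     # start spilling data to disk
    "distributed.worker.memory.spill": 0.70,      # spill based on process memory
    "distributed.worker.memory.pause": 0.80,      # pause accepting new tasks
    "distributed.worker.memory.terminate": 0.95,  # restart the worker
})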

Could you please run dask-scheduler and dask-worker as separate processes from the command line and figure out if the memory increase is in the client, scheduler, or specific to LocalCluster?
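
As a sketch, assuming the default scheduler port 8786:

# In separate terminals, start one scheduler and (for example) four single-threaded workers:
#   dask-scheduler
#   dask-worker tcp://127.0.0.1:8786 --nthreads 1 --memory-limit 5GB
#
# Then point the client at the external scheduler instead of creating a LocalCluster:
from dask.distributed import Client

client = Client("tcp://127.0.0.1:8786")

That way the client, scheduler, and workers are separate processes, and you can watch their RSS independently (e.g. with psutil or top).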

@crusaderky We will look into the workers being killed when they only report ~1GB of usage in the system monitor.

I’ve added a comment to Possible memory leak when using LocalCluster · Issue #5960 · dask/distributed · GitHub, where I ran the simplified script above with a separate scheduler and 4 workers. The memory growth is clearly in the scheduler.