Testing lazy evaluation of task graphs

Hi everyone!

This is a question about testing the lazy evaluation of dask task graphs!

We’re using Dask + Xarray for a library of processes that can be chained together to manipulate large amounts of Earth Observation data. Since we’re usually operating on data that is larger than the available memory in our Dask cluster, we rely on evaluation remaining lazy throughout the task graph, i.e. no intermediate .compute() calls.
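
For context, a minimal sketch of the kind of chained, lazy processing I mean (the file name, variable name and chunk sizes are made up for illustration):

import xarray as xr

# Open a (hypothetical) large dataset lazily; chunks= makes it dask-backed
ds = xr.open_dataset("observations.nc", chunks={"time": 100})

# Chaining operations only grows the task graph; nothing is computed yet
anomaly = ds["temperature"] - ds["temperature"].mean(dim="time")
smoothed = anomaly.rolling(time=5).mean()

# Only an explicit trigger (compute(), persist(), writing to disk) should execute the graph
result = smoothed.compute()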

We’ve noticed that it is quite easy for us to produce a certain type of bug where this lazy evaluation is accidentally broken (something along the way triggers a call to .compute()), too much data is loaded prematurely, and Dask workers die from running out of memory.
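
One way this can sneak in (an illustrative example, reusing the hypothetical dataset from the sketch above): anything that converts a dask-backed DataArray to a plain NumPy array, such as .values or np.asarray(), forces the whole array into memory:

import numpy as np
import xarray as xr

da = xr.open_dataset("observations.nc", chunks={"time": 100})["temperature"]

# Still lazy: arithmetic only extends the task graph
scaled = da * 0.1

# Accidentally eager: converting to NumPy materializes the full array,
# just like an explicit scaled.compute() would
arr = np.asarray(scaled)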

To better guard against this type of bug in the future, I thought I could extend our test suite with a function that checks whether a given process implementation inadvertently loads the data into cluster memory.

I’ve tried wrapping the function-under-test with the MemorySampler from distributed.diagnostics.memory_sampler and comparing memory usage on a LocalCluster before and after calling the function-under-test. However, I’ve noticed that cluster memory on a LocalCluster seems to slowly increase even when no tasks are being worked on (see below for code to replicate). It’s therefore not obvious to me how I would reliably detect any accidental loads against this background noise.

Before going any further down this particular rabbit hole, I thought I’d ask here:

  1. Does this approach to testing even make sense? Or am I misunderstanding some fundamental Dask concepts here?
  2. If yes, does anyone have any pointers for how I’d reasonably go about this?

Thanks in advance for any support!! :slight_smile:
Lukas.


Code to replicate what I’ve tried:

from dask.distributed import Client
from distributed.diagnostics.memory_sampler import MemorySampler
import numpy as np
from time import sleep

with Client():
    ms = MemorySampler()
    # Sample cluster memory every 0.5 s while doing essentially nothing
    with ms.sample("mem_usage", interval=0.5):
        sleep(0.5)
        print("hello")
        sleep(2)

    # Peak-to-peak range of the sampled memory; naively one would expect ~0 on an idle cluster
    ms_pandas = ms.to_pandas()
    mem_usage_range = np.ptp(ms_pandas["mem_usage"])
    print(ms_pandas)
    print(mem_usage_range)
    assert mem_usage_range == 0  # This assertion errors

Python version: 3.9.5
Dask version: 2022.12.0

Hi @luk,

Sorry for the long delay before you finally got a (partial) answer. Welcome to this forum, even if this is probably not the start you hoped for.

Just to clarify: is this a bug in Dask, or is it that the code you wrote somehow implies/needs a compute() somewhere, i.e. a bug in your code rather than in Dask?

Thanks to the provided code, I was able to easily replicate the behavior. However, since each worker is a Python process that does a lot of communication with the scheduler, reports metrics, and does other background work (even the MemorySampler itself must use some memory), I’m not really surprised by this slow increase. It would probably be cleaned up after some time by garbage collection.
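
If part of that drift really is uncollected garbage, one thing you could try (just a sketch, I haven’t measured how much it helps) is forcing a collection pass on all workers before you start sampling, using Client.run:

import gc
from dask.distributed import Client

client = Client()

# Run gc.collect() on every worker to reduce the baseline noise a bit
# before starting the MemorySampler
client.run(gc.collect)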

I think this approach could make sense in your case, yes. Watching worker memory in tests to detect big bumps sounds reasonable to me.

I think you should take this slow increase into account and define a threshold for detecting memory problems. Running your test for 2 minutes leads to an increase of about 7 MiB in my case. That is probably much lower than the bumps you’re trying to detect? If so, you could go with something like:

assert mem_usage_range <= 1e8  # absolute threshold, roughly 100 MB
# or, scaled with the duration of the test in seconds (roughly 100 kB/s):
assert mem_usage_range <= duration * 1e5
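
To make that concrete, here is a rough sketch of how such a check could be wrapped into a reusable test helper. The name assert_stays_lazy, the cluster sizing, and the default threshold are all placeholders you would adapt to your setup:

import numpy as np
from dask.distributed import Client, LocalCluster
from distributed.diagnostics.memory_sampler import MemorySampler

def assert_stays_lazy(func, *args, max_increase=1e8, **kwargs):
    """Fail if calling func(*args, **kwargs) grows cluster memory by more than max_increase bytes."""
    with LocalCluster(n_workers=2, memory_limit="1GiB") as cluster, Client(cluster):
        ms = MemorySampler()
        # Sample cluster memory while the (supposedly lazy) function runs
        with ms.sample("mem_usage", interval=0.5):
            result = func(*args, **kwargs)
        mem_usage_range = np.ptp(ms.to_pandas()["mem_usage"])
        assert mem_usage_range <= max_increase, (
            f"Cluster memory grew by {mem_usage_range:.0f} bytes; "
            "an accidental .compute() may have been triggered"
        )
        return result

In a pytest suite you would then call assert_stays_lazy(my_process, some_lazy_input) once per process implementation you want to guard, keeping the lazily-opened input well above the threshold so that an accidental load is unmistakable.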

Maybe @jrbourbeau or @fjetter have more things to say about the Worker memory slowly increasing over time?


Thank you so much for the thoughtful response, really appreciate it!

Yes, I’m referring entirely to bugs in our code and our usage of Dask, not in Dask itself!

Thanks for the advice - I’ll give this approach a shot and report back on how well it works out in practice! :wink:
