Testing lazy evaluation of task graphs

Hi everyone!

This is a question about testing the lazy evaluation of Dask task graphs!

We’re using Dask + Xarray for a library of processes that can be chained together to manipulate large amounts of Earth Observation data. Since we’re usually operating on data that is larger than the memory available in our Dask cluster, we rely on the evaluation remaining lazy throughout the task graph, i.e. no intermediate .compute() calls.

We’ve noticed that it is quite easy for us to introduce a certain type of bug where this lazy evaluation is accidentally broken (something along the way triggers a call to .compute()), too much data is loaded prematurely, and Dask workers die from running out of memory.
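
For illustration, here is a toy sketch (not our real code) of the kind of accidental trigger I mean, using a small dask-backed DataArray:

import numpy as np
import xarray as xr

# A lazy, dask-backed DataArray standing in for real EO data
lazy = xr.DataArray(np.ones((1000, 1000))).chunk({"dim_0": 100})

# This stays lazy: it only extends the task graph
still_lazy = (lazy * 2).mean("dim_0")

# This accidentally materialises the whole result in memory,
# because np.asarray() (like .values) implicitly triggers a compute
eager = np.asarray(still_lazy)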

To better guard against this type of bug in the future, I thought I could extend our test suite with a function that checks whether a given process implementation inadvertently loads the data into cluster memory.
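
Roughly, this is the shape of helper I had in mind (just a sketch; process_under_test, lazy_input and the tolerance are placeholders, not names from our code base):

from distributed.diagnostics.memory_sampler import MemorySampler

def assert_no_eager_load(client, process_under_test, lazy_input, tolerance_bytes=100e6):
    # Sample cluster-wide memory while the (supposedly lazy) process builds its graph
    ms = MemorySampler()
    with ms.sample("during_process", client=client, interval=0.5):
        result = process_under_test(lazy_input)
    samples = ms.to_pandas()["during_process"]
    # If memory grew noticeably, data was probably loaded into cluster memory
    assert samples.max() - samples.min() < tolerance_bytes
    return result

The difficulty is choosing tolerance_bytes sensibly, which brings me to the problem below.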

I’ve tried wrapping the function-under-test with the MemorySampler from distributed.diagnostics.memory_sampler and comparing memory usage on a LocalCluster before/after calling the function-under-test. However, I’ve noticed that cluster memory on a LocalCluster seems to slowly increase even when no tasks are being worked on (see below for code to replicate). It’s therefore not obvious to me how I could reliably detect accidental loads against this background noise.

Before going any further down this particular rabbit hole, I thought I’d ask here:

  1. Does this approach to testing even make sense? Or am I misunderstanding some fundamental Dask concepts here?
  2. If yes, does anyone have any pointers for how I’d reasonably go about this?

Thanks in advance for any support!! :slight_smile:
Lukas.


Code to replicate what I’ve tried:

from dask.distributed import Client
from distributed.diagnostics.memory_sampler import MemorySampler
import numpy as np
from time import sleep

with Client():  # implicitly starts a LocalCluster
    ms = MemorySampler()
    with ms.sample("mem_usage", interval=0.5):
        # No tasks are submitted here, so cluster memory should stay flat
        sleep(0.5)
        print("hello")
        sleep(2)

    ms_pandas = ms.to_pandas()
    # Peak-to-peak range of the sampled cluster memory
    mem_usage_range = np.ptp(ms_pandas["mem_usage"])
    print(ms_pandas)
    print(mem_usage_range)
    assert mem_usage_range == 0  # This assertion fails: memory drifts even while idle

Python version: 3.9.5
Dask version: 2022.12.0