Tracking progress of logical groupings of tasks at execution time

I want to group tasks together in a dask execution graph and then track the progress of those groups when the graph is executing.

For instance, given the following operations:

# a dask-backed xr.DataArray
data_array: xr.DataArray = ...

# Group 1: "filter"
data_array_filtered = data_array.sel(time=slice("2024-01-01", "2024-12-31"))
# This will create multiple tasks. I want to be able to associate
# all of them with the "filter" group at runtime.

# Group 2: "adjust"
data_array_adjusted = (data_array_filtered ** 2) / 2 + 1
# This will create multiple tasks. I want to be able to associate
# all of them with the "adjust" group at runtime.

# Group 3: "aggregate"
data_array_mean = data_array_adjusted.mean()
# This will create multiple tasks. I want to be able to associate
# all of them with the "aggregate" group at runtime.

data_array_mean.compute()

During execution, I want to be able to generate events of the form:

group "filter" started
group "filter" completed
group "adjust" started
group "adjust" completed
group "aggregate" started
group "aggregate" completed

The solution should be general so that it works on arbitrarily complex graphs.

Things I’ve tried that don’t quite work:

dask.annotate to attach labels to tasks and then catch state transitions in a SchedulerPlugin

This only works if I wrap the .compute() call in a with annotate() block, but that assigns the same annotation to every task, which is not what I want.

Something like the following would be ideal, but it doesn’t work:

with annotate(label="filter"):
    data_array_filtered = data_array.sel(time=slice("2024-01-01", "2024-12-31"))

with annotate(label="adjust"):
    data_array_adjusted = (data_array_filtered ** 2) / 2 + 1

with annotate(label="aggregate"):
    data_array_mean = data_array_adjusted.mean()
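On the SchedulerPlugin side, what I have in mind is roughly the sketch below. This assumes the distributed scheduler; GroupProgressPlugin is a placeholder name, and only the "started" events are shown, since "completed" would need extra per-group bookkeeping of remaining tasks:

```python
from distributed.diagnostics.plugin import SchedulerPlugin


class GroupProgressPlugin(SchedulerPlugin):
    """Sketch: emit a 'started' event the first time any task
    carrying a given `label` annotation begins processing."""

    def __init__(self):
        self.started = set()
        self.scheduler = None

    async def start(self, scheduler):
        # the scheduler hands itself to the plugin on registration
        self.scheduler = scheduler

    def transition(self, key, start, finish, *args, **kwargs):
        ts = self.scheduler.tasks.get(key) if self.scheduler else None
        if ts is None:
            return
        label = (ts.annotations or {}).get("label")
        if label and finish == "processing" and label not in self.started:
            self.started.add(label)
            print(f'group "{label}" started')
```

The plugin would be registered with client.register_scheduler_plugin(GroupProgressPlugin()); the open question is how to get per-operation labels onto the tasks in the first place.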

Create a mapping between task keys and groups by inspecting data_array.data.dask after each group’s operation to detect the newly added task keys, then catch the corresponding state transitions in a SchedulerPlugin as above.

I can do

data_array_filtered = data_array.sel(time=slice("2024-01-01", "2024-12-31"))
data_array_filtered_keys = extract_keys(data_array_filtered.data.dask)

data_array_adjusted = (data_array_filtered ** 2) / 2 + 1
data_array_adjusted_keys = extract_keys(data_array_adjusted.data.dask)

# the set difference between data_array_adjusted_keys and
# data_array_filtered_keys contains only the keys that are part of
# the "adjust" group

This almost works, but some tasks get replaced during graph optimization, so the task keys at runtime differ from those in the mapping constructed earlier. E.g.

  • tasks get fused and their names concatenated, e.g. mean_agg-aggregate-mean_chunk-...
  • tasks with keys like …-finalize-hlgfinalizecompute-… get added.

Manually insert new layers in the graph with custom tasks that only log events

I am able to get this to work, more or less, if I use client.get() directly and explicitly request the keys of the new tasks, but not if I use data_array.compute().
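To illustrate the idea on a raw low-level (dict) graph, using the local synchronous scheduler (the task names and the log_group_done helper are made up):

```python
import dask


def add_one(x):
    return x + 1


def log_group_done(label, *values):
    # a task whose only job is to emit the event; it passes data through
    print(f'group "{label}" completed')
    return values[-1]


# a raw low-level graph; "filter-done" is an extra barrier task that
# depends on the group's output key(s)
graph = {
    "a": 1,
    "b": (add_one, "a"),
    "filter-done": (log_group_done, "filter", "b"),
}

# with a distributed client this would be client.get(graph, "filter-done")
result = dask.get(graph, "filter-done")  # prints: group "filter" completed
```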

Manually rename all keys in the graph to add a prefix that is preserved when keys get concatenated during fusion

This is what prompted the question I posted earlier: I can’t find a way to safely and recursively rename all the keys in a graph.

Hi @AdeelH,

A few thoughts on this:

  • First, a real MVCE would really help, ideally without Xarray and using only Dask arrays.
  • Looking at your example, you’ll probably never get exactly what you want, given Dask’s internal behavior. Even if you manage to track all the tasks belonging to one operation, Dask schedules tasks depth first, so you’ll probably see something like:
group "filter" started
group "adjust" started
group "aggregate" started
group "adjust" completed
group "filter" completed
group "aggregate" completed
  • As you’ve noticed, you’ll need to disable any graph optimization to be able to do what you want, which is probably not good for performance.
  • I’m not sure why the annotations are not working — are you using the distributed scheduler?
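For reference, an MVCE along the lines suggested in the first point, using only Dask arrays and no Xarray, might look like this (shapes and chunks are arbitrary):

```python
import dask.array as da

data = da.random.random((1000, 1000), chunks=(100, 100))

# Group 1: "filter" (row selection stands in for the time slice)
filtered = data[:500]

# Group 2: "adjust"
adjusted = (filtered ** 2) / 2 + 1

# Group 3: "aggregate"
mean = adjusted.mean()

result = mean.compute()
```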