Enforce Blocking on Task Submission when Workers are Saturated

Description

I am encountering a behavior where tasks submitted from within a worker (using worker_client) are dispatched to the scheduler immediately, even when the target workers are already fully occupied or saturated with resources.

I would like a mechanism to block or wait at the point of submission (e.g., inside map_func) until the workers have the actual capacity to accept new jobs.

Environment Details

  • Python: 3.11.13
  • Dask: 2025.12.0
  • Distributed: 2025.12.0
  • Cluster Configuration: LocalCluster with 1 thread per worker.

Current Behavior

In the reproduction script provided below, the following sequence occurs:

  1. Parent Task Initiation: A parent task (func_A) is launched on a worker.

  2. Resource Claim: The task dynamically claims capacity on the cluster using set_cluster_resources.

  3. Simulated Workload: The parent tasks simulate work by waiting: 2 tasks wait for 3 seconds, while the remaining 3 tasks wait for 15 seconds.

  4. The Issue (Unrestricted Submission): Immediately after this wait period, func_A submits multiple child tasks (func_B) via client.map. This submission happens instantly, disregarding the fact that the target workers are already saturated. There is currently no mechanism to pause submission until resources become available. Ideally, each func_B task should be submitted only when a worker thread becomes free. Consequently, the dashed lines in the plot below (representing queueing time) for the func_A and func_B jobs should be absent.
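For context, the "submit only when a slot frees up" behavior I'm after can be approximated on the client side by capping the number of in-flight futures and submitting a new task only when a previous one completes. Dask's `distributed.as_completed` supports this pattern; since the sliding-window logic itself is generic, the sketch below demonstrates it with the stdlib `concurrent.futures` instead of a live cluster (the names `bounded_map` and `max_in_flight` are illustrative, not Dask API):

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def bounded_map(submit, func, items, max_in_flight):
    """Run func over items, keeping at most max_in_flight tasks pending.

    `submit` is any callable with the Executor.submit signature; with Dask
    you would pass client.submit and drain futures via distributed.as_completed.
    """
    results = []
    pending = set()
    for item in items:
        if len(pending) >= max_in_flight:
            # Block here until at least one in-flight task finishes.
            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            results.extend(f.result() for f in done)
        pending.add(submit(func, item))
    done, _ = wait(pending)  # drain the remaining tasks
    results.extend(f.result() for f in done)
    return results


# Usage: never more than 3 tasks are pending at any moment.
with ThreadPoolExecutor(max_workers=3) as pool:
    out = bounded_map(pool.submit, lambda x: 2 * x, range(10), max_in_flight=3)
```

Results may arrive out of submission order, since tasks are collected as they complete.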

Visualization

The following plot (generated by the script) illustrates the issue. Note how tasks are queued immediately even though execution cannot start.
For example, a func_B job submitted from Worker TCP-Task UID 42155-b0a7a8 to 33861-90fcbc waits in the queue for over 10 seconds (dashed line), despite resources being available on the other two workers. This indicates that the scheduler is assigning tasks to saturated workers prematurely rather than waiting for or utilizing available slots elsewhere.

Expected Behavior

The submission logic should respect worker saturation: the client should wait to submit the next batch of jobs until the cluster resources (specifically, the dynamic resources set on the workers) have been freed up.
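Another way to get this "wait until a slot is free" semantics is a semaphore with one lease per worker slot, acquired before each submit and released when the task finishes. Dask ships a cluster-wide `distributed.Semaphore` (with a `max_leases` argument) with the same acquire/release semantics; to keep the example self-contained, the gating logic is sketched below with the stdlib `threading.BoundedSemaphore` (the helper name `gated_submit` is illustrative):

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def gated_submit(submit, sem, func, *args):
    """Block until a slot is free, then submit; free the slot when done.

    With Dask, `sem` would be a distributed.Semaphore(max_leases=n_slots)
    shared by all submitters, and `submit` would be client.submit.
    """
    sem.acquire()  # blocks while every slot is taken
    fut = submit(func, *args)
    fut.add_done_callback(lambda _: sem.release())
    return fut


# Usage: two slots, so the third submit blocks until an earlier task finishes.
slots = threading.BoundedSemaphore(2)
with ThreadPoolExecutor(max_workers=2) as pool:
    futs = [gated_submit(pool.submit, slots, lambda x: x + 1, i) for i in range(5)]
    results = [f.result() for f in futs]
```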

Reproduction Script

The following script reproduces the scenario. It simulates a workload, tracks task execution via a custom WorkerPlugin, and generates a visualization showing the task overlap/queueing.

import csv
import glob
import os
import tempfile
import time
from contextlib import contextmanager

import dask
import dask.distributed
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd


def set_cluster_resources(client, resource_name, capacity):
    """
    Dynamically injects resources into a running cluster.
    """
    client.run(_update_worker_capacity, resource_name, capacity)


async def _update_worker_capacity(resource_name, capacity, dask_worker=None):
    """
    Runs INSIDE the worker.
    Uses the official set_resources method to trigger a scheduler update.
    """
    current_resources = dask_worker.state.total_resources.copy()
    current_resources[resource_name] = capacity
    await dask_worker.set_resources(**current_resources)
    return


def map_func(func, args):
    name = f"{func.__name__}"
    address, uid = retrieve_submitter()
    with (
        client_manager() as client,
        dask.annotate(submitter_url=address, submitter_uid=uid),
    ):
        set_cluster_resources(client, name, 1)
        results = client.gather(client.map(func, *args, resources={name: 1}))

    return results


@contextmanager
def client_manager():
    try:
        with dask.distributed.worker_client() as client:
            # NOTE: Delay execution briefly. This prevents a race condition that can cause
            #      failures if two worker_client instances are created in quick succession.
            time.sleep(0.5)
            yield client
    except ValueError:
        yield dask.distributed.get_client()


def func_B(sleep, parent_name, name):
    """The sub-task. It does 'work' (sleeps) and returns None."""
    time.sleep(sleep)
    print(f"{name} (from {parent_name}) completed after {sleep:.2f}s")
    return


def func_A(name, sleep_time, n_subtasks):
    """
    The parent task.
    1. Consumes a resource slot.
    2. Submits n_subtasks of type B.
    """
    worker = dask.distributed.get_worker()
    print(f"[{name}] STARTING on {worker.address}")
    time.sleep(sleep_time)
    args = [
        [2 for i in range(n_subtasks)],
        [name for i in range(n_subtasks)],
        [f"B{i+1}" for i in range(n_subtasks)],
    ]
    results = map_func(func_B, args)
    return


def run_dask_test(n_workers, n_A_tasks, n_B_tasks, fast_tasks, path):
    dask.config.set(
        {
            "distributed.scheduler.worker-saturation": 1.0,
        }
    )
    cluster = dask.distributed.LocalCluster(n_workers=n_workers, threads_per_worker=1)
    client = dask.distributed.Client(cluster)
    plugin = TaskQueueLogger(path=path)
    client.register_plugin(plugin)

    args = [
        [f"A{i+1}" for i in range(n_A_tasks)],
        [3 if i < fast_tasks else 15 for i in range(n_A_tasks)],
        [n_B_tasks for i in range(n_A_tasks)],
    ]

    results = map_func(func_A, args)
    client.close()
    cluster.close()
    print("\nTest Finished.")


def retrieve_submitter():
    from dask.distributed import get_worker

    try:
        worker = get_worker()
        return worker.address, worker.get_current_task()
    except ValueError:
        return "tcp://127.0.0.1:client", "client"


class TaskQueueLogger(dask.distributed.WorkerPlugin):
    def __init__(self, path):
        self.tasks = {}
        self.path = path

    def setup(self, worker):
        self.worker = worker
        safe_addr = worker.address.replace("://", "_").replace(":", "_")
        self.filepath = f"{self.path}/worker_log_{safe_addr}.csv"

        if not os.path.exists(self.filepath):
            with open(self.filepath, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(
                    [
                        "Task_Name",
                        "Task_UID",
                        "Submitter_TCP",
                        "Submitter_UID",
                        "Receiver_TCP",
                        "Queued",
                        "Start",
                        "Finish",
                    ]
                )

    def transition(self, key, start, finish, **kwargs):
        timestamp = time.time()

        if key not in self.tasks:
            # RETRIEVE ANNOTATIONS
            # We look up the task object in the worker's state to find annotations
            task_state = self.worker.state.tasks.get(key)
            annotations = getattr(task_state, "annotations", {}) or {}

            # extract our custom 'submitter' annotation (default to 'Unknown')
            sub_url = annotations.get("submitter_url", "Unknown-URL")
            sub_uid = annotations.get("submitter_uid", "Unknown-UID")

            self.tasks[key] = {
                "sub_url": sub_url,
                "sub_uid": sub_uid.split("-")[-1],
                "queued": None,
                "start": None,
                "finish": None,
            }

        if finish == "waiting":
            self.tasks[key]["queued"] = timestamp

        elif finish == "executing":
            self.tasks[key]["start"] = timestamp

        elif finish == "memory":
            self.tasks[key]["finish"] = timestamp
            self._write_log(key)
            del self.tasks[key]

        elif finish == "released" or finish == "error":
            if key in self.tasks:
                del self.tasks[key]

    def _write_log(self, key):
        record = self.tasks[key]
        if record["start"] and record["finish"]:
            t_queued = self._format_time(record.get("queued"))
            t_start = self._format_time(record.get("start"))
            t_finish = self._format_time(record.get("finish"))

            with open(self.filepath, "a", newline="") as f:
                writer = csv.writer(f)
                task_name, task_uid = key.rsplit("-", 1)
                writer.writerow(
                    [
                        task_name,
                        task_uid,
                        record["sub_url"],
                        record["sub_uid"],
                        self.worker.address,
                        t_queued,
                        t_start,
                        t_finish,
                    ]
                )

    def _format_time(self, ts):
        time_string = "N/A"
        if ts is not None:
            time_string = f"{time.strftime('%H:%M:%S', time.localtime(ts))}.{int((ts % 1) * 1000):03d}"
        return time_string


def concat_csv(path):
    files = glob.glob(f"{path}/*")
    complete_df = []
    for f in files:
        df = pd.read_csv(f)
        complete_df.append(df)
    complete_df = (
        pd.concat(complete_df).sort_values(by=["Receiver_TCP"]).reset_index(drop=True)
    )
    complete_df["TCP_UID"] = (
        complete_df["Receiver_TCP"].str.split(":").str[-1]
        + "-"
        + complete_df["Task_UID"].str[:6]
    )
    time_fmt = "%H:%M:%S.%f"
    for col in ["Queued", "Start", "Finish"]:
        complete_df[col] = pd.to_datetime(complete_df[col], format=time_fmt)

    # Normalize Time (Start at 0 for readability)
    min_time = complete_df["Queued"].min()
    complete_df["Q_sec"] = (complete_df["Queued"] - min_time).dt.total_seconds()
    complete_df["S_sec"] = (complete_df["Start"] - min_time).dt.total_seconds()
    complete_df["F_sec"] = (complete_df["Finish"] - min_time).dt.total_seconds()
    return complete_df


def plot(df, path, file_name, tasks: list | None = None, border=0.4):
    if tasks is not None:
        tasks = [tasks] if isinstance(tasks, str) else list(tasks)
        df = df[df["Task_Name"].isin(list(tasks))]
    task_row_map = {row["Task_UID"]: i for i, row in df.iterrows()}
    fig, ax = plt.subplots(figsize=(14, 7))
    cmap = plt.get_cmap("tab10")
    colors = {i: cmap(counter) for counter, i in enumerate(df["Task_Name"].unique())}

    cmap = plt.get_cmap("tab10_r")
    colors_bg = {
        i: cmap(counter) for counter, i in enumerate(df["Receiver_TCP"].unique())
    }

    for i in df["Receiver_TCP"].unique():
        d = df[df["Receiver_TCP"] == i]
        ymin, ymax = d.index.min() - border, d.index.max() + border
        ax.axhspan(ymin, ymax, color=colors_bg[i], alpha=0.3, zorder=0)

    for i, row in df.iterrows():
        ax.plot(
            [row["Q_sec"], row["S_sec"]],
            [i, i],
            color="gray",
            linestyle=":",
            linewidth=1.5,
        )
        duration = row["F_sec"] - row["S_sec"]
        c = colors[row["Task_Name"]]
        ax.barh(
            i,
            duration,
            left=row["S_sec"],
            height=0.6,
            color=c,
            edgecolor="black",
            alpha=0.8,
        )
    for i, row in df.iterrows():
        parent_uid = row["Submitter_UID"]
        if parent_uid in task_row_map:
            parent_idx = task_row_map[parent_uid]

            rad = 0.2 if parent_idx > i else -0.2
            arrow = mpatches.FancyArrowPatch(
                (row["Q_sec"], parent_idx),
                (row["Q_sec"], i),
                connectionstyle=f"arc3,rad={rad}",
                color="black",
                arrowstyle="-|>",
                mutation_scale=12,
                linewidth=1,
                alpha=0.6,
            )
            ax.add_patch(arrow)

            arrow = mpatches.FancyArrowPatch(
                (row["F_sec"], i),
                (row["F_sec"], parent_idx),
                connectionstyle=f"arc3,rad={rad}",
                color="red",
                arrowstyle="-|>",
                mutation_scale=12,
                linewidth=1,
                alpha=0.7,
            )
            ax.add_patch(arrow)

    ax.set_yticks(range(len(df)))
    ax.set_yticklabels(df["TCP_UID"])
    ax.set_ylabel("Worker TCP - Task UID")
    ax.set_xlabel("Time (seconds)")
    ax.grid(True, axis="x", linestyle="--", alpha=0.3)

    legend_patches = [mpatches.Patch(color=c, label=la) for la, c in colors.items()]
    ax.legend(handles=legend_patches, loc="upper right")
    plt.tight_layout()
    plt.savefig(file_name)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdirname:
        run_dask_test(
            n_workers=3,
            n_A_tasks=5,
            n_B_tasks=2,
            fast_tasks=2,
            path=tmpdirname,
        )

        df = concat_csv(tmpdirname)
        os.makedirs("/tmp/figures", exist_ok=True)
        plot(df, tmpdirname, file_name="/tmp/figures/output.png")

Hi @Ale_dev, welcome to the Dask community!

First, I’d like to say that your workflow is a bit complex and hard to understand, especially the set_cluster_resources part. I fear this is what is messing things up; I’m not sure what you are trying to achieve with it.

Anyway, launching tasks from tasks is always complicated, but when the worker_client context manager is used properly (at least without your set_cluster_resources method), there shouldn’t be this sort of problem.

It’s normal that tasks are submitted to the scheduler on the fly; what you want is rather to avoid them being dispatched to already-full workers. You might try playing with the worker-saturation setting for that. Maybe a value of 1.0 would help; I’m not sure whether you can lower it further than that.

Hi @guillaumeeb ,

If I remove set_cluster_resources, the scheduler sees the worker as free and starts overlapping func_A tasks, as shown in the plot below:

You can see this happening on worker 37485. While the first func_A job is waiting (after the 3-second sleep and func_B submission), the scheduler launches a second func_A job on the same worker.

I’m using set_cluster_resources to prevent this overlap because func_A may hold onto a significant amount of memory. If multiple instances of func_A stack up on the same worker, RAM usage grows indefinitely. I need the resource constraint to force the scheduler to wait until func_A is truly finished before starting the next one.
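For what it’s worth, the dynamic set_cluster_resources step can likely be avoided by declaring the resource when the cluster is created, since LocalCluster forwards resources to its workers and the scheduler then knows about them from the start. A minimal sketch of the setup, assuming an illustrative resource name "slots" and the same three single-threaded workers (not tested against my full workflow):

```python
import dask.distributed

# Declare one "slots" unit per worker at startup instead of injecting
# it later with client.run; the scheduler will then never co-locate
# two {"slots": 1} tasks on the same worker.
cluster = dask.distributed.LocalCluster(
    n_workers=3, threads_per_worker=1, resources={"slots": 1}
)
client = dask.distributed.Client(cluster)

# Each func_A-style task claims the worker's single slot until it finishes.
futures = client.map(lambda x: x, range(6), resources={"slots": 1})
```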

Regarding the worker-saturation parameter: I am currently setting it to 1.0 at the start of run_dask_test.
If I remove that configuration (reverting to the default), I get the following plot:

I don’t see any difference in behavior compared to the original image.