Description
I am encountering a behavior where tasks submitted from within a worker (using worker_client) are dispatched to the scheduler immediately, even when the target workers are already fully occupied or saturated with resources.
I would like a mechanism to block or wait at the point of submission (e.g., inside map_func) until the workers have the actual capacity to accept new jobs.
Environment Details
- Python: 3.11.13
- Dask: 2025.12.0
- Distributed: 2025.12.0
- Cluster Configuration: LocalCluster with 1 thread per worker.
Current Behavior
In the reproduction script provided below, the following sequence occurs:
- Parent Task Initiation: A parent task (`func_A`) is launched on a worker.
- Resource Claim: The task dynamically claims capacity on the cluster using `set_cluster_resources`.
- Simulated Workload: The parent tasks simulate work by waiting: 2 tasks wait for 3 seconds, while the remaining 3 tasks wait for 15 seconds.
- The Issue (Unrestricted Submission): Immediately after this wait period, `func_A` submits multiple child tasks (`func_B`) via `client.map`. This submission happens instantly, disregarding the fact that the target workers are already saturated. There is currently no mechanism to pause submission until resources become available. Ideally, each `func_B` task should be submitted only when a worker thread becomes free. Consequently, the dashed lines in the plot below (representing queueing time) for `func_A` and `func_B` jobs should be absent.
Visualization
The following plot (generated by the script) illustrates the issue. Note how tasks are queued immediately even though execution cannot start.
For example, a func_B job submitted from Worker TCP-Task UID 42155-b0a7a8 to 33861-90fcbc waits in the queue for over 10 seconds (dashed line), despite resources being available on the other two workers. This indicates that the scheduler is assigning tasks to saturated workers prematurely rather than waiting for or utilizing available slots elsewhere.
Expected Behavior
The submission logic should be able to respect the worker’s saturation. The client should wait to submit the next batch of jobs until the cluster resources (specifically the dynamic resources set on the workers) are freed up.
Reproduction Script
The following script reproduces the scenario. It simulates a workload, tracks task execution via a custom WorkerPlugin, and generates a visualization showing the task overlap/queueing.
import csv
import glob
import os
import tempfile
import time
from contextlib import contextmanager
import dask
import dask.distributed
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
def set_cluster_resources(client, resource_name, capacity):
    """Dynamically register a named resource on every worker of a live cluster.

    Broadcasts ``_update_worker_capacity`` to all workers through
    ``client.run`` so each worker advertises ``resource_name`` with the given
    ``capacity`` to the scheduler.
    """
    client.run(_update_worker_capacity, resource_name, capacity)
async def _update_worker_capacity(resource_name, capacity, dask_worker=None):
    """Run INSIDE a worker: overwrite one named resource entry.

    Takes a copy of the worker's current resource mapping, sets
    ``resource_name`` to ``capacity``, and applies it via the official
    ``set_resources`` coroutine so the scheduler is notified of the update.
    """
    resources = dict(dask_worker.state.total_resources)
    resources[resource_name] = capacity
    await dask_worker.set_resources(**resources)
def map_func(func, args):
    """Fan *func* out over *args* on the cluster and block on the results.

    Claims one unit of a resource named after the function, annotates every
    submitted task with the submitter's address and task UID (retrieved via
    ``retrieve_submitter``), then maps and gathers the futures.
    """
    resource_name = func.__name__
    submitter_address, submitter_uid = retrieve_submitter()
    with (
        client_manager() as client,
        dask.annotate(submitter_url=submitter_address, submitter_uid=submitter_uid),
    ):
        set_cluster_resources(client, resource_name, 1)
        futures = client.map(func, *args, resources={resource_name: 1})
        return client.gather(futures)
@contextmanager
def client_manager():
    """Yield a Dask client usable from either a worker task or the main process.

    Prefers ``worker_client`` (available when running inside a task); when
    that raises ``ValueError`` (no worker context), falls back to
    ``get_client``.
    """
    try:
        with dask.distributed.worker_client() as scoped_client:
            # NOTE: Delay execution briefly. This prevents a race condition
            # that can cause failures if two worker_client instances are
            # created in quick succession.
            time.sleep(0.5)
            yield scoped_client
    except ValueError:
        yield dask.distributed.get_client()
def func_B(sleep, parent_name, name):
    """The sub-task: simulates work by sleeping, reports completion, returns None."""
    time.sleep(sleep)
    print(f"{name} (from {parent_name}) completed after {sleep:.2f}s")
def func_A(name, sleep_time, n_subtasks):
    """
    The parent task.
    1. Consumes a resource slot.
    2. Submits n_subtasks of type B.
    """
    worker = dask.distributed.get_worker()
    print(f"[{name}] STARTING on {worker.address}")
    time.sleep(sleep_time)
    # Per-child argument lists: fixed 2s sleep, shared parent name, B1..Bn labels.
    sleeps = [2] * n_subtasks
    parents = [name] * n_subtasks
    labels = [f"B{idx + 1}" for idx in range(n_subtasks)]
    map_func(func_B, [sleeps, parents, labels])
def run_dask_test(n_workers, n_A_tasks, n_B_tasks, fast_tasks, path):
    """Spin up a LocalCluster, run the A/B workload, and tear everything down.

    ``fast_tasks`` of the A-tasks sleep 3s; the rest sleep 15s. Each A-task
    spawns ``n_B_tasks`` B-tasks. A ``TaskQueueLogger`` plugin writes per-task
    timing CSVs under ``path``.
    """
    dask.config.set({"distributed.scheduler.worker-saturation": 1.0})
    cluster = dask.distributed.LocalCluster(n_workers=n_workers, threads_per_worker=1)
    client = dask.distributed.Client(cluster)
    client.register_plugin(TaskQueueLogger(path=path))
    names = [f"A{idx + 1}" for idx in range(n_A_tasks)]
    sleeps = [3 if idx < fast_tasks else 15 for idx in range(n_A_tasks)]
    subtask_counts = [n_B_tasks] * n_A_tasks
    map_func(func_A, [names, sleeps, subtask_counts])
    client.close()
    cluster.close()
    print("\nTest Finished.")
def retrieve_submitter():
    """Identify the submitter: (address, task key) of the current worker.

    Falls back to a synthetic client identity when ``get_worker`` raises
    ``ValueError`` (i.e. we are not running inside a task).
    """
    from dask.distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError:
        return "tcp://127.0.0.1:client", "client"
    return current_worker.address, current_worker.get_current_task()
class TaskQueueLogger(dask.distributed.WorkerPlugin):
    """Worker plugin that logs queued/start/finish timestamps per task to CSV.

    One CSV file is written per worker (named from the worker address); each
    completed task contributes one row once it has both started and finished.
    """

    def __init__(self, path):
        # In-flight task records, keyed by task key.
        self.tasks = {}
        self.path = path

    def setup(self, worker):
        """Create the per-worker CSV log file (with header) if absent."""
        self.worker = worker
        safe_addr = worker.address.replace("://", "_").replace(":", "_")
        self.filepath = f"{self.path}/worker_log_{safe_addr}.csv"
        if os.path.exists(self.filepath):
            return
        header = [
            "Task_Name",
            "Task_UID",
            "Submitter_TCP",
            "Submitter_UID",
            "Receiver_TCP",
            "Queued",
            "Start",
            "Finish",
        ]
        with open(self.filepath, "w", newline="") as f:
            csv.writer(f).writerow(header)

    def transition(self, key, start, finish, **kwargs):
        """Record transition timestamps; flush and drop finished tasks."""
        timestamp = time.time()
        if key not in self.tasks:
            # RETRIEVE ANNOTATIONS
            # Look the task up in the worker's state to recover the custom
            # submitter annotations (default to 'Unknown').
            task_state = self.worker.state.tasks.get(key)
            annotations = getattr(task_state, "annotations", {}) or {}
            sub_url = annotations.get("submitter_url", "Unknown-URL")
            sub_uid = annotations.get("submitter_uid", "Unknown-UID")
            self.tasks[key] = {
                "sub_url": sub_url,
                "sub_uid": sub_uid.split("-")[-1],
                "queued": None,
                "start": None,
                "finish": None,
            }
        if finish == "waiting":
            self.tasks[key]["queued"] = timestamp
        elif finish == "executing":
            self.tasks[key]["start"] = timestamp
        elif finish == "memory":
            self.tasks[key]["finish"] = timestamp
            self._write_log(key)
            del self.tasks[key]
        elif finish in ("released", "error"):
            self.tasks.pop(key, None)

    def _write_log(self, key):
        """Append one CSV row for *key* — only if it both started and finished."""
        record = self.tasks[key]
        if not (record["start"] and record["finish"]):
            return
        # Split "name-uid" on the last dash; name may itself contain dashes.
        task_name, _, task_uid = key.rpartition("-")
        row = [
            task_name,
            task_uid,
            record["sub_url"],
            record["sub_uid"],
            self.worker.address,
            self._format_time(record.get("queued")),
            self._format_time(record.get("start")),
            self._format_time(record.get("finish")),
        ]
        with open(self.filepath, "a", newline="") as f:
            csv.writer(f).writerow(row)

    def _format_time(self, ts):
        """Format a unix timestamp as HH:MM:SS.mmm local time, or 'N/A'."""
        if ts is None:
            return "N/A"
        millis = int((ts % 1) * 1000)
        return f"{time.strftime('%H:%M:%S', time.localtime(ts))}.{millis:03d}"
def concat_csv(path):
    """Merge every per-worker CSV log under *path* into one DataFrame.

    Adds a compact "TCP_UID" label (worker port + first 6 chars of the task
    UID), parses the Queued/Start/Finish time columns, and derives
    Q_sec/S_sec/F_sec as seconds relative to the earliest queued timestamp.
    """
    frames = [pd.read_csv(f) for f in glob.glob(f"{path}/*")]
    merged = (
        pd.concat(frames).sort_values(by=["Receiver_TCP"]).reset_index(drop=True)
    )
    merged["TCP_UID"] = (
        merged["Receiver_TCP"].str.split(":").str[-1]
        + "-"
        + merged["Task_UID"].str[:6]
    )
    time_fmt = "%H:%M:%S.%f"
    for col in ["Queued", "Start", "Finish"]:
        merged[col] = pd.to_datetime(merged[col], format=time_fmt)
    # Normalize Time (start at 0 for readability).
    origin = merged["Queued"].min()
    for src, dst in (("Queued", "Q_sec"), ("Start", "S_sec"), ("Finish", "F_sec")):
        merged[dst] = (merged[src] - origin).dt.total_seconds()
    return merged
def plot(df, path, file_name, tasks: list | None = None, border=0.4):
    """Draw a Gantt-style chart of task queueing and execution times.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``concat_csv`` (needs Q_sec/S_sec/F_sec, Task_Name,
        Task_UID, Submitter_UID, Receiver_TCP, TCP_UID columns).
    path : str
        Unused; kept for backward compatibility with existing callers.
    file_name : str
        Where the figure is saved.
    tasks : list | str | None
        Optional filter on Task_Name; a single string is accepted too.
    border : float
        Vertical padding around each worker's background band.

    Dotted gray segments show queueing (Queued -> Start); colored bars show
    execution (Start -> Finish). Black arrows point from the submitting
    parent's row to the child at submission time; red arrows point back at
    completion time.
    """
    if tasks is not None:
        tasks = [tasks] if isinstance(tasks, str) else list(tasks)
        df = df[df["Task_Name"].isin(list(tasks))]
    # Map each task UID to its row index so parent/child arrows can be drawn.
    task_row_map = {row["Task_UID"]: i for i, row in df.iterrows()}
    fig, ax = plt.subplots(figsize=(14, 7))
    cmap = plt.get_cmap("tab10")
    colors = {i: cmap(counter) for counter, i in enumerate(df["Task_Name"].unique())}
    cmap = plt.get_cmap("tab10_r")
    colors_bg = {
        i: cmap(counter) for counter, i in enumerate(df["Receiver_TCP"].unique())
    }
    # Shaded horizontal band per worker, spanning that worker's rows.
    for i in df["Receiver_TCP"].unique():
        d = df[df["Receiver_TCP"] == i]
        ymin, ymax = d.index.min() - border, d.index.max() + border
        ax.axhspan(ymin, ymax, color=colors_bg[i], alpha=0.3, zorder=0)
    for i, row in df.iterrows():
        # Dotted segment: time spent queued before execution started.
        ax.plot(
            [row["Q_sec"], row["S_sec"]],
            [i, i],
            color="gray",
            linestyle=":",
            linewidth=1.5,
        )
        duration = row["F_sec"] - row["S_sec"]
        c = colors[row["Task_Name"]]
        ax.barh(
            i,
            duration,
            left=row["S_sec"],
            height=0.6,
            color=c,
            edgecolor="black",
            alpha=0.8,
        )
    for i, row in df.iterrows():
        parent_uid = row["Submitter_UID"]
        if parent_uid in task_row_map:
            parent_idx = task_row_map[parent_uid]
            # Curve direction depends on whether the parent row is above/below.
            rad = 0.2 if parent_idx > i else -0.2
            arrow = mpatches.FancyArrowPatch(
                (row["Q_sec"], parent_idx),
                (row["Q_sec"], i),
                connectionstyle=f"arc3,rad={rad}",
                color="black",
                arrowstyle="-|>",
                mutation_scale=12,
                linewidth=1,
                alpha=0.6,
            )
            ax.add_patch(arrow)
            arrow = mpatches.FancyArrowPatch(
                (row["F_sec"], i),
                (row["F_sec"], parent_idx),
                connectionstyle=f"arc3,rad={rad}",
                color="red",
                arrowstyle="-|>",
                mutation_scale=12,
                linewidth=1,
                alpha=0.7,
            )
            ax.add_patch(arrow)
    ax.set_yticks(range(len(df)))
    ax.set_yticklabels(df["TCP_UID"])
    ax.set_ylabel("Worker TCP - Task UID")
    ax.set_xlabel("Time (seconds)")
    ax.grid(True, axis="x", linestyle="--", alpha=0.3)
    legend_patches = [mpatches.Patch(color=c, label=la) for la, c in colors.items()]
    ax.legend(handles=legend_patches, loc="upper right")
    plt.tight_layout()
    plt.savefig(file_name)
    # Release the figure: without this, repeated calls leak open figures
    # (matplotlib keeps every pyplot-created figure alive until closed).
    plt.close(fig)
if __name__ == "__main__":
    # savefig fails if the target directory does not exist, so create it first.
    figures_dir = "/tmp/figures"
    os.makedirs(figures_dir, exist_ok=True)
    with tempfile.TemporaryDirectory() as tmpdirname:
        run_dask_test(
            n_workers=3,
            n_A_tasks=5,
            n_B_tasks=2,
            fast_tasks=2,
            path=tmpdirname,
        )
        df = concat_csv(tmpdirname)
        plot(df, tmpdirname, file_name=f"{figures_dir}/output.png")


