Hello all, I am trying to optimize the process of scaling up/down effectively without Dask resubmitting tasks. For example, suppose I scale up to 20 nodes to handle a workflow that needs to process 100,000 files. My workers become well saturated and the cluster successfully processes 99,000 files; however, 100 of those take significantly longer to process (hours, maybe), so many of my workers just go idle. Now, say I want to scale down this cluster. What I've noticed is that Dask will re-submit tasks from the killed nodes. How can I set up a scenario where I can scale down and Dask will nicely shuffle its remaining tasks onto the smaller cluster? I don't mind re-submitting unfinished tasks, but I do not want to re-submit finished ones.
cluster.scale(jobs=20)
futures = client.map(process_file, file_list, batch_size=n_workers, pure=False)
# after 30 minutes 99,000 tasks are done so I do this
cluster.scale(jobs=2)
# the 100 remaining tasks get re-submitted to the scaled-down cluster (2 nodes)
# 18 nodes can be given back
It looks like you are using dask-jobqueue, is that right? This library currently doesn't handle task locality well when scaling down, which might be your problem. I think you should gather all finished Futures and delete their references in order to prevent them from being executed again!
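For example, something like this (a minimal sketch, continuing from the `futures` and `cluster` objects in your snippet above):

# Split the futures into finished and pending before releasing anything.
finished = [f for f in futures if f.status == "finished"]
pending = [f for f in futures if f.status != "finished"]

results = [f.result() for f in finished]  # gather completed results locally
for f in finished:
    f.release()  # drop our reference so the scheduler can forget these tasks

futures = pending  # keep only unfinished work
cluster.scale(jobs=2)  # finished tasks will not be recomputed now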
Hi @guillaumeeb, that is correct! Using dask-jobqueue at the moment. I found a way to handle this; posting it here in case it helps others implement this pattern.
import math
from dask.distributed import as_completed

seq = as_completed(futures)  # iterator over the 100_000+ futures
current_n_nodes = 20  # number of nodes the cluster starts with
workers = 30  # workers per node

results = []
while seq.count() > 0:
    current_total_workers = current_n_nodes * workers
    # Once fewer tasks remain than the cluster has workers, shrink the
    # cluster so the idle nodes can be given back.
    if seq.count() <= current_total_workers:
        new_n_nodes = math.ceil(seq.count() / workers)
        if new_n_nodes < current_n_nodes:
            print(f"[LOG] Adjusting cluster. Scaling to '{new_n_nodes}' jobs")
            cluster.scale(jobs=new_n_nodes)
            current_n_nodes = new_n_nodes
    f = next(seq)  # block until the next future completes
    if f.status == "finished":
        results.append(f.result())
        f.release()  # drop the reference so this task is never recomputed
    elif f.status != "cancelled":
        f.retry()  # errored (e.g. its worker was killed): run it again
        seq.add(f)  # and put it back into the iterator
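The key piece is f.release(): once a finished future's result has been gathered and our reference dropped, the scheduler no longer needs to keep (or recompute) that task when its worker goes away. The scaling block then only shrinks the cluster once the remaining task count drops below the current worker capacity, so finished work is never lost in a scale-down.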