Restrict task graph to one worker in a distributed cluster

I am using an HPC cluster to do some data processing. The workload is very numpy/dask.array heavy and parallelizes beautifully over the threaded scheduler on a single node (which has 40 cores). It looks something like this:

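x = read_lots_of_array_data_from_disk(*args)  # lazy dask.array built from the read tasks
x = process_data(x)  # FFTs, concatenates, reshapes, etc.
x.compute()  # runs on the threaded scheduler of a single node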

For a single “block”, this usually involves reading ~50-100 GB of data, and the final task graph to compute x has ~10-20k tasks (lots of FFTs, concatenates, reshapes, etc.). There are roughly 60 “read-data” tasks per block of computation (this is relevant later).

Excited to parallelize this across multiple nodes/machines so that I could compute many blocks in parallel, I used dask-jobqueue to create 1 worker per node, each with as many threads as the machine provides. Then I tried dask.compute(list_of_many_x), and this is where I ran into problems: my workflow is suddenly slowed down by tons of transfer tasks.

My hope was that each node would compute a single block by itself, because transfer costs are really high even with InfiniBand. Unfortunately, the read tasks get distributed across multiple nodes (because there are more read tasks than threads on a single node), and then a lot of time is wasted consolidating the array chunks onto a single node for the processing steps.

To solve this issue: is there a way to do dask.compute(x) that restricts the entire task graph to a single worker, while still using a single Cluster/Client object?
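
To make the goal concrete, here is a minimal sketch of the behavior I am after. It assumes that the `workers=` and `allow_other_workers=` arguments of `Client.compute` in dask.distributed restrict every task of the submitted graph, not just the final keys. I have not verified that this holds for all versions. The helper names `pair_blocks_with_workers` and `compute_pinned` are hypothetical, and `worker_addresses` is whatever list of worker addresses the cluster reports.

```python
from itertools import cycle


def pair_blocks_with_workers(blocks, worker_addresses):
    """Pair each block with one worker address, round-robin.

    Hypothetical helper: with more blocks than workers, addresses repeat.
    """
    if not worker_addresses:
        raise ValueError("no worker addresses given")
    return list(zip(blocks, cycle(worker_addresses)))


def compute_pinned(client, blocks, worker_addresses):
    """Submit each block's graph restricted to a single worker.

    `client` is an existing dask.distributed.Client and `blocks` is a list
    of lazy dask collections (list_of_many_x above). Returns one future per
    block; call client.gather(...) on them to collect the results.
    """
    futures = []
    for block, address in pair_blocks_with_workers(blocks, worker_addresses):
        # Assumption: workers=[address] together with
        # allow_other_workers=False keeps this submission on that worker.
        future = client.compute(
            block, workers=[address], allow_other_workers=False
        )
        futures.append(future)
    return futures
```

The intended use would be `futures = compute_pinned(client, list_of_many_x, addresses)` followed by `client.gather(futures)`, with `addresses` taken for example from the keys of `client.scheduler_info()["workers"]`. Whether the read tasks actually stay on one node would still need to be checked on the dashboard.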

PS: I don’t want to resort to tactics like “Make N cluster objects, scale each one to 1 worker, then submit jobs one at a time to each cluster object” (that’s just running MPI with extra steps).