Available options for collective communication in Dask

When I look at the collective communication options, I was only able to find Rabit (used in XGBoost) [1] and RAFT’s NCCL wrapper [2]. I was curious if there are any other options which:
1. Are available in a non-XGBoost-managed Dask cluster.
2. Are more tolerant to node failures than NCCL.

I found Hoplite [3], which seems to be designed for resilient collective communication for frameworks like Dask. However, it doesn’t seem to be readily available even in Ray, for which it was designed.

I was curious if there are implementations which examine this problem.

[1]: https://github.com/dmlc/xgboost/blob/e93a2748230656de9770aa880157dc71c35af2da/python-package/xgboost/dask.py#L219-L233
[2]: RAFT Dask API — raft 23.06.00 documentation
[3]: Zhuang, Siyuan, Zhuohan Li, Danyang Zhuo, Stephanie Wang, Eric Liang, Robert Nishihara, Philipp Moritz, and Ion Stoica. “Hoplite: Efficient and Fault-Tolerant Collective Communication for Task-Based Distributed Systems.” arXiv:2002.05814, September 28, 2021.


NOT A CONTRIBUTION

Hi @vij,

If I understand correctly, you are trying to find a library which would allow you to issue direct communications between Worker processes? Could you give an example of the workflow you’re trying to implement?

I am trying to implement an algorithm which has frequent gather + broadcast / all-gather calls. I was able to perform a fast gather via dask.bag.fold, but I haven’t been able to find any mechanism to perform a broadcast (e.g. a tree-broadcast).
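To make the gather side concrete, here is a minimal sketch of the fold-based gather (the data and merge function are illustrative stand-ins, not my actual workload):

import dask.bag as db

def merge(acc, item):
    # pairwise-mergeable partial results (here: dict counters)
    out = dict(acc)
    for k, v in item.items():
        out[k] = out.get(k, 0) + v
    return out

bag = db.from_sequence(range(1_000), npartitions=50)
partials = bag.map(lambda x: {x % 10: 1})              # per-element partial results
gathered = partials.fold(merge, initial={}).compute()  # aggregated back to the client

What I am missing is the reverse direction: an efficient way to push something like gathered (or a model built from it) back out to all workers.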

This broadcast remains a huge bottleneck, and I am trying to figure out mechanisms to speed it up.


NOT A CONTRIBUTION

Hi vij,

Dask’s general paradigm is not MPI (which is what xgboost.dask does).
Instead of starting services and then using the constructs you described to make them communicate with each other, we use task graphs. For a broadcast, you’ll have a single producing task whose output is a dependency of many downstream tasks.
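For illustration, a minimal sketch of that pattern with dask.delayed (the names are made up):

import dask
from dask import delayed

@delayed
def make_payload():
    return list(range(1_000_000))

@delayed
def consume(i, payload):
    return i + len(payload)

payload = make_payload()                              # single producing task
consumers = [consume(i, payload) for i in range(16)]  # many tasks depend on it
dask.compute(*consumers)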

The current algorithm is fairly naive, as all consumers will fetch the data from the same single producer instead of collaboratively propagating it. This works well enough for small data (e.g. to propagate a scalar or a handful of MiB worth of a domain table), but it could be problematic if you need to broadcast many gigabytes of data to thousands of workers.

This could be solved algorithmically within the dask scheduler and workers. I would strongly discourage an MPI-oriented add-on library.


Thanks for clarifying the design goal.

I am trying out a few variants of task graphs for doing the distribution, including the existing scatter. I will report back if I find any competitive mechanisms.


NOT A CONTRIBUTION

Please don’t use scatter to propagate data in the middle of a computation. That’s designed just to deploy data before you start, and only if you can’t load it from network storage directly from the workers.
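For example, something along these lines, where the workers read the asset themselves (the path and function names are illustrative):

import pickle
from dask.distributed import Client

client = Client()

def run_inference(partition_id, model_path):
    # each task reads the asset from shared/network storage itself,
    # instead of receiving it from the client via scatter
    with open(model_path, "rb") as f:
        model = pickle.load(f)
    return partition_id, len(model)

futures = client.map(run_inference, range(10), model_path="/shared/model.pkl")
client.gather(futures)

In practice you would typically cache the loaded asset per worker process rather than re-reading it in every task.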

Please don’t use scatter to propagate data in the middle of a computation

Definitely.
I am avoiding any mixing of scatter with the asynchronous calls. It is only used between distinct phases, i.e. phases separated by synchronous compute calls. For now it solves the immediate need for broadcasting and performs better than remote-storage-based solutions.
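For concreteness, a minimal sketch of that phase-separated usage (the data, the fold, and the model are stand-ins for my actual workload):

import dask.bag as db
from dask.distributed import Client

client = Client()
bag = db.from_sequence(range(100), npartitions=10).persist()

stats = bag.fold(max).compute()            # phase 1 ends with a synchronous compute

model = {"scale": stats}                   # asset built between phases
model_future, = client.scatter([model])    # scatter sits between the two phases

results = bag.map(lambda x, m: x * m["scale"], m=model_future).compute()  # phase 2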

This could be solved algorithmically within the dask scheduler and workers.

I am interested in implementing such a solution; would you have any pointers to existing methods which might be used for guidance?
Going through the reduction methods in the dask array and bag modules, I got some sense of the tree-reduce implementation.
Going through the client.scatter implementation, I have some sense of how the distribution is being coordinated.
Any other, more relevant pointers would be helpful.
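For what it’s worth, this is the rough shape of the tree-style propagation I have in mind, built only from plain tasks (the doubling strategy and the replicas_wanted parameter are my own assumptions, not an existing Dask mechanism):

from dask.distributed import Client

def _identity(x):
    return x

def tree_replicate(client, future, replicas_wanted):
    # Each round submits identity tasks that depend on an existing copy,
    # roughly doubling the number of keys holding the data. The scheduler
    # decides placement, so the fan-out across workers is not guaranteed.
    copies = [future]
    while len(copies) < replicas_wanted:
        copies += [client.submit(_identity, c, pure=False) for c in copies]
    return copies[:replicas_wanted]

# Illustrative usage:
# client = Client()
# model_future, = client.scatter([model])
# model_copies = tree_replicate(client, model_future, replicas_wanted=64)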


NOT A CONTRIBUTION

@vij do you have some sample code you’d like to run, or some pseudo code? I suspect that dask is already doing what you want. There are a couple of config options that impact data replication, but generally you shouldn’t have to worry about this.

For example

data = client.submit(load_data)
intermediate = client.map(process_data, range(...), data=data)
client.submit(reduce, intermediate)
...

will replicate the result of load_data efficiently to all workers that are processing a process_data task, and then reduce the intermediates back down to one worker, etc.
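A runnable variant of that sketch, with trivial placeholder implementations (reduce_results stands in for reduce above):

from dask.distributed import Client

client = Client()

def load_data():
    return list(range(1_000))

def process_data(i, data):
    return sum(data) + i

def reduce_results(parts):
    return sum(parts)

data = client.submit(load_data)
intermediate = client.map(process_data, range(100), data=data)
result = client.submit(reduce_results, intermediate)
print(result.result())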

Are you experiencing poor performance, or what is pointing you to an MPI-style pattern?

In my case I did the following

model, = client.scatter([model], broadcast=False)
<persisted-dask-bag>.map(run_inference, model=model)

When I ran a snippet like this on a large cluster (several hundred nodes) and counted the number of times each node acted as the source for the model Future*, it turned out that only 10% of the nodes acted as sources at all. Further, a single node acted as the source 78% of the time; the other nodes each accounted for <1% of the transfers.

More importantly, I also observed that the same model was fetched several times by the same worker :confused:, which makes me believe it is being discarded after each chunk is processed.

Are you experiencing poor performance, or what is pointing you to an MPI-style pattern?

There is a lot of delay (~1 min) before the slowest worker starts computing with the model, even for small assets (a few hundred MB). The initial scatter call itself is quite fast (a few seconds), which IIUC is the time to upload it to one worker. So I am assuming the large delay is due to sub-optimal transfer of the model across the pool. I was hoping to transfer this asset more effectively.

The one-minute delay isn’t much if this is a one-time operation; however, for an iterative algorithm with a few hundred scatter-type operations it quickly adds up. Right now these operations take >25% of the execution time in my case.

[*] This accounting was performed by tracking distributed.worker.get_data_from_worker calls, as implemented by a collaborator.
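For reference, a rough sketch of that kind of accounting (not the collaborator’s actual code; it assumes distributed.worker.get_data_from_worker exists as a module-level coroutine that receives the source address as its worker argument, which is version-dependent):

from collections import Counter

def install_transfer_counter():
    # run on every worker, e.g. via client.run(install_transfer_counter)
    import distributed.worker as dw
    if getattr(dw, "_transfer_counter", None) is not None:
        return
    dw._transfer_counter = Counter()
    original = dw.get_data_from_worker

    async def counting_get_data_from_worker(*args, **kwargs):
        source = kwargs.get("worker") or (args[2] if len(args) > 2 else None)
        dw._transfer_counter[source] += 1
        return await original(*args, **kwargs)

    dw.get_data_from_worker = counting_get_data_from_worker

def read_transfer_counter():
    # retrieve per-worker counts, e.g. via client.run(read_transfer_counter)
    import distributed.worker as dw
    return dict(getattr(dw, "_transfer_counter", Counter()))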


NOT A CONTRIBUTION