Optimising Dask computations (memory implications and communication overhead)

I am working on a project which was initially optimised for the Python multiprocessing library, and I am trying to use Dask to distribute the work over multiple nodes. In particular, I am trying to use SSHCluster.

In order to optimise as much as possible, I have made the worker methods more fine-grained, i.e. working at the smallest level possible, requiring smaller inputs and returning smaller outputs. The goal is to use the least memory while taking the least time to complete a single task.

The data structures I have are as follows: a large dict with inner integer values, arrays and dicts, as well as simpler dicts, arrays and sets. These are dynamic in nature, i.e. they should be changed by the worker methods, returned to the client, and then used by subsequent calls to the same (and other) worker methods.

I also have a dict of dicts, which is static, along with some other objects containing static data. I save these properties as JSON files, use the client.upload_file method to upload the files to the workers, and use the client.register_worker_callbacks method to load these files on the worker side, so worker methods can use them as “global shared” memory. This uses up some memory, especially because the data is duplicated in the memory space of each worker; however, it works quite well, because the data is loaded once when the workers are created and is then shared by every subsequent worker method (task) computation.
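The setup looks roughly like this (the file name, loader and data are simplified stand-ins, not the exact project code):

```python
import json

# Module-level dict: lives once per worker process, shared by all tasks there.
STATIC = {}

def load_static(path="static_data.json"):
    # Registered via client.register_worker_callbacks, so it runs once on each
    # worker (including any worker that joins later) and fills STATIC.
    with open(path) as f:
        STATIC.update(json.load(f))

# Local demo of the loader itself, with made-up data:
with open("static_data.json", "w") as f:
    json.dump({"persons": {"A": {"age": 30}}}, f)
load_static()

# In the real setup, against a running Client:
# client.upload_file("static_data.json")         # copy the file to every worker
# client.register_worker_callbacks(load_static)  # load it once per worker
```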

However, the dynamic data (approx. 800 MB in size) needs to be passed to the workers as efficiently as possible before the computation of the worker method starts.

I came up with 3 potential ways to achieve this communication between the client and the workers:

  1. Split the data structures into “partial data structures” based on what each worker method requires. For example, if the worker method will be tackling persons A, B and C, only include the corresponding data for A, B and C, and subsequently only return the data for A, B and C.
  2. Scatter the data structures to the workers using client.scatter(dynamic_obj, broadcast=True), and pass the scattered data structures as futures to the workers (along with other small params). Then, on the worker side, build the “partial data structures” for local usage only, and return the results much like in 1.
  3. Use Dask data structures such as dask.bag or dask.array.

The first one works; I am just not sure whether my memory consumption is optimal. For example, the client is using around 6.3 GB of memory, the scheduler around 2.8 GB, and the workers around 2.3 GB each. I am using the client.submit or client.map methods, and then evaluating the resulting futures with as_completed. I am also releasing each future as soon as I evaluate its result. While I think the client is warranted in using 6.3 GB of memory, I am not sure why the scheduler is using that much, when I should be releasing the results so quickly. The workers seem to have a baseline memory usage of around 1.9 GB, so 2.3 GB seems acceptable as “working memory” while tasks are ongoing.
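For reference, the submit/as_completed pattern I am using looks roughly like this (the worker method and data are tiny stand-ins, and the in-process Client replaces the real SSHCluster):

```python
from distributed import Client, as_completed

# Stand-in for the real worker method: takes a "partial data structure",
# transforms it, and returns a small partial result.
def worker_method(partial):
    return {k: v + 1 for k, v in partial.items()}

client = Client(processes=False)  # in-process cluster just for this sketch

# "Partial data structures": one per task, covering disjoint keys here.
batches = [{"A": 1, "B": 2}, {"C": 3, "D": 4}]

merged = {}
futures = [client.submit(worker_method, b) for b in batches]
for fut in as_completed(futures):
    merged.update(fut.result())  # sync into the "global" collection
    fut.release()                # drop the result from cluster memory asap

client.close()
```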

The second option doesn’t work. When trying to call:

client.scatter(dynamic_obj, broadcast=True)

on what I’ve described as “a large dict, with inner integer values, arrays and dicts”, I get:

distributed.comm.core.CommClosedError: in <TCP (closed) ConnectionPool.scatter local=tcp:// remote=tcp://>: Stream is closed

after around 32 minutes. Is this possibly because of the “nested dict” type of data structure? The collection is not even that large, and most inner dicts/arrays are empty at this point.

I am not sure about the 3rd option, especially because, from the articles I read and the videos I watched, these seem more useful for distributed computations on large data sets, of which I don’t have plenty. The data structures I am using (and referring to above) are what I have found to be the most convenient representations of the results that need to be returned by the worker methods. However, I was wondering whether, when using Dask data structures, I could gain an automatic speed-up with regard to the data communication overhead. I would appreciate being pointed in the right direction in this regard.
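For context, the pattern those articles demonstrate looks something like this tiny dask.bag example (the data and the mapped function are made up):

```python
import dask.bag as db

# A Bag partitions the sequence and maps the logic over each partition,
# so with a distributed Client the per-partition work runs on the workers.
people = [{"name": "A", "score": 1}, {"name": "B", "score": 2},
          {"name": "C", "score": 3}, {"name": "D", "score": 4}]

bag = db.from_sequence(people, npartitions=2)

# scheduler="threads" keeps this sketch single-process; with a Client
# attached, .compute() would use the cluster instead.
doubled = bag.map(lambda p: p["score"] * 2).compute(scheduler="threads")
```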

This sounds a bit suboptimal, generating a lot of data transfer between Client and Workers. It would be much better if you could keep the data on the worker side as Future objects. But I understand you’ve got to somehow synchronise the results into a single data structure?

This is the way Dask collections work: partitioning the data and having each worker process one of the partitions. But the idea with Dask collections is that the partial result stays in the workers’ memory, automatically moving around if another worker needs it.

I don’t think solution 2 is a good one.

For solution 3, it completely depends on your data structure.

This sounds really high, but maybe because of the static data you use?

If you could somehow fit your data structure into a Dask collection, you would benefit from higher-level methods and automatic partitioning of your dataset, which could simplify your workflow a lot.

In order to help more, we would probably need an example of your dynamic data structure.

Hi @guillaumeeb, thanks for always taking the time to answer me.

I agree with this statement. But I am not sure I have an alternative. Think about it this way, I have worker_method1 and worker_method2. If it were a simple, non-distributed setup, worker_method1 returns a dict, which has to be passed as input to worker_method2. The work can be distributed over multiple workers and each task returns a small subset of the “dict”. However, each worker will return keys that overlap with those coming from other workers, and worker_method2 requires the full key value pairs to function properly. I thought that returning the results to the client and syncing them into the “global” collection is likely the only way. If you have other ideas that would function closer to your concept of keeping the “dict” on the worker side as “Future” objects, it would be very helpful if you could suggest how that might work?
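Concretely, the sync step I am describing looks something like this (the helper and the data shapes are illustrative, not my real structures):

```python
# Each task returns a partial dict whose keys may overlap with other tasks'
# results; the client deep-merges them into the "global" collection that
# worker_method2 needs in full.
def merge_partial(global_d, partial):
    for key, value in partial.items():
        if isinstance(value, dict):
            # recurse into nested dicts, creating them on first sight
            merge_partial(global_d.setdefault(key, {}), value)
        elif isinstance(value, list):
            # append to existing arrays instead of overwriting them
            global_d.setdefault(key, []).extend(value)
        else:
            global_d[key] = value

global_collection = {"A": {"hits": [1]}, "B": {"hits": []}}
# A partial result from one task, overlapping on key "A":
merge_partial(global_collection, {"A": {"hits": [2]}, "C": {"hits": [3]}})
```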

I was considering using a Dask.Array or Dask.Bag for the inputs of worker_method1, because structurally it is likely possible to use them. But logically, how would we implement a worker method, which updates these Dask collections (remotely), and makes the results instantly available to other workers, without ever returning the data back to the client? Or are the Dask collections, by nature, passed around, virtually as Futures (rather than pickled back and forth)?

Could you elaborate on why? Is the “nested dict” potentially to blame for the error I am getting when trying to scatter?

There is plenty of static data. However, I don’t think it should be anywhere close to 2 GB. I have noticed that a worker on a smaller machine (slower CPU, 4 cores instead of 10, 2 GB of memory instead of 20 GB) used at most 1.2 GB. This seems to indicate that the rapid task allocation could also be a problem; case in point, the smaller machine does not have the capacity to support as many tasks, and hence uses less memory. I am submitting all tasks in quick succession (all 81685 of them). Maybe batching the tasks in groups, rather than submitting them all together, could alleviate the memory usage.

scatter() tries to be smart: when you pass dicts or lists to it, it traverses them looking for Dask collections. In your case, this is detrimental. Try wrapping your top-level dict in a UserDict to make it opaque before you pass it to scatter.
broadcast=True is unnecessary - dask will take care of propagating your data as needed.
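Something like this (dynamic_obj here is a tiny stand-in for your real structure):

```python
from collections import UserDict

# Wrapping the top-level dict in UserDict makes scatter() treat it as a
# single opaque object instead of a container to traverse element by element.
dynamic_obj = {"A": {"scores": [1, 2]}, "B": {"scores": []}}
opaque = UserDict(dynamic_obj)

# Workers can still use it exactly like a plain dict:
total = sum(len(v["scores"]) for v in opaque.values())

# In the real setup, against a running Client:
# fut = client.scatter(opaque)                   # broadcast=True not needed
# res = client.submit(worker_method, fut)        # worker receives a dict-like
```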

As @guillaumeeb was mentioning, if not all of your static data is required by all of your tasks, it would be beneficial to break it down into smaller pieces attached to separate futures.

Hi @crusaderky, thanks for your reply.

Your idea of using UserDict instead of dict seems to get me a step further.

When using client.scatter with the broadcast=True param, it takes ages before any tasks appear in the Task Stream, presumably because it is moving the data to all the workers prior to starting the work. This is still weird, though, because I am only passing around 200 MB of data to 3 local workers. It takes so long that I always end up giving up.

When using broadcast=False, the work starts significantly faster. However, in the Task Stream, I constantly see red marks pertaining to communication across the workers, and each of these takes very long (6–16 seconds).

When using client.submit and queuing all the tasks at once, the memory usage is excessive, such that running with 4 workers is impossible (it quickly runs out of memory). I tried an approach which assigns batches of tasks rather than all at the same time; however, this results in a lot of idle workers. When using client.map with all tasks at once, the memory usage is a huge problem, because the program runs out of memory before I see any tasks in the Task Stream.

Indeed, it seems that passing “partitions” of the data structures back and forth for each task is the only way that is giving me acceptable results. This approach doesn’t pass around futures as parameters to the workers, it simply keeps the data being passed back and forth from each worker as small as possible.

I will now try to understand more on how I can possibly use Dask collections to optimise further.

When you say this, do you mean that Dask collections can be scattered?

Also can multiple Dask collections be used with the same worker method (in a single task) ?

Did some more digging. Dask collections are not usable for my use case: fast key-based access is a primary aspect of my data structures, and none of the Dask collections achieves this as efficiently as a simple dict. Dask collections are great for applying distributed computations over a large data set, but that is by no means what I am doing.

Indeed, splitting the computation across the available workers seems like the only way. client.submit seems to be faster than the delayed function, and to use considerably less memory.

Initially I had tried an approach whereby the worker method requires less data as input, does less work (and hence computes much faster), and returns smaller outputs. This seemed to be the best approach memory-wise (for both scheduler and workers), but with 4 workers it was running much slower than with a single process (no Dask). Eventually, I found that splitting the computation (and all dependent data) across all the available workers works best.

However, by simply assigning work to each node, we would not be making good use of the Dask scheduler’s capabilities, such as load balancing. If the nodes have different resources, some nodes might take much longer to finish their tasks. We could perhaps “estimate” a proportional assignment of work prior to assigning the tasks, but this would still not be utilising the scheduler.

When splitting the work into smaller tasks, the bottleneck shifts to the client: instead of having to process the results for “n” tasks, where “n” is the number of workers, we would have to process 100k+ results from a single thread in the client, which takes ages. I also tried a batching approach, but processing the results in between batches creates a lot of idle time on the workers.
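One variation I have been sketching is a sliding window over as_completed, which should keep the workers busy without queuing everything at once (worker_method and the task list are stand-ins, and the in-process Client replaces the real cluster):

```python
from distributed import Client, as_completed

# Keep at most `window` tasks in flight; whenever one finishes, merge its
# result and immediately submit the next, so workers never sit idle while a
# whole batch is drained on the client.
def worker_method(x):
    return x * x

client = Client(processes=False)  # in-process cluster just for this sketch
tasks = iter(range(20))           # stand-in for the real task inputs
window = 4

ac = as_completed([client.submit(worker_method, next(tasks))
                   for _ in range(window)])
results = []
for fut in ac:
    results.append(fut.result())
    fut.release()                 # free cluster memory for this result
    nxt = next(tasks, None)
    if nxt is not None:
        ac.add(client.submit(worker_method, nxt))  # refill the window

client.close()
```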

Do you guys have any ideas to potentially combat these issues?

Hi @jurgencuschieri,

It’s a bit hard to help here; we would need a minimal, representative reproducible example. Do you think you could build one?