Bringing a decorator-based approach to dask distributed

Overview

In short, I love Dask Distributed, but I also love the decorators from Dask Delayed. I have an approach that kind of works, but I'm receiving a PicklingError in one edge case that I'd love insight on.

Setup

Without any decorators, let’s consider the following operations. Obviously, in reality, they would be high-compute tasks.

from dask.distributed import Client

client = Client()

def add(a, b):
    return a + b

def make_more(val):
    return [val] * 3

def add_distributed(vals, c):
    return [add(val, c) for val in vals]

out1 = client.submit(add, 1, 2) # 3
out2 = client.submit(make_more, out1) # [3, 3, 3]
out3 = client.submit(add_distributed, out2, 3) # [6, 6, 6]
out3.result()

I understand that it is possible to use .map() here instead of .submit() with some refactoring (a sketch of what I mean is below), but in the spirit of decorators, I want to keep the underlying code largely intact.
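
For reference, this is roughly the kind of .map()-based refactor I have in mind but would rather avoid. It reuses client, add, and make_more from the snippet above, and the exact structure is just a sketch:

out1 = client.submit(add, 1, 2)
out2 = client.submit(make_more, out1)
vals = out2.result()                            # pull [3, 3, 3] back to the client
out3 = client.map(add, vals, [3] * len(vals))   # one task per element
client.gather(out3)                             # [6, 6, 6]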

Decorator Approach

Defining the Decorator

Let’s define a decorator @remote_execute that is simply the following:

import functools
from dask.distributed import Client

client = Client()

def remote_execute(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        future = client.submit(func, *args, **kwargs)
        return future

    return wrapper

Presumably, in practice I should pass in the client somehow so I don't end up instantiating it twice, but that seems solvable.
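
For example, something like the following sketch, which looks the client up lazily with dask.distributed.get_client() instead of capturing a module-level Client, might work (untested):

import functools
from dask.distributed import get_client

def remote_execute(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Find the Client connected to this process at call time,
        # rather than capturing one at import time
        client = get_client()
        return client.submit(func, *args, **kwargs)

    return wrapper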

Trying it Out

Now, we rewrite the above routines:

from dask.distributed import Client

client = Client()

@remote_execute
def add(a, b):
    return a + b

@remote_execute
def make_more(val):
    return [val] * 3

@remote_execute
def add_distributed(vals, c):
    return [add(val, c) for val in vals]

out1 = add(1, 2) # 3
out2 = make_more(out1) # [3, 3, 3]
out3 = add_distributed(out2, 3) # [6, 6, 6]
out3.result()

The Problem

Unfortunately, in running the above example, I get the following traceback:

2023-12-06 20:17:27,910 - distributed.protocol.pickle - ERROR - Failed to serialize <ToPickle: HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x1d3b9db68d0>
 0. 2008808932416
>.
Traceback (most recent call last):
  File "c:\Users\asros\miniconda\envs\quacc\Lib\site-packages\distributed\protocol\pickle.py", line 63, in dumps
    result = pickle.dumps(x, **dump_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_pickle.PicklingError: Can't pickle <function add_distributed at 0x000001D3B8BB2DE0>: it's not the same object as __main__.add_distributed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\Users\asros\miniconda\envs\quacc\Lib\site-packages\distributed\protocol\pickle.py", line 68, in dumps
    pickler.dump(x)
_pickle.PicklingError: Can't pickle <function add_distributed at 0x000001D3B8BB2DE0>: it's not the same object as __main__.add_distributed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\Users\asros\miniconda\envs\quacc\Lib\site-packages\distributed\protocol\pickle.py", line 81, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\asros\miniconda\envs\quacc\Lib\site-packages\cloudpickle\cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "c:\Users\asros\miniconda\envs\quacc\Lib\site-packages\cloudpickle\cloudpickle.py", line 1245, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'TaskStepMethWrapper' object

Any ideas on how to address this? Running a two-step workflow consisting of add then make_more works perfectly, but it’s that third task that really throws a wrench in things.

I’m open to other ideas too if there are better ways to integrate decorators into this approach!

Hi @arosen93,

Before going deeper, I just want to make sure I understand what you are trying to do. Why are you not using Dask Delayed within a distributed cluster?


@guillaumeeb: Thanks for your reply! I suppose the honest answer is that I remain quite unsure about the key differences between Dask Delayed and Distributed. I had, perhaps naively, assumed that if I wanted to run calculations in a distributed environment, I would need to do things the “Dask distributed way” as outlined in the docs.

Would you be willing to provide a minimal example of how one might use Dask Delayed within a distributed cluster for a toy calculation? Are there major limitations in using Dask Delayed in this way that I should be aware of?

Thank you greatly!

Agh, for some reason there was a bug when trying to reply to your comment on mobile, and it edited my original comment, which I can no longer change. @guillaumeeb: do you have admin privileges to revert the change?

This may not be clear at first, but every Dask API and collection, including Delayed, can be used in a Dask Distributed environment. The Scheduling documentation may help clarify that.

On the Dask Examples website, you can see that Delayed is used with a LocalCluster, which is a distributed cluster on a single machine.

There are no limitations on using Delayed on a distributed cluster.
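
For example, here is a minimal sketch along those lines, using an explicit LocalCluster (the inc and double functions are just placeholders):

from dask import delayed
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=2)  # a distributed cluster on one machine
client = Client(cluster)

@delayed
def inc(x):
    return x + 1

@delayed
def double(x):
    return 2 * x

client.compute(double(inc(1))).result()  # 4, executed on the cluster workers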

Does that clarify things?

@guillaumeeb: Thanks for the reply! Yes, I think it does!

Just to clarify, the following is completely valid and encouraged in Dask:

from dask import delayed
from dask.distributed import Client

client = Client()

@delayed
def add(a, b):
    return a + b


@delayed
def mult(a, b):
    return a * b


def workflow(a, b, c):  
    return mult(add(a, b), c)


delayed = workflow(1, 2, 3)  # a Delayed object (note: this rebinds the imported "delayed" name)
result = client.compute(delayed).result()  # 9

Yes it is!

To clarify further, once a Client object is created in your Python process, every compute call on any Dask collection will use it by default. So these two lines of code are also valid:

import dask

result = delayed.compute()
result = dask.compute(delayed)

@guillaumeeb: Ah, thank you so much! I was confused about the difference between all these different .compute() methods.

I’m going to summarize my understanding here, both to make sure I got it right and to inform anyone else who might read this. Please let me know if I’m mistaken!

  1. Doing client.compute(delayed) will submit the task graph to the cluster. A Future is returned immediately, and it can be resolved with .result(), which blocks while it waits for the calculation to finish.

  2. You can do delayed.compute() as a shorthand if you're already connected to a Client. This both dispatches the work and resolves the result, so you don't call .result() explicitly. However, this call is blocking.

  3. You can do dask.compute(delayed), which is very similar to Option 2 in that it both dispatches and resolves. However, it returns a slightly different output in the above example: (9,) instead of 9. (A small sketch of all three patterns is below.)
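
To make that concrete, here is a small sketch of the three patterns side by side, reusing the delayed object and client from the workflow example above (expected outputs in comments):

import dask

future = client.compute(delayed)  # Option 1: returns a Future immediately
future.result()                   # 9, blocks until the graph finishes

delayed.compute()                 # Option 2: dispatches and blocks in one call -> 9

dask.compute(delayed)             # Option 3: same idea, but returns (9,)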

Yes, I think this is correct!

With client.compute() and dask.compute(), you can pass an iterable of Delayed objects as the argument, in order to have all of them computed at the same time. That also explains the tuple that dask.compute() returns.
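
For example, a small sketch (square is just a placeholder function, and client is the one from earlier):

import dask

@dask.delayed
def square(x):
    return x * x

tasks = [square(i) for i in range(4)]

# dask.compute is variadic, hence the tuple it returns
dask.compute(*tasks)              # (0, 1, 4, 9)

# client.compute accepts a list of Delayed objects and returns a list of Futures
futures = client.compute(tasks)
client.gather(futures)            # [0, 1, 4, 9]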
