Dask + Pennylane

Hi,

I am trying to use Dask distributed to parallelize computations in PennyLane, but I am stuck on pickling problems: objects of type ArrayBox fail to serialize and deserialize. Here is the code:

import pennylane as qml
import matplotlib.pyplot as plt
from pennylane import numpy as np
import dask
import time
import random
from dask.distributed import Client, progress
client = Client(threads_per_worker=1, n_workers=4)
dev = qml.device('lightning.qubit', wires=1, shots=None)
@qml.qnode(dev, diff_method='adjoint')
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=0))

def expval(x):
    temp = []
    for i in range(5):
        val = dask.delayed(circuit)(x[i])
        temp.append(val)
    expvals = np.array(dask.compute(*temp))
    expvals = np.sum(expvals)
    return expvals
%%time
x = np.random.random(5)
expval(x)

qml.grad(expval, argnum=0)(x)
Errors:

  • Could not serialize object of type ArrayBox.

  • Can’t pickle local object ‘VJPNode.initialize_root..’

  • cannot pickle ‘generator’ object

Can someone please help with this use of Dask distributed, or point me toward alternatives?

Thanks a lot!!!

Hi @Abhi, welcome to this Forum!

You might want to go through the following documentation pages:
https://distributed.dask.org/en/stable/serialization.html

https://distributed.dask.org/en/stable/protocol.html

To use distributed or multi-process libraries like Dask, you need to send objects and data back and forth between processes, which means serializing them. To serialize functions, Dask relies on pickle. Your objects need to be serializable if you want to use Dask distributed. You can check for serialization problems with the code snippet in Problems with object pickle - #6 by scharlottej13.
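As a rough illustration of that kind of check, here is a minimal sketch that uses only the standard library `pickle` module. The helper name `pickle_error` and the sample objects are made up for this example. Plain `pickle` is stricter than what Dask may end up using (as far as I know, Dask can also fall back to cloudpickle), so treat it as a conservative test.

```python
import pickle


def pickle_error(obj):
    """Return None if obj survives a pickle round trip, else a short error string."""
    try:
        pickle.loads(pickle.dumps(obj))
        return None
    except Exception as exc:  # pickle raises several different exception types
        return f"{type(exc).__name__}: {exc}"


def make_closure():
    # A function defined inside another function is a "local object".
    def inner(x):
        return x + 1

    return inner


# Hypothetical sample objects, chosen to mirror the errors reported above.
candidates = {
    "plain float": 1.5,
    "list of floats": [0.1, 0.2],
    "generator": (i for i in range(3)),
    "local function": make_closure(),
}

for name, obj in candidates.items():
    print(name, "->", pickle_error(obj) or "ok")
```

Running the same helper on the arguments and return values of a delayed function is one way to narrow down which object is the culprit.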

Hi @guillaumeeb, thank you for the help. I am trying, but I find the documentation on serialization and deserialization hard to follow, especially for functions. Could you try the code above so that you can help more directly?

Thanks

I’m currently not able to reproduce your example. Even though I installed the pennylane package, the line:

dev = qml.device('lightning.qubit', wires=1, shots=None)

is throwing an exception:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 dev = qml.device('lightning.qubit', wires=1, shots=None)
      2 @qml.qnode(dev, diff_method='adjoint')
      3 def circuit(x):
      4     qml.RX(x, wires=0)

File ~/miniconda3/envs/coiled/lib/python3.9/site-packages/pennylane/__init__.py:306, in device(name, *args, **kwargs)
    300     raise DeviceError(
    301         f"The {name} plugin requires PennyLane versions {plugin_device_class.pennylane_requires}, "
    302         f"however PennyLane version {__version__} is installed."
    303     )
    305 # Construct the device
--> 306 dev = plugin_device_class(*args, **options)
    308 # Once the device is constructed, we set its custom expansion function if
    309 # any custom decompositions were specified.
    310 if custom_decomps is not None:

File ~/miniconda3/envs/coiled/lib/python3.9/site-packages/pennylane_lightning/lightning_qubit.py:116, in LightningQubit.__init__(self, wires, c_dtype, shots, batch_obs)
    114 else:
    115     raise TypeError(f"Unsupported complex Type: {c_dtype}")
--> 116 super().__init__(wires, r_dtype=r_dtype, c_dtype=c_dtype, shots=shots)
    117 self._batch_obs = batch_obs

TypeError: __init__() got an unexpected keyword argument 'r_dtype'

As I really don’t know what I’m doing, this is hard to debug. Probably some Python env problem.

Anyway, if one object cannot be pickled, you’ve either got to change it so it can, or provide some custom method, see pickle — Python object serialization — Python 3.11.2 documentation.

Hi @guillaumeeb , Thanks a lot for your help.

To reproduce the example, please run the following lines in your notebook.

!pip install pennylane --upgrade
!pip install pennylane-lightning --upgrade
import pennylane as qml
import matplotlib.pyplot as plt
from pennylane import numpy as np
import dask
import time
import random
from dask.distributed.protocol import serialize, deserialize

qml.about()

from dask.distributed import Client, progress
client = Client(threads_per_worker=1, n_workers=4)
client
dev = qml.device('lightning.qubit', wires=1)
@qml.qnode(dev, diff_method='adjoint', max_diff=2)
def circuit(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    return qml.expval(qml.PauliZ(wires=0))
def dask_cost(x, y):
    temp = []
    for i in range(5):
        val = dask.delayed(circuit)(x[i], y)
        temp.append(val)
    expvals = dask.compute(*temp)
    expvals = np.sum(np.array(expvals))
    return expvals

def normal_cost(x, y):
    temp = []
    for i in range(5):
        val = circuit(x[i], y)
        temp.append(val)
    expvals = np.sum(np.array(temp))
    return expvals
x = np.array(np.random.random(5))
y = 2.0
dask_cost(x, y) 
normal_cost(x, y) # This output should match the above output
qml.grad(normal_cost, argnum=1)(x, y) # This line is without using dask.
qml.grad(dask_cost, argnum=1)(x, y) # This is where the error occurred.

Error :

  1. Could not serialize object of type ArrayBox.
  2. Failed to deserialize,
  3. Cannot pickle ‘generator’ object

You are right, we need to change the object into something that Python can pickle. I am working on it, and thanks for your help. If you have tried it out and found something, please let me know.

I’m able to reproduce the issue. But after a few minutes of investigating, I think I won’t get far without really getting into the source code of the pennylane and autograd packages, which would be very time consuming, and I can’t do that.

But just to understand: you’re wrapping a Dask call inside another function, qml.grad. What is this function doing? Is each circuit call (in a real use case) really worth launching in parallel using Dask? It’s a little hard to follow without knowing the subject or the libraries at all.

Hi @guillaumeeb. Thanks a lot for your effort. I also traced the pickling error back to the autograd source code. I am even asking myself how Dask can be applied to speed up the process. My understanding of how I would use Dask is as follows:

  • The circuit function is called for multiple inputs, which are independent of each other. Here I was successful in parallelizing with Dask, as you saw in the code above.

  • All these circuit outputs are merged to give one scalar output.

  • A grad function (from autograd, as you saw in the error) computes the derivative of the above output with respect to the circuit function inputs. This gradient is then used to update the inputs for the circuit, a well-known pattern in machine learning, if you are familiar with it.

  • These three steps are repeated thousands of times.

So, based on this, can you say whether my use of Dask will be useful or not? I am only after Dask’s ability to parallelize the computations. If you say it’s not worth it, I will stop and close this post. Since I am a very new Dask user, this is all I have been able to understand until now. Please let me know what you suggest.

I have also noticed that Dask is used for machine learning (Dask-ML), which is exactly what this simple code is doing. This is another point that links my use case to Dask.
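One possible direction, not confirmed anywhere in this thread, is to restructure the loop described in the bullet points so that each parallel task computes both the value and its derivative and returns plain floats. That way no autograd ArrayBox has to cross a process boundary. The sketch below is only a stand-in: it uses the analytic result for RX(x) acting on |0⟩ (the expectation value of PauliZ is cos(x)) instead of a real PennyLane circuit, and the standard library `ThreadPoolExecutor` instead of Dask. The names `value_and_grad` and `cost_and_grad` are made up for this example.

```python
import math
import pickle
from concurrent.futures import ThreadPoolExecutor


def value_and_grad(x):
    """Stand-in for one circuit evaluation plus its derivative.

    For RX(x) applied to |0>, the expectation value of PauliZ is cos(x),
    so the derivative is -sin(x). A real task would evaluate the circuit
    and its gradient here and convert both results to plain floats.
    """
    return math.cos(x), -math.sin(x)


def cost_and_grad(xs, executor):
    """Steps 1 to 3: evaluate all inputs in parallel, then merge the results."""
    results = list(executor.map(value_and_grad, xs))
    pickle.dumps(results)  # plain floats serialize without any trouble
    cost = sum(value for value, _ in results)
    grad = [g for _, g in results]
    return cost, grad


xs = [0.1, 0.5, 0.9, 1.3, 1.7]
with ThreadPoolExecutor(max_workers=4) as pool:
    for step in range(3):  # step 4: repeat; a real run would use many more steps
        cost, grad = cost_and_grad(xs, pool)
        xs = [x - 0.1 * g for x, g in zip(xs, grad)]  # plain gradient descent update
        print(step, round(cost, 4))
```

With Dask, `client.map(value_and_grad, xs)` followed by `client.gather(...)` would play the role of `executor.map`. Whether a PennyLane QNode and its gradient can be evaluated inside a worker like this is an assumption that would need to be tested.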

What I understand from your example and your explanation is that the circuit function has quite a short execution time, is that right? You have to be careful when introducing Dask for really short computations: the cost of distributing work may be higher than the benefit of parallelization.

But this looks like an iterative algorithm: you cannot compute several batches of circuit inputs in parallel, can you?

Well, it depends on how much time a single computation takes. If it’s above 1s, it’s probably worth it; if it’s less, maybe not.
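A quick way to apply that rule of thumb is to time one call before adding any parallelism. This is a minimal sketch with the standard library only. `slow_task` is a hypothetical stand-in for a single circuit evaluation, and the 1s threshold is the rough figure mentioned above, not a hard limit.

```python
import time


def mean_call_time(func, *args, repeats=3):
    """Average wall-clock time of func(*args) over a few repeats, in seconds."""
    start = time.perf_counter()
    for _ in range(repeats):
        func(*args)
    return (time.perf_counter() - start) / repeats


def slow_task(duration):
    """Hypothetical stand-in for one circuit evaluation."""
    time.sleep(duration)


seconds = mean_call_time(slow_task, 0.05)
threshold = 1.0  # rough rule of thumb from this thread
verdict = "probably worth distributing" if seconds >= threshold else "maybe not worth it"
print(f"{seconds:.3f}s per call: {verdict}")
```

Replacing `slow_task` with the real circuit call gives a first estimate of whether the per-task overhead is likely to dominate.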

Hi @guillaumeeb. The example shown is a very simple one that is easy to explain. In real use cases, the circuit function executions take much longer for multiple inputs. You are right, this is an iterative algorithm. Each iteration requires circuit executions with multiple inputs, so my hope was to speed up the process by applying parallelism inside the iterative algorithm. And I got stuck on the pickle problem.

If it’s above 1s or so, it’s worth continuing your effort!

Hey @Abhi! You might want to check out the PennyLane forum:

:smiley:

Thank you @isaacdevlugt. I got that :grinning: