Initialize worker-side object without serialization

Hi :wave:

Related to this post: "Specify that a given task use huge amount of RAM to the Dask Ressource Manager".

I’m using ONNX in my workers to run a deep learning forward pass. This is done by instantiating an ONNX Runtime InferenceSession, which is a hardware-specific object (hence it cannot be serialized and shared as a pickle).

Currently this object is instantiated in my forward_onnx function (which runs in each task, for each chunk), so a new InferenceSession is created on every task run. I suspect this leads to significant memory use and time overhead.
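For context, the current pattern looks roughly like this (a sketch; the input name "data" is an assumption about the model):

import onnxruntime as ort

def forward_onnx(chunk):
    # A new InferenceSession is built on every task run, once per chunk
    session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
    # "data" is assumed to be the model's input name
    return session.run(None, {"data": chunk})[0]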

According to the Dask Distributed library, the usual way of defining a worker-side object is to define a Worker Plugin.

To do so I defined this WorkerInferenceSessionPlugin class:

from dask.distributed import WorkerPlugin, Worker

class WorkerInferenceSessionPlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        # setup runs once per worker, when the plugin is registered or the worker starts
        import onnxruntime as ort
        model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
        worker.data["model_session"] = model_session

which I then register on my client:

inference_session_plugin = WorkerInferenceSessionPlugin()
client = Client(cluster)
client.register_worker_plugin(inference_session_plugin)
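(For readers on newer versions: if I remember correctly, recent distributed releases expose this as client.register_plugin, with register_worker_plugin kept as a deprecated alias.)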

And then reuse it in my function instead of creating a new one:

from dask.distributed import get_worker

def forward_onnx(...):
    ...
    # Fetch the session that the plugin stored on this worker
    worker = get_worker()
    model_session = worker.data["model_session"]
    ...

However, as my code is quite unstable and I’m facing memory issues, my workers restart frequently. I was expecting my model_session to be re-instantiated when a worker restarts; instead, I get ERROR - Failed to pickle 'model_session'.
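A guess, for later readers: as far as I understand, worker.data is Dask’s store for task results and may be a spill-to-disk buffer, so anything placed in it can end up being pickled. A variation that might avoid this (sketched below with the same model path, but untested in this setup) is to attach the session as a plain attribute on the worker:

from dask.distributed import WorkerPlugin, Worker

class WorkerInferenceSessionPlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        import onnxruntime as ort
        # A plain attribute is not part of worker.data, so Dask never tries
        # to serialize or spill it
        worker.model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])

The task would then read get_worker().model_session instead of worker.data["model_session"].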


The inner workings of the WorkerPlugin are not clear in my opinion, especially as to what runs where. But since you’re getting a pickling error, it seems to indicate that setup is not really being called on the worker side. I have tried a setup similar to yours, using a simpler data structure; however, the key registered into worker.data in the setup method of the WorkerPlugin raises a KeyError when I try to use it from the worker method (a minimal version is sketched below). Did you get around to fixing this, and if so, what approach did you end up using?
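For reference, the minimal version I tried looks roughly like this (illustrative names):

from dask.distributed import Client, Worker, WorkerPlugin, get_worker

class StoreValuePlugin(WorkerPlugin):
    def setup(self, worker: Worker):
        worker.data["some_key"] = 42  # a simple, picklable value

client = Client()
client.register_worker_plugin(StoreValuePlugin())

def read_value():
    return get_worker().data["some_key"]

print(client.submit(read_value).result())  # raises KeyError instead of returning 42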

Hi!

I haven’t managed to configure the startup of my inference session (the ort.InferenceSession) using the WorkerPlugin API. I’ve had to resign myself to opening a new session in each task, which is far more expensive :frowning:


Sorry to hear that. The WorkerPlugin should be perfect for your use case, but I am not sure it is working as intended. Thanks for getting back to me, anyway.


Did you try using "API — Dask.distributed 2023.9.2+2.g1650ceb documentation"?


Did you mean client.register_worker_callbacks?

Something like this:

from dask.distributed import Worker

def setup(dask_worker: Worker):
    import onnxruntime as ort
    model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
    dask_worker.data["model_session"] = model_session

client.register_worker_callbacks(setup)
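If I remember correctly, callbacks registered this way run on all current workers and on any worker that joins afterwards, so they should also cover restarts.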

I don’t know if it works, since my post is already one year old; we found an alternative solution at the time, although it resulted in reduced usage of Dask :pensive:.

But I remember that register_worker_callbacks was deprecated, or at least going to be deprecated, at the time, and that WorkerPlugin was the recommended way to do it.

Maybe @jurgencuschieri can test this on their setup? It could benefit others who come across this issue.


Hi @guillaumeeb and @julien-blanchon, the register_worker_callbacks method seems to work, and from what I can see it is being executed remotely on the worker (as I want). Not sure why there is an open issue to mark this as deprecated in favour of register_worker_plugin.
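For anyone wanting to double-check the same thing, a quick probe (a sketch, assuming the "model_session" key from above):

# Runs on every worker and reports whether the session was installed there
print(client.run(lambda dask_worker: "model_session" in dask_worker.data))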


Nice! I wish I’d had the answer sooner :laughing:. Thanks @guillaumeeb for the hint. Maybe @jurgencuschieri you can post a message on GitHub to ask why?
Anyway, I can finally mark this as resolved :blush: