Initialize worker-side object without serialization

Hi :wave:

Related to this post:
Specify that a given task use huge amount of RAM to the Dask Ressource Manager.

I’m using ONNX in my workers to compute a deep learning forward pass. This is done by instantiating an ONNX Runtime InferenceSession, which is a hardware-specific object (hence it cannot be serialized and shared as a pickle).

Currently this object is instantiated in my forward_onnx function (which is run in each task, for each chunk), so a new InferenceSession is created on every task run. I suppose this leads to significant memory use and time overhead.
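
For reference, the per-task version looks roughly like this (a simplified sketch: the chunk shape, dtype handling and the assumption of a single model input/output are mine for illustration, not my real pipeline):

import numpy as np

def forward_onnx(chunk: np.ndarray) -> np.ndarray:
    import onnxruntime as ort
    # A brand-new InferenceSession is built inside every task, i.e. once per chunk
    session = ort.InferenceSession(
        "./resnet18-v1-7.onnx", providers=["CPUExecutionProvider"]
    )
    input_name = session.get_inputs()[0].name
    (output,) = session.run(None, {input_name: chunk.astype(np.float32)})
    return output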

According to the Dask Distributed library, the usual way of defining a worker-side object is to define a Worker Plugin.

To do so, I defined this WorkerInferenceSessionPlugin class:

from dask.distributed import WorkerPlugin, Worker

class WorkerInferenceSessionPlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        import onnxruntime as ort
        # Load the model once per worker and store the session on the worker
        model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
        worker.data["model_session"] = model_session

I register it on my client:

inference_session_plugin = WorkerInferenceSessionPlugin()
client = Client(cluster)
client.register_worker_plugin(inference_session_plugin)

And then I reuse it in my function instead of instantiating a new session:

from dask.distributed import get_worker

def forward_onnx(...):
    ...
    # Fetch the session stored on the worker by the plugin
    worker = get_worker()
    model_session = worker.data["model_session"]
    ...
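
For context, each task calls this function on one chunk. A simplified sketch of how the tasks are submitted (the chunk shapes, chunk count and the use of client.map here are placeholder assumptions, not my actual graph):

import numpy as np

# Illustrative only: one task per chunk, reusing the client created above
chunks = [np.random.rand(32, 3, 224, 224).astype(np.float32) for _ in range(8)]
futures = client.map(forward_onnx, chunks)
results = client.gather(futures)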

However, as my code is quite unstable and I’m facing memory issues, my workers often restart. I was expecting my model_session to be re-instantiated when a worker restarts; instead, I get ERROR - Failed to pickle 'model_session'.
