Related to this post:
Specify that a given task uses a huge amount of RAM to the Dask Resource Manager.
I’m using ONNX in my workers to compute a deep learning forward pass. This is done by instantiating an ONNX Runtime InferenceSession, which is a hardware-specific object (hence it cannot be serialized and shared as a pickle).
Currently this object is instantiated in my forward_onnx function (which runs in each task, for each chunk), so a new InferenceSession is created on every task run. I suspect this leads to excessive memory use and time overhead.
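To make the overhead concrete, here is a minimal sketch of that per-task pattern (DummySession is an illustrative stand-in for ort.InferenceSession, so the example runs without onnxruntime or the model file; the chunk values are made up):

```python
# Minimal sketch of the current per-task pattern. DummySession is an
# illustrative stand-in for onnxruntime.InferenceSession so the snippet
# runs anywhere; the real session would load ./resnet18-v1-7.onnx.
CONSTRUCTIONS = []  # records every "session" instantiation

class DummySession:
    def __init__(self, model_path):
        CONSTRUCTIONS.append(model_path)  # expensive model load happens here

    def run(self, chunk):
        return [x * 2 for x in chunk]  # pretend forward pass

def forward_onnx(chunk):
    # A fresh session is built on every task run -- the overhead in question.
    session = DummySession("./resnet18-v1-7.onnx")
    return session.run(chunk)

# Simulate 4 tasks, one per chunk, as Dask would run them:
chunks = [[1.0, 1.0] for _ in range(4)]
results = [forward_onnx(c) for c in chunks]
print(len(CONSTRUCTIONS))  # one instantiation per chunk
```

With a real InferenceSession, each of those instantiations repeats the full model load.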
According to the Dask Distributed documentation, the usual way of defining a worker-side object is to define a Worker Plugin.
To do so I defined this plugin:
```python
from dask.distributed import WorkerPlugin, Worker

class WorkerInferenceSessionPlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        import onnxruntime as ort

        model_session = ort.InferenceSession(
            "./resnet18-v1-7.onnx", providers=["CPUExecutionProvider"]
        )
        worker.data["model_session"] = model_session
```
I register it on my client:
```python
inference_session_plugin = WorkerInferenceSessionPlugin()
client = Client(cluster)
client.register_worker_plugin(inference_session_plugin)
```
And then I reuse it in my function instead of instantiating a new one:
```python
from dask.distributed import get_worker

def forward_onnx(...):
    ...
    worker = get_worker()
    model_session = worker.data["model_session"]
    ...
```
However, as my code is quite unstable and I’m facing memory issues, my workers restart frequently. I expected my model_session to be re-instantiated when the workers restart; instead I get:

```
ERROR - Failed to pickle 'model_session'.
```