Hi
Related to this post:
Specify that a given task use huge amount of RAM to the Dask Ressource Manager.
I’m using ONNX in my workers to compute a deep learning forward pass. This is done by instantiating an ONNX Runtime InferenceSession, which is a hardware-specific object (hence it cannot be serialized and shared as a pickle).
Currently this object is instantiated in my forward_onnx function (which is run in each task, for each chunk), so a new InferenceSession is instantiated on every task run. I suppose this could lead to significant memory use and time overhead.
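For context, my forward_onnx currently looks roughly like this (pre/post-processing is omitted and the run call is just illustrative):

import numpy as np
import onnxruntime as ort

def forward_onnx(chunk: np.ndarray) -> np.ndarray:
    # A new InferenceSession is created for every chunk / task run,
    # which is the overhead I would like to avoid
    session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: chunk})[0]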
According to the Dask Distributed documentation, the usual way of defining a worker-side object is to define a Worker Plugin.
To do so I defined this WorkerInferenceSessionPlugin class:
from dask.distributed import WorkerPlugin, Worker

class WorkerInferenceSessionPlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        import onnxruntime as ort

        model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
        worker.data["model_session"] = model_session
I then register it on my client:
inference_session_plugin = WorkerInferenceSessionPlugin()
client = Client(cluster)
client.register_worker_plugin(inference_session_plugin)
And then I reuse it in my function instead of declaring a new one:
from dask.distributed import get_worker

def forward_onnx(...):
    ...
    worker = get_worker()
    model_session = worker.data["model_session"]
    ...
However, as my code is quite unstable and I’m facing memory issues, my workers restart frequently. I was expecting my model_session to be re-instantiated when the workers restart; instead, I get ERROR - Failed to pickle 'model_session'.
The inner workings of the WorkerPlugin are not clear in my opinion, especially as to what runs where. But since you’re getting a pickling error, it seems to indicate that the setup is not really being called on the worker side. I have tried a similar setup to yours, using a simpler data structure (see the sketch below). However, the key that is registered into worker.data in the setup method of the WorkerPlugin returns a KeyError when I try to use it from the worker method. Did you get around to fixing this, and if so, what approach did you end up using?
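For reference, here is a stripped-down sketch of what I tried (the names are just illustrative):

from dask.distributed import Client, Worker, WorkerPlugin, get_worker

class SimplePlugin(WorkerPlugin):
    async def setup(self, worker: Worker):
        # store a plain value instead of an InferenceSession
        worker.data["my_value"] = 42

def read_value():
    # this is where I get KeyError: 'my_value'
    return get_worker().data["my_value"]

client = Client()
client.register_worker_plugin(SimplePlugin())
print(client.submit(read_value).result())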
Hi!
I haven’t managed to configure the startup of my inference session (the ort.InferenceSession) using the WorkerPlugin API. I’ve had to resign myself to opening a new session in each task, which is much more expensive.
Sorry to hear that. The WorkerPlugin should be perfect for your use case, but I am not sure it is working as intended. Thanks for getting back to me, anyway.
Did you mean client.register_worker_callbacks?
Something like this:

from dask.distributed import Worker

def setup(dask_worker: Worker):
    import onnxruntime as ort

    model_session = ort.InferenceSession("./resnet18-v1-7.onnx", providers=['CPUExecutionProvider'])
    dask_worker.data["model_session"] = model_session

client.register_worker_callbacks(setup)
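In your tasks you would then read it back the same way as with the plugin (the run call here is just illustrative):

from dask.distributed import get_worker

def forward_onnx(chunk):
    # reuse the session created once per worker by the setup callback
    model_session = get_worker().data["model_session"]
    input_name = model_session.get_inputs()[0].name
    return model_session.run(None, {input_name: chunk})[0]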
I don’t know if it works, since my post is already 1 year old; we did find an alternative solution at the time, although it resulted in reduced usage of Dask.
But I remember that register_worker_callbacks was deprecated, or at least going to be deprecated, at the time, and that WorkerPlugin was the recommended way to do it.
Maybe @jurgencuschieri can test this on their setup? It could benefit others who come across this issue.
Hi @guillaumeeb and @julien-blanchon, the register_worker_callbacks method seems to work, and from what I can see it is being computed remotely on the worker (as I want). Not sure why there is an open issue to mark this as deprecated in favour of register_worker_plugin.
Nice! I wish I’d had the answer sooner. Thanks @guillaumeeb for the hint. Maybe @jurgencuschieri you can post a message on GitHub to ask why?
Anyway, I can finally mark this as resolved.