How to cache an object instance per worker

I am trying to generate a bunch of images through selenium webdriver across four workers/processes.

Right now, I am restarting the webdriver instance for every task, but that is highly inefficient (10,000 tasks).

Instead, I’m thinking that each worker could start its own webdriver instance when it initializes, and later tasks on that worker would reuse the premade instance. Is that possible?

I printed the worker IDs to confirm that tasks run on the same set of workers.

Worker-b35541a2-7546-4909-bed9-477e28daa9f0
Worker-731ecc29-5eb8-43ba-9fba-710dc6ec8ca6
Worker-b35541a2-7546-4909-bed9-477e28daa9f0
Worker-20c44df1-6fb9-4f86-a296-a93baeb825f8
Worker-0f7788a7-4fdb-4abb-bb03-e1d629068fde
Worker-20c44df1-6fb9-4f86-a296-a93baeb825f8
Worker-731ecc29-5eb8-43ba-9fba-710dc6ec8ca6
Worker-0f7788a7-4fdb-4abb-bb03-e1d629068fde
Worker-b35541a2-7546-4909-bed9-477e28daa9f0
Worker-b35541a2-7546-4909-bed9-477e28daa9f0
Worker-0f7788a7-4fdb-4abb-bb03-e1d629068fde
Worker-731ecc29-5eb8-43ba-9fba-710dc6ec8ca6
Worker-20c44df1-6fb9-4f86-a296-a93baeb825f8
Worker-20c44df1-6fb9-4f86-a296-a93baeb825f8
Worker-731ecc29-5eb8-43ba-9fba-710dc6ec8ca6
Worker-0f7788a7-4fdb-4abb-bb03-e1d629068fde

Here’s some pseudo code that I envision:

import dask
from distributed import Client
from selenium.webdriver.chrome.webdriver import WebDriver

def task(i):
    # reuse this worker's cached webdriver if one exists
    if "webdriver" in dask.worker.cache:  # imaginary API -- does something like this exist?
        webdriver = dask.worker.cache["webdriver"]
    else:
        webdriver = WebDriver()
        dask.worker.cache["webdriver"] = webdriver

    # use webdriver to do things...

client = Client()
client.map(task, range(1000))

Hmm, I might be misunderstanding how workers operate…

I would have expected cache to be mutated on the workers, ending up with len(cache) == num_workers.

import dask.distributed


def test(i):
    worker_id = dask.distributed.get_worker().id
    if worker_id not in cache:
        cache[worker_id] = i
    else:
        i = cache[worker_id]
    print(cache)

cache = {}
client = dask.distributed.Client()
client.gather(client.map(test, range(100)))
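I think I see why this fails: Dask serializes the task function together with its globals out to each worker process, so every worker mutates its own private copy of cache and the client's dict never changes. A stdlib-only sketch of the same effect (plain multiprocessing, no Dask; the fork start method is POSIX-only but keeps the demo simple):

```python
import multiprocessing as mp

cache = {}

def task(i):
    # runs in a child process: this mutates the CHILD's copy of `cache`,
    # never the parent's dict
    cache[i] = i
    return len(cache)

def run():
    # fork gives each child a copy of the parent's globals, roughly like
    # Dask shipping the function (and its globals) to workers
    with mp.get_context("fork").Pool(4) as pool:
        return pool.map(task, range(8))

if __name__ == "__main__":
    sizes = run()
    print("parent cache after map:", cache)  # still {}
    print("cache sizes seen inside children:", sizes)
```

The parent's cache stays empty no matter how many tasks mutate their own copies, which matches what I saw with the Dask version.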

Reading this, it seems promising:
Get data local to each worker. · Issue #5331 · dask/dask (github.com)

import dask.distributed


def test(i):
    worker = dask.distributed.get_worker()
    with worker._lock:
        if not hasattr(worker, "_i"):
            worker._i = i
    return worker._i

client = dask.distributed.Client()
client.gather(client.map(test, range(10)))

Outputs:
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1]

Each of the four workers caches the first i it sees and returns it for all later tasks, which is exactly the per-worker reuse I wanted.


Okay, got it to work with a lock 🙂
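For reference, here is the same pattern without reaching into the private worker._lock attribute: a module-level lock plus double-checked creation, so the expensive object is built at most once per worker process. The helper name get_instance is mine, not a Dask API:

```python
import threading

_lock = threading.Lock()
_instances = {}

def get_instance(key, factory):
    """Create factory() at most once per process and cache it under key.

    Double-checked locking: the unlocked check keeps the hot path cheap,
    and the second check under the lock closes the race where two
    threads both saw the key missing.
    """
    if key not in _instances:
        with _lock:
            if key not in _instances:
                _instances[key] = factory()
    return _instances[key]
```

Inside a task you would call get_instance("webdriver", WebDriver); since the module-level dict lives once per worker process, each worker ends up building exactly one driver, even with multiple threads per worker.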

Support processes in rendering HoloViews by ahuang11 · Pull Request #47 · ahuang11/streamjoy (github.com)
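Another option I noticed afterwards (untested sketch): Dask worker plugins have a setup hook that runs once when each worker starts, which is a natural place to create the driver. With distributed installed you would subclass distributed.WorkerPlugin and register it via the client (register_plugin in recent versions, register_worker_plugin in older ones); the duck-typed class below keeps the sketch dependency-free, and factory is a placeholder for a real WebDriver constructor:

```python
class WebDriverPlugin:
    """Per-worker setup/teardown, mirroring distributed.WorkerPlugin."""

    def __init__(self, factory):
        self.factory = factory  # e.g. selenium's WebDriver class

    def setup(self, worker):
        # called once when the worker starts: stash the driver on it
        worker.webdriver = self.factory()

    def teardown(self, worker):
        # called when the worker shuts down: clean up the driver
        worker.webdriver.quit()
```

Tasks would then grab the instance with dask.distributed.get_worker().webdriver, and the teardown hook gives a clean place to quit the driver when the cluster shuts down.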

Hi @ahuang11, welcome to this community!

Nothing to add, but I’m glad you found a solution. It’s always nice to see people solve their own issues!