Dask Queue: producer/consumer

Hello, 😄

I’m a data scientist/data engineer working on a workflow where I need to process a huge number of documents.
The basic idea is to open each document; each document is a collection of an arbitrary number of pages.
We then apply a function to each page. The time to process a page varies, and we have far more pages than documents (20M pages for 400,000 documents).

I tried using Dask for this pipeline, tried several ways of writing it, and settled on the tasks-within-tasks design pattern (launching tasks from tasks):

import random
from itertools import chain
from time import sleep

from distributed import worker_client

def ocr_image(page):
    time_delay = random.randrange(1, 10)
    sleep(time_delay)  # simulate actual OCR work
    return "this is ocr"

def load_pages(doc):
    # simulate opening the file: a random number of pages per document
    n = random.randint(1, 5)
    # worker_client() lets this task submit subtasks (tasks within tasks)
    with worker_client() as client:
        futures = [client.submit(ocr_image, page, pure=False) for page in range(n)]
        # gather inside the task so the loader returns OCR results, not futures
        return client.gather(futures)

def main(client, filenames):
    # Load and submit tasks: one loader per document, one OCR task per page
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]
    res_loaders = client.gather(loaders)  # one list of OCR results per document
    return list(chain.from_iterable(res_loaders))

The issue with this approach is that it schedules a LOT of small tasks, so I thought about batching, but the difficulty there is the arbitrary number of pages per document (from 1 to 40,000 pages!).
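One way around the arbitrary page count is to batch (document, page) pairs rather than documents, so a batch can span several small documents or cover part of a large one. This is a minimal sketch, assuming the page count of each document can be obtained cheaply before OCR; `page_batches` and the file names are hypothetical, and each yielded batch would then become one submitted task:

```python
from collections.abc import Iterator


def page_batches(page_counts: dict[str, int], batch_size: int) -> Iterator[list[tuple[str, int]]]:
    """Yield lists of (document, page_index) pairs with at most batch_size items each.

    Batches ignore document boundaries, so every batch (except possibly the last)
    has the same size whether a document has 1 page or 40,000.
    """
    batch: list[tuple[str, int]] = []
    for doc, n_pages in page_counts.items():
        for page in range(n_pages):
            batch.append((doc, page))
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:  # last, possibly smaller, batch
        yield batch


counts = {"a.pdf": 1, "b.pdf": 5, "c.pdf": 2}  # hypothetical page counts
print([len(b) for b in page_batches(counts, batch_size=3)])  # [3, 3, 2]
```

Yielding (document, start, stop) page ranges instead of individual pairs would keep the task arguments smaller at 20M pages; the pairs are used here only because they are easier to read.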

The ‘saner’ approach would be a distributed producer/consumer architecture with a queue of pages that workers consume.
Luckily I found the distributed.Queue class, but it has some issues (I know it’s experimental right now).
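A central question in any producer/consumer setup is termination: how does a consumer know there are no more pages? One common answer is a sentinel value pushed once all producers are done. The sketch below is a single-process stand-in using the standard library's `queue.Queue` and threads, not `distributed.Queue`; the names `producer`, `consumer`, `run` and `SENTINEL` are hypothetical, and carrying the idea over to Dask (long-running consumer tasks reading from a `distributed.Queue`) is untested here:

```python
import queue
import threading

SENTINEL = None  # marks "no more pages"; assumes None is never a real page


def producer(q: queue.Queue, page_counts: dict[str, int]) -> None:
    # Put one (document, page_index) item on the shared queue per page
    for doc, n_pages in page_counts.items():
        for page in range(n_pages):
            q.put((doc, page))


def consumer(q: queue.Queue, results: list[str]) -> None:
    # Take pages until the sentinel arrives, so no consumer needs to know the total
    while True:
        item = q.get()
        if item is SENTINEL:
            break
        doc, page = item
        results.append(f"ocr({doc}, {page})")  # stand-in for the real OCR call


def run(page_counts: dict[str, int], n_consumers: int = 4) -> list[str]:
    # Bounded queue: the producer blocks instead of loading every page up front
    q: queue.Queue = queue.Queue(maxsize=100)
    results: list[str] = []
    workers = [threading.Thread(target=consumer, args=(q, results)) for _ in range(n_consumers)]
    for w in workers:
        w.start()
    producer(q, page_counts)
    # Only after all pages are queued: one sentinel per consumer
    for _ in workers:
        q.put(SENTINEL)
    for w in workers:
        w.join()
    return results


results = run({"a.pdf": 3, "b.pdf": 2})
print(len(results))  # 5
```

With several producers (for example one loader per document), the sentinels must only be sent after every producer has finished, otherwise consumers can stop while pages are still arriving.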

import random
from time import sleep

from distributed import Queue, wait

QUEUE_NAME = "pages"  # name shared by producers and consumers
batch_size = 10

def batch_ocr_image():
    q = Queue(QUEUE_NAME)
    # You can't have both a batch size and a timeout:
    # pages = [q.get(timeout="1s") for _ in range(batch_size)]
    pages = q.get(batch=batch_size)  # blocks until batch_size pages are available
    for _ in pages:
        time_delay = random.randrange(1, 10)
        sleep(time_delay)  # simulate actual OCR work
    return ["this is ocr"] * len(pages)

def ocr_image():
    q = Queue(QUEUE_NAME)
    page = q.get(timeout="1s")
    time_delay = random.randrange(1, 10)
    sleep(time_delay)  # simulate actual OCR work
    return "this is ocr"

def load_pages(doc):
    q = Queue(QUEUE_NAME)
    # simulate opening the file: a random number of pages per document
    n = random.randint(1, 5)
    for page in range(n):
        q.put(page)
    return n

def main(client, filenames):
    q = Queue(QUEUE_NAME, client=client)
    ## Load pages into the queue
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]

    # Sync 1: gather loaders
    # approach 1: wait for all loaders to finish: res_loaders = client.gather(loaders)
    # approach 2: wait for the first one to finish, then submit consumers
    done, not_done = wait(loaders, return_when="FIRST_COMPLETED")

    ## Batching
    # Batching is very hard: q.qsize() is unreliable here, since it only counts the pages queued so far
    consumers = [client.submit(batch_ocr_image, pure=False, retries=4)
                 for _ in range(q.qsize() // batch_size)]
    # Sync 2: consume the queue
    res_consumer = client.gather(consumers)
    return loaders, res_consumer
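Two problems show up in the consumer code above: an integer batch size cannot be combined with a timeout, and, as far as I can tell, requesting an integer batch from `distributed.Queue.get` waits until that many items have arrived, so a final partial batch would never be returned. One workaround is to collect items one `get()` at a time against a shared deadline. This is a minimal sketch using the standard library's `queue.Queue` as a stand-in; `get_batch` is a hypothetical helper. `distributed.Queue.get` also accepts `timeout=`, but it signals a timeout with a different exception than `queue.Empty`, so the `except` clause would need adjusting:

```python
import queue
import time


def get_batch(q: queue.Queue, batch_size: int, timeout: float) -> list:
    """Return up to batch_size items, waiting at most `timeout` seconds in total.

    Returns fewer items (possibly none) if the queue stays empty until the deadline.
    """
    deadline = time.monotonic() + timeout
    batch = []
    while len(batch) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Each get waits only for the time left before the shared deadline
            batch.append(q.get(timeout=remaining))
        except queue.Empty:
            break
    return batch


q = queue.Queue()
for page in range(5):
    q.put(page)
print(get_batch(q, 3, 0.1))  # [0, 1, 2]
print(get_batch(q, 3, 0.1))  # [3, 4]
```

A consumer built this way can keep calling `get_batch` until it receives an empty list, which removes the need to size the number of consumers from `qsize()`.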

I might be missing something about how to correctly implement a producer/consumer with distributed. I have just submitted a feature request for some missing methods on the queue.
Thanks a lot for your help and guidance!

@AmineDiro Welcome to discourse and thanks for this question!

I see @crusaderky answered your question on GitHub; just noting it here for reference.

Hello @pavithraes,

Yes, but I have had some issues resolving the delayed objects from aggregate_ocr_results. The .compute() call never runs the actual OCR and returns a List[Delayed]. Is there something I’m missing?

from __future__ import annotations

import distributed

CHUNK_SIZE = 1000  # pages processed by a single task

def parse_document(path: str) -> list[Image]:
    # Load document from disk and return one raw image per page
    ...

def ocr_page(page: Image) -> OCRExitStatus:
    # Run a single page through OCR, dump output to disk, and return metadata / useful info
    ...

def ocr_pages(pages: list[Image]) -> list[OCRExitStatus]:
    return [ocr_page(page) for page in pages]

def aggregate_ocr_results(*chunk_results: list[OCRExitStatus]) -> list[OCRExitStatus]:
    return [r for chunk in chunk_results for r in chunk]

def ocr_document(doc_path: str):
    raw_pages = parse_document(doc_path)
    client = distributed.get_client()
    chunks = client.scatter([
        raw_pages[i: i + CHUNK_SIZE]
        for i in range(0, len(raw_pages), CHUNK_SIZE)
    ])
    return aggregate_ocr_results(ocr_pages(chunk) for chunk in chunks)

client = distributed.Client()
all_results = aggregate_ocr_results(ocr_document(path) for path in paths)
all_results.compute()  # returns list[OCRExitStatus]

@AmineDiro I see @crusaderky mentioned that this code wouldn’t work. Have you tried the second solution proposed there?

Also, on a general note, you can compute a list of delayed objects with dask.compute(*list_of_delayeds).
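A common reason for getting a list of Delayed objects back from `.compute()` is that delayed functions were called inside another delayed task: at run time that only builds new Delayed objects, and the outer task returns them without executing them. Whether that is what happened here depends on the decorated version of the code, which isn't shown above. This is a minimal sketch with toy stand-ins (`ocr_chunk` and `ocr_document_nested` are hypothetical), run on the synchronous scheduler for clarity:

```python
import dask
from dask import delayed
from dask.delayed import Delayed


@delayed
def ocr_chunk(chunk):
    # Stand-in for real OCR: one result string per page in the chunk
    return [f"ocr({page})" for page in chunk]


@delayed
def ocr_document_nested(pages):
    # Calling a delayed function *inside* a task only builds Delayed objects;
    # the outer task returns them unevaluated.
    return [ocr_chunk(pages[i:i + 2]) for i in range(0, len(pages), 2)]


nested = ocr_document_nested([1, 2, 3]).compute(scheduler="sync")
print(isinstance(nested[0], Delayed))  # True

# Building the graph at the top level instead, then computing the whole list at once
chunks = [ocr_chunk(c) for c in ([1, 2], [3])]
results = dask.compute(*chunks, scheduler="sync")
print(results)  # (['ocr(1)', 'ocr(2)'], ['ocr(3)'])
```

`dask.compute(*nested)` would also evaluate the leftover Delayed objects, but building the full graph up front lets the scheduler see all the work at once.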


@AmineDiro Looks like the discussion is continuing on GitHub, so I’ll go ahead and mark this Discourse thread as resolved to avoid duplication. Please feel free to add comments if required though!