Is it possible to map coroutines in a dask bag?

I was wondering if it is possible to map async functions onto a dask bag, something like:

import dask.bag as bag

async def sqr(n):
    return n**2

bag.from_sequence(range(10)).map(sqr).compute()

The above script will raise the following:

/tmp/foo/env/lib/python3.8/site-packages/dask/local.py:226: RuntimeWarning: coroutine 'sqr' was never awaited
  failed = True
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
/tmp/foo/env/lib/python3.8/site-packages/dask/local.py:226: RuntimeWarning: coroutine 'sqr' was never awaited
  failed = True
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/base.py", line 288, in compute
    (result,) = compute(self, traverse=False, **kwargs)
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/base.py", line 571, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/multiprocessing.py", line 219, in get
    result = get_async(
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/local.py", line 507, in get_async
    raise_exception(exc, tb)
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/local.py", line 315, in reraise
    raise exc
  File "/tmp/foo/env/lib/python3.8/site-packages/dask/local.py", line 222, in execute_task
    result = dumps((result, id))
  File "/tmp/foo/env/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/tmp/foo/env/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'coroutine' object
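The "never awaited" warning and the final TypeError share a root cause. Calling an `async def` function does not run it. It only creates a coroutine object. The traceback shows Dask's multiprocessing scheduler pickling each task's result, and a coroutine object cannot be pickled. The sketch below illustrates this with the standard library's `pickle`. It assumes, as the traceback suggests, that cloudpickle rejects coroutines in the same way:

```python
import pickle

async def sqr(n):
    return n**2

# Calling a coroutine function does not execute its body; it returns a coroutine object.
coro = sqr(3)
print(type(coro).__name__)  # coroutine

try:
    # The multiprocessing scheduler pickles each task's result, which fails for coroutines.
    pickle.dumps(coro)
    picklable = True
except TypeError as exc:
    picklable = False
    print(exc)
finally:
    # Close the coroutine so it is not reported as "never awaited".
    coro.close()
```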

Thanks for the question, @ian.liu88!

That’s not currently possible with Dask Bag. It does make sense for the collections to be able to handle async functions, though. Would you like to open a feature request on Dask’s issue tracker?
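In the meantime, one possible workaround for staying within Dask Bag is sketched below. It is not from the original thread. The idea is to wrap the coroutine function in a plain function that drives it to completion with `asyncio.run`, so the bag only ever sees ordinary, picklable return values. `run_sync` is a hypothetical helper, not a Dask API. The example uses the built-in `map` so that it runs without Dask installed:

```python
import asyncio

def run_sync(coro_fn):
    """Hypothetical helper (not part of Dask): turn a coroutine function into a plain one.

    Each call starts a fresh event loop with asyncio.run, so the wrapper must not
    be called from code that is already running inside an event loop.
    """
    def wrapper(*args, **kwargs):
        return asyncio.run(coro_fn(*args, **kwargs))
    return wrapper

async def sqr(n):
    await asyncio.sleep(0)  # stand-in for real asynchronous I/O
    return n**2

sqr_sync = run_sync(sqr)
results = list(map(sqr_sync, range(10)))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

With Dask Bag, the wrapped function would be passed to `map` like any other function, for example `bag.from_sequence(range(10)).map(sqr_sync).compute()`. Because every call starts its own event loop, this approach gives no concurrency within a task. If concurrency matters, a variant built on `map_partitions` may work better. That variant would run all of a partition's coroutines in a single `asyncio.run(asyncio.gather(...))` call.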

Support for running async functions as tasks was recently added to client.submit/client.map in dask.distributed, so you may be able to use that instead:

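import asyncio
from dask.distributed import Client

async def sqr(n):
    return n**2

async def main():
    # Start a local cluster and connect an asynchronous client to it
    async with Client(asynchronous=True) as client:
        # Coroutine functions can be mapped; the workers await them
        futures = client.map(sqr, range(10))
        return await client.gather(futures)

if __name__ == "__main__":
    print(asyncio.run(main()))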

cc @scharlottej13
