Is there a way to use threads / processes exclusively for a code block?

Is there a way to run some dask delayed jobs with threads and others with processes?

e.g.

@dask.delayed
def plot():
    """Placeholder for a matplotlib plotting task (body elided in this example)."""
    ...  # matplotlib job that needs processes because matplotlib is not thread safe

@dask.delayed
def image_manip():
    """Placeholder for an imageio task (body elided in this example)."""
    ...  # imageio job that only needs threads because it's I/O bound

Would this work?

# plot() is the matplotlib job, which is the one that needs processes.
with dask.annotate(scheduler="processes"):
    plot()

# image_manip() is the I/O-bound imageio job, which only needs threads.
with dask.annotate(scheduler="threads"):
    image_manip()

Here’s a minimal example, but it doesn’t seem to do what I want:

import dask
import logging

# Each log line is prefixed with the process id and the thread name, so the
# output shows which worker handled a given message.
formatter = logging.Formatter('%(process)s %(threadName)s %(name)s %(message)s')

# Console handler that accepts records of every level and uses that format.
console = logging.StreamHandler()
console.setFormatter(formatter)
console.setLevel(logging.DEBUG)

# Module-level logger used by the tasks; the console handler itself is
# attached to the root logger.
logger = logging.getLogger(__name__)
logging.getLogger('').addHandler(console)


@dask.delayed
def test(x):
    """Log ``x`` as a warning and return it unchanged."""

    logger.warning(x)
    return x


# Create two tasks while the scheduler="threads" annotation is active ...
with dask.annotate(scheduler="threads"):
    a = test("a")
    b = test("b")

# ... and one task while the scheduler="processes" annotation is active.
with dask.annotate(scheduler="processes"):
    c = test("c")

jobs = [a, b, c]

# A single compute call for all three tasks, made outside both annotate blocks.
# NOTE(review): the output quoted below shows every task on a
# ThreadPoolExecutor thread in one process, so the annotation did not select
# the scheduler in this run.
dask.compute(jobs)

I expected `c` to run in a ProcessPoolExecutor worker, but all three tasks run on ThreadPoolExecutor threads in the same process:

9287 ThreadPoolExecutor-0_0 __main__ a
9287 ThreadPoolExecutor-0_1 __main__ c
9287 ThreadPoolExecutor-0_2 __main__ b
(['a', 'b', 'c'],)

I think the following kind of works, but it requires calling `dask.compute` separately for each group of tasks instead of once for all of them:

import dask
import logging

# Each log line is prefixed with the process id and the thread name, so the
# output shows which worker handled a given message.
formatter = logging.Formatter('%(process)s %(threadName)s %(name)s %(message)s')

# Console handler that accepts records of every level and uses that format.
console = logging.StreamHandler()
console.setFormatter(formatter)
console.setLevel(logging.DEBUG)

# Module-level logger used by the tasks; the console handler itself is
# attached to the root logger.
logger = logging.getLogger(__name__)
logging.getLogger('').addHandler(console)


@dask.delayed
def test(x):
    """Log ``x`` as a warning and return it unchanged."""

    logger.warning(x)
    return x


# Here the scheduler is chosen through dask's config, and compute is called
# inside each block, so each group of tasks is computed separately.
with dask.config.set(scheduler='threads'):
    a = test("a")
    b = test("b")
    dask.compute(a, b)

# Second group: same pattern, with scheduler='processes' set for the block.
with dask.config.set(scheduler='processes'):
    c = test("c")
    d = test("d")
    dask.compute(c, d)

Output:

1389 ThreadPoolExecutor-0_2 __main__ a
1389 ThreadPoolExecutor-0_1 __main__ b
d
c