Dask distributed callbacks

Hi there!

First of all, thanks for your incredible work. I have been using Dask (array) for a while for image processing, and the speed improvements are just great, and the power of working with big images… yeah, just great.

I have a question. In my workflow, one of the steps segments an image using a Keras model. The problem is that I need to measure the time elapsed in that step, and the log must go to standard output (the process runs in a Docker container, with no chance of serving a web page or connecting to the Dask dashboard while it is running). I had thought of using callbacks to filter and time the appropriate tasks, but it seems they do not work for distributed computations.

Is there any way of emulating the callbacks in distributed processing? Or getting the time report for every task in standard output instead of using a web page?

Thank you!!!

Luis

Hi there!

Eventually, I was able to do it using scheduler callbacks (https://distributed.dask.org/en/stable/plugins.html) and checking transitions between the states “waiting”, “processing” and “memory”. The output to stdout is done in the restart method (this works in my case, though I guess it is not usable by everyone).

Anyway, other ideas are welcome :slight_smile:

Thanks!

Luis

@lhcorralo Welcome!

Eventually, I was able to do it using scheduler callbacks

This does seem like the best way :smile:

If you’re comfortable, it’ll be super helpful for future users if you can share your plugin here!

Just for reference, here is something similar that @ian wrote: task-group-statistics-plugin.py · GitHub

Sure!

The code is like this

import logging
import time

from distributed.diagnostics.plugin import SchedulerPlugin

class TaskCounter(SchedulerPlugin):
    def __init__(self, task_name):
        self.started = 0
        self.finished = 0
        self.task_name = task_name
        self.current_tasks = {}
        self.finished_tasks = {}

    def transition(self, key: str, start, finish, *args, **kwargs):
        # Uncomment this in your workflow for getting all the tasks
        # logging.info(f"Key: {key}")
        if key.startswith(self.task_name):
            if start == 'waiting' and finish == 'processing':
                self.current_tasks[key] = time.time()
                self.started += 1
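            if start == 'processing' and finish == 'memory':
                # Skip keys whose start was never recorded (avoids a KeyError)
                start_time = self.current_tasks.pop(key, None)
                if start_time is not None:
                    self.finished += 1
                    self.finished_tasks[key] = time.time() - start_time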

    def _show_stats(self):
        logging.info(f"Analyzing tasks starting with {self.task_name}")
        logging.info(f"Captured a total of {self.finished} tasks")
        total_time = 0.0
        for task_time in self.finished_tasks.values():
            total_time += task_time
        logging.info(f"Tasks took a total of {total_time} seconds")

    def restart(self, scheduler):
        self._show_stats()
        self.started = 0
        self.finished = 0
        self.current_tasks = {}
        self.finished_tasks = {}

    def close(self):
        # If the cluster is never restarted, this may be a good place for self._show_stats()
        pass

And it is registered with

tc = TaskCounter("('Segmentation")
client.register_scheduler_plugin(tc, "taskCounter")
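
One detail that may matter for the stdout requirement: the plugin reports through logging.info, and Python's root logger defaults to the WARNING level and writes to stderr. Below is a minimal sketch of sending INFO records to stdout. The helper name configure_stdout_logging is hypothetical, and it assumes the scheduler runs in the same process as this code; with a scheduler in another process, the configuration would have to happen there instead.

```python
import logging
import sys


def configure_stdout_logging(level=logging.INFO):
    """Attach a stdout handler to the root logger (hypothetical helper)."""
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    )
    root = logging.getLogger()
    root.addHandler(handler)
    # The root logger defaults to WARNING, which would drop logging.info().
    root.setLevel(level)
    return handler


stdout_handler = configure_stdout_logging()
logging.info("Tasks took a total of %.1f seconds", 12.5)
```

Calling this once before creating the client should be enough in the single-process case.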

Note that tc is not updated; that is, tc.started is still 0 after calling compute(), so I guess a copy is made somewhere (because of that I have to do the logging in the restart method, which works because I restart the client). Also note that the key parameter in transition is a string, despite looking like a tuple. That is why TaskCounter filters on a prefix that starts with a parenthesis and a quote.
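
Both observations can be reproduced without a cluster. The sketch below is only an illustration under assumptions: copy.deepcopy stands in for whatever serialization round trip the registration performs (my guess at why tc stays at 0), and Counter and the sample key are made up for the example.

```python
import copy


class Counter:
    """Hypothetical stand-in for the plugin; it only holds a counter."""

    def __init__(self):
        self.started = 0


local = Counter()
# Assumption: registering the plugin hands the scheduler its own copy.
remote = copy.deepcopy(local)
remote.started += 1  # what the scheduler-side copy would do on a transition

print(local.started, remote.started)  # prints: 0 1

# A tuple key rendered as a string starts with "('" followed by the task name.
key = str(("Segmentation-abc123", 0, 1))
print(key.startswith("('Segmentation"))  # prints: True
```

If this guess is right, anything the plugin collects has to be read on the scheduler side or sent back explicitly, since the local object never sees the updates.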

I hope this helps somebody :slight_smile:

Luis

BTW, the code you shared is much better and cleaner, so if anybody is reading this, forget my code; it is a bunch of workarounds. Use the one pavithraes linked instead!

Hi!

I have a problem using @ian’s code (task-group-statistics-plugin); maybe you can help.

My plugin now looks like this:

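from datetime import datetime
from typing import List

import pandas as pd
from distributed.diagnostics.plugin import SchedulerPlugin
# Assumed import location; it may differ between distributed releases
from distributed.utils import key_split, key_split_group

class TaskCounter(SchedulerPlugin):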
    def __init__(self, task_names: List):
        self.task_names = task_names
        self.groups = {}
        self.scheduler = None

    async def get_task_stats(self, comm):
        df = pd.DataFrame.from_dict(self.groups, orient="index")
        return df

    def start(self, scheduler):
        """Called on scheduler start as well as on registration time"""
        self.scheduler = scheduler
        scheduler.handlers["get_task_stats"] = self.get_task_stats

    def transition(self, key: str, start, finish, *args, **kwargs):
        # Uncomment this in your workflow for getting all the tasks
        # logging.info(f"Key: {key}")

        prefix_name = key_split(key)
        group_name = key_split_group(key)

        if prefix_name in self.task_names:
            if start == 'processing' and finish == 'memory':
                # Add the stats
                if group_name not in self.groups:
                    self.groups[group_name] = {}

                group = self.scheduler.task_groups[group_name]

                self.groups[group_name]["prefix"] = prefix_name
                self.groups[group_name]["duration"] = group.duration
                self.groups[group_name]["start"] = str(datetime.fromtimestamp(group.start))
                self.groups[group_name]["stop"] = str(datetime.fromtimestamp(group.stop))
                self.groups[group_name]["nbytes"] = group.nbytes_total

    def restart(self, scheduler):
        self.groups = {}

It is mostly like Ian’s. The problem is that, when I want to retrieve the stats (after compute() has finished, but before client.close() is called), I run

stats_data = client.sync(client.scheduler.get_task_stats)

And I get the following traceback:

tornado.application - ERROR - Exception in callback functools.partial(<bound method IOLoop._discard_future_result of <tornado.platform.asyncio.AsyncIOLoop object at 0x7f7a19154a30>>, <Task finished name='Task-47' coro=<Server.handle_comm() done, defined at /opt/venv/lib/python3.8/site-packages/distributed/core.py:433> exception=ValueError('The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().')>)
Traceback (most recent call last):
  File "/opt/venv/lib/python3.8/site-packages/tornado/ioloop.py", line 741, in _run_callback
    ret = callback()
  File "/opt/venv/lib/python3.8/site-packages/tornado/ioloop.py", line 765, in _discard_future_result
    future.result()
  File "/opt/venv/lib/python3.8/site-packages/distributed/core.py", line 537, in handle_comm
    if reply and result != Status.dont_reply:
  File "/opt/venv/lib/python3.8/site-packages/pandas/core/generic.py", line 1535, in __nonzero__
    raise ValueError(
ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/project/dl_oneo_gstp_crops/crops_yearly.py", line 188, in <module>
    main()
  File "/opt/project/dl_oneo_gstp_crops/crops_yearly.py", line 154, in main
    stats_data = client.sync(client.scheduler.get_task_stats)
  File "/opt/venv/lib/python3.8/site-packages/distributed/utils.py", line 309, in sync
    return sync(
  File "/opt/venv/lib/python3.8/site-packages/distributed/utils.py", line 376, in sync
    raise exc.with_traceback(tb)
  File "/opt/venv/lib/python3.8/site-packages/distributed/utils.py", line 349, in f
    result = yield future
  File "/opt/venv/lib/python3.8/site-packages/tornado/gen.py", line 762, in run
    value = future.result()
  File "/opt/venv/lib/python3.8/site-packages/distributed/core.py", line 905, in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
  File "/opt/venv/lib/python3.8/site-packages/distributed/core.py", line 674, in send_recv
    response = await comm.read(deserializers=deserializers)
  File "/opt/venv/lib/python3.8/site-packages/distributed/comm/inproc.py", line 199, in read
    raise CommClosedError()
distributed.comm.core.CommClosedError

I made a workaround by overriding the remove_client and restart methods to show the stats at that moment (as in my first code), but that solution is much less “polite”, as I cannot poll for the stats exactly when I want :frowning:

Thank you!
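
One possible reading of the traceback, offered as a guess rather than a confirmed diagnosis: the ValueError is raised in handle_comm at `if reply and result != Status.dont_reply`, where result is the DataFrame returned by the handler, and a DataFrame comparison cannot be truth-tested. If that is the cause, having the handler return plain built-in containers and building the DataFrame on the client might avoid it. The sketch below shows only the conversion step; groups_to_records and the sample data are hypothetical, and nothing here has been tested against distributed.

```python
def groups_to_records(groups):
    """Flatten {group_name: {field: value}} into a list of row dicts."""
    return [{"group": name, **fields} for name, fields in groups.items()]


# Hypothetical sample of what self.groups might hold after a run.
sample = {
    "segmentation-abc": {
        "prefix": "segmentation",
        "duration": 12.5,
        "nbytes": 1024,
    },
}
records = groups_to_records(sample)

# Unlike a DataFrame, a list compares to a plain bool, so it can be truth-tested.
print(records != "dont_reply")  # prints: True

for row in records:
    print(row)
```

Under this assumption, get_task_stats would return groups_to_records(self.groups), and the client would call pd.DataFrame(client.sync(client.scheduler.get_task_stats)) to get the table back.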