Accessing TaskState metadata for running tasks

Hi, we’re using Dask distributed to execute some custom, long-running computations on a cluster. The workers spawn a separate process to perform the computation and we want to record the PID of that process, along with the hostname. This is needed for us to perform some troubleshooting if the task is stuck. We see that TaskState has a metadata field (Add TaskState metadata by jrbourbeau · Pull Request #4191 · dask/distributed · GitHub) and would like to leverage that. The issue is that TaskState metadata is only synced to the scheduler after the task finishes processing. Is there a way to sync the task state to the Dask scheduler (and then we can access via Client.run_on_scheduler) while the task is running? Thanks.


@dcheng Welcome to Discourse!

Could you please share some more details about your use case, and if possible, a minimal example? How/why are you accessing the TaskState? It'll allow us to help you better. :smile:

I’d also suggest looking into Scheduler/Worker plugins, which might be better suited for logging PID: Plugins — Dask.distributed 2022.8.1+6.gc15a10e8 documentation
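For instance, here's a minimal `WorkerPlugin` sketch (the class name `TaskStartLogger` and the log format are just illustrative) that reports the worker's address and PID whenever a task transitions to the "executing" state:

```python
import os

from distributed.diagnostics.plugin import WorkerPlugin


class TaskStartLogger(WorkerPlugin):
    """Report the worker PID whenever a task actually starts executing."""

    def setup(self, worker):
        # Called once when the plugin is attached to each worker
        self.worker = worker

    def transition(self, key, start, finish, **kwargs):
        # "executing" fires only once the worker really runs the task,
        # i.e. after any queueing or work stealing has resolved
        if finish == "executing":
            print(f"[plugin] {key} started on {self.worker.address} "
                  f"(worker PID {os.getpid()})")


# Attach to all current and future workers:
# client.register_worker_plugin(TaskStartLogger())
```

Note that `os.getpid()` here is the worker process's PID; a subprocess spawned inside the task would still need to report its own PID separately.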

Thanks for your reply! We’re using Dask distributed to run tasks on a cluster in AWS. Each task spawns a separate process, and we want to keep track of the process IDs on each worker host for debugging purposes. The following code demonstrates what we’re trying to do.

Code
import time
import subprocess
import distributed
from distributed import Client, LocalCluster


def do_some_work(duration: int):
    print('[Worker] doing some work')
    process = subprocess.Popen(["sleep", str(duration)])
    _update_task_state(process.pid)
    while True:
        try:
            # Poll in 10-second intervals until the subprocess exits
            process.wait(timeout=10)
            break
        except subprocess.TimeoutExpired:
            pass
    print('[Worker] subprocess is complete')


def _update_task_state(pid: int):
    from distributed.worker import thread_state
    key = thread_state.key
    dask_worker = distributed.get_worker()
    task_state = dask_worker.tasks.get(key)
    task_state.metadata.update({"pid": pid})
    print(f'[Worker] Saved PID to task state metadata: {task_state.metadata}')


def _get_task_state_metadata(dask_scheduler, key: str):
    ts = dask_scheduler.tasks.get(key)
    if ts:
        return ts.metadata

def main():
    cluster = LocalCluster()
    client = Client(cluster)

    try:
        print(client)
        future = client.submit(do_some_work, 30)
        while future.status not in ['error', 'finished']:
            task_state_metadata = client.run_on_scheduler(_get_task_state_metadata, key=future.key)
            print(f'Waiting for key {future.key} to finish, metadata={task_state_metadata}')
            time.sleep(5)
        print(f'final status: {future}')
    finally:
        client.close()


if __name__ == '__main__':
    main()

On the workers we store the process ID of the spawned process in the task state, and expect to be able to retrieve that task state from the scheduler host by calling client.run_on_scheduler(). However, the retrieved task state doesn’t contain the PID, as shown in the following output.

Output of code
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
[Worker] doing some work
[Worker] Saved PID to task state metadata: {'pid': 32545}
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
Waiting for key do_some_work-331de1ee9134f5b8ba065406f7a6a655 to finish, metadata={}
[Worker] subprocess is complete
final status: <Future: finished, type: NoneType, key: do_some_work-331de1ee9134f5b8ba065406f7a6a655>

Adding to this, what we’re also looking for is a callback/notification when a task starts executing. Based on the documentation, it doesn’t look like the scheduler (Scheduling State — Dask.distributed 2022.8.1 documentation) has that information: it only knows when a task is assigned to a worker, but the task might still be sitting in that worker’s queue and could be stolen by another worker. We could use Worker plugins and then build a mechanism outside of Dask to notify our application when Dask starts executing a task.
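For reference, one workaround we’re considering for getting the PID to the scheduler while the task is still running is Dask’s structured events (`Worker.log_event` / `Client.get_events`), which bypass `TaskState.metadata` entirely. A rough sketch, with the topic name `"task-pids"` chosen arbitrarily:

```python
import subprocess

from distributed import get_worker
from distributed.worker import thread_state


def do_some_work(duration: int):
    # Same task as before, but the PID is pushed to the scheduler
    # immediately as a structured event rather than via TaskState.metadata
    process = subprocess.Popen(["sleep", str(duration)])
    get_worker().log_event("task-pids", {"key": thread_state.key,
                                         "pid": process.pid})
    process.wait()


# On the client side, the events are readable while the task still runs:
#   for timestamp, msg in client.get_events("task-pids"):
#       print(msg)
```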