I have a Dask pipeline with a few hundred stages of operation. I periodically cache some of the Dask collections generated by the pipeline, and I would like the pipeline to resume from this persisted cache when the corresponding tasks are re-executed, e.g. after a node failure.
One mechanism to achieve this is to perform a cache check inside the UDF. However, this approach still executes the whole chain of UDFs leading up to the desired task, where the cached result, if available, is fetched.
Despite its inefficiency, this at least works when the pipeline is composed of simple chains. Whenever there is a more complicated flow, e.g. a reduce operation or a non-deterministic operation, this UDF-based cache check can lead to correctness/consistency issues.
I was curious whether there are more standard mechanisms to add resume capabilities in Dask.
What kind of mechanism are you using? Are you caching collections on disk?
Could you give a quick code snippet to illustrate what you are doing?
I’m not sure I’m following correctly. You are checking for result availability from inside a UDF, so inside a function called by map_blocks, for example, if working on a Dask Array? Couldn’t you check the cache externally in your pipeline, so before calling map_blocks?
It is just an iterative algorithm with the same set of transforms applied several times, hence the size of the transform chain.
I just save the distributed collection to disk using the available to_parquet methods.
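As a rough sketch of the kind of checkpoint I mean, for a dataframe-shaped collection (the helper name and path handling are just illustrative, not an existing API):

import os
import dask.dataframe as dd

def checkpoint(df, path):
    # If this stage was already materialized, resume from the parquet files
    # instead of recomputing the upstream tasks.
    if os.path.exists(path):
        return dd.read_parquet(path)
    df.to_parquet(path)
    # Reading the data back makes downstream tasks depend on the files on disk
    # rather than on the full upstream graph.
    return dd.read_parquet(path)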
In the task graph below, where the squares represent tasks, I perform a “caching” stage on the dask.Bag at one point. Now, if there is a node failure (e.g. the square in red), Dask restarts from the first task in that chain, only for this partition. I would like behavior where it just retraces the steps from the last cached output.
When result availability is checked inside the UDF, we still retrace all the steps from the beginning, but at least avoid the expensive compute.
e.g.
def _my_udf(some_input, cache):
    if cache.has_result(some_input):
        return cache.get_result(some_input)
    # expensive compute related to the input
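The UDF is then mapped over the collection, something like the following (the inputs and the cache object here are placeholders; the cache has to be reachable from every worker, e.g. a shared on-disk store):

import dask.bag as db

inputs = range(1000)  # placeholder for the real inputs
bag = db.from_sequence(inputs, npartitions=32)
# Extra keyword arguments to Bag.map are forwarded to the mapped function.
results = bag.map(_my_udf, cache=cache).compute()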
IIUC, you are suggesting something like:
def _my_iterative_dask_pipeline(input: dask.Bag, num_iterations: int, cache):
    prev_iteration_output = input
    for i in range(num_iterations):
        if cache.has_results(prev_iteration_output):
            iteration_output = cache.get_results(prev_iteration_output)
        else:
            iteration_output = _transformation(prev_iteration_output)
            # Store the outputs
            cache.store(iteration_output)
        prev_iteration_output = iteration_output
    return prev_iteration_output
Tasks residing on failed nodes are individually rescheduled by the distributed scheduler onto new nodes, and the scheduler re-executes all the tasks in the chain on the new node; I would like it to start from the last stored cache location instead.
For sure there would be a cache miss on cache.has_results(prev_iteration_output), as the top-level iteration is not re-executed after a node failure; only the tasks related to an individual partition are.
I don’t think there is any standard mechanism to do this kind of caching/snapshotting of intermediate results in Dask. It is often considered that re-computation is cheaper than disk resources in Big Data scenarios. However, I understand there are some cases like yours where this could be useful.
You could try to do something with persist() and replicate(), but this means duplicating the data into worker memory, so if results are large, this is probably not possible.
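A minimal sketch of that idea, assuming an already-running distributed cluster (the transform and sizes are placeholders):

from dask.distributed import Client, futures_of
import dask.bag as db

client = Client()  # connect to (or start) a distributed cluster

def expensive_transform(x):
    return x * 2  # placeholder for the real per-element UDF

bag = db.from_sequence(range(1000), npartitions=10).map(expensive_transform)
bag = bag.persist()                     # keep intermediate partitions in worker memory
client.replicate(futures_of(bag), n=2)  # copy each partition to 2 workers so that a
                                        # single node failure does not force recomputation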