What is the best and most flexible way to optimize a task graph that consists of Arrays and Delayed tasks (or any other dask object’s tasks)?
I’m running into an old issue that I’ve now realized is affecting me more than I previously thought. I maintain a library that lets users do a lot of dask Array operations. Some of those Array operations, including writing to disk, can currently only be done through a Delayed function. The problem is that when a user computes via result.compute() or dask.compute(result1, result2), any final result that ends up as a Delayed object is optimized (obj.__dask_optimize__) as a Delayed object, and the underlying Array tasks are not fully optimized. If the final result is an Array object then there is no problem. I’ve run into this problem with da.store before (see da.store loses dependency information · Issue #8380 · dask/dask · GitHub and Remove source and target optimization in array.store by djhoese · Pull Request #9732 · dask/dask · GitHub), where I discovered that the Array optimization is done “manually” in the da.store function. In my simple tests, if I force Array graph optimization on my Delayed results I see a 35+% improvement in execution time.
So, what possible solutions or workarounds are there?
1. Update dask to always perform all types of graph optimizations for all object types (Delayed, Array, DataFrame, etc.).
2. Somehow detect that a Delayed object consists of Array or Array-like inputs. Note that optimization should only be done at compute time to ensure shared tasks between multiple results are optimized together.
3. Document (in my library, and preferably in dask) and provide utility functions in my library for overriding the Delayed optimization function before computing results (e.g. dask.config.set(delayed_optimize=array_and_delayed_optimize)); a rough sketch of such a combined optimizer follows this list.
4. Other ideas?
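To make option 3 concrete, here is roughly what I have in mind. This is only a sketch: array_and_delayed_optimize is a made-up name, the trivial delayed_sum is only there so the snippet runs on its own, and I haven’t verified that chaining the two optimize functions this way behaves identically across dask versions.

```python
import dask
import dask.array as da
from dask.delayed import optimize as delayed_optimize


def array_and_delayed_optimize(dsk, keys, **kwargs):
    # Run the Array optimizations (slice fusion, blockwise fusion, etc.)
    # first, then the regular Delayed optimizations on the resulting graph.
    dsk = da.optimize(dsk, keys, **kwargs)
    return delayed_optimize(dsk, keys, **kwargs)


# A trivial Delayed result that wraps Array tasks, just to exercise the path
delayed_sum = dask.delayed(float)(da.ones((4, 4), chunks=2).sum())

# Swap the optimizer only at compute time so tasks shared between
# multiple results are still optimized together.
with dask.config.set(delayed_optimize=array_and_delayed_optimize):
    print(dask.compute(delayed_sum))
```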
Options 1 and 2 are the only ones I see that could allow completely fixing #8380 and the related PR, since da.store should really not be optimizing the graph “manually”.
For the record, in the short term I ended up going with the lazy version of option 3. That is, I added a dask.config.set(delayed_optimize=...) to my library (satpy) in its utility function that merges all final results and computes them. This function is used in our two main common use cases, so it is a no-op if there is no Delayed result and an improvement if there is one. However, this still does not resolve the generic problem, nor all of the optimization and task sharing/reuse issues of da.store.
Thanks for raising this point. To better understand it, it would really help to have an MVCE reproducing your situation.
That said, I agree that at least the documentation should be clearer. Currently, this is basically what I see on the corresponding page:

> In most cases, users won’t need to interact with these functions directly, as specialized subsets of these transforms are done automatically in the Dask collections (dask.array, dask.bag, and dask.dataframe)
But it is hard to say which ones…
> users working with custom graphs or computations may find that applying these methods results in substantial speedups
which seems to be your case, but this doesn’t really help.
Could you also point to the code? Do you think some utility method should be added to Dask and documented there?
Here’s the smallest MVCE I could think of. The basic idea is that one of the optimizations done for Arrays but not for Delayed objects is combining slicing operations. The + 1 in the delayed function is just there to have some operation done inside; it isn’t realistic, and I understand that that operation could be done in Array-friendly ways.
```python
import dask
import dask.array as da


@dask.delayed(pure=True)
def delayed_func(arr1):
    return arr1 + 1


start = da.zeros((2, 2), chunks=1)
subarr1 = start[:1]
subarr2 = subarr1[:, :1]
delay_result = delayed_func(subarr2)
assert len(delay_result.dask.keys()) == 9  # zeros * 4 -> getitem * 3 -> finalize -> delayed_func

if True:
    # current dask
    assert len(dask.optimize(delay_result)[0].dask.keys()) == 5  # zeros * 1 -> getitem * 2 -> finalize -> delayed_func
else:
    # if Arrays were detected in Delayed graphs:
    assert len(dask.optimize(delay_result)[0].dask.keys()) == 2  # finalize-getitem-zeros -> delayed_func
```

Here’s my workaround:

```python
# use array optimize instead of delayed
with dask.config.set(delayed_optimize=da.optimize):
    assert len(dask.optimize(delay_result)[0].dask.keys()) == 2  # finalize-getitem-zeros -> delayed_func
```
On the point of “custom graphs”, I’m not sure I would consider what I’m doing to be a custom graph. If you want to call it a “custom computation” because my calculations require running a Delayed function, then sure, it is custom in the sense that it doesn’t stay entirely within one dask collection type.
Another solution that I’ve tried, but didn’t mention in my original post, is to take the final Delayed result and convert it back to an Array. Then a compute/optimize would perform Array graph optimization.
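Roughly, reusing delay_result and start from the MVCE above (the catch is that the shape and dtype have to be known ahead of time):

```python
import dask.array as da

# Wrap the Delayed result back into an Array; computing/optimizing this
# object then goes through the Array optimization functions.
arr_result = da.from_delayed(delay_result, shape=(1, 1), dtype=start.dtype)
print(arr_result.compute())
```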
> Do you think some utility method should be added to Dask and documented there?
I’m not sure what the best method is. I suppose there could be a section on combining graph optimization functions and why it might be needed, but I’d be curious if there is a way to run all optimization methods for all collection types. Sure, it is a performance hit during graph optimization to search for optimizations that aren’t possible, but theoretically that happens once and should be minimal compared to the overall computation. For example, a graph made up purely of Delayed functions won’t gain anything by trying to fuse slice/getitem tasks or other numpy function calls the way an Array graph would. However, if combining/mixing collection types is even a little bit common, trying all graph optimizations could really improve performance for those use cases.
I forgot to mention that the other benefit of performing all optimizations at the same time is that dask.array.store would no longer need to optimize Array graphs internally. The store function could then:

- Properly respect graph optimization keyword arguments, like optimize_graph=False, which it does not currently do (see the sketch after this list).
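For illustration, here is a minimal sketch of the kind of thing I mean, using a plain numpy array as the store target (exact behavior depends on the dask version and the store internals discussed in the PR above):

```python
import numpy as np
import dask.array as da

arr = da.zeros((4, 4), chunks=2) + 1
target = np.empty((4, 4))

# compute=False returns a Delayed wrapping the store tasks
stored = da.store(arr, target, compute=False)

# The Array portion of the graph was already optimized inside da.store,
# so asking for an unoptimized compute here can't really take effect.
stored.compute(optimize_graph=False)
```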
Thanks a lot @djhoese for all the explanations here.
I have to admit that I’m not sure of the real reason why optimizations are not always performed for all collections by default, but you’ve sort of given the argument for why yourself above.
I also find Dask optimization and the related documentation a bit cryptic when it comes to being sure which optimization is performed and when.
Considering your solution in pytroll, I’m still glad there is a somewhat easy way of forcing the optimizations you want on your custom computation, but I agree this is not ideal, or at the very least not well documented.
In the end, I would recommend opening a new issue on GitHub describing this case and the potential solutions you propose; you’ll get more traction from the maintainers there. I’d be happy to help clarify the documentation if it comes to that.