Hi,
I am working on a problem where I need different dask workers to have access to an entire array that can itself be lazily computed. The easiest way to explain this is with an example. We first set up a dask array for which the internal details are not important – it is an array that might be expensive to compute:
import numpy as np
import dask.array as da
from dask import delayed
@delayed(pure=True)
def make_global_array():
    print('Making global array')
    return np.random.random((2048, 2048))

global_array = make_global_array()
I could also simply write:
global_array = da.random.random((2048, 2048))
for this example, but I am using the code above so that I can more easily see when the array is actually computed. The details of the array are unimportant, but in the real-life scenario it is an array whose computation could be expensive and could potentially be done efficiently in parallel.
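To illustrate the laziness: nothing is printed when global_array is created above, and the message only appears once compute() is called (the variable name below is just for illustration):

# Nothing has been printed at this point; the delayed function
# only runs when explicitly computed:
computed = global_array.compute()  # prints 'Making global array'
print(computed.shape)              # (2048, 2048)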
I will now set up another, regular dask array:
array = da.random.random((1024, 1024), chunks=(512, 512))
I then want to operate over chunks of this array, and for each chunk I need to call a C function that needs the data for that chunk, as well as the entire global_array, as C-contiguous buffers. The C function might look like:
def c_function(a, global_array):
    return a + global_array.mean()
Again, the details inside the function are unimportant, but note that all of the chunk data and all (not just part) of the global array are needed.
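(As an aside, since the buffers must be C-contiguous, one way to guarantee this is to guard the call with np.ascontiguousarray, which is a no-op for arrays that are already C-contiguous and copies otherwise; a minimal sketch, with an illustrative wrapper name:)

def call_c_function(a, global_array_values):
    # Ensure both buffers are C-contiguous before crossing into C;
    # np.ascontiguousarray only copies when necessary.
    a = np.ascontiguousarray(a)
    global_array_values = np.ascontiguousarray(global_array_values)
    return c_function(a, global_array_values)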
I can then use map_blocks to iterate over the chunks and do the operation I need:
def process(array):
    def do_computation(a, block_info=None):
        computed_global_array = global_array.compute()
        return c_function(a, computed_global_array)
    return da.map_blocks(do_computation, array)
Note that above I need to compute the global array before passing it to the C function.
Finally, I compute the result array:
dask_result = process(array)
dask_result.compute()
At this point, this will print the following output:
Making global array
Making global array
Making global array
Making global array
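The four lines match the chunk structure of array: a (1024, 1024) array with (512, 512) chunks has four blocks, so do_computation (and hence global_array.compute()) runs once per block:

print(array.numblocks)  # (2, 2), i.e. four blocks -> four computations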
The bottom line of the issue I am trying to solve is that I want to make sure the global array gets computed only once, using all available workers, rather than being computed once per worker or chunk.
An obvious solution is to move:
computed_global_array = global_array.compute()
outside of the do_computation function, like this:
def process(array):
    computed_global_array = global_array.compute()
    def do_computation(a, block_info=None):
        return c_function(a, computed_global_array)
    return da.map_blocks(do_computation, array)
and the global array does indeed get computed only once, but it gets computed during the following step:
dask_result = process(array)
not during the compute step:
dask_result.compute()
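To make the timing concrete, this is what I observe with the second version of process (illustrative only):

dask_result = process(array)  # prints 'Making global array' here, eagerly
dask_result.compute()         # nothing printed: the computed array is already
                              # embedded in the do_computation closure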
What I really want is for process(array) to be fast and not do any computation, but for the computation of the global array to happen at most once, when the resulting dask array (or an array depending on it) is computed.
Does anyone see a way to achieve this?
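One direction I have been wondering about (I am not sure it is the intended pattern, so treat this as a sketch; the names are illustrative and it continues from the definitions above) is to wrap the delayed result with da.from_delayed, so the global array becomes a single-chunk dask array inside the task graph, and then use da.blockwise with indices that force the whole of that array to be passed to every call:

# Wrap the delayed result as a single-chunk dask array; the shape
# and dtype have to be known up front.
global_dask = da.from_delayed(global_array, shape=(2048, 2048), dtype=float)

def do_computation(a, g):
    # g arrives here as a fully computed NumPy array
    return c_function(a, g)

def process(array):
    # 'kl' does not appear in the output indices, so with
    # concatenate=True the whole of global_dask is handed to each call
    return da.blockwise(
        do_computation, 'ij',
        array, 'ij',
        global_dask, 'kl',
        dtype=array.dtype,
        concatenate=True,
    )

With this, process(array) only builds a graph, and the global array is a single node that should be computed at most once when the result is computed – but I am not sure whether this is the recommended approach, or how well it behaves on a distributed cluster where every worker would need a copy of the full array.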
For background, this is for the astropy reproject package, and the internal C functions are similar to (but not exactly the same as) map_coordinates, in that each chunk of the output array might be some arbitrary combination of points sampled from anywhere in the input array, which is why the internal C functions need access to the whole input array.