Here’s another dask.array-ish solution that ran 10 times faster (on my laptop) than the original one using a list comprehension. Feel free to ask for clarification if you need it.
import numpy as np
import dask.array as da
def your_expression(arr, arr_idx, rng):
    """Sum the rows of ``arr`` grouped by the labels in ``arr_idx``.

    For every label ``i`` in ``rng``, all rows of ``arr`` whose entry in
    ``arr_idx`` equals ``i`` are summed along axis 0. A label with no
    matching rows yields a row of zeros.

    Parameters
    ----------
    arr : array_like, shape (rows, cols)
        Values to aggregate.
    arr_idx : array_like, shape (rows,)
        Group label of each row of ``arr``.
    rng : iterable
        Labels to aggregate, in output order.

    Returns
    -------
    numpy.ndarray, shape (1, len(rng), cols)
        Per-label row sums. The leading length-1 axis exists because the
        caller's blockwise() output index is 3-D ('ikj').
    """
    arr = np.asarray(arr)
    arr_idx = np.asarray(arr_idx)
    sums = [arr[arr_idx == label].sum(axis=0) for label in rng]
    if sums:
        res = np.stack(sums, axis=0)
    else:
        # np.stack/np.concatenate raise on an empty sequence; build an empty
        # result whose trailing shape and dtype match what sum() would give.
        template = arr[:0].sum(axis=0)
        res = np.empty((0,) + template.shape, dtype=template.dtype)
    # res.ndim = 2; insert new axis to get 3, as expected by blockwise()
    return res[np.newaxis, ...]
# Problem size: `n` group labels applied to the rows of an
# n_rows x n_cols matrix. darr_idx must have one label per row of darr.
n = 256
n_rows = n_cols = 10000

darr = da.random.randint(n, size=(n_rows, n_cols))
darr_idx = da.random.randint(n, size=n_rows)
range_ = da.arange(n)

# One task per (row-chunk, label-chunk, column-chunk) block. Each task
# returns a (1, labels, cols) partial sum, so adjust_chunks={'i': 1}
# collapses every row-chunk to length 1 along axis 0.
res = da.blockwise(
    your_expression, 'ikj',
    darr, 'ij',
    darr_idx, 'i',
    range_, 'k',
    adjust_chunks={'i': 1},
    dtype=np.int_,
    # Give meta the same dtype as `dtype` instead of an implicit float64.
    meta=np.empty((0, 0, 0), dtype=np.int_),
)


def identity(x, axis, keepdims):
    """Chunk step for da.reduction: return the block unchanged.

    Blocks produced above already have length 1 along axis 0 (see
    adjust_chunks), so no per-chunk reduction is needed.
    """
    return x


meta = np.empty((0, 0), dtype=np.int_)
# Sum the per-row-chunk partial results into the final (n, n_cols) array.
res = da.reduction(
    res, identity, np.sum, axis=0, dtype=np.int_, meta=meta
)
result = res.compute()