Summing multiple non-contiguous subsets of an array

Do you see a more dask-ish way of achieving the same result? Using a list comprehension seems hackish…

import dask.array as da

# For each label i, sum the rows of darr whose entry in darr_idx equals i,
# then stack the 256 row-sums into a (256, 10000) array.
darr = da.random.randint(256, size=(10000, 10000))
darr_idx = da.random.randint(256, size=10000)
res = da.concatenate([darr[da.where(darr_idx == i)].sum(axis=0)[None, :] for i in range(256)], axis=0)

Hi @mewmew_laser_kittens, welcome to Discourse! I’m sharing a snippet from @pavithraes that relies on Dask DataFrame, which might be a bit more readable than the list comprehension:

import dask.array as da
import dask.dataframe as dd

darr = da.random.randint(256, size=(10000, 10000))
ddf = dd.from_array(darr)

darr_idx = da.random.randint(256, size=10000)
ddf['idx'] = dd.from_array(darr_idx)  # attach the labels as a column

ddf = ddf.sort_values('idx')  # optional

# A single groupby sums all rows sharing the same label
result = ddf.groupby('idx').sum().reset_index()
result.compute()
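If you need the grouped sums back as a Dask array afterwards, one option might be to_dask_array (just a sketch):

# Sketch: back to a (256, 10000) Dask array; set_index sorts by label and
# lengths=True computes chunk sizes up front.
res_arr = result.set_index('idx').to_dask_array(lengths=True)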

Thank you for your answer! Although your snippet runs slower than my example, it seems to scale much better with dask.distributed. However, I receive a deprecation warning. Is there a way around it?

C:\Users\___\anaconda3\envs\dask\lib\site-packages\dask\dataframe\methods.py:333: FutureWarning: reindexing with a non-unique Index is deprecated and will raise in a future version.
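For context, here is roughly how I ran it on a cluster (a minimal sketch; the worker settings are placeholders, not what I actually used):

from dask.distributed import Client

client = Client(n_workers=4, threads_per_worker=2)  # illustrative local cluster

result = ddf.groupby('idx').sum().reset_index()
result.compute()  # now executed on the distributed scheduler

client.close()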

@mewmew_laser_kittens Thanks for sharing! I’m not getting that warning with Dask 2022.02.0, could you please share your Dask version?

I cannot access the computer I used right now, but from memory it was 2021.10.x, which was the latest version available from conda at the time.

I have experimented a bit since then, and it turns out that the groupby method is unfortunately really slow. By using a simple Python loop iterating over the values of darr_idx, I get a speedup of more than 50× (!!!).
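The loop I mean is essentially the following (a rough sketch; it assumes both arrays fit in memory once computed, roughly 800 MB here for int64):

import numpy as np

arr = darr.compute()          # materialize the Dask arrays
arr_idx = darr_idx.compute()

# Accumulate each row of arr into the bucket selected by its label
res = np.zeros((256, arr.shape[1]), dtype=arr.dtype)
for row, i in zip(arr, arr_idx):
    res[i] += row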

Hi @mewmew_laser_kittens,

Here’s another dask.array-style solution that (on my laptop) ran 10 times faster than the original list-comprehension version. Feel free to ask for clarification if you need it.

import numpy as np
import dask.array as da


def your_expression(arr, arr_idx, rng):
    # Same per-block computation as the original list comprehension
    res = np.concatenate(
        [arr[np.where(arr_idx == i)].sum(axis=0)[None, :] for i in rng],
        axis=0,
    )
    # res.ndim == 2; insert a new axis to get 3, as expected by blockwise()
    return res[np.newaxis, ...]


n = 256
darr = da.random.randint(n, size=(10000, 10000))
darr_idx = da.random.randint(n, size=10000)
range_ = da.arange(n)

# Apply your_expression block-wise: each block of darr ('ij') is paired with
# the matching block of darr_idx ('i') and the label range ('k'); every call
# yields a (1, n, n_cols) partial result, hence chunk size 1 along 'i'.
res = da.blockwise(
    your_expression, 'ikj',
    darr, 'ij',
    darr_idx, 'i',
    range_, 'k',
    adjust_chunks={'i': 1},
    dtype=np.int_,
    meta=np.array([[[]]]),
)


def identity(x, axis, keepdims):
    return x


# Sum the per-block partial results over axis 0 with a tree reduction
res = da.reduction(
    res, identity, np.sum, axis=0, dtype=np.int_, meta=np.array([[]])
)
res.compute()
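A quick sanity check against the original list-comprehension version (a sketch; cheap enough at this size, though it computes both graphs):

expected = da.concatenate(
    [darr[da.where(darr_idx == i)].sum(axis=0)[None, :] for i in range(n)],
    axis=0,
)
np.testing.assert_array_equal(res.compute(), expected.compute())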