Using a cull-like function to prune graph branches based on knowledge of all-zero chunks

Hi!
I have a use case where I rely on dask's computational graph and its built-in optimizations to reduce the amount of computation being done. For my first example (subsampling), this works great.

import numpy as np
import dask.array as da
import dask_image.ndfilters as din
from dask.threaded import get
from dask.optimization import cull, inline, inline_functions, fuse
from dask.base import visualize

def subsample(arr, ind):
    "Selects the elements of arr at the given indices"
    return arr[ind]

# create random input array (10 elements in chunks of 2)
arr = da.random.random(10, chunks=(2,))

# perform small convolution (the kernel is a small in-memory NumPy array)
weights = np.array([0.5, 0.5])
out = din.convolve(arr, weights)

# subsample result: keep only the first and last elements
sub = subsample(out, [0, 9])

[Figure: non-optimized task graph]

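# cull unused branches, inline cheap tasks, then fuse linear chains of tasks
keys = sub.__dask_keys__()
dsk1, deps = cull(sub.dask, keys)
dsk2 = inline(dsk1, dependencies=deps)
dsk3 = inline_functions(dsk2, keys)
# pass the output keys so that fuse does not fuse them away
dsk4, deps = fuse(dsk3, keys=keys)
get(dsk4, keys)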

[Figure: culled task graph]

Dask is smart enough to cull branches of the compute graph that aren’t ever used in the final output of a computation.
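The culling step can be observed in isolation with a minimal sketch like the one below. It materializes a small graph as a plain dict and compares its size before and after `dask.optimization.cull`. It uses a plain slice instead of the `[0, 9]` list index so that the set of surviving chunks is easy to predict; the array sizes and variable names are arbitrary.

```python
import dask.array as da
from dask.optimization import cull

arr = da.random.random(10, chunks=(2,))  # 5 chunks of 2 elements
out = arr + 1                             # one task per input chunk
sub = out[:2]                             # only touches the first chunk

# Materialize every task in the graph, whether or not it is needed
graph = dict(sub.__dask_graph__())
keys = sub.__dask_keys__()

# cull keeps only the tasks that the requested keys depend on
culled, deps = cull(graph, keys)

print(len(graph), len(culled))
# Which chunks of the random input survive culling
print(sorted(k for k in culled if isinstance(k, tuple) and k[0] == arr.name))
```

Only the first random chunk and the tasks built on it remain; the other four branches are dropped before anything runs.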

I want to extend this culling mechanism to do a bit more than what I believe dask was designed for, and this is where I need some help: is this possible and worthwhile, or a fool’s errand?

The basic idea is that I would like dask to recognize blocks that are all zeros, such as those initialized from da.zeros(), so that for certain operations such as multiplication, sum, reshape, convolve, and correlate, the downstream computations on those blocks don’t need to run at all. The zeros_like task could essentially absorb certain downstream tasks and replace them with a zeros_like task of the required output shape.
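The propagation rules behind this idea can be sketched at the level of individual blocks. The `ZeroBlock` class and the `z_*` functions below are hypothetical (nothing like them exists in dask); the sketch assumes each operand is either a NumPy array or a marker that records only the shape and dtype of a block known to be all zeros. Multiplication, summation, and reshaping of such a marker then produce another marker, without touching any data.

```python
import math
import numpy as np


class ZeroBlock:
    """Hypothetical stand-in for a chunk known to be all zeros.

    Only shape and dtype are stored; no data is allocated.
    """

    def __init__(self, shape, dtype=np.float64):
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)

    def materialize(self):
        return np.zeros(self.shape, dtype=self.dtype)


def z_multiply(a, b):
    """Elementwise product; 0 * x == 0, so a zero operand gives a zero result."""
    if isinstance(a, ZeroBlock) or isinstance(b, ZeroBlock):
        shape = np.broadcast_shapes(a.shape, b.shape)
        return ZeroBlock(shape, np.result_type(a.dtype, b.dtype))
    return np.multiply(a, b)


def z_sum(a, axis=None):
    """Sum over one axis (or all axes); a sum of zeros is zero."""
    if isinstance(a, ZeroBlock):
        if axis is None:
            shape = ()
        else:
            ax = axis % len(a.shape)  # allow negative axes
            shape = a.shape[:ax] + a.shape[ax + 1:]
        # the result dtype NumPy would use, taken from a zero-length array
        dtype = np.zeros(0, dtype=a.dtype).sum().dtype
        return ZeroBlock(shape, dtype)
    return np.sum(a, axis=axis)


def z_reshape(a, shape):
    """Reshape; a reshaped zero block is still a zero block."""
    if isinstance(a, ZeroBlock):
        if math.prod(shape) != math.prod(a.shape):
            raise ValueError("cannot reshape: sizes differ")
        return ZeroBlock(shape, a.dtype)
    return np.reshape(a, shape)


z = ZeroBlock((5, 5))
x = np.random.default_rng(0).random((5, 5))
p = z_multiply(z, x)       # still a ZeroBlock, nothing computed
s = z_sum(p, axis=0)       # ZeroBlock of shape (5,)
r = z_reshape(p, (25,))    # ZeroBlock of shape (25,)
print(type(s).__name__, s.shape, r.shape)  # ZeroBlock (5,) (25,)
```

Convolution and correlation are deliberately left out here, because an all-zero chunk only yields an all-zero output chunk if the neighbouring chunks that fall inside the kernel's overlap are also zero.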

In this example, all the computations could in principle be skipped, since both matrices are known to be all zeros.

import dask.array as da
matrix1 = da.zeros((10, 10), chunks=(5, 5))
matrix2 = da.zeros((10, 10), chunks=(5, 5))
matrix3 = matrix1 @ matrix2
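For the matmul case, the shortcut can be sketched outside dask as a block-sparse product. Each operand is a grid of blocks in which a missing entry means "known to be all zeros", and any block product involving a zero block is skipped. The function names and the dict-of-blocks representation are hypothetical, chosen only to illustrate the idea; the 2 x 2 grid of 5 x 5 blocks mirrors `chunks=(5, 5)` above.

```python
import numpy as np


def block_sparse_matmul(A, B, n_i, n_k, n_j):
    """Multiply an (n_i x n_k) block grid A by an (n_k x n_j) block grid B.

    A and B map (row, col) block indices to dense NumPy blocks; a missing
    key means the block is known to be all zeros. Returns the result grid
    (same convention) and the number of block products actually computed.
    """
    C = {}
    n_products = 0
    for i in range(n_i):
        for j in range(n_j):
            acc = None
            for k in range(n_k):
                a, b = A.get((i, k)), B.get((k, j))
                if a is None or b is None:
                    continue  # a zero block times anything is zero: skip it
                prod = a @ b
                n_products += 1
                acc = prod if acc is None else acc + prod
            if acc is not None:  # otherwise C[i, j] stays implicitly zero
                C[(i, j)] = acc
    return C, n_products


def assemble(grid, n_i, n_j, block):
    """Expand a block grid into a dense array, filling missing blocks with zeros."""
    out = np.zeros((n_i * block, n_j * block))
    for (i, j), blk in grid.items():
        out[i * block:(i + 1) * block, j * block:(j + 1) * block] = blk
    return out


# Both operands all zero: no block products are computed at all
C0, n0 = block_sparse_matmul({}, {}, 2, 2, 2)
print(C0, n0)  # {} 0

# One non-zero block in each operand: only one product is needed
A = {(0, 0): np.ones((5, 5))}
B = {(0, 1): np.ones((5, 5))}
C1, n1 = block_sparse_matmul(A, B, 2, 2, 2)
print(sorted(C1), n1)  # [(0, 1)] 1
```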

The reason I want this is that I will be creating very sparse arrays, and I know ahead of time where the zeros are. With a chunking strategy built around the zero and non-zero regions, I was hoping to spend computation only on the blocks that need it, while the large zeros_like blocks reduce to a simple concatenation of zeros. Here is a small example:

import numpy as np
import dask.array as da
import dask_image.ndfilters as din

def subsample(arr, ind):
    "Selects the elements of arr at the given indices"
    return arr[ind]

def subsample_adjoint(arr, shape, ind):
    "Places arr into array of zeros based on indices"
    y = da.zeros(shape=shape)
    y[ind] = arr
    return y

def subsample_adjoint_chunks(arr, shape, ind):
    """Concatenates arr with arrays of zeros to create a chunking structure based around zeros and non-zeros.

    Assumes ind is sorted in ascending order without duplicates.
    """
    concat_list = []
    prev_ind = -1
    for idx, i in enumerate(ind):
        # number of zeros between the previous index and this one
        diff = i - prev_ind - 1
        if diff > 0:
            concat_list.append(da.zeros(shape=(diff,)))
        # the subsampled value that belongs at position i
        concat_list.append(arr[idx:idx + 1])
        prev_ind = i
    # trailing zeros after the last index
    if shape[0] - ind[-1] - 1 > 0:
        concat_list.append(da.zeros(shape=(shape[0] - ind[-1] - 1,)))
    return da.concatenate(concat_list)

# create random input array
shape = (10,)
arr = da.random.random(shape, chunks=(2,))

# perform small convolution
weights = np.array([0.5, 0.5])
out = din.convolve(arr, weights)

# subsample result
ind = [0]
sub = subsample(out, ind)

# create super sparse array (all zeros except where subsampled)
sub_adj = subsample_adjoint_chunks(sub, shape, ind)

# perform correlation -> can we save computation here by culling the compute graph given we know most of sub_adj is zeros?
final_out = din.correlate(sub_adj, weights)
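One way to gauge how much work the zeros could save in the correlation step: correlation is linear, so each non-zero entry contributes a shifted copy of the kernel, and every other output position stays zero. The sketch below uses a hypothetical helper (not part of dask or dask_image) that computes only those contributions, so its cost scales with the number of non-zeros rather than with the array length. It assumes zero padding at the edges, i.e. the equivalent of `mode='constant'` with `cval=0` and `origin=0`, and SciPy's convention of centring the kernel at index `len(weights) // 2`. dask_image's default mode is `'reflect'`, which gives different values near the array edges.

```python
import numpy as np


def sparse_correlate1d(positions, values, n, weights):
    """Correlate a length-n signal that is zero everywhere except at
    `positions` (holding `values`) with a 1-D kernel, assuming zero padding.

    Returns a dict {output_index: value} holding only the non-zero outputs.
    """
    weights = np.asarray(weights, dtype=float)
    half = len(weights) // 2  # kernel centre, as in scipy.ndimage
    out = {}
    for p, v in zip(positions, values):
        # output[i] = sum_j weights[j] * x[i + j - half], so a spike at p
        # lands at i = p + half - j for each kernel tap j
        for j, w in enumerate(weights):
            i = p + half - j
            if 0 <= i < n:  # contributions outside the array are dropped
                out[i] = out.get(i, 0.0) + w * v
    return out


# Same setting as the example: length 10, one non-zero at index 0, kernel [0.5, 0.5]
nonzero = sparse_correlate1d([0], [0.7], 10, [0.5, 0.5])
print(sorted(nonzero))  # [0, 1]
```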


I have looked at the sparse package (which works with dask), which would make a lot of sense since my arrays are sparse. However, sparse arrays don’t support many operations, which severely limits how general such an approach can be.

I am mainly looking for advice on whether such an approach would be possible and feasible.


Welcome @alec-flowers!

Your idea is in fact a very tantalizing one. It points to the concept of “Sparse Block Matrices”, which I imagine someone may already have attempted to implement elsewhere.

At first sight, though, it doesn’t look like it is going to be a “quick fix” (if that’s what you’re aiming for), and it may require more consideration. For example, it seems that you would have to reimplement every dask function you plan to use so that it recognizes these “zero-chunks” at graph construction time and handles them appropriately, possibly by skipping them. Different functions may also need to handle such chunks differently. As an example, a chunk operation such as X + 1, where X is a zero-chunk, cannot simply be “ignored”, since its output is a non-zero chunk, albeit a special kind of non-zero chunk (one filled entirely with ones).
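The X + 1 point can be made concrete with a minimal sketch, assuming the zero parts are known when the graph is built (as they are in `subsample_adjoint_chunks` above). For an elementwise function, an all-zero part maps to a constant part holding `func(0)`: that constant is 0 for `np.sin` but 1 for X + 1. So instead of ignoring the part, one can emit a cheap `da.full` without creating any per-chunk `func` tasks. The helper name and its list-of-parts-plus-flags interface are hypothetical.

```python
import numpy as np
import dask.array as da


def zero_aware_elementwise(parts, zero_flags, func):
    """Apply an elementwise func to 1-D dask arrays, then concatenate them.

    parts:      list of 1-D dask arrays
    zero_flags: list of bools, True where the part is known to be all zeros
    func:       elementwise NumPy-compatible function
    """
    out = []
    for part, is_zero in zip(parts, zero_flags):
        if is_zero:
            # An all-zero input maps to a constant block holding func(0):
            # 0 for np.sin, 1 for lambda x: x + 1. No func tasks are created.
            fill = np.asarray(func(np.zeros((), dtype=part.dtype)))
            out.append(da.full(part.shape, fill.item(),
                               chunks=part.chunks, dtype=fill.dtype))
        else:
            # Non-zero parts go through the normal blockwise path
            out.append(part.map_blocks(func))
    return da.concatenate(out)


zeros = da.zeros(6, chunks=3)
rand = da.random.random(4, chunks=2)

res_sin = zero_aware_elementwise([zeros, rand], [True, False], np.sin)
res_add = zero_aware_elementwise([zeros, rand], [True, False], lambda x: x + 1)
print(res_sin.compute()[:6], res_add.compute()[:6])
```

The shortcut only works because the function is elementwise; for operations with overlap, such as convolution, the neighbouring chunks would also have to be known to be zero.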

Still, if you can overcome challenges such as these, you would have a really attractive approach that would likely improve dask.array's performance.