Skip nan blocks in map_blocks

One very common pattern I use is a map_blocks (or most often map_overlap) call whose result I pass directly to to_zarr to trigger computation. Oftentimes my array is 2d and many chunks are all nan because the regions of interest don’t fill the 2d grid.
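For context, the basic pipeline is something like this (just a sketch; expensive_function, depth and the output path are placeholders):

import dask.array as da

# overlapped map followed by a zarr write, which is what triggers computation
result = array.map_overlap(expensive_function, depth=1, dtype=array.dtype)
da.to_zarr(result, "output.zarr")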

Is there a way to reduce the number of tasks in the map_blocks or map_overlap function by specifying which chunks to focus on or to avoid altogether? I am thinking something like

blocks = array.blocks.ravel()

def block_condition(block):
    # simulate some condition to trigger on specific blocks
    condition = np.random.random() > .5
    return np.reshape(condition, [1] * len(block.shape))

r = array.map_blocks(block_condition, chunks=[1] * len(blocks[0].shape))

# incur this cost upfront
passing_blocks = r.ravel().compute()
block_names = [x._name for x in blocks]
passing_block_names = zip(passing_blocks, block_names)

# reduce tasks to blocks that pass condition only 
array.map_blocks(expensive_function, block_ids=[name for passed, name in passing_block_names if passed])

Is there any discussion or existing work on this topic?

Hi @ljstrnadiii,

Would it do to build a wrapper around expensive_function? That wrapper could then return the block unchanged if it contains only NaNs, and call expensive_function with the provided args and kwargs if it does contain data.

Something like

import numpy

def expensive_function_nan_filter(block, *args, **kwargs):
    if numpy.isnan(block).all():
        return block
    return expensive_function(block, *args, **kwargs)

my_array.map_blocks(expensive_function_nan_filter)

This way you would not need to call compute on each block separately first, to check if a block is all NaN.

This would not quite, as you say, “reduce the number of tasks in the map_blocks or map_overlap function”. You still have a task for each block, but it does prevent you from calling expensive_function on all of them.

Let me know if this helps at all.

Cheers,
Timo


Hi @ljstrnadiii,

I’ve got to admit that I’m not aware of any discussion on this kind of graph optimization.

The solution proposed by @TMillenaar is the simplest one, but as he said, this does not reduce the number of tasks, which is probably what you are really after.

If you can tell from the chunk position in the array alone whether it is only NaN or not, then you can probably optimize the graph at some point by removing some nodes, but I’m not sure if there is an easy way. Your solution seems to rely on a first call to map_blocks, which means that you’ll generate as many tasks as blocks anyway at some point, is that correct?


Hi @TMillenaar, this is what I am actually doing now and it helps a ton! Thank you :pray:

Hey @guillaumeeb, I am not sure there is a way around all the tasks for checking which blocks are non-nan, but those are much cheaper than the expensive op. I am currently slicing very large regions and submitting those in a way that keeps the number of tasks below some threshold. This speeds things up a ton when I skip the expensive op with the wrapper proposed by @TMillenaar, but I don’t get great resource utilization, since only a small subset of those tasks are heavy cpu operations and only a few workers may be doing anything once the nan-blocks are done.
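The cheap upfront check is roughly this kind of thing (a sketch assuming a 2d, numpy-backed array; block_has_data is just an illustrative name):

import numpy as np

def block_has_data(block):
    # one boolean per block, reshaped so each output chunk holds exactly one element
    return np.array(~np.isnan(block).all()).reshape((1,) * block.ndim)

# cheap pass: one tiny task per block, computed upfront
valid = array.map_blocks(block_has_data, chunks=(1,) * array.ndim, dtype=bool).compute()
# valid is now a numpy array with shape == array.numblocks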

I think I see so many tasks with map_overlap because each block depends on all its neighboring blocks, which are tasks themselves. Maybe we could replace the nan blocks with a new dask array and reconstruct the map_overlap output with da.block? This could help with the number of tasks, though only for map_overlap. (I’ll report back if this is helpful.)
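Roughly what I have in mind, reusing the valid mask from the check above (a sketch for a 2d array; depth and expensive_function are placeholders):

import numpy as np
import dask.array as da

result = array.map_overlap(expensive_function, depth=1, dtype=array.dtype)

rows = []
for i in range(array.numblocks[0]):
    row = []
    for j in range(array.numblocks[1]):
        if valid[i, j]:
            # keep the expensive map_overlap output for this block
            row.append(result.blocks[i, j])
        else:
            # swap in a cheap constant block instead of computing anything
            shape = (array.chunks[0][i], array.chunks[1][j])
            row.append(da.full(shape, np.nan, chunks=shape, dtype=result.dtype))
    rows.append(row)

stitched = da.block(rows)

One caveat: a valid block’s overlap still pulls in its neighbors, so how much this actually prunes depends on how the valid blocks are laid out.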

In reality, I am writing the result from map_overlap to zarr (with to_zarr, back in xarray land). Perhaps we can do a similar thing and call to_zarr with compute=False for only the valid blocks, then submit those in batches to limit the number of tasks and to only submit tasks with the expensive op, so dask gets a more uniform set of tasks. (I’ll report back on this as well.)
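Something along these lines is what I’m picturing (a sketch back in xarray land; ds, valid_regions, the output path and the batch size are all placeholders):

import dask

# write the metadata (and lazily-initialized arrays) first
ds.to_zarr("output.zarr", compute=False)

# one delayed write per region that actually contains data,
# e.g. valid_regions = [{"y": slice(0, 512), "x": slice(0, 512)}, ...]
# depending on the dataset, variables without these dims may need to be dropped first
writes = [
    ds.isel(region).to_zarr("output.zarr", region=region, compute=False)
    for region in valid_regions
]

# submit in batches so only a bounded number of expensive tasks are in flight
batch_size = 8
for start in range(0, len(writes), batch_size):
    dask.compute(*writes[start:start + batch_size])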

Any concerns or gotchas in mind?

Have you considered conditionally converting chunks to sparse?

import numpy
import sparse

def maybe_to_sparse(x, threshold):
    if x.size and numpy.isnan(x).sum() / x.size > threshold:
        return sparse.COO.from_numpy(x, fill_value=numpy.nan)
    else:
        return x

def ensure_dense(x):
    return x.todense() if isinstance(x, sparse.COO) else x

a = a.map_blocks(maybe_to_sparse, threshold=0.7)
...
a = a.map_blocks(ensure_dense)

This does not guarantee that your functions between the two map_blocks calls will be any faster on the sparse data out of the box - but it does make it fairly straightforward to optimize them yourself if they aren’t. ufuncs should run in O(1) out of the box.

@crusaderky thanks for the recommendation. I have not tried this, but will consider this going forward!