Fast dask-parallel in-place modification of xr.Dataset data (specific dimensions)

This was originally posted on GitHub Discussions: “fast dask-parallel in-place modification of xr.Dataset data (specific dimensions)” (dask/dask Discussion #9426).

[ Re-posted here following advice from the xarray-dev community: “fast dask-parallel "shuffling" of xr.Dataset dimensions” (pydata/xarray Discussion #6951). ]

Hi there,

I am trying to come up with a fast, Dask-parallel way of “shuffling” the contents of an xr.Dataset (or DataArray) along certain dimensions. For example:

import numpy as np
import xarray as xr

# create LocalCluster, dask Client, etc.

# open a large netCDF file, chunked along t and z
dset = xr.open_dataset(..., chunks={"t": 1, "z": 10})
# dset contains a single variable called "dummy" with dimensions (t, z, y, x) == (5, 50, 1000, 1000)

def pseudo_shuffle_xy(ds: xr.Dataset) -> xr.Dataset:
    """Shuffle values along the (x, y) dimensions and return the shuffled dataset."""
    ny, nx = ds.sizes["y"], ds.sizes["x"]
    for k in range(ds.sizes["z"]):
        dict_isel_ = {
            'x': xr.DataArray(data=np.random.randint(low=0, high=nx, size=(ny, nx)), dims=['y', 'x']),
            'y': xr.DataArray(data=np.random.randint(low=0, high=ny, size=(ny, nx)), dims=['y', 'x'])
        }
        # not proper shuffling (indices are drawn with replacement), but close enough;
        # assign by position, since the z labels inside a block need not start at 0
        ds.dummy[dict(z=k)] = ds.dummy.isel(z=k).isel(dict_isel_)
    return ds

result = dset.map_blocks(pseudo_shuffle_xy, template=dset).compute()

The map_blocks call works as expected, but it is probably slower than it needs to be, and the manual loop over z is ugly.
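One way to drop the Python loop, sketched here as an assumption rather than a tested recipe: build the (y, x) indices for all z levels at once with NumPy and apply them with a single fancy-indexing step. Like the original, one permutation is drawn per z level and reused for every t, but here it is a true permutation (no repeated values). The function name `shuffle_xy_block` and its `seed` parameter are hypothetical; it assumes "dummy" has dims t, z, y, x and that y and x are not split across chunks (true for `chunks={"t": 1, "z": 10}`).

```python
import numpy as np
import xarray as xr


def shuffle_xy_block(ds: xr.Dataset, seed=None) -> xr.Dataset:
    """Loop-free alternative to pseudo_shuffle_xy, meant for Dataset.map_blocks.

    Hypothetical helper: draws one permutation of the (y, x) plane per z level
    and applies it to every t, like the original loop, but without repeats.
    Assumes "dummy" has dims t, z, y, x and that y and x are not chunked.
    """
    rng = np.random.default_rng(seed)
    da = ds["dummy"].transpose("t", "z", "y", "x")
    arr = da.values                                   # this block as a NumPy array
    _, nz, ny, nx = arr.shape
    # one shuffled copy of the flat indices 0..ny*nx-1 per z level: shape (nz, ny*nx)
    flat_perm = rng.permuted(np.tile(np.arange(ny * nx), (nz, 1)), axis=1)
    # convert flat indices back to (y, x) positions, each of shape (nz, ny, nx)
    iy, ix = np.divmod(flat_perm.reshape(nz, ny, nx), nx)
    iz = np.arange(nz)[:, None, None]                 # broadcasts against iy / ix
    shuffled = arr[:, iz, iy, ix]                     # result shape (t, z, y, x)
    new_da = da.copy(data=shuffled).transpose(*ds["dummy"].dims)
    return ds.assign(dummy=new_da)
```

It would be called the same way as before, e.g. `dset.map_blocks(shuffle_xy_block, template=dset).compute()`. Leaving `seed=None` lets every block draw fresh randomness; passing a fixed seed through `kwargs` would repeat the same permutations in every equally shaped block.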
I’ve also tried to use xr.apply_ufunc:

def pseudo_shuffle_xy_v2(arr):
    # with dask="allowed", arr is the dask array with the core dims (z, t) moved last:
    # arr.shape = (1000, 1000, 50, 5) -> (y, x, z, t), chunk size = (1000, 1000, 10, 1)
    ny, nx = arr.shape[0], arr.shape[1]
    idx_shuffled_y, idx_shuffled_x = np.random.permutation(ny), np.random.permutation(nx)
    return arr.vindex[idx_shuffled_y, ...][:, idx_shuffled_x, ...]

# no vectorize=True: np.vectorize would pass NumPy slices, which have no .vindex
result = xr.apply_ufunc(pseudo_shuffle_xy_v2, dset["dummy"],
                        input_core_dims=[["z", "t"]], output_core_dims=[["z", "t"]],
                        dask="allowed").compute()

This is faster than the map_blocks call, but I get warnings about Dask creating large chunks, most likely caused by the vindex call.
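A possible variant of the apply_ufunc approach, offered as a sketch under assumptions: declare y and x (the dimensions being shuffled) as the core dimensions and use `dask="parallelized"`, so apply_ufunc runs a plain NumPy function once per chunk and the output chunks mirror the input chunks instead of going through vindex. In this mode a core dimension must not be split across several chunks, so each call sees complete (y, x) planes. Unlike the question's code, each (t, z) plane gets its own permutation. The names `_shuffle_plane` and `shuffle_xy` are hypothetical.

```python
import numpy as np
import xarray as xr


def _shuffle_plane(arr: np.ndarray, seed=None) -> np.ndarray:
    """Shuffle the elements of every trailing (y, x) plane independently."""
    rng = np.random.default_rng(seed)
    # merge the two trailing core axes into one axis of length ny * nx
    flat = arr.reshape(arr.shape[:-2] + (arr.shape[-2] * arr.shape[-1],))
    return rng.permuted(flat, axis=-1).reshape(arr.shape)


def shuffle_xy(da: xr.DataArray, seed=None) -> xr.DataArray:
    """Shuffle da over (y, x) with apply_ufunc; y and x are the core dims."""
    out = xr.apply_ufunc(
        _shuffle_plane,
        da,
        input_core_dims=[["y", "x"]],    # moved to the last two axes
        output_core_dims=[["y", "x"]],
        kwargs={"seed": seed},
        dask="parallelized",             # one NumPy call per chunk
        output_dtypes=[da.dtype],
    )
    return out.transpose(*da.dims)       # restore the original dim order
```

With the question's data this would be `shuffle_xy(dset["dummy"]).compute()`. As with any per-chunk random function, a fixed `seed` would give every chunk the same permutations, so `seed=None` is the safer default under Dask.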

What’s the best way to vectorize the “shuffle” (in-place assignment, really) using Dask?


Both map_blocks and apply_ufunc shuffle elements within a block (chunk), so neither performs a complete shuffle: an element in the first block can never end up in the last block. It could be worth rechunking the data so that all the values you’d like to shuffle are in the same chunk.
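A small Dask-only illustration of this point, using a hypothetical helper `shuffle_block`: a block-wise shuffle keeps every value inside its original chunk, while after rechunking the axis into a single chunk the same function shuffles globally. For an xarray object the analogous step would be something like `dset.chunk({"y": -1, "x": -1})` before calling map_blocks or apply_ufunc, at the cost of larger chunks.

```python
import dask.array as dsa
import numpy as np


def shuffle_block(block: np.ndarray) -> np.ndarray:
    """Shuffle all elements of one block (hypothetical helper for illustration)."""
    rng = np.random.default_rng()
    return rng.permuted(block.ravel()).reshape(block.shape)


x = dsa.arange(8, chunks=4)                  # two blocks: [0..3] and [4..7]

# block-wise shuffle: values are only permuted inside their own block
blockwise = x.map_blocks(shuffle_block).compute()
print(blockwise)

# rechunk so the whole axis is a single block; the shuffle is now global
full = x.rechunk(-1)
print(full.numblocks)                        # (1,)
print(full.map_blocks(shuffle_block).compute())
```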
