How to efficiently extract patches from a xarray/dask dataset?

I am trying to divide a large xarray dataset into smaller patches and save these patches as separate files. However, xarray/dask takes an incredibly long time for this, and I don't know where my mistake lies. Even with parallel computing, extracting the patches takes an eternity. I would love to use xarray and dask for this use case, since my datasets are too large to load into RAM and the structure and idea behind xarray with dask make a lot of sense. However, the performance leaves a lot to be desired, and I don't understand where I'm going wrong.

# S2.data is my xarray Dataset with dims (time: 76, y: 7336, x: 4302) and 12 data variables (satellite bands), loaded from 76 GeoTIFF files
from pathlib import Path

import dask
from dask.diagnostics import ProgressBar
from tqdm import tqdm

patches = patch_indices["Annotation"] + patch_indices["No-Annotation"]
tasks = []
for patch_idx in tqdm(patches, total=len(patches), desc="Creating patches"):
    i, j = patch_idx
    filename = Path(patch_folder, f"patch_{i}_{j}.nc")
    if not filename.exists():
        # select a 128x128 (or smaller, at the border) spatial window over all timesteps
        patch = S2.data.isel(
            x=slice(i, min(i + 128, S2.data.x.size)),
            y=slice(j, min(j + 128, S2.data.y.size)),
        )
        # build a delayed write instead of writing immediately
        delayed_obj = patch.to_netcdf(filename, format="NETCDF4", engine="netcdf4", compute=False)
        tasks.append(delayed_obj)

if tasks:
    print("Saving patches")
    with ProgressBar():
        dask.compute(*tasks)

What are your chunk sizes in the original data? Are these patches overlapping at all?

I set the chunk size to "auto" in the big original dataset. Should I set it to something like (time: 1, x: 500, y: 500)?

Hi @Paulus, welcome to Dask community!

@martindurant's question makes a lot of sense. Since you are reading Sentinel-2 data file by file, I assume your chunks initially have a size of 1 along the time dimension. If I understand your workflow correctly, you are trying to build multi-temporal patches of 128 by 128 pixels, but spanning the entire time dimension.

I imagine that in this case, to build a single patch, Dask/Xarray has to go through 76 S2 images times 12 bands (76 × 12 = 912 small reads per patch), take only a really small portion of each file, and build the netCDF. This is probably why it takes such a long time.

The best approach here would probably be to rechunk your data on the time dimension, and then perform the operation. But again as Martin said, this also depends on your patches, their overlap, and other considerations.

So what chunk does this choose? Your xarray dataset’s repr will tell you.
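For example (a minimal sketch; "B02" is a hypothetical band name, not something from this thread):

print(S2.data)                 # the Dataset repr shows the chunk layout of each dask-backed variable
print(S2.data.chunks)          # mapping of {dimension: chunk sizes} for the whole Dataset
print(S2.data["B02"].chunks)   # per-dimension chunk sizes of a single band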

So the "auto" option is creating chunks with dimensions (1, 7400, 6500). This is a full S2 image at one timestep.

The patches are created with an overlap of 32 on each side and a width/height of 128. But I am not using all of the patches in the end, because I only have ground-truth data for some parts of the S2 images.

For the example I mentioned with 76 S2 images (7400x6500), I extract roughly 1200 patches, though of course some of them overlap.

Can you give an example of how to chunk the data properly on the time dimension? Do I use a single timestep with smaller width/height or multiple timesteps?

And I read that when you use xarray.open_mfdataset(parallel=True, chunks="auto", …) you should not rechunk your data (Parallel computing with Dask)… but since I am not using this function, because of its poor performance, this might not be a problem.

It probably just comes from the way you read the files in the first place: they are chunked this way on disk.
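Since the thread does not show how the Dataset is built, here is only a hedged sketch of one way to read the GeoTIFFs with spatial chunks that already line up with the patches; rioxarray and the geotiff_paths list are assumptions on my part:

import rioxarray
import xarray as xr

# Assumption: one GeoTIFF per timestep. Asking for 512x512 spatial chunks at
# read time means each chunk covers several patches; the time chunk is
# necessarily 1 per file at this stage.
arrays = [
    rioxarray.open_rasterio(path, chunks={"band": -1, "y": 512, "x": 512})
    for path in sorted(geotiff_paths)  # geotiff_paths is a placeholder
]
S2_stack = xr.concat(arrays, dim="time")

Merging the per-file time chunks afterwards is then a much lighter rechunk than starting from full-image chunks, because the spatial blocks already match.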

Since you want to write the whole time slice, I'd use the full time dimension in depth, maybe something like (76, 512, 512).
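Translated into code, a hedged sketch that reuses the names from the original snippet; xarray.save_mfdataset is just one way to batch all the writes into a single computation:

from pathlib import Path

import xarray as xr

# Full time depth per chunk, 512x512 spatially, as suggested above
rechunked = S2.data.chunk({"time": 76, "y": 512, "x": 512})

datasets, paths = [], []
for i, j in patches:  # `patches` as defined in the original snippet
    filename = Path(patch_folder, f"patch_{i}_{j}.nc")
    if filename.exists():
        continue
    datasets.append(
        rechunked.isel(
            x=slice(i, min(i + 128, rechunked.x.size)),
            y=slice(j, min(j + 128, rechunked.y.size)),
        )
    )
    paths.append(filename)

# Write every patch in one Dask computation
xr.save_mfdataset(datasets, paths, format="NETCDF4", engine="netcdf4")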

How are you creating your Xarray Dataset in the first place? Rechunking is not recommended because it can be a heavy and costly process, even if it is simple in theory. That's why packages such as rechunker have been created.
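For completeness, a hedged sketch of what rechunker could look like here; the Zarr store paths, the memory budget, and the per-variable chunk spec are all assumptions, not something from this thread:

import xarray as xr
from rechunker import rechunk

# Rechunk every band variable to full time depth and 512x512 spatial blocks;
# coordinates are left as they are (None means "do not rechunk").
target_chunks = {name: {"time": 76, "y": 512, "x": 512} for name in S2.data.data_vars}
target_chunks.update({name: None for name in S2.data.coords})

plan = rechunk(
    S2.data,
    target_chunks=target_chunks,
    max_mem="2GB",                     # arbitrary per-worker memory budget
    target_store="S2_rechunked.zarr",  # hypothetical output store
    temp_store="S2_rechunk_tmp.zarr",  # hypothetical scratch store
)
plan.execute()

rechunked = xr.open_zarr("S2_rechunked.zarr")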