A more efficient way to sample an array

I have a large, chunked 3D zarr of geospatial data with dimensions (features, y, x), chunked over (y, x), and I want to sample around 1M points on the (y, x) coords to compute some quantiles. This is a very large array (actually a concatenation of multiple zarrs) and dask’s quantile functions raise memory errors like

MemoryError: Task ('custom_quantile-rechunk-merge-transpose-f5359f8b871b8fb06c24c4b2121bd049', 0, 2) has 814.72 GiB worth of input dependencies, but worker tls://10.8.0.13:42617 has memory_limit set to 30.55 GiB.

so I am trying to approximate it from a sample instead. I tried to use `sel` to extract the sample, but I get a 0.5GB graph and a performance warning. It sometimes works, but takes ages for the computation to start, and sometimes the scheduler disconnects and I have to try again. A side question is why this operation produces such a large graph. This is roughly what I am doing:

import numpy as np
import xarray as xr
import geopandas as gpd
import dask.array as da

pcts = np.array([1e-6, 0.5, 1 - 1e-6])

# sample: (N, 2) array of (x, y) coordinates, ~1M points taken across the array
tile_samples = []
for tile_dataset, tile_geometry in zip(tile_datasets, tile_geometries):
    # keep only the points that fall inside this tile's geometry
    is_valid_sample = gpd.points_from_xy(sample[:, 0], sample[:, 1]).within(tile_geometry)
    valid_sample = sample[is_valid_sample]
    # get sample from each sub-array via vectorized (pointwise) selection
    tile_sample = tile_dataset.sel(
        x=xr.DataArray(valid_sample[:, 0], dims="points"),
        y=xr.DataArray(valid_sample[:, 1], dims="points"),
        method="nearest",
    )
    tile_samples.append(tile_sample)

# concatenate samples from all sub-arrays
all_samples = xr.concat(tile_samples, dim="points")

# compute percentiles across all samples
percentiles = da.nanquantile(
    all_samples.to_array().data,
    q=pcts,
    axis=(1,),
).compute()

It would be very useful to be able to sample efficiently for other MC estimates. Am I missing something here?

Hi @Nestor_Sanchez, welcome to Dask community!

It’s not easy to grasp your code as is, because it lacks information about some variables like tile_datasets, tile_geometries, or even sample. But if I understand correctly, you’re looping through tiles, selecting points with a nearest-neighbor method in each one, then concatenating all these sel operations together and applying the quantile. This probably builds a huge graph.

I’ve got a suggestion: if you know the chunking on the (x, y) coords, and more precisely the number of chunks, you could determine a target number of points per chunk and take the sample using map_blocks from Dask Array, since you don’t seem to use Xarray for the quantile. Or maybe you could apply your code inside a map_blocks directly; this way the graph would probably be much simpler and you wouldn’t have to concat the results.
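To illustrate the first idea, here is a minimal sketch of per-chunk random sampling with map_blocks. The array shape, chunking, and the n_per_chunk value are stand-ins for your real data (I'm assuming every chunk holds at least n_per_chunk points); each block keeps three dimensions so no drop_axis bookkeeping is needed:

```python
import numpy as np
import dask.array as da

def sample_block(block, n_per_chunk=1000):
    # Flatten the (y, x) dims of this chunk and draw points at random.
    # Assumes each chunk contains at least n_per_chunk points.
    flat = block.reshape(block.shape[0], -1)
    rng = np.random.default_rng()
    idx = rng.choice(flat.shape[1], size=n_per_chunk, replace=False)
    # Keep 3 dims so the output block shape matches `chunks` below:
    # (features, points, 1)
    return flat[:, idx][:, :, None]

# Stand-in array with the layout described: (features, y, x), chunked over (y, x)
arr = da.random.random((3, 2000, 2000), chunks=(3, 500, 500))

n_per_chunk = 1000
sampled = arr.map_blocks(
    sample_block,
    n_per_chunk=n_per_chunk,
    chunks=(3, n_per_chunk, 1),  # per-block output shape
    dtype=arr.dtype,
)

# The sample is small enough to pull into memory before taking quantiles
points = sampled.compute().reshape(arr.shape[0], -1)
percentiles = np.nanquantile(points, [1e-6, 0.5, 1 - 1e-6], axis=1)
```

The graph here stays one task per chunk, since each chunk is read once and immediately reduced to n_per_chunk points.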