MemoryError even when managing memory - Question

Hi!

I am having trouble with Dask memory management. I am trying to interpolate a large raster using Dask, rioxarray and rasterio.

Code

I open my GeoTIFF file with rioxarray and chunk it to get a Dask array:

import rioxarray
import rasterio

ds = rioxarray.open_rasterio(filename="D:\\Documents\\MISSIONS\\DASK\\interpolation\\echant_much_more_bigger_30cm_pix_l93.tif", 
                             chunks=(1, 5000, 5000)).astype(rasterio.int8)
ds

Then I retrieve all my chunks as a flat array of Delayed objects:

chunks = ds.data.to_delayed().ravel()
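To make that step concrete: `to_delayed()` returns a NumPy object array of Delayed objects shaped like the block grid (`numblocks`), and `ravel()` flattens it to one entry per block. The sketch below uses a hypothetical all-zero `da.zeros` array as a stand-in for `ds.data`, since the real raster is not available here.

```python
import dask.array as da
import numpy as np

# Hypothetical stand-in for ds.data: a lazy (band, y, x) int8 array
# chunked like the question's raster, i.e. (1, 5000, 5000) blocks.
data = da.zeros((1, 12000, 10000), dtype=np.int8, chunks=(1, 5000, 5000))

# Number of blocks along each axis: (1, 3, 2) for this shape.
print(data.numblocks)

# to_delayed() gives an object array of Delayed objects with the same
# shape as the block grid; ravel() flattens it to a 1-D array.
blocks = data.to_delayed().ravel()
print(len(blocks))  # 1 * 3 * 2 = 6 blocks

# Computing one Delayed returns just that block as a NumPy array.
first = blocks[0].compute()
print(first.shape)
```

Each of those Delayed objects is a separate little graph, which is why the rest of the thread suggests staying in the Dask Array API instead.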

To interpolate, I use the rasterio.fill.fillnodata function, so I have to create a mask array:

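import numpy as np
from dask.distributed import Client

client = Client()  # local cluster; an existing Client works the same way

def compute_chunk(chunk):
    # Runs on a worker: turns one Delayed block into a NumPy array
    return chunk.compute()

def create_mask(chunk):
    # For fillnodata, 0 in the mask marks pixels to fill: values < 9 become 0
    mask = np.where(chunk < 9, 0, chunk)
    return mask

masks = []
for chunk in chunks:
    computed = client.submit(compute_chunk, chunk)
    mask = client.submit(create_mask, computed)
    del computed
    masks.append(mask)
    del mask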
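What `create_mask` produces can be checked on a tiny hand-made block. Per the rasterio documentation, `fillnodata` treats mask pixels equal to 0 as the ones to interpolate and non-zero pixels as valid data, so every value below 9 is marked for filling. The block values below are hypothetical.

```python
import numpy as np

def create_mask(chunk):
    # Same rule as in the question: values below 9 become 0, others are kept.
    return np.where(chunk < 9, 0, chunk)

# Tiny hypothetical block; int8 matches the .astype(rasterio.int8) cast.
block = np.array([[0, 5, 12], [9, 3, 40]], dtype=np.int8)
mask = create_mask(block)
print(mask)
# [[ 0  0 12]
#  [ 9  0 40]]
```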

Problem

Once my masks list is built, process memory is around 8 GiB, including about 2 GiB of unmanaged memory (1 GiB old and 1 GiB recent). Then, when I try to get the results of my masks (a list of futures) like this:

res = [future.result() for future in masks]

I get a MemoryError.
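Calling `future.result()` on every future copies each finished mask from the workers into the client process, so the client must hold the entire mask raster at once. Dask can size that from metadata alone, without reading any data. The 40000 x 40000 shape below is hypothetical, since the real raster size is not given; a dtype wider than int8 multiplies these numbers.

```python
import dask.array as da
import numpy as np

# Hypothetical raster size (not the real one), chunked like the
# question's (1, 5000, 5000) blocks.
data = da.zeros((1, 40000, 40000), dtype=np.int8, chunks=(1, 5000, 5000))

# Shape/dtype metadata only: nothing is read or computed here.
block_bytes = np.prod(data.chunksize) * data.dtype.itemsize
print(f"{data.npartitions} blocks of {block_bytes / 1e6:.0f} MB")
print(f"{data.nbytes / 1e9:.1f} GB if every block is gathered to the client")
```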

I did some research and tried changing how I manage memory, but I think I’m doing something wrong, and I don’t know what. Does anyone have an idea or some advice?

Thank you in advance! 🙂

Clément

@ClementAlba Welcome to Dask! Apologies for the delayed response.

I think we can optimize your code to use Dask more efficiently, which will help with memory management:

  • You don’t need Delayed here. We can use the Dask Array API directly because ds.data is already a Dask Array (mixing collections isn’t generally recommended).
  • compute_chunk() isn’t necessary. If you really needed every block in memory, dask.compute(*chunks) would compute them all in one call, but a Python for-loop that submits tasks one by one is generally a bottleneck, and we don’t actually need to compute the chunks at all: we can stay lazy.
  • Instead of the for-loop + create_mask, you can use map_blocks, or call da.where() (Dask Array’s where function) directly.
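Both suggestions in the list above can be compared side by side: `map_blocks` applies the question's NumPy function to each block lazily, while `da.where` expresses the same rule in the Dask Array API. This is a sketch on a small random in-memory array standing in for `ds.data`; the shape, chunking and seed are arbitrary.

```python
import dask.array as da
import numpy as np

def create_mask(block):
    # Same per-block rule as in the question.
    return np.where(block < 9, 0, block)

# Hypothetical stand-in for ds.data, small enough to check eagerly.
rng = np.random.default_rng(0)
values = rng.integers(-20, 20, size=(1, 600, 400), dtype=np.int8)
data = da.from_array(values, chunks=(1, 250, 250))

# Option 1: apply the NumPy function to every block, lazily.
mask_blocks = data.map_blocks(create_mask, dtype=data.dtype)

# Option 2: Dask Array's own where(), also lazy.
mask_where = da.where(data < 9, 0, data)

# Nothing has run yet; da.compute executes both graphs together.
a, b = da.compute(mask_blocks, mask_where)
print(np.array_equal(a, b), np.array_equal(a, create_mask(values)))
```

Because both stay lazy, the scheduler only materializes a few blocks at a time instead of one future per chunk held by the client.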

Reproducible example:

import rioxarray
import rasterio

import dask.array as da
from dask.distributed import Client

client = Client()

# Small public sample GeoTIFF so the example runs anywhere
ds = rioxarray.open_rasterio(filename="https://github.com/rasterio/rasterio/raw/1.2.1/tests/data/RGB.byte.tif", 
                             chunks=(1, 5000, 5000)).astype(rasterio.int8)

# Stay in the Dask Array API: no to_delayed(), no per-chunk futures
chunks = ds.data.ravel()

# Lazy, block-wise mask: values below 9 become 0
result = da.where(chunks < 9, 0, chunks)
result.compute()
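For the full-size raster, `result.compute()` still assembles the whole array in the client’s memory. One way around that, offered here as an assumption rather than the only option, is `da.store`, which writes each computed block into any target that supports slice assignment, here a NumPy memmap backed by a temporary `.npy` file. The file name and the random input are hypothetical; for GeoTIFF output, rioxarray’s `rio.to_raster()` may be a more natural target, so check its documentation for how it handles Dask-backed data.

```python
import os
import tempfile

import dask.array as da
import numpy as np

# In-memory stand-in for ds.data; a real raster would be read lazily.
rng = np.random.default_rng(1)
values = rng.integers(-20, 20, size=(1, 3000, 2000), dtype=np.int8)
data = da.from_array(values, chunks=(1, 1000, 1000))
result = da.where(data < 9, 0, data).astype(np.int8)

# Disk-backed target with the same shape and dtype (hypothetical file name).
path = os.path.join(tempfile.mkdtemp(), "mask.npy")
target = np.lib.format.open_memmap(path, mode="w+", dtype=result.dtype, shape=result.shape)

# Each block is computed and written into its slice of `target`,
# so the result is never assembled as one in-memory array.
da.store(result, target, lock=True)
target.flush()

check = np.load(path, mmap_mode="r")
print(check.shape, int(check.min()))
```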