Hi!
I am having trouble with Dask memory management: I am trying to interpolate a large raster using Dask, rioxarray and rasterio.
Code
I open my GeoTIFF file using rioxarray and chunk it to get a Dask array:
```python
import rioxarray
import rasterio

ds = rioxarray.open_rasterio(
    filename="D:\\Documents\\MISSIONS\\DASK\\interpolation\\echant_much_more_bigger_30cm_pix_l93.tif",
    chunks=(1, 5000, 5000),
).astype(rasterio.int8)
ds
```
Then, I recover all my chunks with this call:
```python
chunks = ds.data.to_delayed().ravel()
```
To interpolate, I use the rasterio.fill.fillnodata function, so I have to create a mask array:
```python
import numpy as np

def compute_chunk(chunk):
    return chunk.compute()

def create_mask(chunk):
    mask = np.where(chunk < 9, 0, chunk)
    return mask
```
```python
masks = []
for chunk in chunks:
    computed = client.submit(compute_chunk, chunk)
    mask = client.submit(create_mask, computed)
    del computed
    masks.append(mask)
    del mask
```
Problem
When my masks array is created, my process memory is around 8 GiB, with about 2 GiB of unmanaged memory (1 GiB old and 1 GiB recent). So when I try to get the results of my masks (an array of futures) like this:

```python
res = [future.result() for future in masks]
```

I get a MemoryError.
I did some research and tried to change my memory management approach, but I think I'm doing something wrong… I just don't know what. Does anyone have an idea or advice?
Thank you in advance!
Clément
@ClementAlba Welcome to Dask! Apologies for the delay in response.
I think we can optimize your code to use Dask more efficiently, which will help with memory management:
- You need not use Delayed here. We can use the Dask Array API directly, because `ds.data` is already a Dask Array (mixing collections isn't generally recommended).
- `compute_chunk()` isn't necessary; you could directly do `computed = chunks.compute()`. But not only are for-loops bottlenecks in general, we don't really need to compute the chunks at all: we can continue in a lazy fashion.
- You can use `map_blocks` instead of the for-loop + `create_mask`. Or, use `da.where()`, Dask Array's `where` function, directly.
Reproducible example:
```python
import dask
import rioxarray
import rasterio
import dask.array as da
import numpy as np
from distributed import Client

client = Client()

ds = rioxarray.open_rasterio(
    filename="https://github.com/rasterio/rasterio/raw/1.2.1/tests/data/RGB.byte.tif",
    chunks=(1, 5000, 5000),
).astype(rasterio.int8)

chunks = ds.data.ravel()
result = da.where(chunks < 9, 0, chunks)
result.compute()
```
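As a side note, the `map_blocks` route mentioned above looks like this. This is a minimal sketch on a small synthetic array (not the GeoTIFF from the question), reusing the same `< 9` threshold:

```python
import numpy as np
import dask.array as da

# Small synthetic array standing in for one band of the raster
data = np.arange(16).reshape(4, 4)
arr = da.from_array(data, chunks=(2, 2))

def create_mask(block):
    # Zero out values below the threshold, one block at a time
    return np.where(block < 9, 0, block)

masked = arr.map_blocks(create_mask, dtype=arr.dtype)
result = masked.compute()  # same values as np.where(data < 9, 0, data)
```

`map_blocks` applies the function to each chunk independently, which is exactly what the original for-loop over futures was doing, but with the scheduling left to Dask.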
Hi @pavithraes, thanks for your response!
I wasn't available for more than a month and tried your solution when I got back. Unfortunately it didn't work, but I finally solved the problem.
After creating my masks I want to interpolate my data, and so as not to overflow my memory, each time a tile is interpolated I write it to a directory (so I don't keep any arrays in memory). The idea is to keep all my functions lazy inside the for-loop and do the computation afterwards:
```python
# Create the mask for a given chunk
@dask.delayed
def create_mask(chunk):
    mask = np.where(chunk < config.get("mask").get("limit"), 0, chunk)
    return mask

# Interpolate a given chunk; returns a numpy array
@dask.delayed
def interpolation(chunk, mask):
    return rasterio.fill.fillnodata(chunk, mask, config.get("interpolation").get("max_search_distance"))

# Create a DataArray from a chunk. The advantage of DataArrays is that
# they can easily be converted to rasters
@dask.delayed
def create_data_array(chunk, coords):
    return xr.DataArray(
        chunk,
        dims=["band", "y", "x"],
        coords=coords).rio.write_crs(2154)

# Export a chunk as a .tif file
@dask.delayed
def write_tile(tile, i):
    tile.rio.to_raster(config.get("directories").get("output") + "/" + str(i) + ".tif")

tiles_to_write = []
for k in range(len(delayed_chunks)):
    # Compute the mask (the functions are already decorated with
    # @dask.delayed, so no extra dask.delayed() wrapping is needed here)
    mask = create_mask(delayed_chunks[k])
    # Interpolate the chunk
    interpoled = interpolation(delayed_chunks[k], mask)
    # Create the DataArray
    data_array = create_data_array(interpoled, coordinates[k])
    # Store the delayed write
    tiles_to_write.append(write_tile(data_array, k))
```
With this strategy I haven't had any memory errors so far. I also tune my workers and threads_per_worker when I create my Client.
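For context, the worker and thread limits are passed when the Client is created. The values below are illustrative only, not a recommendation from this thread; tune n_workers, threads_per_worker and memory_limit for your machine:

```python
from distributed import Client

# Illustrative limits: 2 worker processes, 1 thread each, 1 GiB per worker
client = Client(n_workers=2, threads_per_worker=1, memory_limit="1GiB")
n_workers = len(client.scheduler_info()["workers"])
client.close()
```

Capping memory_limit per worker makes the scheduler spill to disk or pause workers before the whole process runs out of memory.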
If you wan’t to see my full code you can go to my GitHub repository : parallelisation_process/interpolation at main · ClementAlba/parallelisation_process · GitHub
Obviously, if you have advices or remarks on my usage of Dask, I’d be happy to read them !