dask.array.einsum memory leak

Hi Dask community,

I’m taking my first steps with Dask, and I’m particularly interested in performing intensive calculations with NumPy arrays, mainly using the einsum function. When doing this with dask.array, I’ve noticed that the calculations consume so much memory that the process is aborted due to resource exhaustion.

A simplified version of my code is as follows:

```python
import numpy as np
import dask.array as da

n1 = 18
n2 = 30
c1 = np.random.random((n2, n1, n2, n2))
c2 = np.random.random((n2, n2, n1, n2))
c3 = np.random.random((n1, n2, n1, n1))
c4 = np.random.random((n2, n1, n1, n1))

mask_bg = np.eye(n1)
mask_np = np.eye(n2)
mask_ag = np.eye(n1)
mask_mp = np.eye(n2)
delta_vir = 1 - np.eye(n2)
delta_occ = 1 - np.eye(n1)
deltas = np.einsum('nm,ab->nbma', delta_vir, delta_occ)

# s_2 is a 6-D array of shape (n2, n1, n2, n1, n2, n1)
s_2 = da.einsum('ag,nbmp->nbmapg', mask_ag, c1)
s_2 += da.einsum('bg,nmap->nbmapg', mask_bg, c2)
s_2 += da.einsum('np,bmag->nbmapg', mask_np, c3)
s_2 += da.einsum('mp,nbag->nbmapg', mask_mp, c4)

# Contract with deltas (summing over p and g), then compute the result
s_2 = da.einsum('nbma,nbmapg->nbma', deltas, s_2).compute()
```

I’ve also tried da.map_blocks(), with no better results.
Is there a more effective way to reduce RAM usage than just adjusting the chunk size? Is it possible to set a maximum amount of RAM to be used during the calculation?

I appreciate any guidance or advice you can offer.

Best regards


Hi @DanielBajac, welcome to the Dask Discourse forum!

I just tried your code; it ran very fast, so I didn’t notice any memory problem. Is this example reproducing your problems?

Adjusting the chunk size and expressing the problem differently are the best ways to optimize RAM usage. Using a LocalCluster, you can also set a maximum memory per worker, but if your computation needs more than that, it can just block the workers and never finish.
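For the per-worker memory cap, a minimal sketch, assuming `dask.distributed` is installed. `memory_limit` is the `LocalCluster` argument that sets each worker's limit; the `"2GiB"` value and the toy array are placeholders. `processes=False` only keeps the sketch self-contained; with the default process-based workers, a nanny process supervises each worker and can restart one that exceeds its limit.

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Placeholder limits: one in-process worker capped at 2 GiB.
# processes=False keeps this sketch self-contained; drop it in real use so
# that workers run in separate processes supervised by a nanny.
with LocalCluster(n_workers=1, threads_per_worker=2, processes=False,
                  memory_limit="2GiB") as cluster, Client(cluster) as client:
    x = da.ones((2000, 2000), chunks=(500, 500))
    # Any .compute() inside this block runs on the memory-limited worker.
    total = x.sum().compute()
```

The limit only controls how the workers react to memory pressure (spilling, pausing, restarting); it does not shrink the memory a single oversized task needs.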

> Is this example reproducing your problems?

Yes, but my laptop has only 12 GB of RAM. If you increase the array sizes (e.g., n1 and n2 set to 100), you will likely run into problems. The following image shows the memory leak observed in my system monitor.

With n1 and n2 set to 100, you create intermediate arrays with ~1 TB chunks!
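For scale, the full 6-D intermediate `s_2` ('nbmapg') has n2 × n1 × n2 × n1 × n2 × n1 float64 elements of 8 bytes each. The helper below is hypothetical and only does that arithmetic; on the lazy dask array, `s_2.nbytes` reports the total size before `.compute()` is called.

```python
def intermediate_nbytes(n1: int, n2: int, itemsize: int = 8) -> int:
    """Bytes needed for the full 6-D 'nbmapg' array (float64 by default).

    Axis sizes: n, m, p have length n2; b, a, g have length n1.
    """
    return (n2 ** 3) * (n1 ** 3) * itemsize


for n1, n2 in [(18, 30), (100, 100)]:
    nbytes = intermediate_nbytes(n1, n2)
    print(f"n1={n1}, n2={n2}: {nbytes / 1e9:,.2f} GB")
```

That is about 1.26 GB for the original sizes and 8 TB for n1 = n2 = 100, while the final 'nbma' result only has n2² × n1² elements.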

I’m not familiar with the computations you are doing, but I imagine this is where the problem comes from.
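One way to express the problem differently is to never build the 6-D `s_2` at all. The last step sums over p and g, and `s_2` is a sum of four terms, so the contraction with `deltas` distributes over that sum: each term can be contracted with `deltas` in a single `einsum` call. This sketch uses plain NumPy, and the helper name `contract_without_s2` is hypothetical. With `optimize=True`, NumPy picks a pairwise contraction order, and for these subscripts every pairwise intermediate is at most 4-D. The same subscripts should also work with `da.einsum` if the inputs need to be chunked, but that variant is not checked here.

```python
import numpy as np


def contract_without_s2(deltas, mask_ag, mask_bg, mask_np, mask_mp, c1, c2, c3, c4):
    """Hypothetical helper: einsum('nbma,nbmapg->nbma', deltas, s_2)
    without ever materializing the 6-D s_2.

    Because s_2 is a sum of four terms, the contraction with deltas
    distributes over that sum, so each term is contracted on its own.
    """
    terms = [
        ("nbma,ag,nbmp->nbma", mask_ag, c1),
        ("nbma,bg,nmap->nbma", mask_bg, c2),
        ("nbma,np,bmag->nbma", mask_np, c3),
        ("nbma,mp,nbag->nbma", mask_mp, c4),
    ]
    result = np.zeros_like(deltas, dtype=np.result_type(deltas, c1, c2, c3, c4))
    for subscripts, mask, c in terms:
        # optimize=True lets NumPy sum out p and g through small pairwise steps
        result += np.einsum(subscripts, deltas, mask, c, optimize=True)
    return result


# Small sizes so the result can be checked against the original 6-D approach
rng = np.random.default_rng(0)
n1, n2 = 4, 5
c1 = rng.random((n2, n1, n2, n2))
c2 = rng.random((n2, n2, n1, n2))
c3 = rng.random((n1, n2, n1, n1))
c4 = rng.random((n2, n1, n1, n1))
mask_bg, mask_ag = np.eye(n1), np.eye(n1)
mask_np, mask_mp = np.eye(n2), np.eye(n2)
deltas = np.einsum("nm,ab->nbma", 1 - np.eye(n2), 1 - np.eye(n1))

fast = contract_without_s2(deltas, mask_ag, mask_bg, mask_np, mask_mp, c1, c2, c3, c4)
```

With n1 = n2 = 100, the largest arrays in this version are 4-D (100⁴ float64 values, 800 MB each) rather than the 6-D intermediate, although the einsum calls will still take some time.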