Hi @guillaumeeb Thanks for your reply!

I will attach the complete code here and read the documents you povide.

`da_x`

and `da_y`

are dask-arrays acquired using `xarray.open_dataset("a NetCDF4 file").data`

. I could not provide the raw data, but using `da.from_array()`

to create dummy data might be enough to reproduce rest of the code.

```
import time
import math
import numpy as np
import netCDF4 as nc
from numba import cuda
import cupy as cp
from dask.distributed import Client, progress
import dask.array as da
from dask_cuda import LocalCUDACluster
import xarray as xr
@cuda.jit
def cuda_corr(X, Y, out):
"""
X: ([1440/c], 721, n)
Y: ([480/c], 371, n)
out: [[480/c], 371, 721, [1440/c]]
"""
start_i, start_j, start_k = cuda.grid(3)
stride_i, stride_j, stride_k = cuda.gridsize(3)
n = X.shape[2] # Time
for k in range(start_k, X.shape[0], stride_k):
for j in range(start_j, X.shape[1], stride_j):
for i in range(start_i, Y.shape[0], stride_i):
for p in range(Y.shape[1]):
sum_X = 0.
sum_Y = 0.
sum_XY = 0.
squareSum_X = 0.
squareSum_Y = 0.
for h in range(n):
sum_X = sum_X + X[k, j, h]
sum_Y = sum_Y + Y[i, p, h]
sum_XY = sum_XY + X[k, j, h] * Y[i, p, h]
squareSum_X = squareSum_X + X[k, j, h] * X[k, j, h]
squareSum_Y = squareSum_Y + Y[i, p, h] * Y[i, p, h]
out[i, p, j, k] = (n * sum_XY - sum_X * sum_Y)/ \
(math.sqrt((n * squareSum_X - \
sum_X * sum_X)* (n * squareSum_Y - \
sum_Y * sum_Y)))
def foo(X, Y):
cuda_X = cp.asarray(X) # copying dask-array from CPU to GPU
cuda_Y = cp.asarray(Y) # copying dask-array from CPU to GPU
cuda_out = cp.zeros((Y.shape[0], Y.shape[1], X.shape[1], X.shape[2])) # create a cupy array directly on GPU
cuda_corr[(8, 8, 8), (8, 8, 8)](cuda_X.T, cuda_Y, cuda_out)
out = cp.asnumpy(cuda_out) # converting cupy array to numpy array
return out
fileName="tp.nc" #netCDF (time: 10958, lon: 1440, lat: 721)
nc1 = xr.open_dataset(fileName, chunks={'time': -1, 'lon': 90, 'lat': -1})
da_x = nc1['tp'].sel({'time': slice("2000-01-01", "2010-01-10")}).data
fileName="rr.nc" #netCDF (time: 8401, lon: 371, lat: 480)
nc2 = xr.open_dataset(fileName, chunks={'time': -1, 'lon': -1, 'lat': 24})
da_y = nc2['rr'].sel({'time': slice("2000-01-01", "2010-01-10")}).data
# da_x: dask.array<getitem, shape=(3663, 721, 1440), dtype=float32, chunksize=(3663, 721, 90), chunktype=numpy.ndarray>,
# da_y: dask.array<getitem, shape=(3663, 480, 371), dtype=float32, chunksize=(3663, 24, 371), chunktype=numpy.ndarray>
client = Client(n_workers=4, threads_per_worker=2)
result = da.map_blocks(foo, da_x, da_y, chunks=(24, 371, 721, 90), drop_axis=0, new_axis=[1,2], dtype=np.float64)
result_ = result.compute()
```

Hope these materials could provide more information. Thanks again!