PyTorch computation with Distributed utilizing GPUs

I would like to scale PyTorch computations (computing the gradient of the loss with respect to the model parameters)
using Dask distributed. For this, I define a class that holds the model and exposes a method to compute the per-sample gradients of a batch/block. Using map_blocks, I map the batch function over the blocks of Dask arrays (chunked only in the batch dimension). A fully working example looks like this:

import torch
from torch.func import functional_call
import dask.array as da
from distributed import Client, LocalCluster

class LossGrad:
    def __init__(self, model, loss):
        self.loss = loss
        self.model = model

    def device(self):
        return next(self.model.parameters()).device

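    def _compute_single_loss(self, params, x, y):
        # vmap strips the batch dimension, so re-add it for the model and the loss
        outputs = functional_call(self.model, params, (x.unsqueeze(0),))
        return self.loss(outputs, y.unsqueeze(0))

    def apply(self, x: torch.Tensor, y: torch.Tensor) -> dict:
        params = {k: p.detach() for k, p in self.model.named_parameters()}
        # per-sample losses via vmap, differentiated w.r.t. params via jacrev
        gradient_fnc = torch.func.jacrev(torch.vmap(self._compute_single_loss, in_dims=(None, 0, 0)))
        # move the block to the device the model lives on (self.device is a method)
        device = self.device()
        return gradient_fnc(params, x.to(device), y.to(device))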

    def to(self, device):
        self.model = self.model.to(device)
        return self

def block_apply(x, y, _loss_grad: LossGrad):
    tensor_dict = _loss_grad.apply(torch.as_tensor(x, dtype=torch.float32), torch.as_tensor(y, dtype=torch.float32))
    # flatten each per-parameter gradient to (block_size, -1) and concatenate into (block_size, num_params)
    return torch.cat([t.reshape(t.shape[0], -1) for t in tensor_dict.values()], dim=-1).cpu().numpy()

if __name__ == "__main__":
    dimensions = (200, 2)
    num_params = (dimensions[0] + 1) * dimensions[1]
    num_data = int(1e6)
    chunk_size = 1000
    t_x = torch.rand(num_data, dimensions[0])
    t_y = torch.rand(num_data, dimensions[1])

    with Client(LocalCluster(processes=False)) as client:
        torch_model = torch.nn.Linear(*dimensions, bias=True)
        loss_grad = LossGrad(torch_model, torch.nn.functional.mse_loss)
        da_x = da.from_array(t_x.numpy(), chunks=(chunk_size, -1))
        da_y = da.from_array(t_y.numpy(), chunks=(chunk_size, -1))
        # ??? what to do if I want to have the model on the GPU?
        loss_grad_future = client.scatter(loss_grad, broadcast=True)
        # only works when chunk_size divides num_data
        grads = da.map_blocks(block_apply, da_x, da_y, loss_grad_future, dtype=da_x.dtype, chunks=(chunk_size, num_params))
        result = da.to_zarr(grads, "grads.zarr", overwrite=True)
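The apply method above builds per-sample gradients as jacrev over the vmapped losses. That is correct, but as far as we understand it, it computes the full Jacobian of all B losses of a block with respect to the parameters, i.e. B reverse passes that each run over the whole block, even though loss i only depends on sample i. The torch.func per-sample-gradient recipe uses vmap(grad(...)) instead. A minimal CPU sketch, assuming a torch.nn.Linear model as in the question; per_sample_grads is a hypothetical helper name, and the flattening mirrors block_apply:

```python
import torch
from torch.func import functional_call, grad, vmap


def per_sample_grads(model: torch.nn.Module, loss_fn, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flattened per-sample gradients, shape (batch, num_params)."""
    params = {k: p.detach() for k, p in model.named_parameters()}

    def single_loss(p, xi, yi):
        # vmap strips the batch dim, so re-add it for the model and the loss
        out = functional_call(model, p, (xi.unsqueeze(0),))
        return loss_fn(out, yi.unsqueeze(0))

    # grad differentiates w.r.t. the first argument (the params dict);
    # vmap maps over samples while sharing the params (in_dims=None)
    grads = vmap(grad(single_loss), in_dims=(None, 0, 0))(params, x, y)
    # each value has shape (batch, *param.shape): flatten and concatenate
    return torch.cat([g.reshape(g.shape[0], -1) for g in grads.values()], dim=-1)


torch.manual_seed(0)
model = torch.nn.Linear(200, 2, bias=True)
x, y = torch.rand(8, 200), torch.rand(8, 2)
g = per_sample_grads(model, torch.nn.functional.mse_loss, x, y)
print(g.shape)  # torch.Size([8, 402])
```

Swapping this into apply would not change the (block_size, num_params) layout that block_apply returns; only the way the gradients are computed differs.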

Doing this, I faced two questions:

  1. map_blocks only works if chunk_size divides num_data when writing the result with to_zarr (calling .compute() works either way). Besides handling this manually (concatenating several map_blocks calls), is there a better way to do it?
  2. Later, I want to use dask_cuda with LocalCUDACluster instead of the local cluster. How can I move the model copies (is this what client.scatter does?) to the respective GPUs, with one worker per available GPU?


  • torch: 2.1.0
  • dask: 2023.5.0
  • distributed: 2023.5.0
  • python: 3.8.17

Thanks for your help

Hi @schroedk, welcome to Dask community!

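  1. This is because the chunks kwarg you pass to map_blocks declares every output block to have the same shape, (chunk_size, num_params). When chunk_size does not divide num_data, the last block is smaller, so the declared chunks no longer match what block_apply actually returns.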

One way to handle this correctly is to specify all the output chunks explicitly, for example by reusing the input chunking scheme:

out_chunks = list(da_x.chunks)
out_chunks[1] = num_params
out_chunks = tuple(out_chunks)

grads = da.map_blocks(block_apply, da_x, da_y, loss_grad_future, dtype=da_x.dtype, chunks=out_chunks)
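A small, self-contained check of this fix, assuming a stand-in block function so no PyTorch is needed (block_fn and num_out are hypothetical, replacing block_apply and num_params) and num_data = 2500, which is not a multiple of chunk_size = 1000. With chunks=(chunk_size, num_out) the declared row chunks would be (1000, 1000, 1000); reusing x.chunks[0] makes them (1000, 1000, 500), matching what each block returns:

```python
import numpy as np
import dask.array as da

num_data, n_features, chunk_size = 2500, 3, 1000  # 2500 is not a multiple of 1000
num_out = 5  # hypothetical stand-in for num_params


def block_fn(x_block: np.ndarray) -> np.ndarray:
    # Stand-in for block_apply: maps (rows, n_features) -> (rows, num_out)
    return np.repeat(x_block.sum(axis=1, keepdims=True), num_out, axis=1)


x = da.from_array(np.ones((num_data, n_features), dtype=np.float32), chunks=(chunk_size, -1))
print(x.chunks)  # ((1000, 1000, 500), (3,))

# Reuse the row chunking of the input and replace only the column chunking
out_chunks = (x.chunks[0], (num_out,))
out = da.map_blocks(block_fn, x, dtype=x.dtype, chunks=out_chunks)
print(out.chunks)  # ((1000, 1000, 500), (5,))

result = out.compute()
print(result.shape)  # (2500, 5)
```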
  2. I guess you’ll have to move both your model and your data to the GPU. For the data, you could use something like x = x.map_blocks(cupy.asarray), or do the equivalent with PyTorch. You should apply the same kind of transformation to your model too, but I have to admit I don’t know much about this.
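For the second point, a minimal sketch of one way to place the model, under assumptions not confirmed in this thread: as we understand the dask_cuda documentation, LocalCUDACluster starts one worker process per GPU and sets CUDA_VISIBLE_DEVICES so that each worker's own GPU is the first visible device, so torch.device("cuda") inside a task refers to that worker's GPU. client.scatter(..., broadcast=True) places a copy of loss_grad on every worker, so block_apply can move that worker-local copy the first time it runs. worker_device is a hypothetical helper, and LossGrad is a compact copy of the question's class:

```python
import numpy as np
import torch
from torch.func import functional_call


def worker_device() -> torch.device:
    # Hypothetical helper. Assumes LocalCUDACluster pinned this worker to one GPU,
    # so "cuda" means that GPU; falls back to CPU when no GPU is visible.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LossGrad:
    # Compact copy of the question's class (jacrev over vmapped per-sample losses)
    def __init__(self, model: torch.nn.Module, loss):
        self.model, self.loss = model, loss

    def device(self) -> torch.device:
        return next(self.model.parameters()).device

    def to(self, device) -> "LossGrad":
        self.model = self.model.to(device)
        return self

    def _compute_single_loss(self, params, x, y):
        outputs = functional_call(self.model, params, (x.unsqueeze(0),))
        return self.loss(outputs, y.unsqueeze(0))

    def apply(self, x: torch.Tensor, y: torch.Tensor) -> dict:
        params = {k: p.detach() for k, p in self.model.named_parameters()}
        fn = torch.func.jacrev(torch.vmap(self._compute_single_loss, in_dims=(None, 0, 0)))
        dev = self.device()
        return fn(params, x.to(dev), y.to(dev))


def block_apply(x: np.ndarray, y: np.ndarray, _loss_grad: LossGrad) -> np.ndarray:
    # The scattered object is this worker's own copy, so moving it here moves the
    # model onto this worker's device; later calls find it already there (cheap).
    _loss_grad.to(worker_device())
    grads = _loss_grad.apply(torch.as_tensor(x, dtype=torch.float32),
                             torch.as_tensor(y, dtype=torch.float32))
    flat = torch.cat([g.reshape(g.shape[0], -1) for g in grads.values()], dim=-1)
    # Copy back to host memory so dask/zarr receive a NumPy block
    return flat.cpu().numpy()
```

With this block_apply, the question's main block would only need Client(LocalCUDACluster()) (imported from dask_cuda) in place of Client(LocalCluster(processes=False)); the scatter and map_blocks calls stay the same. If a worker runs several threads, more than one task may call .to on first use, so keeping one thread per worker avoids that.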