I would like to scale PyTorch computations (computing the gradient of a loss with respect to the model parameters)
using dask.distributed. For this, I define a class which holds the model and exposes a method to compute the per-sample gradients for a batch/block. Using map_blocks, I map this batch function over the blocks of dask arrays (chunked only in the batch dimension). A fully working example looks like this:
import torch
from torch.func import functional_call
import dask.array as da
from distributed import Client, LocalCluster
from typing import Dict


# holds the model and the loss and computes per-sample gradients of the loss w.r.t. the model parameters
class LossGrad:
    def __init__(self, model, loss):
        self.loss = loss
        self.model = model

    @property
    def device(self):
        return next(self.model.parameters()).device

    def _compute_single_loss(self, params, x, y):
        # loss of a single sample, unsqueezed to a batch of size one
        outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device)))
        return self.loss(outputs, y.unsqueeze(0))

    def apply(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        params = {k: p.detach() for k, p in self.model.named_parameters()}
        # per-sample gradients: jacobian of the vmapped per-sample losses w.r.t. the parameters,
        # returned as {param_name: tensor of shape (batch, *param_shape)}
        gradient_fnc = torch.func.jacrev(torch.vmap(self._compute_single_loss, in_dims=(None, 0, 0)))
        return gradient_fnc(params, x.to(self.device), y.to(self.device))

    def to(self, device):
        self.model = self.model.to(device)
        return self


def block_apply(x, y, _loss_grad: LossGrad):
    # map one numpy block to a (block_size, num_params) array of flattened per-sample gradients
    tensor_dict = _loss_grad.apply(
        torch.as_tensor(x, dtype=torch.float32), torch.as_tensor(y, dtype=torch.float32)
    )
    return torch.cat([t.reshape(t.shape[0], -1) for t in tensor_dict.values()], dim=-1).cpu().numpy()


if __name__ == "__main__":
    dimensions = (200, 2)
    num_params = (dimensions[0] + 1) * dimensions[1]
    num_data = int(1e6)
    chunk_size = 1000
    t_x = torch.rand(num_data, dimensions[0])
    t_y = torch.rand(num_data, dimensions[1])

    with Client(LocalCluster(processes=False)) as client:
        torch_model = torch.nn.Linear(*dimensions, bias=True)
        loss_grad = LossGrad(torch_model, torch.nn.functional.mse_loss)
        da_x = da.from_array(t_x.numpy(), chunks=(chunk_size, -1))
        da_y = da.from_array(t_y.numpy(), chunks=(chunk_size, -1))

        # ??? what to do, if I want to have the model on gpu?
        loss_grad_future = client.scatter(loss_grad, broadcast=True)

        # only works when chunk_size divides num_data
        grads = da.map_blocks(
            block_apply, da_x, da_y, loss_grad_future,
            dtype=da_x.dtype, chunks=(chunk_size, num_params),
        )
        result = da.to_zarr(grads, "grads.zarr", overwrite=True)
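For reference, each row of grads is the flattened gradient of one sample's loss with respect to all num_params parameters; a quick check without dask (reusing the names from the example above, small_grads is just an illustrative variable):

small_grads = block_apply(t_x[:4].numpy(), t_y[:4].numpy(), loss_grad)
print(small_grads.shape)  # (4, 402), i.e. (batch, num_params)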
Doing this, I ran into two questions:
- map_blocks only works if chunk_size divides num_data when the result is written with to_zarr (calling .compute() instead works even with an uneven last chunk). Besides handling this manually (concatenating several map_blocks calls, see the sketch after this list), is there a better way to do it?
- Later, I want to use dask_cuda with a LocalCUDACluster instead of the local one. How can I move the model copies (is this what client.scatter does?) to the respective GPU of each worker (so having one worker per available GPU)? A rough, non-working sketch of what I have in mind is also below.
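This is roughly what I mean by handling the uneven last chunk manually: splitting off the remainder rows and concatenating two map_blocks results (split_at is just an illustrative name, and this assumes chunk_size does not divide num_data):

split_at = (num_data // chunk_size) * chunk_size  # largest multiple of chunk_size
grads_even = da.map_blocks(
    block_apply, da_x[:split_at], da_y[:split_at], loss_grad_future,
    dtype=da_x.dtype, chunks=(chunk_size, num_params),
)
grads_rest = da.map_blocks(
    block_apply, da_x[split_at:], da_y[split_at:], loss_grad_future,
    dtype=da_x.dtype, chunks=(num_data - split_at, num_params),
)
grads = da.concatenate([grads_even, grads_rest], axis=0)

And this is the rough idea for the GPU case; block_apply_gpu is just an illustrative name, and I don't know whether moving the model inside the block function is the right approach (assuming dask_cuda is installed):

from dask_cuda import LocalCUDACluster  # intended: one worker per available GPU

def block_apply_gpu(x, y, _loss_grad: LossGrad):
    # ??? is this where the scattered model copy should be moved to the worker's GPU?
    return block_apply(x, y, _loss_grad.to("cuda"))

with Client(LocalCUDACluster()) as client:
    loss_grad_future = client.scatter(loss_grad, broadcast=True)
    grads = da.map_blocks(
        block_apply_gpu, da_x, da_y, loss_grad_future,
        dtype=da_x.dtype, chunks=(chunk_size, num_params),
    )
    da.to_zarr(grads, "grads.zarr", overwrite=True)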
Versions:
- torch: 2.1.0
- dask: 2023.5.0
- distributed: 2023.5.0
- python: 3.8.17
Thanks for your help