Very long wait time when reading zarr array into PyTorch

I have a Zarr (v3) array (shape: [256000, 20, 64, 64]), which is effectively a collection of video sequences that I want to use for training a PyTorch model. But the initial access time is very long, for reasons unknown to me.

To read the data I simply created a torch Dataset, which uses dask.array.core.from_zarr to read the data as a Dask Array. As the data is stored in uint8 (for smaller file size), I also need to cast it to float32 and normalize by dividing by 255.0. Then, when accessing, I slice the Dask array and call compute() to convert the slice to a torch.Tensor.

The core functionality boils down to this:

import dask.array.core
import numpy
import torch
from torch.utils.data import Dataset


class ZarrDataset(Dataset):
    def __init__(self):
        super().__init__()
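        # from_zarr builds a lazy Dask graph; nothing is read from disk yet,
        # and the cast and division below are recorded in that graph too.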
        data = dask.array.core.from_zarr("dataset.zarr/arr").astype(numpy.float32, casting="safe")
        self._data = data / 255.0


    def __getitem__(self, idx):

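        # compute() materializes just this slice: Dask reads the needed
        # chunks, casts them to float32 and divides by 255.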
        inputs = torch.tensor(
            self._data[idx, :10, :, :].compute(),
            requires_grad=True,
            dtype=torch.float,
        ).unsqueeze(1)

        targets = torch.tensor(
            self._data[idx, 10:, :, :].compute(),
            requires_grad=True,
            dtype=torch.float,
        ).unsqueeze(1)

        return inputs, targets

__getitem__ is where the issue occurs, as that function just hangs for minutes doing I don’t really know what - there’s no significant CPU usage, RAM usage or IO usage.

I would prefer to use xarray for this task, but it doesn’t support Zarr v3.

I tried calling .compute() in the constructor, but that only resulted in very fast growth of memory usage. I also tried doing the casting and division after calling .compute(), but it didn’t help.

I obviously don’t understand something about Dask’s inner workings, and I cannot find any information on what kind of issue my code is running into.

How can I make my code read the data fast?

Hi @quba, welcome to the Dask community!

First, how is your Zarr dataset chunked on disk? Is it aligned with how you are trying to access the data?
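You can check this with something like the snippet below (using the array path from your code):

import dask.array.core
import zarr

arr = zarr.open("dataset.zarr/arr", mode="r")
print(arr.shape, arr.chunks, arr.dtype)   # chunk shape on disk

darr = dask.array.core.from_zarr("dataset.zarr/arr")
print(darr.chunks)                        # how Dask splits the array into tasks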

Then, I would try to look at the performance outside of this Dataset class. What happens when you time

data = dask.array.core.from_zarr("dataset.zarr/arr").astype(numpy.float32, casting="safe")
data = data / 255.0
data[idx, :10, :, :].compute()

And in particular just the third line?
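For example, a quick timing sketch (using index 0 as a stand-in for idx):

import time

import dask.array.core
import numpy

data = dask.array.core.from_zarr("dataset.zarr/arr").astype(numpy.float32, casting="safe")
data = data / 255.0

start = time.perf_counter()
sample = data[0, :10, :, :].compute()
print(f"compute() took {time.perf_counter() - start:.2f} s")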

More generally, I think that since you are not using Dask to perform a distributed computation, but just to access parts of the data, you would get better performance using Zarr directly!
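Something along these lines (a minimal sketch, assuming the same array layout and path as in your code; I left out requires_grad, which you usually don’t need on inputs):

import numpy
import torch
import zarr
from torch.utils.data import Dataset


class ZarrDataset(Dataset):
    def __init__(self, path="dataset.zarr/arr"):
        super().__init__()
        # Opening the array is cheap; chunks are only read when sliced.
        self._data = zarr.open(path, mode="r")

    def __len__(self):
        return self._data.shape[0]

    def __getitem__(self, idx):
        # Read one sequence as uint8, then cast/normalize the small slice.
        sample = self._data[idx].astype(numpy.float32) / 255.0
        inputs = torch.from_numpy(sample[:10]).unsqueeze(1)
        targets = torch.from_numpy(sample[10:]).unsqueeze(1)
        return inputs, targets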

Testing your minimal example, the performance is as expected, and the issue persists even when Dask is not involved at all. So this is definitely an issue with the PyTorch DataLoader, which is used to access the Dataset.

Thanks for your help.

If someone reading this encounters a similar issue: this PyTorch forum thread suggests such behaviour could be caused by a deadlock between Dask and torch. But that is not the case for me.
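A quick way to check whether the DataLoader’s worker processes are involved is to compare a single-process run with a multi-worker one (just a diagnostic sketch; it assumes the Dataset also defines __len__):

from torch.utils.data import DataLoader

dataset = ZarrDataset()

# If the first batch arrives quickly with num_workers=0 but stalls with
# num_workers > 0, the problem is in the worker processes, not the array access.
for num_workers in (0, 4):
    loader = DataLoader(dataset, batch_size=4, num_workers=num_workers)
    inputs, targets = next(iter(loader))
    print(num_workers, inputs.shape, targets.shape)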
