I have a zarr (v3) array (shape: [256000, 20, 64, 64]), which is effectively a collection of video sequences that I want to use for training a PyTorch model. However, the initial access time is very long, for reasons unknown to me.
To read the data I simply created a torch Dataset which uses dask.array.core.from_zarr to read the data as a Dask array. As the data is stored in uint8 (for smaller file size), I also need to cast it to float32 and normalize by dividing by 255.0. Then, when accessing an item, I slice the Dask array and call compute() to convert the slice to a torch.tensor.
The core functionality boils down to this:
import dask.array.core
import numpy
import torch
from torch.utils.data import Dataset

class ZarrDataset(Dataset):
    def __init__(self):
        super().__init__()
        # Open the zarr array lazily as a Dask array, cast to float32 and normalize.
        data = dask.array.core.from_zarr("dataset.zarr/arr").astype(numpy.float32, casting="safe")
        self._data = data / 255.0

    def __getitem__(self, idx):
        # The first 10 frames of a sequence are the inputs, the last 10 the targets.
        inputs = torch.tensor(
            self._data[idx, :10, :, :].compute(),
            requires_grad=True,
            dtype=torch.float,
        ).unsqueeze(1)
        targets = torch.tensor(
            self._data[idx, 10:, :, :].compute(),
            requires_grad=True,
            dtype=torch.float,
        ).unsqueeze(1)
        return inputs, targets
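For context, the dataset is then consumed the usual way during training. A minimal sketch of the consumption side (the batch size is an assumption, and it assumes the real class also defines __len__, which I left out above):

from torch.utils.data import DataLoader

# Consumption sketch; ZarrDataset is assumed to also define __len__
# (e.g. returning self._data.shape[0]), omitted in the snippet above.
dataset = ZarrDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True)  # batch size is an assumption

for inputs, targets in loader:
    pass  # training step goes here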
__getitem__ is where the issue occurs: that function just hangs for minutes doing I don't really know what - there is no significant CPU, RAM or IO usage.
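To illustrate, a minimal timing sketch of the access that hangs (the index 0 is arbitrary):

import time

dataset = ZarrDataset()

start = time.perf_counter()
inputs, targets = dataset[0]  # this single call takes minutes
print(f"first __getitem__ took {time.perf_counter() - start:.1f} s")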
I would prefer to use xarray for this task, but it doesn't support zarr v3.
I tried calling .compute() in the constructor, but that only resulted in very fast growth of memory usage. I also tried doing the casting and division after calling .compute(), but it didn't help.
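Roughly, the two variants I tried look like this (sketches, not my exact code):

# Variant 1: call .compute() in the constructor so the whole normalized
# array is materialised up front -- this is what made memory usage grow so fast.
def __init__(self):
    super().__init__()
    data = dask.array.core.from_zarr("dataset.zarr/arr").astype(numpy.float32, casting="safe")
    self._data = (data / 255.0).compute()

# Variant 2: keep self._data as the raw uint8 Dask array (no astype/division
# in the constructor) and cast/normalize only after .compute() -- no speedup.
def __getitem__(self, idx):
    raw_inputs = self._data[idx, :10, :, :].compute()
    inputs = torch.tensor(raw_inputs.astype(numpy.float32) / 255.0, requires_grad=True).unsqueeze(1)
    raw_targets = self._data[idx, 10:, :, :].compute()
    targets = torch.tensor(raw_targets.astype(numpy.float32) / 255.0, requires_grad=True).unsqueeze(1)
    return inputs, targets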
I obviously don't understand something about Dask's inner workings, and I cannot find any information on what kind of issue my code is running into.
How can I make my code read the data fast?