What's the most efficient way to fetch a shuffled dataframe in batches for ML training?

In most ML algorithms, data is fed to the algorithm in batches (where the batch size is a parameter). Typically, the order of elements is shuffled before each epoch (each full pass over the data).
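For data that fits in memory, "shuffle, then batch" can be sketched with NumPy alone. This is the per-epoch behaviour the question is trying to reproduce on a Dask DataFrame. The helper name `shuffled_batches` is hypothetical.

```python
from typing import Iterator

import numpy as np


def shuffled_batches(n_rows: int, batch_size: int, rng: np.random.Generator) -> Iterator[np.ndarray]:
    """Yield row-index arrays of at most batch_size rows, covering every row exactly once."""
    order = rng.permutation(n_rows)  # fresh random order for this epoch
    for start in range(0, n_rows, batch_size):
        yield order[start:start + batch_size]  # the last batch may be smaller
```

Calling it once per epoch with the same `rng` gives a different row order each epoch.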

What is the most efficient way to replicate this using a Dask dataframe, where each row is an example? We will assume the model is being trained on the same machine as the Client.

This is the paradigm I am considering so far:

In each iteration:

  • Select a random partition (numpy.random.randint(…)) and fetch it to the client as a pandas DataFrame using .compute(). Keep track of the partition number.
  • Feed it to the model (this itself can be done in smaller batches with shuffling)
  • Select the next partition.
  • Etc.

Is there a built-in way to do this? With my approach, I feel that (a) state management falls to me, and (b) an entire partition must be loaded on the client, which might be infeasible depending on the client's memory limits and the partition size (it would be better to fetch only the necessary subset of rows).
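A minimal sketch of the paradigm described above, assuming a hypothetical `fetch_partition(i)` callable that returns partition `i` as pandas (for a Dask DataFrame `ddf`, something like `lambda i: ddf.partitions[i].compute()`). It swaps `randint` for a permutation so each partition is visited exactly once per epoch, and shuffles rows inside each partition with pandas' `DataFrame.sample(frac=1)`. The name `two_level_shuffle` is hypothetical.

```python
from typing import Callable, Iterator, Optional

import numpy as np
import pandas as pd


def two_level_shuffle(
    fetch_partition: Callable[[int], pd.DataFrame],  # hypothetical: returns partition i as pandas
    npartitions: int,
    batch_size: int = 64,
    seed: Optional[int] = None,
) -> Iterator[pd.DataFrame]:
    """Visit every partition once in random order, shuffle rows inside it, then yield batches."""
    rng = np.random.default_rng(seed)
    # A permutation samples partitions without replacement, so each one is visited once per epoch;
    # repeated randint calls could pick the same partition twice and skip others.
    for part_id in rng.permutation(npartitions):
        part = fetch_partition(int(part_id))
        # Shuffle rows within this partition only; rows from different partitions are never mixed.
        part = part.sample(frac=1, random_state=int(rng.integers(2**32)))
        for start in range(0, len(part), batch_size):
            yield part.iloc[start:start + batch_size]
```

Because rows from different partitions never share a batch, this only approximates a full shuffle; if the data is partitioned by something correlated with the target (e.g. sorted by label or date), batches will be biased.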

I’m not sure if this is entirely relevant to your example, but most ML libraries have this kind of functionality built in. For example, take a look at the dask_ml.model_selection.KFold docs in the dask_ml package:


It allows you to do this using the shuffle argument.


Hey Matthias. I’m not looking for K-Fold, which, from my understanding, creates 2*K indexes (a train index and a validation index for each fold) that can be used to access the dataframe.
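To make the distinction concrete: K-fold only splits row indices into K (train, validation) pairs, i.e. the 2*K index arrays mentioned above, and says nothing about how rows are then streamed into the model. A plain-NumPy sketch of that index generation (not dask_ml's implementation; `kfold_indices` is a hypothetical name):

```python
from typing import Iterator, Tuple

import numpy as np


def kfold_indices(n_rows: int, k: int, seed: int = 0) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield k (train_idx, val_idx) pairs; each row appears in exactly one validation fold."""
    order = np.random.default_rng(seed).permutation(n_rows)  # the optional "shuffle" step
    folds = np.array_split(order, k)
    for i in range(k):
        val_idx = folds[i]
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_idx, val_idx
```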

My question is about streaming data to a model for training.

Suppose I have to train a BERT model on 100MM rows of data. I’m definitely not going to feed the model all 100MM rows at once; I will use minibatch stochastic gradient descent and feed the model, say, 64 rows at a time.

So basically I want to iterate over the entire dataframe 64 rows at a time and convert each 64-row chunk into a pandas DataFrame with .compute(). I have written some code that does this:

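import numpy as np
import pandas as pd
import dask.dataframe as dd
from typing import Iterator

def stream_chunks(
        df: dd.DataFrame,
        chunk_num_rows: int = 64,
        shuffle: bool = True,
) -> Iterator[pd.DataFrame]:
    # Visit partitions in (optionally) random order; rows inside a partition keep their order
    idx: np.ndarray = np.arange(df.npartitions)
    if shuffle:
        idx = np.random.permutation(idx)
    for partition_i in idx:
        # Materialize one partition on the client as a pandas DataFrame
        df_partition = df.partitions[int(partition_i)].compute()
        # Slice it into chunks of at most chunk_num_rows rows
        for i in range(0, len(df_partition), chunk_num_rows):
            df_chunk = df_partition.iloc[i:i + chunk_num_rows, :]
            yield df_chunk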

This returns a generator, which can be consumed as follows:

num_rows_in_df = np.sum([df_chunk.shape[0] for df_chunk in stream_chunks(df, chunk_num_rows=64)])
assert num_rows_in_df == len(df)

However, this method has a few disadvantages:

  1. It does not support state across multiple clients. E.g., if I am doing distributed training, I might want chunk_0_63 to go to client A (which is running model_copy_A) and chunk_64_127 to go to client B (which is running model_copy_B).

  2. It’s slow. When using a local cluster and pre-persisting the dataframe in memory via df.persist(), each call to df_partition = df.partitions[partition_i].compute() takes about 50 ms on an m5.24xlarge machine, so roughly 50 ms is spent just converting a partition to pandas. Ideally I would want several partitions to be fetched ahead of time and cached, since this is an IO-bound operation that is trivially parallelizable even with Python's GIL.

  3. (Minor) It requires storing an entire df_partition in memory (as pandas) on the client side, which can be large depending on the kind of data and the partition size.
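For disadvantage 1, a common approach in distributed training is round-robin sharding: replica `rank` out of `world_size` takes every `world_size`-th chunk. `shard`, `rank`, and `world_size` are hypothetical names here, and the sketch assumes every replica iterates the same stream in the same order (i.e. uses the same shuffle seed).

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard(chunks: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Round-robin sharding: replica `rank` gets items rank, rank + world_size, rank + 2*world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return islice(chunks, rank, None, world_size)
```

Applying `shard` to the shuffled partition indices rather than to the chunks means each replica only fetches the partitions it will actually use, instead of computing every partition and discarding most of them.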
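For disadvantage 2, fetches can overlap with training by keeping a small window of partition fetches in flight on a thread pool; because each fetch mostly waits on the scheduler, threads can help despite the GIL. This is a sketch: `prefetch` and `depth` are hypothetical names, `fetch` could be `lambda i: df.partitions[i].compute()`, and it assumes concurrent `.compute()` calls from several threads are acceptable in your setup.

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(fetch: Callable[[int], T], ids: Iterable[int], depth: int = 4) -> Iterator[T]:
    """Yield fetch(i) for each i in ids, in order, with up to `depth` fetches in flight."""
    id_iter = iter(ids)
    with ThreadPoolExecutor(max_workers=depth) as pool:
        pending: Deque[Future] = deque()
        # Start the first `depth` fetches.
        for i in id_iter:
            pending.append(pool.submit(fetch, i))
            if len(pending) >= depth:
                break
        while pending:
            # Block until the oldest fetch is done, which preserves the original order.
            result = pending.popleft().result()
            # Refill the window before handing the result to the caller.
            nxt = next(id_iter, None)
            if nxt is not None:
                pending.append(pool.submit(fetch, nxt))
            yield result
```

Memory use grows with `depth`, since up to `depth` pandas partitions can be held at once, which trades against disadvantage 3.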

Since you’re trying to stream data, I’d take a look at this:


And at the suggestions listed there, in this order:

1. Use normal for loops with Client.submit/gather and as_completed

2. Use asynchronous async/await code and a few coroutines

3. Try out the Streamz project, which has Dask support
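Suggestion 2 can be sketched with the standard library's asyncio: one coroutine produces fetched items, another consumes them for training, and a bounded `asyncio.Queue` provides backpressure so at most `maxsize` items wait in memory. `producer`, `consumer`, `run`, `fetch`, and `train_step` are hypothetical names; the blocking fetch runs via `asyncio.to_thread` (Python 3.9+). Dask's Client can also be used with async/await directly, which this sketch does not attempt.

```python
import asyncio
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the stream


async def producer(fetch: Callable[[int], T], ids: Iterable[int], queue: asyncio.Queue) -> None:
    for i in ids:
        # Run the blocking fetch in a worker thread so it can proceed while training runs.
        item = await asyncio.to_thread(fetch, i)
        await queue.put(item)  # waits while the queue is full (backpressure)
    await queue.put(_DONE)


async def consumer(queue: asyncio.Queue, train_step: Callable[[T], None]) -> None:
    while True:
        item = await queue.get()
        if item is _DONE:
            break
        # Synchronous stand-in for one training step; an in-flight fetch keeps running in its thread.
        train_step(item)


async def run(fetch: Callable[[int], T], ids: Iterable[int],
              train_step: Callable[[T], None], maxsize: int = 2) -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
    await asyncio.gather(producer(fetch, ids, queue), consumer(queue, train_step))
```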

Some general feedback, which you’ll have to take with a serious grain of salt:

  1. Am I understanding correctly that you’d try to run multiple Clients? It may be better to run a single Dask scheduler/client pair and have each worker train a single model locally, which you could then try to combine after training.
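If each worker trains its own model copy, one simple way to combine them afterwards is a weighted average of their parameters. This sketch (`average_weights` is a hypothetical name) assumes the parameters are directly comparable, e.g. linear models with the same feature layout; for deep networks trained independently, naive weight averaging can perform poorly.

```python
from typing import Sequence

import numpy as np


def average_weights(weights: Sequence[np.ndarray], n_samples: Sequence[int]) -> np.ndarray:
    """Average same-shaped parameter arrays, weighting each worker by the rows it trained on."""
    stacked = np.stack(weights)  # shape: (n_workers, *param_shape)
    return np.average(stacked, axis=0, weights=np.asarray(n_samples, dtype=float))
```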

It seems like the Streamz project may have some pre-built machinery you could use?