partial_fit in dask_ml.wrappers.Incremental exhausts memory

I have a dataset larger than the available memory. I am using Scikit-learn's SGDRegressor to fit a Lasso-style model (L1 penalty) on this data with stochastic gradient descent. I pass the Scikit-learn estimator into the Incremental wrapper and then call the partial_fit function as follows:

est = Incremental(SGDRegressor(penalty="l1")).partial_fit(X_train, y_train)

X_train and y_train are dask arrays containing the data for training.
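For reference, I build those arrays roughly like this (the path and column names below are placeholders, not my actual ones):

import dask.dataframe as dd

ddf = dd.read_parquet("data/train/*.parquet")                        # many parquet files
X_train = ddf.drop(columns=["target"]).to_dask_array(lengths=True)  # feature columns
y_train = ddf["target"].to_dask_array(lengths=True)                 # target column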

I noticed in the dashboard that the problem may be related to Dask reading my parquet files (the data is split into many parquet files) faster than the estimator can be (partially) fitted on each chunk. As a result, memory fills up quickly while the tasks wait for the fit function. Is there a solution to this problem? Can we make Dask read the parquet files more slowly so that partial_fit can catch up?

Hi @Zalla, welcome to Dask community!

I think the correct way to use the Incremental wrapper (as described in the documentation) would be something like:

est = Incremental(SGDRegressor(penalty="l1"))
est.fit(X_train, y_train)

You shouldn’t call partial_fit directly; fit already trains incrementally, block by block.
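After calling fit this way, the wrapper behaves like a fitted estimator, so for example (X_test here is just a placeholder for held-out data):

y_pred = est.predict(X_test)        # lazy, evaluated chunk by chunk
print(est.estimator_.coef_)         # coefficients of the wrapped SGDRegressor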

I also tried fit, but I get the same problem: in the Dask dashboard the memory fills up quickly with read_parquet tasks, and the fit function doesn’t catch up (which would help free the memory).
I reimplemented my own fitting loop with Scikit-learn as follows:

import dask
from sklearn.linear_model import SGDRegressor

est = SGDRegressor(penalty="l1")
# Compute one pair of chunks at a time and feed it to partial_fit
for X_block, y_block in zip(X_train.blocks, y_train.blocks):
    X_train_i, y_train_i = dask.compute(X_block, y_block)
    est.partial_fit(X_train_i, y_train_i)

It works well, but is there a better way?

Which Dask version are you using? The memory-filling issue was fixed quite a while ago: Dask shouldn’t read more than about 1.1 * nthreads chunks of data per worker.

I am using dask_ml 2023.3.24 and dask 2023.6.0

I don’t have this problem when I use plain Dask. It only happens with dask_ml, when using the Incremental wrapper on Scikit-learn’s SGDRegressor or SGDClassifier estimators.

Those versions are recent enough, so this is not the problem I was thinking of. Again, you shouldn’t have more than about 1.1 times as many input chunks in memory as the total number of worker threads.
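If you are using the distributed scheduler, you can check the thread count per worker like this (assuming a Client is, or can be, connected to your cluster):

from dask.distributed import Client

client = Client()            # or connect to your existing cluster
print(client.nthreads())     # mapping of worker address -> number of threads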

Could you produce a minimal reproducer?
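Something along these lines, with synthetic data standing in for your parquet files (shapes and chunk sizes below are arbitrary), would already help:

import dask.array as da
from dask_ml.wrappers import Incremental
from sklearn.linear_model import SGDRegressor

# Synthetic stand-in for the parquet-backed training data
X = da.random.random((1_000_000, 20), chunks=(10_000, 20))
y = da.random.random(1_000_000, chunks=10_000)

est = Incremental(SGDRegressor(penalty="l1"))
est.fit(X, y)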

It might be due to your input dataset having chunks that are too large. Are you able to compute, say, the mean of X_train?
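For example (assuming X_train is the same Dask array you pass to fit):

print(X_train.chunks)                     # chunk sizes along each dimension
print(X_train.nbytes / 1e9, "GB total")   # rough total size, if the shape is known
print(X_train.mean().compute())           # should run without exhausting worker memory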