Non-negative Matrix Factorization (NMF) in Dask is too slow

Hello All,

I’ve rewritten the multiplicative-update NMF in Dask and found that it’s too slow running in Colab with 8 cores. It took 6 minutes to complete on a random matrix of size [4000 x 3000]. By comparison, the non-distributed NumPy version of NMF took less than 5 seconds. This doesn’t make sense to me, as the Dask version of NMF should be faster. If anyone has a useful tip, that would be great. Here’s the NMF Dask rewrite:

#----- Code Snippets -----

import dask.array as da
import numpy as np

def multiplicative_update_dask(A: da.Array, k: int = 2, max_iter: int = 100, init_mode: str = 'random'):
    """

    Args:
      A:
      k:
      max_iter:
      init_mode:

    Returns:
      W:
      H:
      norms:

    """
    rank = k
    num_rows, num_cols = A.shape
  
    W = da.random.random((num_rows, rank), chunks=(100, 200)) 
    H = da.random.random((rank, num_cols), chunks=(100, 200)) 

    norms = []
    epsilon = 1.0e-10

    for _ in range(max_iter):
        # Update H
        W_TA = da.dot(W.T, A)
        WH = da.dot(W, H)
        W_TWH = da.dot(W.T, WH) + epsilon
        H = H * (W_TA / W_TWH)

        # Update W
        AH_T = da.dot(A, H.T)
        HHT = da.dot(H, H.T)
        WHH_T = da.dot(W, HHT) + epsilon
        W = W * (AH_T / WHH_T)

        WH2 = da.dot(W, H)
        norm = da.linalg.norm(A - WH2, 'fro')
        norms.append(norm.compute())

    return W.compute(), H.compute(), norms

nrows, ncols = 4000, 3000  # matrix size from the timings above

A = da.random.random((nrows, ncols), chunks=(100, 100)).persist()

W, H, norms = multiplicative_update_dask(A, k=4, max_iter=5)

#The function call above took ~6 minutes

###-------------------------------

def multiplicative_update_numpy(A, k, max_iter, init_mode='random'):
    """
    Perform Multiplicative Update (MU) algorithm for Non-negative Matrix Factorization (NMF).

    Parameters:
    - A: Input matrix
    - k: Rank of the factorization
    - max_iter: Maximum number of iterations
    - init_mode: Initialization mode ('random' or 'nndsvd')

    Returns:
    - W: Factorized matrix W
    - H: Factorized matrix H
    - norms: List of Frobenius norms at each iteration
    """

    rank = k
    num_rows, num_cols = A.shape
   
    W = np.random.rand(num_rows, rank)
    H = np.random.rand(rank, num_cols)

    norms = []
    epsilon = 1.0e-10
    for cn in range(max_iter):
        # Update H
        W_TA = W.T @ A
        W_TWH = W.T @ W @ H + epsilon
        H *= W_TA / W_TWH

        # Update W
        AH_T = A @ H.T
        WHH_T = W @ H @ H.T + epsilon
        W *= AH_T / WHH_T

        norm = np.linalg.norm(A - W @ H, 'fro')
        norms.append(norm)

        print('Iter #', (cn+1))

    return W, H, norms

Ac = A.compute()

W2, H2, norms2 = multiplicative_update_numpy(Ac, k=4, max_iter=5)

#The function call above took less than 5 seconds.

#----- End Code Snippets -----

Does anyone have a tip about the NMF Dask rewrite? Any feedback would be much appreciated.

Thanks.

Hi @spalu, welcome to Dask Discourse forum!

Firstly:

In general, you do not want to use Dask to speed up problems that are not memory bound and take seconds to finish. Why would you want it to take less than 5 seconds? Do you need to scale it to a bigger problem at some point? Something that would not fit into memory?

I tried your code snippet and reproduced the same behavior: less than a second with NumPy, 35 seconds with your Dask version. I did not look into the code, but again, for such a fast computation with a low memory footprint, this is perfectly normal!

A few other things:

  • you are using (100, 100) chunks, which is way too small and increases Dask overhead. Just by using (1000, 1000) chunks, the time was reduced to 12s (see the sketch after this list).
  • I get a lot of messages saying: PerformanceWarning: Increasing number of chunks by factor of 30. Again, I did not look inside your code, but you are probably doing operations that generate a lot more chunks.
  • Dask introduces overhead for every task and partition generated by its task graph. Splitting your data into too-small parts is one problem, but in any case you will generate some overhead. On a computation that takes less than a second, or a few seconds at most, Dask will not help.
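
To make the chunking point concrete, here is a minimal sketch of the lines that would change. The (1000, 1000) chunk size is just the one I tried above, the shape and rank come from your post, and keeping the small rank dimension in a single chunk while aligning W's and H's chunks with A's is my guess at what avoids the chunk blow-up behind the PerformanceWarning, not a tuned configuration:

import dask.array as da

nrows, ncols, rank = 4000, 3000, 4  # shape and rank from the original post

# Larger chunks -> far fewer tasks and much less scheduler overhead
A = da.random.random((nrows, ncols), chunks=(1000, 1000)).persist()

# Inside multiplicative_update_dask, the factors could be initialized so the
# tiny rank dimension sits in a single chunk and the other dimension matches
# A's chunking; the dot products and element-wise updates then work on
# aligned blocks instead of rechunking on every iteration
W = da.random.random((nrows, rank), chunks=(1000, rank))
H = da.random.random((rank, ncols), chunks=(rank, 1000))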