Hello All,

I’ve rewritten the multiplicative-update NMF in Dask, and it runs far too slowly in Colab with 8 cores: it took about 6 minutes to factor a random matrix of size [4000 x 3000]. By comparison, the non-distributed NumPy version of NMF finished in under 5 seconds. This surprised me, since I expected the Dask version to be faster. If anyone has a useful tip, that would be great. Here’s the NMF Dask rewrite:

#----- Code Snippets -----

```
import dask.array as da
import numpy as np

def multiplicative_update_dask(A: da.Array, k: int = 2, max_iter: int = 100, init_mode: str = 'random'):
    """
    Perform Multiplicative Update (MU) NMF on a Dask array.
    Args:
        A: Input matrix (Dask array)
        k: Rank of the factorization
        max_iter: Maximum number of iterations
        init_mode: Initialization mode (only random initialization is implemented)
    Returns:
        W: Factorized matrix W
        H: Factorized matrix H
        norms: List of Frobenius norms at each iteration
    """
    rank = k
    num_rows, num_cols = A.shape
    W = da.random.random((num_rows, rank), chunks=(100, 200))
    H = da.random.random((rank, num_cols), chunks=(100, 200))
    norms = []
    epsilon = 1.0e-10
    for _ in range(max_iter):
        # Update H
        W_TA = da.dot(W.T, A)
        WH = da.dot(W, H)
        W_TWH = da.dot(W.T, WH) + epsilon
        H = H * (W_TA / W_TWH)
        # Update W
        AH_T = da.dot(A, H.T)
        HHT = da.dot(H, H.T)
        WHH_T = da.dot(W, HHT) + epsilon
        W = W * (AH_T / WHH_T)
        # Frobenius norm of the residual (triggers a compute every iteration)
        WH2 = da.dot(W, H)
        norm = da.linalg.norm(A - WH2, 'fro')
        norms.append(norm.compute())
    return W.compute(), H.compute(), norms

nrows, ncols = 4000, 3000
A = da.random.random((nrows, ncols), chunks=(100, 100)).persist()
W, H, norms = multiplicative_update_dask(A, k=4, max_iter=5)
```

#The function call above took ~6 minutes
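A minimal sketch of two changes that may be worth timing against the version above, under the assumption (not verified here) that the slowdown comes from scheduler overhead and repeated work rather than from the arithmetic itself. Because W and H are never persisted in the loop above, each `norm.compute()` evaluates a graph that contains every earlier update, and the 100 x 100 chunks split A into 40 x 30 = 1200 blocks. The sketch persists W and H after every update and uses 1000 x 1000 chunks; the function name `multiplicative_update_dask_persist`, the variable `A_fine`, and the chunk size are hypothetical, illustrative choices, not tuned values.

```python
import dask.array as da

def multiplicative_update_dask_persist(A, k=2, max_iter=100, epsilon=1.0e-10):
    """Hypothetical variant of multiplicative_update_dask (a sketch, not a drop-in fix).

    W and H are chunked to line up with A's row and column chunks, and both are
    persisted after every update, so each iteration starts from stored blocks
    instead of a graph that still contains every earlier iteration.
    """
    num_rows, num_cols = A.shape
    # One chunk along the rank dimension; the other dimension follows A's chunking.
    W = da.random.random((num_rows, k), chunks=(A.chunks[0], (k,))).persist()
    H = da.random.random((k, num_cols), chunks=((k,), A.chunks[1])).persist()
    norms = []
    for _ in range(max_iter):
        # Update H; (W^T W) H avoids forming the num_rows x num_cols product W H.
        WTW = da.dot(W.T, W)
        H = (H * (da.dot(W.T, A) / (da.dot(WTW, H) + epsilon))).persist()
        # Update W; W (H H^T) likewise only multiplies by a k x k matrix.
        HHT = da.dot(H, H.T)
        W = (W * (da.dot(A, H.T) / (da.dot(W, HHT) + epsilon))).persist()
        # The residual norm still needs W H, but only for the current W and H.
        norms.append(float(da.linalg.norm(A - da.dot(W, H), 'fro').compute()))
    return W.compute(), H.compute(), norms


# Same 4000 x 3000 problem as above; 1000 x 1000 is an illustrative chunk size.
A_fine = da.random.random((4000, 3000), chunks=(100, 100))  # chunking used above
A = da.random.random((4000, 3000), chunks=(1000, 1000)).persist()
print(A_fine.numblocks, A.numblocks)  # (40, 30) (4, 3)

W, H, norms = multiplicative_update_dask_persist(A, k=4, max_iter=5)
print(W.shape, H.shape, len(norms))  # (4000, 4) (4, 3000) 5
```

Calling `.compute()` on the norm every iteration still forces a full pass over A each time; computing it only every few iterations is another knob, if the per-iteration history is not needed.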

###-------------------------------

```
import numpy as np

def multiplicative_update_numpy(A, k, max_iter, init_mode='random'):
    """
    Perform Multiplicative Update (MU) algorithm for Non-negative Matrix Factorization (NMF).
    Parameters:
    - A: Input matrix
    - k: Rank of the factorization
    - max_iter: Maximum number of iterations
    - init_mode: Initialization mode ('random' or 'nndsvd'); only random initialization is implemented below
    Returns:
    - W: Factorized matrix W
    - H: Factorized matrix H
    - norms: List of Frobenius norms at each iteration
    """
    rank = k
    num_rows, num_cols = A.shape
    W = np.random.rand(num_rows, rank)
    H = np.random.rand(rank, num_cols)
    norms = []
    epsilon = 1.0e-10
    for cn in range(max_iter):
        # Update H
        W_TA = W.T @ A
        W_TWH = W.T @ W @ H + epsilon
        H *= W_TA / W_TWH
        # Update W
        AH_T = A @ H.T
        WHH_T = W @ H @ H.T + epsilon
        W *= AH_T / WHH_T
        norm = np.linalg.norm(A - W @ H, 'fro')
        norms.append(norm)
        print('Iter #', (cn+1))
    return W, H, norms

# Same matrix as in the Dask run, pulled into memory as a NumPy array
Ac = A.compute()
W2, H2, norms2 = multiplicative_update_numpy(Ac, k=4, max_iter=5)
```

#The function call above took less than 5 seconds.
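The two snippets also group the denominator products differently. NumPy's `@` is left-associative, so `W.T @ W @ H` is evaluated as `(W.T @ W) @ H`, while `W @ H @ H.T` becomes `(W @ H) @ H.T`; the Dask code does the opposite in both places. A back-of-the-envelope multiply-add count for the [4000 x 3000] case with k=4, assuming the usual p·q·r cost for a dense (p x q) @ (q x r) product, shows that each version pays for one full-size product per iteration in its denominators. The helper `matmul_cost` is a hypothetical name used only for this arithmetic.

```python
import numpy as np

def matmul_cost(p, q, r):
    """Multiply-add count for a dense (p x q) @ (q x r) product (hypothetical helper)."""
    return p * q * r

m, n, k = 4000, 3000, 4  # A is m x n, W is m x k, H is k x n

# H-update denominator
dask_h = matmul_cost(m, k, n) + matmul_cost(k, m, n)   # W.T @ (W @ H)
numpy_h = matmul_cost(k, m, k) + matmul_cost(k, k, n)  # (W.T @ W) @ H

# W-update denominator
numpy_w = matmul_cost(m, k, n) + matmul_cost(m, n, k)  # (W @ H) @ H.T
dask_w = matmul_cost(k, n, k) + matmul_cost(m, k, k)   # W @ (H @ H.T)

# Size of the dense m x n intermediate W @ H in float64
intermediate_bytes = m * n * np.dtype(np.float64).itemsize

print(dask_h, numpy_h)     # 96000000 112000
print(numpy_w, dask_w)     # 96000000 112000
print(intermediate_bytes)  # 96000000

# Both groupings give the same matrix (checked on small random factors);
# np.linalg.multi_dot chooses the parenthesization itself.
rng = np.random.default_rng(0)
Ws, Hs = rng.random((40, k)), rng.random((k, 30))
print(np.allclose((Ws.T @ Ws) @ Hs, Ws.T @ (Ws @ Hs)))                     # True
print(np.allclose(np.linalg.multi_dot([Ws, Hs, Hs.T]), Ws @ (Hs @ Hs.T)))  # True
```

Since both versions also compute `W.T @ A`, `A @ H.T`, and `W @ H` for the norm, their per-iteration totals come out the same, so the grouping alone does not seem to explain a 6-minute versus 5-second gap. Writing `(W.T @ W) @ H` and `W @ (H @ H.T)` explicitly is still a cheap way to drop one 4000 x 3000 intermediate per iteration in either version.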

#----- End Code Snippets -----

Does anyone have a tip about the NMF Dask rewrite? Any feedback would be much appreciated.

Thanks.