dask.array.linalg.svd_compressed() fails when computing the SVD of a complex matrix

I reconstructed the matrix as M’ = UΣV^† after computing the SVD. A real input matrix yields a reasonable answer, but a complex input matrix does not. Here is my code:

import os
import sys
import time
import numpy as np
from opt_einsum import contract

from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.initialize import initialize
from dask.utils import parse_bytes
from dask.distributed import performance_report
from dask.distributed import wait
from dask.distributed import get_task_stream

import cupy
import rmm
import cudf
import dask.array as da


def setup_rmm_pool(client):
    """Configure GPU memory allocation on every worker of ``client``.

    Two steps are executed on each worker via ``client.run``:

    1. ``cudf.set_allocator(pool=False, allocator="default")``.
    2. Install an RMM ``ManagedMemoryResource`` as the current device
       resource and make CuPy allocate through RMM.

    Parameters
    ----------
    client : dask.distributed.Client
        Client connected to the cluster whose workers should be configured.
    """
    client.run(
        cudf.set_allocator,
        pool=False,
        allocator="default",
    )

    def _use_managed_memory():
        # This body executes inside each worker process.  Previously the
        # ``set_current_device_resource(...)`` call was written as an
        # *argument* to ``client.run``, so it was evaluated once on the
        # client and its return value (``None``) was what the workers
        # received as the CuPy allocator.
        rmm.mr.set_current_device_resource(rmm.mr.ManagedMemoryResource())
        cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

    client.run(_use_managed_memory)

# Reproducer: build a complex matrix M on the GPU, take a rank-k
# svd_compressed of it, rebuild M' from the factors and print the trace and
# norm of the residual M - M'.
if __name__ == "__main__":

    initialize(create_cuda_context=True)
    
    # Local CUDA cluster; worker files go under ./tmp/ and no memory limit
    # is passed (memory_limit=None).
    cluster = LocalCUDACluster(local_directory="./tmp/",
                                memory_limit=None)
       
    client = Client(cluster)
    setup_rmm_pool(client)

    # nprs: host-side NumPy generator used to build the input vector.
    # rs: dask random state backed by CuPy, passed to svd_compressed as seed.
    nprs = np.random.RandomState(seed=1234)
    rs = da.random.RandomState(seed=1234,RandomState=cupy.random.RandomState)

    SIZE = 15000  # length of the vector b; the matrix a is SIZE x SIZE
    k = 32  # rank requested from svd_compressed

    # Complex vector: independent uniform draws for real and imaginary parts.
    b = nprs.rand(SIZE) + 1j * nprs.rand(SIZE)
    b = da.from_array(b, chunks=(5000))
    # Convert every chunk to a CuPy array.
    b = b.map_blocks(cupy.asarray)
    #a = contract("i,j->ij",b,b) * 10
    # Outer product of b with itself, scaled by 10.
    a = da.einsum("i,j->ij",b,b) * 10
    #a = a.persist()

    # Element-wise exponential; this is the matrix M that gets decomposed.
    a = da.exp(1.2*a)
    
    # t0/t1 bracket the decomposition and its compute; the elapsed time is
    # recorded but never printed.
    t0=time.time()
    u,s,vh=da.linalg.svd_compressed(a,k=k, seed=rs)
    u,s,vh=da.compute(u,s,vh)
    t1=time.time()
    # Wrap the computed factors as dask arrays again for the reconstruction.
    u=da.from_array(u,chunks=(5000,k))
    vh=da.from_array(vh,chunks=(k,5000))
        
    #b = contract("ij,j,jk->ik",u,s,vh)
    # Reconstruct M' = u * diag(s) * vh.  Note that the name b is reused
    # here for M' and a is reused below for the residual.
    b = da.einsum("ij,j,jk->ik",u,s,vh)
    a = a - b #<-a:M, b:M'
    # Trace of the residual M - M'.
    tr = da.sum(da.diagonal(a)).compute()
    print("trace:{:}".format(tr))
    # Norm of the residual, using da.linalg.norm with its default arguments.
    norm=da.linalg.norm(a).compute()
    print("norm:{:}".format(norm))

    sys.exit(0)

I calculated the trace and norm of M − M’ (M: the original matrix; M’: its reconstruction, M’ ≈ UΣV^†), both of which should be small. In the complex case, the difference between M and M’ is large.

trace:(768972.9893344672+7563949.776901031j)
norm:54801713.60352144

In the real case, however, it yields a reasonable answer.

trace:-2.470109161656353e-08
norm:1.8832952585877038e-07

I also posted the same question on GitHub.

Thanks for this question. This is a bug; see “svd_compressed() fails for complex input” (Issue #7639 in the dask/dask repository on GitHub).