I reconstructed the matrix as M’ = UΣV^† after computing the SVD. A real input matrix yields a reasonable result, but the reconstruction fails for a complex input matrix. Here is my code:
import os
import sys
import time
import numpy as np
from opt_einsum import contract
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.initialize import initialize
from dask.utils import parse_bytes
from dask.distributed import performance_report
from dask.distributed import wait
from dask.distributed import get_task_stream
import cupy
import rmm
import cudf
import dask.array as da
def setup_rmm_pool(client):
    """Configure the GPU memory allocators on every worker of *client*.

    Two calls are broadcast with ``client.run``:

    1. ``cudf.set_allocator(pool=False, allocator="default")``.
    2. A helper that installs an RMM ``ManagedMemoryResource`` as the current
       device resource and points CuPy at ``rmm.rmm_cupy_allocator``.

    Args:
        client: object exposing ``run(func, *args, **kwargs)``; the script
            passes a ``dask.distributed.Client``.
    """
    client.run(
        cudf.set_allocator,
        pool=False,
        allocator="default",
    )

    def _use_managed_memory():
        # Imported here so the function is self-contained when it is shipped
        # to a worker process.
        import cupy
        import rmm

        # Both calls must execute on the worker.  Passing the *result* of
        # set_current_device_resource(...) as an argument to client.run would
        # evaluate it on the client only and hand ``None`` to
        # cupy.cuda.set_allocator, resetting CuPy to its default allocator.
        rmm.mr.set_current_device_resource(rmm.mr.ManagedMemoryResource())
        cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

    client.run(_use_managed_memory)
def main():
    """Measure how well ``svd_compressed`` reconstructs a complex matrix.

    Builds ``M = exp(1.2 * 10 * outer(b, b))`` from a seeded complex vector
    ``b`` held in CuPy-backed dask chunks, requests ``k`` singular triplets
    with ``da.linalg.svd_compressed``, forms ``M' = u * s * vh`` and prints
    the trace and the norm of ``M - M'``.
    """
    size = 15000
    chunk = 5000  # chunk length used for every chunked dimension below
    k = 32

    initialize(create_cuda_context=True)

    # The context managers shut the workers and the client connection down
    # even when a computation raises.
    with LocalCUDACluster(local_directory="./tmp/", memory_limit=None) as cluster:
        with Client(cluster) as client:
            setup_rmm_pool(client)

            host_rng = np.random.RandomState(seed=1234)
            device_rng = da.random.RandomState(
                seed=1234, RandomState=cupy.random.RandomState
            )

            # Real part is drawn first, imaginary part second.
            vec = host_rng.rand(size) + 1j * host_rng.rand(size)
            vec = da.from_array(vec, chunks=chunk).map_blocks(cupy.asarray)

            matrix = da.einsum("i,j->ij", vec, vec) * 10
            matrix = da.exp(1.2 * matrix)

            u, s, vh = da.linalg.svd_compressed(matrix, k=k, seed=device_rng)
            u, s, vh = da.compute(u, s, vh)

            # Re-wrap the computed factors so the reconstruction is chunked.
            u = da.from_array(u, chunks=(chunk, k))
            vh = da.from_array(vh, chunks=(k, chunk))

            # NOTE(review): the third factor is used as returned, with no
            # extra conjugation.  If the residual is large for complex input
            # only, confirm that the installed dask's svd_compressed handles
            # complex matrices before suspecting this reconstruction.
            approx = da.einsum("ij,j,jk->ik", u, s, vh)
            residual = matrix - approx  # M - M'

            tr = da.sum(da.diagonal(residual)).compute()
            print("trace:{:}".format(tr))
            norm = da.linalg.norm(residual).compute()
            print("norm:{:}".format(norm))


if __name__ == "__main__":
    main()
I calculated the trace and the norm of M − M’ (M: the original matrix; M’ ≈ UΣV^†: its reconstruction), both of which should be small. In the complex case, the difference between M and M’ is large:
trace:(768972.9893344672+7563949.776901031j)
norm:54801713.60352144
But in the real case, it yields a reasonable answer:
trace:-2.470109161656353e-08
norm:1.8832952585877038e-07
I also posted the same question on GitHub.