I'm planning to use Dask for tensor network research, but in my case Dask runs much slower than NumPy. The code is as follows:
import dask.array as da
import numpy as np
import time
from opt_einsum import contract
from dask.distributed import Client, LocalCluster
def tensor_reshape(T, Renorm_direction: str):
    if Renorm_direction == "X" or Renorm_direction == "x":
        return contract("cdab->abcd", T)
    elif Renorm_direction == "Y" or Renorm_direction == "y":
        return T
def tensor_contract(T, Dcut: int, Renorm_direction: str):
    T = tensor_reshape(T, Renorm_direction)
    t0 = time.time()
    MM_dagLU = contract("aijk,blkm,cijn,dlnm->abcd", T, T, T, T)
    MM_dagRD = contract("iajk,lbkm,icjn,ldnm->abcd", T, T, T, T)
    t1 = time.time()
    MM_dagLU = da.reshape(MM_dagLU, (MM_dagLU.shape[0]*MM_dagLU.shape[1], MM_dagLU.shape[2]*MM_dagLU.shape[3]))
    MM_dagRD = da.reshape(MM_dagRD, (MM_dagRD.shape[0]*MM_dagRD.shape[1], MM_dagRD.shape[2]*MM_dagRD.shape[3]))
    t2 = time.time()
    Eigval_LU, Eigvect_LU = da.apply_gufunc(np.linalg.eigh, '(i,j)->(i),(i,j)', MM_dagLU, allow_rechunk=True)
    Eigval_RD, Eigvect_RD = da.apply_gufunc(np.linalg.eigh, '(i,j)->(i),(i,j)', MM_dagRD, allow_rechunk=True)
    t3 = time.time()
    del MM_dagRD, MM_dagLU
    D = len(Eigval_LU)
    if D <= Dcut:
        Eigvect_cut = Eigvect_LU
    else:
        t4 = time.time()
        epsilon_LU = da.sum(Eigval_LU[:D - Dcut])
        epsilon_RD = da.sum(Eigval_RD[:D - Dcut])
        t5 = time.time()
        t8 = time.time()
        if epsilon_LU < epsilon_RD:
            Eigvect_cut = Eigvect_LU[:, D - Dcut:]
        else:
            Eigvect_cut = Eigvect_RD[:, D - Dcut:]
        t9 = time.time()
    del Eigvect_LU, Eigval_RD
    I = int(np.sqrt(Eigvect_cut.shape[0]))
    Eigvect_cut = da.reshape(Eigvect_cut, (I, I, Eigvect_cut.shape[1]))
    t6 = time.time()
    T_new = contract("ikcm,jlmd,ija,klb->abcd", T, T, da.conj(Eigvect_cut), Eigvect_cut)
    t7 = time.time()
    dt1 = t1 - t0
    dt3 = t3 - t2
    dt5 = t5 - t4
    dt7 = t7 - t6
    dt9 = t9 - t8
    print("4 tensors contraction: {:.2e} s, evd: {:.2e} s, eigval sum: {:.2e} s, compare: {:.2e} s, new tensor: {:.2e} s".format(dt1, dt3, dt5, dt9, dt7))
    return tensor_reshape(T_new, Renorm_direction)
if __name__ == "__main__":
    cluster = LocalCluster(n_workers=16, threads_per_worker=1)
    client = Client(cluster)
    print(client)
    Dcut = 40
    size = (Dcut, Dcut, Dcut, Dcut)
    chunk = int(Dcut/2)
    chunks = (chunk, chunk, chunk, chunk)
    rs = da.random.RandomState(seed=1234, RandomState=np.random.RandomState)
    T = rs.random(size=size, chunks=chunks)
    T = T/da.max(T)
    for i in range(10):
        T = tensor_contract(T, Dcut, "Y")
        T = T/da.max(T)
        T = tensor_contract(T, Dcut, "X")
        T = T/da.max(T)
    t0 = time.time()
    T = T.compute()
    t1 = time.time()
    print("{:.2e} s".format(t1 - t0))
In the code I time each operation inside tensor_contract(). I know Dask is lazily evaluated, so these timings should mostly measure task-graph construction rather than actual computation. The problem is that this time is very long and keeps increasing with the number of loop iterations in main.
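To illustrate what I mean by the graph growing, here is a minimal sketch (separate from the script above, not part of it) that counts the tasks behind an array after each iteration; since nothing is computed in between, the count only ever goes up:

import dask.array as da

x = da.random.random((40, 40, 40, 40), chunks=(20, 20, 20, 20))
for i in range(5):
    x = x / x.max()                            # each iteration appends new tasks
    print(i, len(dict(x.__dask_graph__())))    # total number of tasks accumulated so far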
The timing details are shown below. With Dask, the compare time (the operation between t8 and t9) keeps growing, and the total run time ends up being very large. With plain NumPy, the total time is much lower.
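My guess (not verified in the script above) is that the compare step is expensive because `if epsilon_LU < epsilon_RD:` takes `bool()` of a Dask scalar, which forces everything that scalar depends on to be computed, so each loop pays for the whole graph built so far. A minimal sketch of that behaviour:

import dask.array as da

a = da.random.random((1000,), chunks=100).sum()
b = da.random.random((1000,), chunks=100).sum()
flag = a < b      # still lazy: a 0-d Dask array
if flag:          # bool(flag) triggers computation of everything flag depends on
    pass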
Output details when using Dask:
<Client: 'tcp://127.0.0.1:37922' processes=16 threads=16, memory=186.33 GiB>
4 tensors contraction: 1.31e-02 s, evd: 2.82e-03 s, eigval sum: 2.72e-03 s, compare: 4.85e+00 s, new tensor: 9.04e-03 s
4 tensors contraction: 1.22e-02 s, evd: 3.11e-03 s, eigval sum: 4.23e-03 s, compare: 1.04e+01 s, new tensor: 7.68e-03 s
4 tensors contraction: 1.21e-02 s, evd: 2.67e-03 s, eigval sum: 2.71e-03 s, compare: 1.36e+01 s, new tensor: 6.20e-03 s
4 tensors contraction: 1.31e-02 s, evd: 2.74e-03 s, eigval sum: 2.66e-03 s, compare: 2.00e+01 s, new tensor: 6.17e-03 s
4 tensors contraction: 1.16e-02 s, evd: 2.55e-03 s, eigval sum: 2.78e-03 s, compare: 2.47e+01 s, new tensor: 7.07e-03 s
4 tensors contraction: 1.48e-02 s, evd: 3.77e-03 s, eigval sum: 1.80e-02 s, compare: 2.98e+01 s, new tensor: 8.28e-03 s
4 tensors contraction: 1.33e-02 s, evd: 3.10e-03 s, eigval sum: 2.65e-03 s, compare: 3.46e+01 s, new tensor: 7.88e-03 s
4 tensors contraction: 2.75e-02 s, evd: 5.13e-03 s, eigval sum: 3.01e-03 s, compare: 4.22e+01 s, new tensor: 8.39e-03 s
4 tensors contraction: 1.41e-02 s, evd: 3.43e-03 s, eigval sum: 3.70e-03 s, compare: 4.68e+01 s, new tensor: 8.48e-03 s
4 tensors contraction: 1.45e-02 s, evd: 3.68e-03 s, eigval sum: 3.57e-03 s, compare: 5.18e+01 s, new tensor: 8.52e-03 s
4 tensors contraction: 1.46e-02 s, evd: 3.98e-03 s, eigval sum: 3.56e-03 s, compare: 5.86e+01 s, new tensor: 8.00e-03 s
4 tensors contraction: 1.55e-02 s, evd: 4.63e-03 s, eigval sum: 3.45e-03 s, compare: 6.14e+01 s, new tensor: 8.07e-03 s
4 tensors contraction: 1.58e-02 s, evd: 5.25e-03 s, eigval sum: 3.03e-03 s, compare: 6.49e+01 s, new tensor: 8.72e-03 s
4 tensors contraction: 1.61e-02 s, evd: 1.72e-02 s, eigval sum: 3.45e-03 s, compare: 7.27e+01 s, new tensor: 8.26e-03 s
4 tensors contraction: 1.58e-02 s, evd: 4.30e-03 s, eigval sum: 3.45e-03 s, compare: 7.77e+01 s, new tensor: 9.91e-03 s
4 tensors contraction: 1.67e-02 s, evd: 1.50e-02 s, eigval sum: 5.95e-03 s, compare: 8.43e+01 s, new tensor: 8.66e-03 s
4 tensors contraction: 1.72e-02 s, evd: 5.07e-03 s, eigval sum: 1.99e-02 s, compare: 8.65e+01 s, new tensor: 8.51e-03 s
4 tensors contraction: 1.59e-02 s, evd: 4.30e-03 s, eigval sum: 7.72e-03 s, compare: 9.67e+01 s, new tensor: 8.77e-03 s
4 tensors contraction: 1.60e-02 s, evd: 4.90e-03 s, eigval sum: 3.61e-03 s, compare: 1.02e+02 s, new tensor: 8.64e-03 s
4 tensors contraction: 2.02e-02 s, evd: 1.53e-02 s, eigval sum: 3.96e-03 s, compare: 1.04e+02 s, new tensor: 9.15e-03 s
1.06e+02 s
Output details when using NumPy:
4 tensors contraction: 2.39e-01 s, evd: 8.25e-01 s, eigval sum: 5.21e-04 s, compare: 1.62e-05 s, new tensor: 2.47e+00 s
4 tensors contraction: 2.27e-01 s, evd: 7.48e-01 s, eigval sum: 6.87e-05 s, compare: 3.34e-06 s, new tensor: 2.48e+00 s
4 tensors contraction: 2.29e-01 s, evd: 7.08e-01 s, eigval sum: 4.00e-04 s, compare: 2.79e-05 s, new tensor: 2.44e+00 s
4 tensors contraction: 2.29e-01 s, evd: 7.03e-01 s, eigval sum: 5.96e-05 s, compare: 3.10e-06 s, new tensor: 2.45e+00 s
4 tensors contraction: 2.52e-01 s, evd: 7.01e-01 s, eigval sum: 6.25e-05 s, compare: 2.86e-06 s, new tensor: 2.50e+00 s
4 tensors contraction: 2.35e-01 s, evd: 7.09e-01 s, eigval sum: 6.18e-05 s, compare: 3.10e-06 s, new tensor: 2.50e+00 s
4 tensors contraction: 2.36e-01 s, evd: 7.04e-01 s, eigval sum: 8.77e-05 s, compare: 5.48e-06 s, new tensor: 2.49e+00 s
4 tensors contraction: 2.38e-01 s, evd: 6.99e-01 s, eigval sum: 6.99e-05 s, compare: 5.48e-06 s, new tensor: 2.50e+00 s
4 tensors contraction: 2.34e-01 s, evd: 7.05e-01 s, eigval sum: 6.20e-05 s, compare: 3.10e-06 s, new tensor: 2.46e+00 s
4 tensors contraction: 2.42e-01 s, evd: 7.04e-01 s, eigval sum: 6.10e-05 s, compare: 2.86e-06 s, new tensor: 2.49e+00 s
4 tensors contraction: 2.36e-01 s, evd: 7.15e-01 s, eigval sum: 6.13e-05 s, compare: 3.10e-06 s, new tensor: 2.72e+00 s
4 tensors contraction: 2.37e-01 s, evd: 7.15e-01 s, eigval sum: 9.30e-05 s, compare: 2.38e-06 s, new tensor: 2.46e+00 s
4 tensors contraction: 2.37e-01 s, evd: 7.14e-01 s, eigval sum: 7.77e-05 s, compare: 5.96e-06 s, new tensor: 2.43e+00 s
4 tensors contraction: 2.34e-01 s, evd: 7.03e-01 s, eigval sum: 6.99e-05 s, compare: 5.48e-06 s, new tensor: 2.50e+00 s
4 tensors contraction: 2.35e-01 s, evd: 7.11e-01 s, eigval sum: 6.27e-05 s, compare: 2.86e-06 s, new tensor: 2.44e+00 s
4 tensors contraction: 2.35e-01 s, evd: 7.11e-01 s, eigval sum: 6.13e-05 s, compare: 3.10e-06 s, new tensor: 2.51e+00 s
4 tensors contraction: 2.57e-01 s, evd: 8.80e-01 s, eigval sum: 5.98e-05 s, compare: 3.58e-06 s, new tensor: 2.58e+00 s
4 tensors contraction: 2.34e-01 s, evd: 7.18e-01 s, eigval sum: 6.22e-05 s, compare: 2.86e-06 s, new tensor: 2.46e+00 s
4 tensors contraction: 2.32e-01 s, evd: 7.36e-01 s, eigval sum: 7.27e-05 s, compare: 3.58e-06 s, new tensor: 2.56e+00 s
4 tensors contraction: 2.32e-01 s, evd: 7.25e-01 s, eigval sum: 6.13e-05 s, compare: 2.62e-06 s, new tensor: 2.45e+00 s
7.00e+01 s