Skip to content

Commit

Permalink
decomp: make eigh_truncated backend agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 24, 2024
1 parent ef966ae commit c8ba863
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
38 changes: 35 additions & 3 deletions quimb/tensor/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,46 @@ def svdvals_eig(x): # pragma: no cover
return s2[::-1] ** 0.5


@compose
def eigh_truncated(
x,
cutoff=-1.0,
cutoff_mode=4,
max_bond=-1,
absorb=0,
renorm=0,
backend=None,
):
with backend_like(backend):
s, U = do("linalg.eigh", x)

# make sure largest singular value first
k = do("argsort", -do("abs", s))
s, U = s[k], U[:, k]

# absorb phase into V
V = ldmul(sgn(s), dag(U))
s = do("abs", s)
return _trim_and_renorm_svd_result(
U, s, V, cutoff, cutoff_mode, max_bond, absorb, renorm
)


@eigh_truncated.register("numpy")
@njit # pragma: no cover
def eigh(x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0):
def eigh_truncated_numba(
x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0
):
"""SVD-decomposition, using hermitian eigen-decomposition, only works if
``x`` is hermitian.
"""
s, U = np.linalg.eigh(x)
s, U = s[::-1], U[:, ::-1] # make sure largest singular value first

# make sure largest singular value first
k = np.argsort(-np.abs(s))
s, U = s[k], U[:, k]

# absorb phase into V
V = ldmul_numba(sgn_numba(s), dag_numba(U))
s = np.abs(s)
return _trim_and_renorm_svd_result_numba(
Expand Down Expand Up @@ -597,7 +629,7 @@ def eigsh(x, cutoff=0.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0):
if k == "full":
if not isinstance(x, np.ndarray):
x = x.to_dense()
return eigh(x, cutoff, cutoff_mode, max_bond, absorb)
return eigh_truncated(x, cutoff, cutoff_mode, max_bond, absorb)

s, U = base_linalg.eigh(x, k=k)
s, U = s[::-1], U[:, ::-1] # make sure largest singular value first
Expand Down
4 changes: 2 additions & 2 deletions quimb/tensor/tensor_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from .array_ops import asarray, sensibly_scale, reshape, do
from .contraction import array_contract
from .decomp import eigh
from .decomp import eigh_truncated
from .tensor_arbgeom import (
TensorNetworkGen,
TensorNetworkGenVector,
Expand Down Expand Up @@ -1785,7 +1785,7 @@ def classical_ising_sqrtS_matrix(beta, j=1.0, asymm=None):
network.
"""
if (j < 0.0) and (asymm is not None):
Slr = eigh(classical_ising_S_matrix(beta=beta, j=j))
Slr = eigh_truncated(classical_ising_S_matrix(beta=beta, j=j))
S_1_2 = {
"l": Slr[0],
"lT": Slr[0].T,
Expand Down

0 comments on commit c8ba863

Please sign in to comment.