From 71b82c81c5bba0e190372d23d7f9df5e46fa6859 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 23 Jan 2024 22:55:31 -0800 Subject: [PATCH] tensor_core.py: black format --- quimb/tensor/tensor_core.py | 2277 +++++++++++++++++++---------------- 1 file changed, 1271 insertions(+), 1006 deletions(-) diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index e0b4d4d3..24c2e2a1 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -16,21 +16,52 @@ import numpy as np import scipy.sparse.linalg as spla from autoray import ( - do, conj, astype, infer_backend, get_dtype_name, dag, shape, size + astype, + conj, + dag, + do, + get_dtype_name, + infer_backend, + shape, + size, ) + try: from autoray import get_common_dtype except ImportError: from ..core import common_type as get_common_dtype -from ..core import (qarray, prod, realify_scalar, vdot, make_immutable) -from ..utils import (check_opt, oset, concat, frequencies, unique, deprecated, - valmap, ensure_dict, gen_bipartitions, tree_map) -from ..gen.rand import randn, seed_rand, rand_matrix, rand_uni +from ..core import qarray, prod, realify_scalar, vdot, make_immutable +from ..utils import ( + check_opt, + concat, + deprecated, + ensure_dict, + frequencies, + gen_bipartitions, + oset, + tree_map, + unique, + valmap, +) +from ..gen.rand import randn, seed_rand, rand_matrix, rand_uni, rand_iso from . import decomp -from .array_ops import (iscomplex, norm_fro, ndim, asarray, PArray, - find_diag_axes, find_antidiag_axes, find_columns) -from .drawing import draw_tn, visualize_tensor, auto_color_html +from .array_ops import ( + asarray, + find_antidiag_axes, + find_columns, + find_diag_axes, + iscomplex, + ndim, + norm_fro, + PArray, +) +from .drawing import ( + auto_color_html, + draw_tn, + visualize_tensor, + visualize_tensors, +) from .contraction import ( array_contract_expression, @@ -47,18 +78,18 @@ inds_to_symbols, ) -_inds_to_eq = deprecated(inds_to_eq, '_inds_to_eq', 'inds_to_eq') +_inds_to_eq = deprecated(inds_to_eq, "_inds_to_eq", "inds_to_eq") get_symbol = deprecated( - get_symbol, 'tensor_core.get_symbol', 'contraction.get_symbol' + get_symbol, "tensor_core.get_symbol", "contraction.get_symbol" ) # --------------------------------------------------------------------------- # # Tensor Funcs # # --------------------------------------------------------------------------- # + def oset_union(xs): - """Non-variadic ordered set union taking any sequence of iterables. - """ + """Non-variadic ordered set union taking any sequence of iterables.""" return oset(concat(xs)) @@ -68,8 +99,7 @@ def oset_intersection(xs): def tags_to_oset(tags): - """Parse a ``tags`` argument into an ordered set. - """ + """Parse a ``tags`` argument into an ordered set.""" if tags is None: return oset() elif isinstance(tags, (str, int)): @@ -103,12 +133,7 @@ def _gen_output_inds(all_inds): def _tensor_contract_get_other( - arrays, - inds, - inds_out, - shapes, - get, - **contract_opts + arrays, inds, inds_out, shapes, get, **contract_opts ): check_opt("get", get, _VALID_CONTRACT_GET) @@ -122,7 +147,7 @@ def _tensor_contract_get_other( output=inds_out, shapes=shapes, constants=constants, - **contract_opts + **contract_opts, ) if get == "tree": @@ -163,7 +188,7 @@ def tensor_contract( backend=None, preserve_tensor=False, drop_tags=False, - **contract_opts + **contract_opts, ): """Contract a collection of tensors into a scalar or tensor, automatically aligning their indices and computing an optimized contraction path. 
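    For illustration, a minimal usage sketch (the arrays and index names
    below are assumptions of this example, not part of the patch):

    >>> import numpy as np
    >>> import quimb.tensor as qtn
    >>> ta = qtn.Tensor(np.random.rand(2, 3), inds=('a', 'b'))
    >>> tb = qtn.Tensor(np.random.rand(3, 4), inds=('b', 'c'))
    >>> qtn.tensor_contract(ta, tb).inds  # 'b' is repeated, so summed over
    ('a', 'c')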
@@ -241,15 +266,17 @@ def tensor_contract(
         shapes=shapes,
         get=get,
         optimize=optimize,
-        **contract_opts
+        **contract_opts,
     )

     # perform the contraction!
     data_out = array_contract(
-        arrays, inds, inds_out,
+        arrays,
+        inds,
+        inds_out,
         optimize=optimize,
         backend=backend,
-        **contract_opts
+        **contract_opts,
     )

     if not inds_out and not preserve_tensor:
@@ -273,7 +300,7 @@
     itertools.chain.from_iterable(
         itertools.product(_RAND_ALPHABET, repeat=repeat)
         for repeat in itertools.count(5)
-    )
+    ),
 )


@@ -292,39 +319,54 @@ def rand_uuid(base=""):
     return f"{base}_{_RAND_PREFIX}{next(RAND_UUIDS)}"


-_VALID_SPLIT_GET = {None, 'arrays', 'tensors', 'values'}
+_VALID_SPLIT_GET = {None, "arrays", "tensors", "values"}
 _SPLIT_FNS = {
-    'svd': decomp.svd_truncated,
-    'eig': decomp.svd_via_eig_truncated,
-    'lu': decomp.lu_truncated,
-    'qr': decomp.qr_stabilized,
-    'lq': decomp.lq_stabilized,
-    'polar_right': decomp.polar_right,
-    'polar_left': decomp.polar_left,
-    'eigh': decomp.eigh,
-    'cholesky': decomp.cholesky,
-    'isvd': decomp.isvd,
-    'svds': decomp.svds,
-    'rsvd': decomp.rsvd,
-    'eigsh': decomp.eigsh,
+    "svd": decomp.svd_truncated,
+    "eig": decomp.svd_via_eig_truncated,
+    "lu": decomp.lu_truncated,
+    "qr": decomp.qr_stabilized,
+    "lq": decomp.lq_stabilized,
+    "polar_right": decomp.polar_right,
+    "polar_left": decomp.polar_left,
+    "eigh": decomp.eigh_truncated,
+    "cholesky": decomp.cholesky,
+    "isvd": decomp.isvd,
+    "svds": decomp.svds,
+    "rsvd": decomp.rsvd,
+    "eigsh": decomp.eigsh,
 }

-_SPLIT_VALUES_FNS = {'svd': decomp.svdvals, 'eig': decomp.svdvals_eig}
-_FULL_SPLIT_METHODS = {'svd', 'eig', 'eigh'}
-_RANK_HIDDEN_METHODS = {'qr', 'lq', 'cholesky', 'polar_right', 'polar_left'}
+_SPLIT_VALUES_FNS = {"svd": decomp.svdvals, "eig": decomp.svdvals_eig}
+_FULL_SPLIT_METHODS = {"svd", "eig", "eigh"}
+_RANK_HIDDEN_METHODS = {"qr", "lq", "cholesky", "polar_right", "polar_left"}
 _DENSE_ONLY_METHODS = {
-    'svd', 'eig', 'eigh', 'cholesky', 'qr', 'lq', 'polar_right', 'polar_left',
-    'lu',
+    "svd",
+    "eig",
+    "eigh",
+    "cholesky",
+    "qr",
+    "lq",
+    "polar_right",
+    "polar_left",
+    "lu",
 }
-_LEFT_ISOM_METHODS = {'qr', 'polar_right'}
-_RIGHT_ISOM_METHODS = {'lq', 'polar_left'}
-_ISOM_METHODS = {'svd', 'eig', 'eigh', 'isvd', 'svds', 'rsvd', 'eigsh'}
+_LEFT_ISOM_METHODS = {"qr", "polar_right"}
+_RIGHT_ISOM_METHODS = {"lq", "polar_left"}
+_ISOM_METHODS = {"svd", "eig", "eigh", "isvd", "svds", "rsvd", "eigsh"}

 _CUTOFF_LOOKUP = {None: -1.0}
-_ABSORB_LOOKUP = {'left': -1, 'both': 0, 'right': 1, None: None}
+_ABSORB_LOOKUP = {"left": -1, "both": 0, "right": 1, None: None}
 _MAX_BOND_LOOKUP = {None: -1}
-_CUTOFF_MODES = {'abs': 1, 'rel': 2, 'sum2': 3,
-                 'rsum2': 4, 'sum1': 5, 'rsum1': 6}
-_RENORM_LOOKUP = {'sum2': 2, 'rsum2': 2, 'sum1': 1, 'rsum1': 1}
+_CUTOFF_MODES = {
+    "abs": 1,
+    "rel": 2,
+    "sum2": 3,
+    "rsum2": 4,
+    "sum1": 5,
+    "rsum1": 6,
+}
+_RENORM_LOOKUP = {"sum2": 2, "rsum2": 2, "sum1": 1, "rsum1": 1}


 @functools.lru_cache(None)
 def _parse_split_opts(method, cutoff, absorb, max_bond, cutoff_mode, renorm):
     if absorb is None:
         raise ValueError(
             "You can't return the singular values separately when "
-            "`method='{}'`.".format(method))
+            "`method='{}'`.".format(method)
+        )

         # options are only relevant for handling singular values
         return opts

     # convert defaults and settings to numeric type for numba funcs
-    opts['cutoff'] = _CUTOFF_LOOKUP.get(cutoff, cutoff)
-    opts['absorb'] = _ABSORB_LOOKUP[absorb]
-    opts['max_bond'] = _MAX_BOND_LOOKUP.get(max_bond, max_bond)
-    opts['cutoff_mode'] = _CUTOFF_MODES[cutoff_mode]
+    opts["cutoff"] = _CUTOFF_LOOKUP.get(cutoff, cutoff)
+    opts["absorb"] = _ABSORB_LOOKUP[absorb]
+    opts["max_bond"] = _MAX_BOND_LOOKUP.get(max_bond, max_bond)
+    opts["cutoff_mode"] = _CUTOFF_MODES[cutoff_mode]

     # renorm doubles up as the power used to renormalize
     if (method in _FULL_SPLIT_METHODS) and (renorm is None):
-        opts['renorm'] = _RENORM_LOOKUP.get(cutoff_mode, 0)
+        opts["renorm"] = _RENORM_LOOKUP.get(cutoff_mode, 0)
     else:
-        opts['renorm'] = 0 if renorm is None else renorm
+        opts["renorm"] = 0 if renorm is None else renorm

     return opts


 @functools.lru_cache(None)
 def _check_left_right_isom(method, absorb):
-    left_isom = (
-        (method in _LEFT_ISOM_METHODS) or
-        (method in _ISOM_METHODS and absorb in (None, 'right'))
+    left_isom = (method in _LEFT_ISOM_METHODS) or (
+        method in _ISOM_METHODS and absorb in (None, "right")
     )
-    right_isom = (
-        (method == _RIGHT_ISOM_METHODS) or
-        (method in _ISOM_METHODS and absorb in (None, 'left'))
+    right_isom = (method in _RIGHT_ISOM_METHODS) or (
+        method in _ISOM_METHODS and absorb in (None, "left")
     )
     return left_isom, right_isom


 def tensor_split(
     T,
     left_inds,
-    method='svd',
+    method="svd",
     get=None,
-    absorb='both',
+    absorb="both",
     max_bond=None,
     cutoff=1e-10,
-    cutoff_mode='rel',
+    cutoff_mode="rel",
     renorm=None,
     ltags=None,
     rtags=None,
@@ -478,7 +519,7 @@ def tensor_split(
         ``absorb=None`` the returned objects correspond to
         ``(left, singular_values, right)``.
     """
-    check_opt('get', get, _VALID_SPLIT_GET)
+    check_opt("get", get, _VALID_SPLIT_GET)

     if left_inds is None:
         left_inds = oset(T.inds) - oset(right_inds)
@@ -499,19 +540,20 @@ def tensor_split(
         array = T
     else:
         TT = T.transpose(*left_inds, *right_inds)
-        left_dims = TT.shape[:len(left_inds)]
-        right_dims = TT.shape[len(left_inds):]
+        left_dims = TT.shape[: len(left_inds)]
+        right_dims = TT.shape[len(left_inds) :]

         if (len(left_dims), len(right_dims)) != (1, 1):
             array = do("reshape", TT.data, (prod(left_dims), prod(right_dims)))
         else:
             array = TT.data

-    if get == 'values':
+    if get == "values":
         return _SPLIT_VALUES_FNS[method](array)

     opts = _parse_split_opts(
-        method, cutoff, absorb, max_bond, cutoff_mode, renorm)
+        method, cutoff, absorb, max_bond, cutoff_mode, renorm
+    )

     # ``s`` itself will be None unless ``absorb=None`` is specified
     left, s, right = _SPLIT_FNS[method](array, **opts)
@@ -520,7 +562,7 @@ def tensor_split(
     if len(right_dims) != 1:
         right = do("reshape", right, (-1, *right_dims))

-    if get == 'arrays':
+    if get == "arrays":
         if absorb is None:
             return left, s, right
         return left, right
@@ -546,19 +588,14 @@ def tensor_split(
     if right_isom:
         Tr.modify(left_inds=right_inds)

-    if get == 'tensors':
+    if get == "tensors":
         return tensors

     return TensorNetwork(tensors, virtual=True)


 def tensor_canonize_bond(
-    T1,
-    T2,
-    absorb='right',
-    gauges=None,
-    gauge_smudge=1e-6,
-    **split_opts
+    T1, T2, absorb="right", gauges=None, gauge_smudge=1e-6, **split_opts
 ):
     r"""Inplace 'canonization' of two tensors. This gauges the bond between
     the two such that ``T1`` is isometric::

     split_opts
         Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`, with
         modified defaults of ``method='qr'`` and ``absorb='right'``.
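
    Examples
    --------
    A minimal sketch for illustration only -- the arrays, index names and
    the explicit import path are assumptions of this example, not part of
    the patch:

    >>> import numpy as np
    >>> import quimb.tensor as qtn
    >>> from quimb.tensor.tensor_core import tensor_canonize_bond
    >>> t1 = qtn.Tensor(np.random.rand(2, 3), inds=('x', 'b'))
    >>> t2 = qtn.Tensor(np.random.rand(3, 2), inds=('b', 'y'))
    >>> tensor_canonize_bond(t1, t2)  # acts in place: t1 becomes isometric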
""" - check_opt('absorb', absorb, ('left', 'both', 'right')) + check_opt("absorb", absorb, ("left", "both", "right")) - if absorb == 'both': + if absorb == "both": # same as doing reduced compression with no truncation - split_opts.setdefault('cutoff', 0.0) + split_opts.setdefault("cutoff", 0.0) return tensor_compress_bond( - T1, T2, gauges=gauges, gauge_smudge=gauge_smudge, **split_opts) + T1, T2, gauges=gauges, gauge_smudge=gauge_smudge, **split_opts + ) - split_opts.setdefault('method', 'qr') - if absorb == 'left': + split_opts.setdefault("method", "qr") + if absorb == "left": T1, T2 = T2, T1 lix, bix, _ = tensor_make_single_bond(T1, T2, gauges=gauges) @@ -608,7 +646,7 @@ def tensor_canonize_bond( outer, _ = tn.gauge_simple_insert(gauges, smudge=gauge_smudge) gauges.pop(bix, None) - new_T1, tRfact = T1.split(lix, get='tensors', **split_opts) + new_T1, tRfact = T1.split(lix, get="tensors", **split_opts) new_T2 = tRfact @ T2 new_T1.transpose_like_(T1) @@ -625,11 +663,11 @@ def tensor_compress_bond( T1, T2, reduced=True, - absorb='both', + absorb="both", gauges=None, gauge_smudge=1e-6, info=None, - **compress_opts + **compress_opts, ): r"""Inplace compress between the two single tensors. It follows the following steps to minimize the size of SVD performed:: @@ -665,7 +703,6 @@ def tensor_compress_bond( compress_opts : Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. """ - lix, bix, rix = tensor_make_single_bond(T1, T2, gauges=gauges) if not bix: raise ValueError("The tensors specified don't share an bond.") @@ -678,18 +715,22 @@ def tensor_compress_bond( if reduced is True: # a) -> b) T1_L, T1_R = T1.split( - left_inds=lix, right_inds=bix, - get='tensors', method='qr') + left_inds=lix, right_inds=bix, get="tensors", method="qr" + ) T2_L, T2_R = T2.split( - left_inds=bix, right_inds=rix, - get='tensors', method='lq') + left_inds=bix, right_inds=rix, get="tensors", method="lq" + ) # b) -> c) M = T1_R @ T2_L # c) -> d) M_L, *s, M_R = M.split( - left_inds=T1_L.bonds(M), bond_ind=bix, - get='tensors', absorb=absorb, **compress_opts) + left_inds=T1_L.bonds(M), + bond_ind=bix, + get="tensors", + absorb=absorb, + **compress_opts, + ) # d) -> e) T1C = T1_L.contract(M_L, output_inds=T1.inds) @@ -698,14 +739,15 @@ def tensor_compress_bond( elif reduced == 'lazy': compress_opts.setdefault('method', 'isvd') T12 = TNLinearOperator((T1, T2), lix, rix) - T1C, *s, T2C = T12.split(get='tensors', absorb=absorb, **compress_opts) + T1C, *s, T2C = T12.split(get="tensors", absorb=absorb, **compress_opts) T1C.transpose_like_(T1) T2C.transpose_like_(T2) else: T12 = T1 @ T2 - T1C, *s, T2C = T12.split(left_inds=lix, get='tensors', - absorb=absorb, **compress_opts) + T1C, *s, T2C = T12.split( + left_inds=lix, get="tensors", absorb=absorb, **compress_opts + ) T1C.transpose_like_(T1) T2C.transpose_like_(T2) @@ -713,13 +755,13 @@ def tensor_compress_bond( T1.modify(data=T1C.data) T2.modify(data=T2C.data) - if absorb == 'right': + if absorb == "right": T1.modify(left_inds=lix) - elif absorb == 'left': + elif absorb == "left": T2.modify(left_inds=rix) if s and info is not None: - info['singular_values'] = s[0].data + info["singular_values"] = s[0].data if gauges is not None: tn.gauge_simple_remove(outer=outer) @@ -746,7 +788,7 @@ def tensor_balance_bond(t1, t2, smudge=1e-6): Avoid numerical issues by 'smudging' the correctional factor by this much - the gauging introduced is still exact. 
""" - ix, = bonds(t1, t2) + (ix,) = bonds(t1, t2) x = tensor_contract(t1.H, t1, output_inds=[ix]).data y = tensor_contract(t2.H, t2, output_inds=[ix]).data s = (x + smudge) / (y + smudge) @@ -762,9 +804,10 @@ def tensor_multifuse(ts, inds, gauges=None): if (gauges is not None) and any(ix in gauges for ix in inds): # gauge fusing gs = [ - gauges.pop(ix) if ix in gauges else + gauges.pop(ix) + if ix in gauges # if not present, ones is the identity gauge - do("ones", ts[0].ind_size(ix), like=ts[0].data) + else do("ones", ts[0].ind_size(ix), like=ts[0].data) for ix in inds ] # contract into a single gauge @@ -803,7 +846,7 @@ def tensor_fuse_squeeze(t1, t2, squeeze=True, gauges=None): t2.squeeze_(include=(ind0,)) if gauges is not None: - s0_1_2 = gauges.pop(ind0).item()**0.5 + s0_1_2 = gauges.pop(ind0).item() ** 0.5 t1 *= s0_1_2 t2 *= s0_1_2 @@ -836,15 +879,16 @@ def new_bond(T1, T2, size=1, name=None, axis1=0, axis2=0): def rand_padder(vector, pad_width, iaxis, kwargs): - """Helper function for padding tensor with random entries. - """ - rand_strength = kwargs.get('rand_strength') + """Helper function for padding tensor with random entries.""" + rand_strength = kwargs.get("rand_strength") if pad_width[0]: - vector[:pad_width[0]] = rand_strength * randn(pad_width[0], - dtype='float32') + vector[: pad_width[0]] = rand_strength * randn( + pad_width[0], dtype="float32" + ) if pad_width[1]: - vector[-pad_width[1]:] = rand_strength * randn(pad_width[1], - dtype='float32') + vector[-pad_width[1] :] = rand_strength * randn( + pad_width[1], dtype="float32" + ) return vector @@ -879,13 +923,14 @@ def array_direct_product(X, Y, sum_axes=()): padY.append((d1, 0)) else: if d1 != d2: - raise ValueError("Can only add sum tensor " - "indices of the same size.") + raise ValueError( + "Can only add sum tensor " "indices of the same size." + ) padX.append((0, 0)) padY.append((0, 0)) - pX = do('pad', X, padX, mode='constant') - pY = do('pad', Y, padY, mode='constant') + pX = do("pad", X, padX, mode="constant") + pY = do("pad", Y, padY, mode="constant") return pX + pY @@ -1143,7 +1188,7 @@ def maybe_unwrap( return t # else get the single tensor - t, = t.tensor_map.values() + (t,) = t.tensor_map.values() if output_inds is not None and t.inds != output_inds: t.transpose_(*output_inds) @@ -1162,7 +1207,7 @@ def tensor_network_distance( xAA=None, xAB=None, xBB=None, - method='auto', + method="auto", **contract_opts, ): r"""Compute the Frobenius norm distance between two tensor networks: @@ -1203,7 +1248,7 @@ def tensor_network_distance( ------- D : float """ - check_opt('method', method, ('auto', 'dense', 'overlap')) + check_opt("method", method, ("auto", "dense", "overlap")) tnA = tnA.as_network() tnB = tnB.as_network() @@ -1215,28 +1260,28 @@ def tensor_network_distance( "networks with matching outer indices." 
) - if method == 'auto': + if method == "auto": d = tnA.inds_size(oix) if d <= 1 << 16: - method = 'dense' + method = "dense" else: - method = 'overlap' + method = "overlap" # directly form vectorizations of both - if method == 'dense': + if method == "dense": A = tnA.to_dense(oix) B = tnB.to_dense(oix) - return do('linalg.norm', A - B) + return do("linalg.norm", A - B) # overlap method if xAA is None: - xAA = (tnA | tnA.H).contract(all, **contract_opts) + xAA = (tnA | tnA.H).contract(..., **contract_opts) if xAB is None: - xAB = (tnA | tnB.H).contract(all, **contract_opts) + xAB = (tnA | tnB.H).contract(..., **contract_opts) if xBB is None: - xBB = (tnB | tnB.H).contract(all, **contract_opts) + xBB = (tnB | tnB.H).contract(..., **contract_opts) - return do('abs', xAA - 2 * do('real', xAB) + xBB)**0.5 + return do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5 def tensor_network_fit_autodiff( @@ -1244,12 +1289,12 @@ def tensor_network_fit_autodiff( tn_target, steps=1000, tol=1e-9, - autodiff_backend='autograd', - contract_optimize='auto-hq', - distance_method='auto', + autodiff_backend="autograd", + contract_optimize="auto-hq", + distance_method="auto", inplace=False, progbar=False, - **kwargs + **kwargs, ): """Optimize the fit of ``tn`` with respect to ``tn_target`` using automatic differentation. This minimizes the norm of the difference @@ -1287,17 +1332,19 @@ def tensor_network_fit_autodiff( from .optimize import TNOptimizer xBB = (tn_target | tn_target.H).contract( - ..., output_inds=(), optimize=contract_optimize, + ..., + output_inds=(), + optimize=contract_optimize, ) tnopt = TNOptimizer( tn=tn, loss_fn=tensor_network_distance, - loss_constants={'tnB': tn_target, 'xBB': xBB}, - loss_kwargs={'method': distance_method, 'optimize': contract_optimize}, + loss_constants={"tnB": tn_target, "xBB": xBB}, + loss_kwargs={"method": distance_method, "optimize": contract_optimize}, autodiff_backend=autodiff_backend, progbar=progbar, - **kwargs + **kwargs, ) tn_fit = tnopt.optimize(steps, tol=tol) @@ -1321,18 +1368,17 @@ def _tn_fit_als_core( steps, enforce_pos, pos_smudge, - solver='solve', + solver="solve", progbar=False, ): # shared intermediates + greedy = good reuse of contractions with contract_strategy(contract_optimize): - # prepare each of the contractions we are going to repeat env_contractions = [] for tg in var_tags: # varying tensor and conjugate in norm - tk = tnAA['__KET__', tg] - tb = tnAA['__BRA__', tg] + tk = tnAA["__KET__", tg] + tb = tnAA["__BRA__", tg] # get inds, and ensure any bonds come last, for linalg.solve lix, bix, rix = group_inds(tb, tk) @@ -1340,48 +1386,49 @@ def _tn_fit_als_core( tb.transpose_(*lix, *bix) # form TNs with 'holes', i.e. 
environment tensor networks
-            A_tn = tnAA.select((tg,), '!all')
-            y_tn = tnAB.select((tg,), '!all')
+            A_tn = tnAA.select((tg,), "!all")
+            y_tn = tnAB.select((tg,), "!all")

             env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn))

     if tol != 0.0:
-        old_d = float('inf')
+        old_d = float("inf")

     if progbar:
         import tqdm
+
         pbar = tqdm.trange(steps)
     else:
         pbar = range(steps)

     # the main iterative sweep on each tensor, locally optimizing
     for _ in pbar:
-        for (tk, tb, lix, bix, rix, A_tn, y_tn) in env_contractions:
+        for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions:
             Ni = A_tn.to_dense(lix, rix)
             Wi = y_tn.to_dense(rix, bix)

             if enforce_pos:
-                el, ev = do('linalg.eigh', Ni)
-                el = do('clip', el, el[-1] * pos_smudge, None)
-                Ni_p = ev * do('reshape', el, (1, -1)) @ dag(ev)
+                el, ev = do("linalg.eigh", Ni)
+                el = do("clip", el, el[-1] * pos_smudge, None)
+                Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev)
             else:
                 Ni_p = Ni

-            if solver == 'solve':
-                x = do('linalg.solve', Ni_p, Wi)
-            elif solver == 'lstsq':
-                x = do('linalg.lstsq', Ni_p, Wi, rcond=pos_smudge)[0]
+            if solver == "solve":
+                x = do("linalg.solve", Ni_p, Wi)
+            elif solver == "lstsq":
+                x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0]

-            x_r = do('reshape', x, tk.shape)
+            x_r = do("reshape", x, tk.shape)
             # n.b. because we are using virtual TNs -> updates propagate
             tk.modify(data=x_r)
-            tb.modify(data=do('conj', x_r))
+            tb.modify(data=do("conj", x_r))

         # assess | A - B | for convergence or printing
         if (tol != 0.0) or progbar:
-            xAA = do('trace', dag(x) @ (Ni @ x))  # <x|A|x>
-            xAB = do('trace', do('real', dag(x) @ Wi))  # <x|B>
-            d = do('abs', (xAA - 2 * xAB + xBB))**0.5
+            xAA = do("trace", dag(x) @ (Ni @ x))  # <x|A|x>
+            xAB = do("trace", do("real", dag(x) @ Wi))  # <x|B>
+            d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5
             if abs(d - old_d) < tol:
                 break
             old_d = d
@@ -1396,13 +1443,13 @@ def tensor_network_fit_als(
     tags=None,
     steps=100,
     tol=1e-9,
-    solver='solve',
+    solver="solve",
     enforce_pos=False,
     pos_smudge=None,
     tnAA=None,
     tnAB=None,
     xBB=None,
-    contract_optimize='greedy',
+    contract_optimize="greedy",
     inplace=False,
     progbar=False,
 ):
@@ -1462,29 +1509,31 @@
     """
     # mark the tensors we are going to optimize
     tna = tn.copy()
-    tna.add_tag('__KET__')
+    tna.add_tag("__KET__")

     if tags is None:
         to_tag = tna
     else:
-        to_tag = tna.select_tensors(tags, 'any')
+        to_tag = tna.select_tensors(tags, "any")

     var_tags = []
     for i, t in enumerate(to_tag):
-        var_tag = f'__VAR{i}__'
+        var_tag = f"__VAR{i}__"
         t.add_tag(var_tag)
         var_tags.append(var_tag)

     # form the norm of the varying TN (A) and its overlap with the target (B)
     if tnAA is None:
-        tnAA = tna | tna.H.retag_({'__KET__': '__BRA__'})
+        tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"})
     if tnAB is None:
         tnAB = tna | tn_target.H

     if (tol != 0.0) and (xBB is None):
         # <B|B>
         xBB = (tn_target | tn_target.H).contract(
-            ..., optimize=contract_optimize, output_inds=(),
+            ...,
+            optimize=contract_optimize,
+            output_inds=(),
         )

     if pos_smudge is None:
@@ -1519,6 +1568,7 @@
 # Tensor Class                                                                #
 # --------------------------------------------------------------------------- #

+
 class Tensor:
     """A labelled, tagged n-dimensional array.
The index labels are used instead of axis numbers to identify dimensions,
    and are preserved through operations.

    """

-    __slots__ = ('_data', '_inds', '_tags', '_left_inds', '_owners')
+    __slots__ = ("_data", "_inds", "_tags", "_left_inds", "_owners")

     def __init__(self, data=1.0, inds=(), tags=None, left_inds=None):
         # a new or copied Tensor always has no owners
@@ -1707,8 +1757,7 @@ def add_owner(self, tn, tid):
         self._owners[hash(tn)] = (weakref.ref(tn), tid)

     def remove_owner(self, tn):
-        """Remove TensorNetwork ``tn`` as an owner of this Tensor.
-        """
+        """Remove TensorNetwork ``tn`` as an owner of this Tensor."""
         self._owners.pop(hash(tn), None)

     def check_owners(self):
@@ -1742,16 +1791,16 @@ def modify(self, **kwargs):
         left_inds : sequence of str, optional
             New grouping of indices to be 'on the left'.
         """
-        if 'data' in kwargs:
-            self._set_data(kwargs.pop('data'))
+        if "data" in kwargs:
+            self._set_data(kwargs.pop("data"))
             self._left_inds = None

-        if 'apply' in kwargs:
-            self._apply_function(kwargs.pop('apply'))
+        if "apply" in kwargs:
+            self._apply_function(kwargs.pop("apply"))
             self._left_inds = None

-        if 'inds' in kwargs:
-            inds = tuple(kwargs.pop('inds'))
+        if "inds" in kwargs:
+            inds = tuple(kwargs.pop("inds"))
             # if this tensor has owners, update their ``ind_map``, but only if
             # the indices are actually being changed not just permuted
             old_inds = oset(self.inds)
@@ -1763,8 +1812,8 @@ def modify(self, **kwargs):
             self._inds = inds
             self._left_inds = None

-        if 'tags' in kwargs:
-            tags = tags_to_oset(kwargs.pop('tags'))
+        if "tags" in kwargs:
+            tags = tags_to_oset(kwargs.pop("tags"))
             # if this tensor has owners, update their ``tag_map``.
             if self.check_owners():
                 for ref, tid in self._owners.values():
@@ -1772,19 +1821,23 @@ def modify(self, **kwargs):

             self._tags = tags

-        if 'left_inds' in kwargs:
-            self.left_inds = kwargs.pop('left_inds')
+        if "left_inds" in kwargs:
+            self.left_inds = kwargs.pop("left_inds")

         if kwargs:
             raise ValueError(f"Option(s) {kwargs} not valid.")

         if len(self.inds) != ndim(self.data):
-            raise ValueError("Mismatch between number of data dimensions and "
-                             "number of indices supplied.")
+            raise ValueError(
+                "Mismatch between number of data dimensions and "
+                "number of indices supplied."
+            )

         if self.left_inds and any(i not in self.inds for i in self.left_inds):
-            raise ValueError(f"The 'left' indices {self.left_inds} are "
-                             f"not found in {self.inds}.")
+            raise ValueError(
+                f"The 'left' indices {self.left_inds} are "
+                f"not found in {self.inds}."
+            )

     def apply_to_arrays(self, fn):
         """Apply the function ``fn`` to the underlying data array(s). This
@@ -1823,7 +1876,8 @@ def isel(self, selectors, inplace=False):
         T = self if inplace else self.copy()

         new_inds = tuple(
-            ix for ix in self.inds
+            ix
+            for ix in self.inds
             if (ix not in selectors) or isinstance(selectors[ix], slice)
         )
@@ -1888,11 +1942,11 @@ def new_ind(self, name, size=1, axis=0):

         new_inds.insert(axis, name)

-        new_data = do('expand_dims', self.data, axis=axis)
+        new_data = do("expand_dims", self.data, axis=axis)

         self.modify(data=new_data, inds=new_inds)
         if size > 1:
             self.expand_ind(name, size)

     new_bond = new_bond
@@ -1915,18 +1969,17 @@ def new_ind_with_identity(self, name, left_inds, right_inds, axis=0):
             Position of the new index.
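
        Examples
        --------
        A minimal sketch for illustration only (the array and index names
        are assumptions of this example):

        >>> import numpy as np
        >>> import quimb.tensor as qtn
        >>> t = qtn.Tensor(np.random.rand(2, 2), inds=('k', 'b'))
        >>> t.new_ind_with_identity('w', left_inds=('k',), right_inds=('b',))
        >>> t.inds  # modified in place, new index inserted at axis=0
        ('w', 'k', 'b')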
""" ldims = tuple(map(self.ind_size, left_inds)) - x_id = do('eye', prod(ldims), dtype=self.dtype, like=self.data) - x_id = do('reshape', x_id, ldims + ldims) + x_id = do("eye", prod(ldims), dtype=self.dtype, like=self.data) + x_id = do("reshape", x_id, ldims + ldims) t_id = Tensor(x_id, inds=left_inds + right_inds) t_id.transpose_(*self.inds) - new_data = do('stack', (self.data, t_id.data), axis=axis) + new_data = do("stack", (self.data, t_id.data), axis=axis) new_inds = list(self.inds) new_inds.insert(axis, name) self.modify(data=new_data, inds=new_inds) def conj(self, inplace=False): - """Conjugate this tensors data (does nothing to indices). - """ + """Conjugate this tensors data (does nothing to indices).""" t = self if inplace else self.copy() t.modify(apply=conj, left_inds=t.left_inds) return t @@ -1935,48 +1988,41 @@ def conj(self, inplace=False): @property def H(self): - """Conjugate this tensors data (does nothing to indices). - """ + """Conjugate this tensors data (does nothing to indices).""" return self.conj() @property def shape(self): - """The size of each dimension. - """ + """The size of each dimension.""" return shape(self._data) @property def ndim(self): - """The number of dimensions. - """ + """The number of dimensions.""" return len(self._inds) @property def size(self): - """The total number of array elements. - """ + """The total number of array elements.""" # more robust than calling _data.size (e.g. for torch) - consider # adding do('size', x) to autoray? return prod(self.shape) @property def dtype(self): - """The data type of the array elements. - """ + """The data type of the array elements.""" return getattr(self._data, "dtype", None) @property def backend(self): - """The backend inferred from the data. - """ + """The backend inferred from the data.""" return infer_backend(self._data) def iscomplex(self): return iscomplex(self.data) def astype(self, dtype, inplace=False): - """Change the type of this tensor to ``dtype``. - """ + """Change the type of this tensor to ``dtype``.""" T = self if inplace else self.copy() if T.dtype != dtype: T.modify(apply=lambda data: astype(data, dtype)) @@ -1985,30 +2031,25 @@ def astype(self, dtype, inplace=False): astype_ = functools.partialmethod(astype, inplace=True) def max_dim(self): - """Return the maximum size of any dimension, or 1 if scalar. - """ + """Return the maximum size of any dimension, or 1 if scalar.""" if self.ndim == 0: return 1 return max(self.shape) def ind_size(self, ind): - """Return the size of dimension corresponding to ``ind``. - """ + """Return the size of dimension corresponding to ``ind``.""" return int(self.shape[self.inds.index(ind)]) def inds_size(self, inds): - """Return the total size of dimensions corresponding to ``inds``. - """ + """Return the total size of dimensions corresponding to ``inds``.""" return prod(map(self.ind_size, inds)) def shared_bond_size(self, other): - """Get the total size of the shared index(es) with ``other``. 
- """ + """Get the total size of the shared index(es) with ``other``.""" return bonds_size(self, other) def inner_inds(self): - """ - """ + """ """ ind_freqs = frequencies(self.inds) return tuple(i for i in self.inds if ind_freqs[i] == 2) @@ -2084,15 +2125,17 @@ def transpose_like(self, other, inplace=False): diff_ix = set(t.inds) - set(other.inds) if len(diff_ix) > 1: - raise ValueError("More than one index don't match, the transpose " - "is therefore not well-defined.") + raise ValueError( + "More than one index don't match, the transpose " + "is therefore not well-defined." + ) # if their indices match, just plain transpose if not diff_ix: t.transpose_(*other.inds) else: - di, = diff_ix + (di,) = diff_ix new_ix = (i if i in t.inds else di for i in other.inds) t.transpose_(*new_ix) @@ -2131,16 +2174,11 @@ def moveindex(self, ind, axis, inplace=False): moveindex_ = functools.partialmethod(moveindex, inplace=True) def item(self): - """Return the scalar value of this tensor, if it has a single element. - """ + """Return the scalar value of this tensor, if it has a single element.""" return self.data.item() def trace( - self, - left_inds, - right_inds, - preserve_tensor=False, - inplace=False + self, left_inds, right_inds, preserve_tensor=False, inplace=False ): """Trace index or indices ``left_inds`` with ``right_inds``, removing them. @@ -2188,7 +2226,9 @@ def trace( raise ValueError(f"Indices {tuple(remap)} not found.") expr = array_contract_expression( - inputs=[old_inds], output=new_inds, shapes=[t.shape], + inputs=[old_inds], + output=new_inds, + shapes=[t.shape], ) t.modify(apply=expr, inds=new_inds, left_inds=None) @@ -2216,8 +2256,8 @@ def sum_reduce(self, ind, inplace=False): """ t = self if inplace else self.copy() axis = t.inds.index(ind) - new_inds = t.inds[:axis] + t.inds[axis + 1:] - t.modify(apply=lambda x: do('sum', x, axis=axis), inds=new_inds) + new_inds = t.inds[:axis] + t.inds[axis + 1 :] + t.modify(apply=lambda x: do("sum", x, axis=axis), inds=new_inds) return t sum_reduce_ = functools.partialmethod(sum_reduce, inplace=True) @@ -2240,12 +2280,12 @@ def vector_reduce(self, ind, v, inplace=False): Tensor """ t = self if inplace else self.copy() - axis= t.inds.index(ind) + axis = t.inds.index(ind) new_data = array_contract( (t.data, v), (tuple(range(self.ndim)), (axis,)), ) - new_inds = t.inds[:axis] + t.inds[axis + 1:] + new_inds = t.inds[:axis] + t.inds[axis + 1 :] t.modify(data=new_data, inds=new_inds) return t @@ -2263,14 +2303,17 @@ def collapse_repeated(self, inplace=False): return t expr = array_contract_expression( - inputs=[old_inds], output=new_inds, shapes=[t.shape], + inputs=[old_inds], + output=new_inds, + shapes=[t.shape], ) t.modify(apply=expr, inds=new_inds, left_inds=None) return t collapse_repeated_ = functools.partialmethod( - collapse_repeated, inplace=True) + collapse_repeated, inplace=True + ) @functools.wraps(tensor_contract) def contract(self, *others, output_inds=None, **opts): @@ -2279,7 +2322,8 @@ def contract(self, *others, output_inds=None, **opts): @functools.wraps(tensor_direct_product) def direct_product(self, other, sum_inds=(), inplace=False): return tensor_direct_product( - self, other, sum_inds=sum_inds, inplace=inplace) + self, other, sum_inds=sum_inds, inplace=inplace + ) direct_product_ = functools.partialmethod(direct_product, inplace=True) @@ -2381,14 +2425,14 @@ def gate( t.modify(data=new_data) else: # simply update index labels - new_inds = (ind, *t.inds[:ax], *t.inds[ax + 1:]) + new_inds = (ind, *t.inds[:ax], *t.inds[ax + 1 :]) 
t.modify(data=new_data, inds=new_inds) return t gate_ = functools.partialmethod(gate, inplace=True) - def singular_values(self, left_inds, method='svd'): + def singular_values(self, left_inds, method="svd"): """Return the singular values associated with splitting this tensor according to ``left_inds``. @@ -2405,9 +2449,9 @@ def singular_values(self, left_inds, method='svd'): 1d-array The singular values. """ - return self.split(left_inds=left_inds, method=method, get='values') + return self.split(left_inds=left_inds, method=method, get="values") - def entropy(self, left_inds, method='svd'): + def entropy(self, left_inds, method="svd"): """Return the entropy associated with splitting this tensor according to ``left_inds``. @@ -2423,9 +2467,9 @@ def entropy(self, left_inds, method='svd'): ------- float """ - el = self.singular_values(left_inds=left_inds, method=method)**2 + el = self.singular_values(left_inds=left_inds, method=method) ** 2 el = el[el > 0.0] - return do('sum', -el * do('log2', el)) + return do("sum", -el * do("log2", el)) def retag(self, retag_map, inplace=False): """Rename the tags of this tensor, optionally, in-place. @@ -2497,7 +2541,7 @@ def fuse(self, fuse_map, inplace=False): # compute numerical axes groups to supply to the array function fuse ind2ax = {ind: ax for ax, ind in enumerate(t.inds)} axes_groups = [] - gax0 = float('inf') + gax0 = float("inf") for fused_ind_group in fused_inds: group = [] for ind in fused_ind_group: @@ -2572,8 +2616,11 @@ def unfuse(self, unfuse_map, shape_map, inplace=False): # create new tensor with new + remaining indices # + updated 'left' marked indices assuming all unfused left inds # remain 'left' marked - t.modify(data=do("reshape", t.data, new_dims), - inds=new_inds, left_inds=new_left_inds) + t.modify( + data=do("reshape", t.data, new_dims), + inds=new_inds, + left_inds=new_left_inds, + ) return t @@ -2586,7 +2633,7 @@ def to_dense(self, *inds_seq, to_qarray=False): """ fuse_map = [(f"__d{i}__", ix) for i, ix in enumerate(inds_seq)] x = self.fuse(fuse_map).data - if to_qarray and (infer_backend(x) == 'numpy'): + if to_qarray and (infer_backend(x) == "numpy"): return qarray(x) return x @@ -2623,13 +2670,15 @@ def squeeze( new_inds = [] new_shape = [] - any_squeezed = False + any_squeezed = False for ix, d in zip(t.inds, t.shape): keep = ( # not squeezable - (d > 1) or + (d > 1) + or # is not in the list of allowed indices - (include is not None and ix not in include) or + (include is not None and ix not in include) + or # is in the list of not allowed indices (exclude is not None and ix in exclude) ) @@ -2646,8 +2695,9 @@ def squeeze( # we can propagate 'left' marked indices through squeezing new_left_inds = ( - None if self.left_inds is None else - (i for i in self.left_inds if i in new_inds) + None + if self.left_inds is None + else (i for i in self.left_inds if i in new_inds) ) t.modify(data=new_data, inds=new_inds, left_inds=new_left_inds) @@ -2659,7 +2709,7 @@ def largest_element(self): r"""Return the largest element, in terms of absolute magnitude, of this tensor. 
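
        Examples
        --------
        For illustration only (the values here are assumptions of this
        example):

        >>> import numpy as np
        >>> import quimb.tensor as qtn
        >>> qtn.Tensor(np.array([-3.0, 2.0]), inds=('a',)).largest_element()
        3.0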
""" - return do('max', do('abs', self.data)) + return do("max", do("abs", self.data)) def idxmin(self, f=None): """Get the index configuration of the minimum element of this tensor, @@ -2715,7 +2765,6 @@ def idxmax(self, f=None): idx = np.unravel_index(flat_idx, self.shape) return dict(zip(self.inds, idx)) - def norm(self): r"""Frobenius norm of this tensor: @@ -2748,7 +2797,7 @@ def symmetrize(self, ind1, ind2, inplace=False): symmetrize_ = functools.partialmethod(symmetrize, inplace=True) - def isometrize(self, left_inds=None, method='qr', inplace=False): + def isometrize(self, left_inds=None, method="qr", inplace=False): r"""Make this tensor unitary (or isometric) with respect to ``left_inds``. The underlying method is set by ``method``. @@ -2800,7 +2849,8 @@ def isometrize(self, left_inds=None, method='qr', inplace=False): raise ValueError( "You must specify `left_inds` since this tensor does not " "have any indices marked automatically as such in the " - "attribute `left_inds`.") + "attribute `left_inds`." + ) else: left_inds = self.left_inds @@ -2832,8 +2882,8 @@ def isometrize(self, left_inds=None, method='qr', inplace=False): return Tu isometrize_ = functools.partialmethod(isometrize, inplace=True) - unitize = deprecated(isometrize, 'unitize', 'isometrize') - unitize_ = deprecated(isometrize_, 'unitize_', 'isometrize_') + unitize = deprecated(isometrize, "unitize", "isometrize") + unitize_ = deprecated(isometrize_, "unitize_", "isometrize_") def randomize(self, dtype=None, inplace=False, **randn_opts): """Randomize the entries of this tensor. @@ -2890,20 +2940,19 @@ def multiply_index_diagonal(self, ind, x, inplace=False): return t multiply_index_diagonal_ = functools.partialmethod( - multiply_index_diagonal, inplace=True) + multiply_index_diagonal, inplace=True + ) def almost_equals(self, other, **kwargs): - """Check if this tensor is almost the same as another. - """ - same_inds = (set(self.inds) == set(other.inds)) + """Check if this tensor is almost the same as another.""" + same_inds = set(self.inds) == set(other.inds) if not same_inds: return False otherT = other.transpose(*self.inds) - return do('allclose', self.data, otherT.data, **kwargs) + return do("allclose", self.data, otherT.data, **kwargs) def drop_tags(self, tags=None): - """Drop certain tags, defaulting to all, from this tensor. - """ + """Drop certain tags, defaulting to all, from this tensor.""" if tags is None: self.modify(tags=oset()) else: @@ -2966,7 +3015,7 @@ def __matmul__(self, other): ax1 = tuple(self.inds.index(b) for b in bix) ax2 = tuple(other.inds.index(b) for b in bix) data_out = do( - 'tensordot', + "tensordot", self.data, other.data, axes=(ax1, ax2), @@ -2983,14 +3032,12 @@ def __matmul__(self, other): return self.__class__(data_out, inds=new_inds, tags=new_tags) def as_network(self, virtual=True): - """Return a ``TensorNetwork`` with only this tensor. - """ + """Return a ``TensorNetwork`` with only this tensor.""" return TensorNetwork((self,), virtual=virtual) @functools.wraps(draw_tn) def draw(self, *args, **kwargs): - """Plot a graph of this tensor and its indices. - """ + """Plot a graph of this tensor and its indices.""" return draw_tn(self.as_network(), *args, **kwargs) graph = draw @@ -3028,8 +3075,7 @@ def _repr_info_extra(self): } def _repr_info_str(self, normal=True, extra=False): - """Render the general info as a string. 
- """ + """Render the general info as a string.""" info = {} if normal: info.update(self._repr_info()) @@ -3041,17 +3087,16 @@ def _repr_info_str(self, normal=True, extra=False): ) def _repr_html_(self): - """Render this Tensor as HTML, for Jupyter notebooks. - """ + """Render this Tensor as HTML, for Jupyter notebooks.""" s = "" s += "
" s += "" - shape_repr = ', '.join(auto_color_html(d) for d in self.shape) - inds_repr = ', '.join(auto_color_html(ix) for ix in self.inds) - tags_repr = ', '.join(auto_color_html(tag) for tag in self.tags) + shape_repr = ", ".join(auto_color_html(d) for d in self.shape) + inds_repr = ", ".join(auto_color_html(ix) for ix in self.inds) + tags_repr = ", ".join(auto_color_html(tag) for tag in self.tags) s += ( f"{auto_color_html(self.__class__.__name__)}(" - f'shape=({shape_repr}), inds=[{inds_repr}], tags={{{tags_repr}}}' + f"shape=({shape_repr}), inds=[{inds_repr}], tags={{{tags_repr}}}" ")," ) s += "" @@ -3066,14 +3111,10 @@ def _repr_html_(self): return s def __str__(self): - return ( - f"{self.__class__.__name__}({self._repr_info_str(extra=True)})" - ) + return f"{self.__class__.__name__}({self._repr_info_str(extra=True)})" def __repr__(self): - return ( - f"{self.__class__.__name__}({self._repr_info_str()})" - ) + return f"{self.__class__.__name__}({self._repr_info_str()})" @functools.lru_cache(128) @@ -3196,55 +3237,61 @@ def COPY_tree_tensors(d, inds, tags=None, dtype=float, ssa_path=None): # ------------------------- Add ufunc like methods -------------------------- # -def _make_promote_array_func(op, meth_name): +def _make_promote_array_func(op, meth_name): @functools.wraps(getattr(np.ndarray, meth_name)) def _promote_array_func(self, other): - """Use standard array func, but make sure Tensor inds match. - """ + """Use standard array func, but make sure Tensor inds match.""" if isinstance(other, Tensor): - if set(self.inds) != set(other.inds): - raise ValueError("The indicies of these two tensors do not " - f"match: {self.inds} != {other.inds}") + raise ValueError( + "The indicies of these two tensors do not " + f"match: {self.inds} != {other.inds}" + ) otherT = other.transpose(*self.inds) return Tensor( - data=op(self.data, otherT.data), inds=self.inds, - tags=self.tags | other.tags) + data=op(self.data, otherT.data), + inds=self.inds, + tags=self.tags | other.tags, + ) else: - return Tensor(data=op(self.data, other), - inds=self.inds, tags=self.tags) + return Tensor( + data=op(self.data, other), inds=self.inds, tags=self.tags + ) return _promote_array_func -for meth_name, op in [('__add__', operator.__add__), - ('__sub__', operator.__sub__), - ('__mul__', operator.__mul__), - ('__pow__', operator.__pow__), - ('__truediv__', operator.__truediv__)]: +for meth_name, op in [ + ("__add__", operator.__add__), + ("__sub__", operator.__sub__), + ("__mul__", operator.__mul__), + ("__pow__", operator.__pow__), + ("__truediv__", operator.__truediv__), +]: setattr(Tensor, meth_name, _make_promote_array_func(op, meth_name)) def _make_rhand_array_promote_func(op, meth_name): - @functools.wraps(getattr(np.ndarray, meth_name)) def _rhand_array_promote_func(self, other): - """Right hand operations -- no need to check ind equality first. 
- """ - return Tensor(data=op(other, self.data), - inds=self.inds, tags=self.tags) + """Right hand operations -- no need to check ind equality first.""" + return Tensor( + data=op(other, self.data), inds=self.inds, tags=self.tags + ) return _rhand_array_promote_func -for meth_name, op in [('__radd__', operator.__add__), - ('__rsub__', operator.__sub__), - ('__rmul__', operator.__mul__), - ('__rpow__', operator.__pow__), - ('__rtruediv__', operator.__truediv__)]: +for meth_name, op in [ + ("__radd__", operator.__add__), + ("__rsub__", operator.__sub__), + ("__rmul__", operator.__mul__), + ("__rpow__", operator.__pow__), + ("__rtruediv__", operator.__truediv__), +]: setattr(Tensor, meth_name, _make_rhand_array_promote_func(op, meth_name)) @@ -3252,16 +3299,25 @@ def _rhand_array_promote_func(self, other): # Tensor Network Class # # --------------------------------------------------------------------------- # + def _tensor_network_gate_inds_basic( - tn, G, inds, ng, tags, contract, isparam, info, **compress_opts, + tn, + G, + inds, + ng, + tags, + contract, + isparam, + info, + **compress_opts, ): tags = tags_to_oset(tags) if (ng == 1) and contract: # single site gate, eagerly applied so contract in directly -> # useful short circuit as it maintains the index structure exactly - ix, = inds - t, = tn._inds_get(ix) + (ix,) = inds + (t,) = tn._inds_get(ix) t.gate_(G, ix) t.add_tag(tags) return tn @@ -3273,7 +3329,8 @@ def _tensor_network_gate_inds_basic( # tensor representing the gate if isparam: TG = PTensor.from_parray( - G, inds=(*inds, *bnds), tags=tags, left_inds=bnds) + G, inds=(*inds, *bnds), tags=tags, left_inds=bnds + ) else: TG = Tensor(G, inds=(*inds, *bnds), tags=tags, left_inds=bnds) @@ -3289,7 +3346,7 @@ def _tensor_network_gate_inds_basic( tn |= TG return tn - tids = tn._get_tids_from_inds(inds, 'any') + tids = tn._get_tids_from_inds(inds, "any") if (contract is True) or (len(tids) == 1): # @@ -3300,7 +3357,7 @@ def _tensor_network_gate_inds_basic( tn.reindex_(reindex_map) # get the sites that used to have the physical indices - site_tids = tn._get_tids_from_inds(bnds, which='any') + site_tids = tn._get_tids_from_inds(bnds, which="any") # pop the sites, contract, then re-add pts = [tn.pop_tensor(tid) for tid in site_tids] @@ -3313,7 +3370,7 @@ def _tensor_network_gate_inds_basic( tl, tr = tn._inds_get(ixl, ixr) bnds_l, (bix,), bnds_r = group_inds(tl, tr) - if contract == 'split': + if contract == "split": # # │╱ │╱ │╱ │╱ # ──GGGGG── -> ──G~~~G── @@ -3322,16 +3379,19 @@ def _tensor_network_gate_inds_basic( # contract with new gate tensor tlGr = tensor_contract( - tl.reindex(reindex_map), - tr.reindex(reindex_map), - TG) + tl.reindex(reindex_map), tr.reindex(reindex_map), TG + ) # decompose back into two tensors tln, *maybe_svals, trn = tlGr.split( - left_inds=bnds_l, right_inds=bnds_r, - bond_ind=bix, get='tensors', **compress_opts) + left_inds=bnds_l, + right_inds=bnds_r, + bond_ind=bix, + get="tensors", + **compress_opts, + ) - if contract == 'reduce-split': + if contract == "reduce-split": # move physical inds on reduced tensors # # │ │ │ │ @@ -3341,11 +3401,19 @@ def _tensor_network_gate_inds_basic( # ╱ ╱ ╱ ╱ # tmp_bix_l = rand_uuid() - tl_Q, tl_R = tl.split(left_inds=None, right_inds=[bix, ixl], - method='qr', bond_ind=tmp_bix_l) + tl_Q, tl_R = tl.split( + left_inds=None, + right_inds=[bix, ixl], + method="qr", + bond_ind=tmp_bix_l, + ) tmp_bix_r = rand_uuid() - tr_L, tr_Q = tr.split(left_inds=[bix, ixr], right_inds=None, - method='lq', bond_ind=tmp_bix_r) + tr_L, tr_Q = 
tr.split( + left_inds=[bix, ixr], + right_inds=None, + method="lq", + bond_ind=tmp_bix_r, + ) # contract reduced tensors with gate tensor # @@ -3356,9 +3424,8 @@ def _tensor_network_gate_inds_basic( # ╱ ╱ ╱ ╱ # tlGr = tensor_contract( - tl_R.reindex(reindex_map), - tr_L.reindex(reindex_map), - TG) + tl_R.reindex(reindex_map), tr_L.reindex(reindex_map), TG + ) # split to find new reduced factors # @@ -3368,8 +3435,12 @@ def _tensor_network_gate_inds_basic( # ╱ ╱ ╱ ╱ # tl_R, *maybe_svals, tr_L = tlGr.split( - left_inds=[tmp_bix_l, ixl], right_inds=[tmp_bix_r, ixr], - bond_ind=bix, get='tensors', **compress_opts) + left_inds=[tmp_bix_l, ixl], + right_inds=[tmp_bix_r, ixr], + bond_ind=bix, + get="tensors", + **compress_opts, + ) # absorb reduced factors back into site tensors # @@ -3385,7 +3456,7 @@ def _tensor_network_gate_inds_basic( # return them via ``info``, e.g. for ``SimpleUpdate` if maybe_svals and info is not None: s = next(iter(maybe_svals)).data - info['singular_values', bix] = s + info["singular_values", bix] = s # update original tensors tl.modify(data=tln.transpose_like_(tl).data) @@ -3393,53 +3464,60 @@ def _tensor_network_gate_inds_basic( def _tensor_network_gate_inds_lazy_split( - tn, G, inds, ng, tags, contract, dims, **compress_opts, + tn, + G, + inds, + ng, + tags, + contract, + dims, + **compress_opts, ): - lix = [f'l{i}' for i in range(ng)] - rix = [f'r{i}' for i in range(ng)] + lix = [f"l{i}" for i in range(ng)] + rix = [f"r{i}" for i in range(ng)] TG = Tensor(data=G, inds=lix + rix, tags=tags, left_inds=rix) # check if we should split multi-site gates (which may result in an easier # tensor network to contract if we use compression) - if contract in ('split-gate', 'auto-split-gate'): + if contract in ("split-gate", "auto-split-gate"): # | | | | # GGG --> G~G # | | | | - tnG_spat = TG.split(('l0', 'r0'), bond_ind='b', **compress_opts) + tnG_spat = TG.split(("l0", "r0"), bond_ind="b", **compress_opts) # sometimes it is worth performing the decomposition *across* the gate, # effectively introducing a SWAP - if contract in ('swap-split-gate', 'auto-split-gate'): + if contract in ("swap-split-gate", "auto-split-gate"): # \ / # | | X # GGG --> / \ # | | G~G # | | - tnG_swap = TG.split(('l0', 'r1'), bond_ind='b', **compress_opts) + tnG_swap = TG.split(("l0", "r1"), bond_ind="b", **compress_opts) # like 'split-gate' but check the rank for swapped indices also, and if no # rank reduction, simply don't swap - if contract == 'auto-split-gate': + if contract == "auto-split-gate": # | | \ / # | | | | X | | # GGG --> G~G or / \ or ... GGG # | | | | G~G | | # | | | | - spat_rank = tnG_spat.ind_size('b') - swap_rank = tnG_swap.ind_size('b') + spat_rank = tnG_spat.ind_size("b") + swap_rank = tnG_swap.ind_size("b") if swap_rank < spat_rank: - contract = 'swap-split-gate' + contract = "swap-split-gate" elif spat_rank < prod(dims): - contract = 'split-gate' + contract = "split-gate" else: # else no rank reduction available - leave as ``contract=False``. 
contract = False - if contract == 'swap-split-gate': + if contract == "swap-split-gate": tnG = tnG_swap - elif contract == 'split-gate': + elif contract == "split-gate": tnG = tnG_spat else: tnG = TG @@ -3448,20 +3526,23 @@ def _tensor_network_gate_inds_lazy_split( _BASIC_GATE_CONTRACT = { - False, True, - 'split', - 'reduce-split', + False, + True, + "split", + "reduce-split", } _SPLIT_GATE_CONTRACT = { - 'auto-split-gate', - 'split-gate', - 'swap-split-gate', + "auto-split-gate", + "split-gate", + "swap-split-gate", } _VALID_GATE_CONTRACT = _BASIC_GATE_CONTRACT | _SPLIT_GATE_CONTRACT def tensor_network_gate_inds( - self, G, inds, + self, + G, + inds, contract=False, tags=None, info=None, @@ -3582,7 +3663,7 @@ def tensor_network_gate_inds( depending on whether either results in a lower rank. """ - check_opt('contract', contract, _VALID_GATE_CONTRACT) + check_opt("contract", contract, _VALID_GATE_CONTRACT) tn = self if inplace else self.copy() @@ -3595,13 +3676,16 @@ def tensor_network_gate_inds( G = do("reshape", G, dims * 2) if not all(d == dims[i % ng] for i, d in enumerate(G.shape)): - raise ValueError(f"Gate with shape {G.shape} doesn't match " - f"indices {inds} with dimensions {dims}.") + raise ValueError( + f"Gate with shape {G.shape} doesn't match " + f"indices {inds} with dimensions {dims}." + ) - basic = (contract in _BASIC_GATE_CONTRACT) + basic = contract in _BASIC_GATE_CONTRACT if ( # if single ind, gate splitting methods are same as lazy - ((not basic) and (ng == 1)) or + ((not basic) and (ng == 1)) + or # or for 3+ sites, treat auto as no splitting ((contract == "auto-split-gate") and (ng > 2)) ): @@ -3610,26 +3694,29 @@ def tensor_network_gate_inds( isparam = isinstance(G, PArray) if isparam: - if contract == 'auto-split-gate': + if contract == "auto-split-gate": # simply don't split basic = True contract = False elif contract and ng > 1: raise ValueError( "For a parametrized gate acting on more than one site " - "``contract`` must be false to preserve the array shape.") + "``contract`` must be false to preserve the array shape." + ) if basic: # no gate splitting involved _tensor_network_gate_inds_basic( - tn, G, inds, ng, tags, contract, isparam, info, **compress_opts) + tn, G, inds, ng, tags, contract, isparam, info, **compress_opts + ) else: # possible splitting of gate itself involved if ng > 2: raise ValueError(f"`contract='{contract}'` invalid for >2 sites.") _tensor_network_gate_inds_lazy_split( - tn, G, inds, ng, tags, contract, dims, **compress_opts) + tn, G, inds, ng, tags, contract, dims, **compress_opts + ) return tn @@ -3677,7 +3764,6 @@ class TensorNetwork(object): _CONTRACT_STRUCTURED = False def __init__(self, ts=(), *, virtual=False, check_collisions=True): - # short-circuit for copying or casting as TensorNetwork if isinstance(ts, TensorNetwork): self.tag_map = valmap(lambda tids: tids.copy(), ts.tag_map) @@ -3796,8 +3882,7 @@ def from_TN(cls, tn, like=None, inplace=False, **kwargs): return new_tn def view_as(self, cls, inplace=False, **kwargs): - """View this tensor network as subclass ``cls``. - """ + """View this tensor network as subclass ``cls``.""" return cls.from_TN(self, inplace=inplace, **kwargs) view_as_ = functools.partialmethod(view_as, inplace=True) @@ -3806,8 +3891,9 @@ def view_like(self, like, inplace=False, **kwargs): """View this tensor network as the same subclass ``cls`` as ``like`` inheriting its extra properties as well. 
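
        Examples
        --------
        A hedged sketch (``tn`` and ``mps`` stand for any existing plain
        network and structured network respectively -- both are assumptions
        of this example):

        >>> tn2 = tn.view_like(mps)  # tn2 takes mps's class and extra props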
""" - return self.view_as(like.__class__, like=like, - inplace=inplace, **kwargs) + return self.view_as( + like.__class__, like=like, inplace=inplace, **kwargs + ) view_like_ = functools.partialmethod(view_like, inplace=True) @@ -3824,8 +3910,7 @@ def copy(self, virtual=False, deep=False): __copy__ = copy def get_params(self): - """Get a pytree of the 'parameters', i.e. all underlying data arrays. - """ + """Get a pytree of the 'parameters', i.e. all underlying data arrays.""" return {tid: t.get_params() for tid, t in self.tensor_map.items()} def set_params(self, params): @@ -3836,8 +3921,7 @@ def set_params(self, params): self.tensor_map[tid].set_params(t_params) def _link_tags(self, tags, tid): - """Link ``tid`` to each of ``tags``. - """ + """Link ``tid`` to each of ``tags``.""" for tag in tags: if tag in self.tag_map: self.tag_map[tag].add(tid) @@ -3845,8 +3929,7 @@ def _link_tags(self, tags, tid): self.tag_map[tag] = oset((tid,)) def _unlink_tags(self, tags, tid): - """"Unlink ``tid`` from each of ``tags``. - """ + """ "Unlink ``tid`` from each of ``tags``.""" for tag in tags: try: tids = self.tag_map[tag] @@ -3859,8 +3942,7 @@ def _unlink_tags(self, tags, tid): pass def _link_inds(self, inds, tid): - """Link ``tid`` to each of ``inds``. - """ + """Link ``tid`` to each of ``inds``.""" for ind in inds: if ind in self.ind_map: self.ind_map[ind].add(tid) @@ -3871,8 +3953,7 @@ def _link_inds(self, inds, tid): self._outer_inds.add(ind) def _unlink_inds(self, inds, tid): - """"Unlink ``tid`` from each of ``inds``. - """ + """ "Unlink ``tid`` from each of ``inds``.""" for ind in inds: try: tids = self.ind_map[ind] @@ -3906,8 +3987,7 @@ def _next_tid(self): return self._tid_counter def add_tensor(self, tensor, tid=None, virtual=False): - """Add a single tensor to this network - mangle its tid if neccessary. - """ + """Add a single tensor to this network - mangle its tid if neccessary.""" # check for tid conflict if (tid is None) or (tid in self.tensor_map): tid = self._next_tid() @@ -3922,8 +4002,7 @@ def add_tensor(self, tensor, tid=None, virtual=False): self._link_inds(T.inds, tid) def add_tensor_network(self, tn, virtual=False, check_collisions=True): - """ - """ + """ """ if check_collisions: # add tensors individually # check for matching inner_indices -> need to re-index clash_ix = self._inner_inds & tn._inner_inds @@ -3941,32 +4020,34 @@ def add_tensor_network(self, tn, virtual=False, check_collisions=True): self.exponent = self.exponent + tn.exponent def add(self, t, virtual=False, check_collisions=True): - """Add Tensor, TensorNetwork or sequence thereof to self. - """ + """Add Tensor, TensorNetwork or sequence thereof to self.""" if isinstance(t, (tuple, list)): for each_t in t: - self.add(each_t, virtual=virtual, - check_collisions=check_collisions) + self.add( + each_t, virtual=virtual, check_collisions=check_collisions + ) return istensor = isinstance(t, Tensor) istensornetwork = isinstance(t, TensorNetwork) if not (istensor or istensornetwork): - raise TypeError("TensorNetwork should be called as " - "`TensorNetwork(ts, ...)`, where each " - "object in 'ts' is a Tensor or " - "TensorNetwork.") + raise TypeError( + "TensorNetwork should be called as " + "`TensorNetwork(ts, ...)`, where each " + "object in 'ts' is a Tensor or " + "TensorNetwork." 
+ ) if istensor: self.add_tensor(t, virtual=virtual) else: - self.add_tensor_network(t, virtual=virtual, - check_collisions=check_collisions) + self.add_tensor_network( + t, virtual=virtual, check_collisions=check_collisions + ) def make_tids_consecutive(self, tid0=0): - """Reset the `tids` - node identifies - to be consecutive integers. - """ + """Reset the `tids` - node identifies - to be consecutive integers.""" tids = tuple(self.tensor_map.keys()) ts = tuple(map(self.pop_tensor, tids)) self._tid_counter = tid0 @@ -3996,19 +4077,16 @@ def _modify_tensor_inds(self, old, new, tid): @property def num_tensors(self): - """The total number of tensors in the tensor network. - """ + """The total number of tensors in the tensor network.""" return len(self.tensor_map) @property def num_indices(self): - """The total number of indices in the tensor network. - """ + """The total number of indices in the tensor network.""" return len(self.ind_map) def pop_tensor(self, tid): - """Remove tensor with ``tid`` from this network, and return it. - """ + """Remove tensor with ``tid`` from this network, and return it.""" # pop the tensor itself t = self.tensor_map.pop(tid) @@ -4021,9 +4099,13 @@ def pop_tensor(self, tid): return t - _pop_tensor = deprecated(pop_tensor, "_pop_tensor", "pop_tensor",) + _pop_tensor = deprecated( + pop_tensor, + "_pop_tensor", + "pop_tensor", + ) - def delete(self, tags, which='all'): + def delete(self, tags, which="all"): """Delete any tensors which match all or any of ``tags``. Parameters @@ -4038,8 +4120,7 @@ def delete(self, tags, which='all'): self.pop_tensor(tid) def check(self): - """Check some basic diagnostics of the tensor network. - """ + """Check some basic diagnostics of the tensor network.""" for tid, t in self.tensor_map.items(): t.check() @@ -4091,7 +4172,7 @@ def check(self): f"'{ix}' in tensors {ts}." ) - def add_tag(self, tag, where=None, which='all'): + def add_tag(self, tag, where=None, which="all"): """Add tag to every tensor in this network, or if ``where`` is specified, the tensors matching those tags -- i.e. adds the tag to all tensors in ``self.select_tensors(where, which=which)``. @@ -4112,7 +4193,7 @@ def drop_tags(self, tags=None): """ if tags is not None: tags = tags_to_oset(tags) - tids = self._get_tids_from_tags(tags, which='any') + tids = self._get_tids_from_tags(tags, which="any") else: tids = self.tensor_map.keys() @@ -4132,7 +4213,7 @@ def retag(self, tag_map, inplace=False): tn = self if inplace else self.copy() # get ids of tensors which have any of the tags - tids = tn._get_tids_from_tags(tag_map.keys(), which='any') + tids = tn._get_tids_from_tags(tag_map.keys(), which="any") for tid in tids: t = tn.tensor_map[tid] @@ -4187,8 +4268,7 @@ def mangle_inner_(self, append=None, which=None): return self def conj(self, mangle_inner=False, inplace=False): - """Conjugate all the tensors in this network (leaves all indices). - """ + """Conjugate all the tensors in this network (leaves all indices).""" tn = self if inplace else self.copy() for t in tn: @@ -4204,14 +4284,12 @@ def conj(self, mangle_inner=False, inplace=False): @property def H(self): - """Conjugate all the tensors in this network (leaves all indices). - """ + """Conjugate all the tensors in this network (leaves all indices).""" return self.conj() def item(self): - """Return the scalar value of this tensor network, if it is a scalar. 
- """ - t, = self.tensor_map.values() + """Return the scalar value of this tensor network, if it is a scalar.""" + (t,) = self.tensor_map.values() return t.item() def largest_element(self): @@ -4234,12 +4312,12 @@ def norm(self, **contract_opts): root of the sum of squared singular values across any partition. """ norm = self | self.conj() - return norm.contract(**contract_opts)**0.5 + return norm.contract(**contract_opts) ** 0.5 def make_norm( self, - mangle_append='*', - layer_tags=('KET', 'BRA'), + mangle_append="*", + layer_tags=("KET", "BRA"), return_all=False, ): """Make the norm tensor network of this tensor network ``tn.H & tn``. @@ -4281,7 +4359,7 @@ def multiply(self, x, inplace=False, spread_over=8): """ multiplied = self if inplace else self.copy() - if spread_over == 'all': + if spread_over == "all": spread_over = self.num_tensors else: spread_over = min(self.num_tensors, spread_over) @@ -4294,7 +4372,7 @@ def multiply(self, x, inplace=False, spread_over=8): if iscomplex(x): x_sign = 1.0 else: - x_sign = do('sign', x) + x_sign = do("sign", x) x = abs(x) x_spread = x ** (1 / spread_over) @@ -4340,28 +4418,23 @@ def multiply_each(self, x, inplace=False): multiply_each_ = functools.partialmethod(multiply_each, inplace=True) def __mul__(self, other): - """Scalar multiplication. - """ + """Scalar multiplication.""" return self.multiply(other) def __rmul__(self, other): - """Right side scalar multiplication. - """ + """Right side scalar multiplication.""" return self.multiply(other) def __imul__(self, other): - """Inplace scalar multiplication. - """ + """Inplace scalar multiplication.""" return self.multiply_(other) def __truediv__(self, other): - """Scalar division. - """ + """Scalar division.""" return self.multiply(other**-1) def __itruediv__(self, other): - """Inplace scalar division. - """ + """Inplace scalar division.""" return self.multiply_(other**-1) def __iter__(self): @@ -4369,14 +4442,12 @@ def __iter__(self): @property def tensors(self): - """Get the tuple of tensors in this tensor network. - """ + """Get the tuple of tensors in this tensor network.""" return tuple(self.tensor_map.values()) @property def arrays(self): - """Get the tuple of raw arrays containing all the tensor network data. 
- """ + """Get the tuple of raw arrays containing all the tensor network data.""" return tuple(t.data for t in self) def get_symbol_map(self): @@ -4441,8 +4512,8 @@ def get_inputs_output_size_dict(self, output_inds=None): get_symbol_map, get_equation """ eq = self.get_equation(output_inds=output_inds) - lhs, output = eq.split('->') - inputs = lhs.split(',') + lhs, output = eq.split("->") + inputs = lhs.split(",") size_dict = {} for term, t in zip(inputs, self): for k, d in zip(term, t.shape): @@ -4501,11 +4572,15 @@ def geometry_hash(self, output_inds=None, strict_index_order=False): ) if strict_index_order: - return hashlib.sha1(pickle.dumps(( - tuple(map(tuple, inputs)), - tuple(output), - sortedtuple(size_dict.items()) - ))).hexdigest() + return hashlib.sha1( + pickle.dumps( + ( + tuple(map(tuple, inputs)), + tuple(output), + sortedtuple(size_dict.items()), + ) + ) + ).hexdigest() edges = collections.defaultdict(list) for ix in output: @@ -4517,9 +4592,9 @@ def geometry_hash(self, output_inds=None, strict_index_order=False): # then sort edges by each's incidence nodes canonical_edges = sortedtuple(map(sortedtuple, edges.values())) - return hashlib.sha1(pickle.dumps(( - canonical_edges, sortedtuple(size_dict.items()) - ))).hexdigest() + return hashlib.sha1( + pickle.dumps((canonical_edges, sortedtuple(size_dict.items()))) + ).hexdigest() def tensors_sorted(self): """Return a tuple of tensors sorted by their respective tags, such that @@ -4541,13 +4616,13 @@ def apply_to_arrays(self, fn): # ----------------- selecting and splitting the network ----------------- # def _get_tids_from(self, xmap, xs, which): - inverse = which[0] == '!' + inverse = which[0] == "!" if inverse: which = which[1:] combine = { - 'all': oset_intersection, - 'any': oset_union, + "all": oset_intersection, + "any": oset_union, }[which] tid_sets = tuple(xmap[x] for x in xs) @@ -4561,7 +4636,7 @@ def _get_tids_from(self, xmap, xs, which): return tids - def _get_tids_from_tags(self, tags, which='all'): + def _get_tids_from_tags(self, tags, which="all"): """Return the set of tensor ids that match ``tags``. Parameters @@ -4587,15 +4662,13 @@ def _get_tids_from_tags(self, tags, which='all'): return self._get_tids_from(self.tag_map, tags, which) - def _get_tids_from_inds(self, inds, which='all'): - """Like ``_get_tids_from_tags`` but specify inds instead. - """ + def _get_tids_from_inds(self, inds, which="all"): + """Like ``_get_tids_from_tags`` but specify inds instead.""" inds = tags_to_oset(inds) return self._get_tids_from(self.ind_map, inds, which) def _tids_get(self, *tids): - """Convenience function that generates unique tensors from tids. - """ + """Convenience function that generates unique tensors from tids.""" seen = set() sadd = seen.add tmap = self.tensor_map @@ -4605,8 +4678,7 @@ def _tids_get(self, *tids): sadd(tid) def _inds_get(self, *inds): - """Convenience function that generates unique tensors from inds. - """ + """Convenience function that generates unique tensors from inds.""" seen = set() sadd = seen.add tmap = self.tensor_map @@ -4618,8 +4690,7 @@ def _inds_get(self, *inds): sadd(tid) def _tags_get(self, *tags): - """Convenience function that generates unique tensors from tags. 
- """ + """Convenience function that generates unique tensors from tags.""" seen = set() sadd = seen.add tmap = self.tensor_map @@ -4630,7 +4701,7 @@ def _tags_get(self, *tags): yield tmap[tid] sadd(tid) - def select_tensors(self, tags, which='all'): + def select_tensors(self, tags, which="all"): """Return the sequence of tensors that match ``tags``. If ``which='all'``, each tensor must contain every tag. If ``which='any'``, each tensor can contain any of the tags. @@ -4673,7 +4744,7 @@ def _select_without_tids(self, tids, virtual=True): tn.pop_tensor(tid) return tn - def select(self, tags, which='all', virtual=True): + def select(self, tags, which="all", virtual=True): """Get a TensorNetwork comprising tensors that match all or any of ``tags``, inherit the network properties/structure from ``self``. This returns a view of the tensors not a copy. @@ -4700,10 +4771,10 @@ def select(self, tags, which='all', virtual=True): tagged_tids = self._get_tids_from_tags(tags, which=which) return self._select_tids(tagged_tids, virtual=virtual) - select_any = functools.partialmethod(select, which='any') - select_all = functools.partialmethod(select, which='all') + select_any = functools.partialmethod(select, which="any") + select_all = functools.partialmethod(select, which="all") - def select_neighbors(self, tags, which='any'): + def select_neighbors(self, tags, which="any"): """Select any neighbouring tensors to those specified by ``tags``.self Parameters @@ -4746,8 +4817,11 @@ def _select_local_tids( exclude=None, ): span = self.get_tree_span( - tids, max_distance=max_distance, - include=include, exclude=exclude, inwards=inwards, + tids, + max_distance=max_distance, + include=include, + exclude=exclude, + inwards=inwards, ) local_tids = oset(tids) for s in span: @@ -4768,42 +4842,52 @@ def _select_local_tids( tn_sl = self._select_tids(local_tids, virtual=virtual) # optionally remove/reduce outer indices that appear outside `tag` - if reduce_outer == 'sum': + if reduce_outer == "sum": for ix in tn_sl.outer_inds(): - tid_edge, = tn_sl.ind_map[ix] + (tid_edge,) = tn_sl.ind_map[ix] if tid_edge in tids: continue tn_sl.tensor_map[tid_edge].sum_reduce_(ix) - elif reduce_outer == 'svd': + elif reduce_outer == "svd": for ix in tn_sl.outer_inds(): # get the tids that stretch across the border tid_out, tid_in = sorted( - self.ind_map[ix], key=tn_sl.tensor_map.__contains__) + self.ind_map[ix], key=tn_sl.tensor_map.__contains__ + ) # rank-1 decompose the outer tensor - l, r = self.tensor_map[tid_out].split( - left_inds=None, right_inds=[ix], - max_bond=1, get='arrays', absorb='left') + _, r = self.tensor_map[tid_out].split( + left_inds=None, + right_inds=[ix], + max_bond=1, + get="arrays", + absorb="left", + ) # absorb the factor into the inner tensor to remove that ind tn_sl.tensor_map[tid_in].gate_(r, ix).squeeze_(include=[ix]) - elif reduce_outer == 'svd-sum': + elif reduce_outer == "svd-sum": for ix in tn_sl.outer_inds(): # get the tids that stretch across the border tid_out, tid_in = sorted( - self.ind_map[ix], key=tn_sl.tensor_map.__contains__) + self.ind_map[ix], key=tn_sl.tensor_map.__contains__ + ) # full-rank decompose the outer tensor l, r = self.tensor_map[tid_out].split( - left_inds=None, right_inds=[ix], - max_bond=None, get='arrays', absorb='left') + left_inds=None, + right_inds=[ix], + max_bond=None, + get="arrays", + absorb="left", + ) # absorb the factor into the inner tensor then sum over it tn_sl.tensor_map[tid_in].gate_(r, ix).sum_reduce_(ix) - elif reduce_outer == 'reflect': + elif 
reduce_outer == "reflect": tn_sl |= tn_sl.H return tn_sl @@ -4811,7 +4895,7 @@ def _select_local_tids( def select_local( self, tags, - which='all', + which="all", max_distance=1, fillin=False, reduce_outer=None, @@ -4859,8 +4943,11 @@ def select_local( ------- TensorNetwork """ - check_opt('reduce_outer', reduce_outer, - (None, 'sum', 'svd', 'svd-sum', 'reflect')) + check_opt( + "reduce_outer", + reduce_outer, + (None, "sum", "svd", "svd-sum", "reflect"), + ) return self._select_local_tids( tids=self._get_tids_from_tags(tags, which), @@ -4869,7 +4956,8 @@ def select_local( reduce_outer=reduce_outer, virtual=virtual, include=include, - exclude=exclude) + exclude=exclude, + ) def __getitem__(self, tags): """Get the tensor(s) associated with ``tags``. @@ -4886,7 +4974,7 @@ def __getitem__(self, tags): if isinstance(tags, slice): return self.select_any(self.maybe_convert_coo(tags)) - tensors = self.select_tensors(tags, which='all') + tensors = self.select_tensors(tags, which="all") if len(tensors) == 0: raise KeyError(f"Couldn't find any tensors matching {tags}.") @@ -4897,29 +4985,30 @@ def __getitem__(self, tags): return tensors def __setitem__(self, tags, tensor): - """Set the single tensor uniquely associated with ``tags``. - """ - tids = self._get_tids_from_tags(tags, which='all') + """Set the single tensor uniquely associated with ``tags``.""" + tids = self._get_tids_from_tags(tags, which="all") if len(tids) != 1: - raise KeyError("'TensorNetwork.__setitem__' is meant for a single " - "existing tensor only - found {} with tag(s) '{}'." - .format(len(tids), tags)) + raise KeyError( + "'TensorNetwork.__setitem__' is meant for a single " + "existing tensor only - found {} with tag(s) '{}'.".format( + len(tids), tags + ) + ) if not isinstance(tensor, Tensor): raise TypeError("Can only set value with a new 'Tensor'.") - tid, = tids + (tid,) = tids self.pop_tensor(tid) self.add_tensor(tensor, tid=tid, virtual=True) def __delitem__(self, tags): - """Delete any tensors which have all of ``tags``. - """ - tids = self._get_tids_from_tags(tags, which='all') + """Delete any tensors which have all of ``tags``.""" + tids = self._get_tids_from_tags(tags, which="all") for tid in tuple(tids): self.pop_tensor(tid) - def partition_tensors(self, tags, inplace=False, which='any'): + def partition_tensors(self, tags, inplace=False, which="any"): """Split this TN into a list of tensors containing any or all of ``tags`` and a ``TensorNetwork`` of the the rest. @@ -4951,7 +5040,7 @@ def partition_tensors(self, tags, inplace=False, which='any'): return untagged_tn, tagged_ts - def partition(self, tags, which='any', inplace=False): + def partition(self, tags, which="any", inplace=False): """Split this TN into two, based on which tensors have any or all of ``tags``. Unlike ``partition_tensors``, both results are TNs which inherit the structure of the initial TN. 
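As a point of reference for the selection and partitioning methods reformatted above, the tag machinery can be exercised as in the following minimal sketch. This is illustrative only and not part of the patch; it assumes a small random MPS built with quimb's `MPS_rand_state`, whose tensors carry the conventional site tags `"I0"`, `"I1"`, ...:

    import quimb.tensor as qtn

    # small MPS with tensors tagged "I0" ... "I5"
    psi = qtn.MPS_rand_state(6, bond_dim=4)

    # the sequence of tensors matching any of the given tags
    ts = psi.select_tensors(["I2", "I3"], which="any")

    # a virtual sub-network view of those same tensors
    sub = psi.select(["I2", "I3"], which="any")

    # split into the untagged remainder and the tagged part
    rest, middle = psi.partition(["I2", "I3"], which="any")

Both results of `partition` inherit the network structure, whereas `partition_tensors` returns the raw tagged tensors alongside the remaining network.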
@@ -4976,7 +5065,7 @@ def partition(self, tags, which='any', inplace=False):
         """
         tagged_tids = self._get_tids_from_tags(tags, which=which)
 
-        kws = {'check_collisions': False}
+        kws = {"check_collisions": False}
 
         if inplace:
             t1 = self
@@ -4997,7 +5086,7 @@ def partition(self, tags, which='any', inplace=False):
 
     def _split_tensor_tid(self, tid, left_inds, **split_opts):
         t = self.pop_tensor(tid)
-        tl, tr = t.split(left_inds=left_inds, get='tensors', **split_opts)
+        tl, tr = t.split(left_inds=left_inds, get="tensors", **split_opts)
         self.add_tensor(tl)
         self.add_tensor(tr)
         return self
@@ -5012,10 +5101,10 @@ def split_tensor(
         resulting tensors from the decomposition back into the network.
         Inplace operation.
        """
-        tid, = self._get_tids_from_tags(tags, which='all')
+        (tid,) = self._get_tids_from_tags(tags, which="all")
         self._split_tensor_tid(tid, left_inds, **split_opts)
 
-    def replace_with_identity(self, where, which='any', inplace=False):
+    def replace_with_identity(self, where, which="any", inplace=False):
         r"""Replace all tensors marked by ``where`` with an identity. E.g. if
         ``X`` denote ``where`` tensors::
@@ -5048,23 +5137,40 @@ def replace_with_identity(self, where, which='any', inplace=False):
             return tn
 
         (dl, il), (dr, ir) = TensorNetwork(
-            self.select_tensors(where, which=which)).outer_dims_inds()
+            self.select_tensors(where, which=which)
+        ).outer_dims_inds()
 
         if dl != dr:
             raise ValueError(
                 "Can only replace_with_identity when the remaining indices "
-                f"have matching dimensions, but {dl} != {dr}.")
+                f"have matching dimensions, but {dl} != {dr}."
+            )
 
         tn.delete(where, which=which)
 
         tn.reindex_({il: ir})
         return tn
 
-    def replace_with_svd(self, where, left_inds, eps, *, which='any',
-                         right_inds=None, method='isvd', max_bond=None,
-                         absorb='both', cutoff_mode='rel', renorm=None,
-                         ltags=None, rtags=None, keep_tags=True,
-                         start=None, stop=None, inplace=False):
+    def replace_with_svd(
+        self,
+        where,
+        left_inds,
+        eps,
+        *,
+        which="any",
+        right_inds=None,
+        method="isvd",
+        max_bond=None,
+        absorb="both",
+        cutoff_mode="rel",
+        renorm=None,
+        ltags=None,
+        rtags=None,
+        keep_tags=True,
+        start=None,
+        stop=None,
+        inplace=False,
+    ):
         r"""Replace all tensors marked by ``where`` with an iteratively
         constructed SVD. E.g. if ``X`` denote ``where`` tensors::
@@ -5121,8 +5227,9 @@ def replace_with_svd(self, where, left_inds, eps, *, which='any',
         --------
         replace_with_identity
         """
-        leave, svd_section = self.partition(where, which=which,
-                                            inplace=inplace)
+        leave, svd_section = self.partition(
+            where, which=which, inplace=inplace
+        )
 
         tags = svd_section.tags if keep_tags else oset()
         ltags = tags_to_oset(ltags)
@@ -5130,31 +5237,47 @@ def replace_with_svd(self, where, left_inds, eps, *, which='any',
 
         if right_inds is None:
             # compute
-            right_inds = tuple(i for i in svd_section.outer_inds()
-                               if i not in left_inds)
+            right_inds = tuple(
+                i for i in svd_section.outer_inds() if i not in left_inds
+            )
 
         if (start is None) and (stop is None):
-            A = svd_section.aslinearoperator(left_inds=left_inds,
-                                             right_inds=right_inds)
+            A = svd_section.aslinearoperator(
+                left_inds=left_inds, right_inds=right_inds
+            )
        else:
             from .tensor_1d import TNLinearOperator1D
 
             # check if need to invert start stop as well
-            if '!' in which:
+            if "!" in which:
                 start, stop = stop, start + self.L
                 left_inds, right_inds = right_inds, left_inds
                 ltags, rtags = rtags, ltags
 
-            A = TNLinearOperator1D(svd_section, start=start, stop=stop,
-                                   left_inds=left_inds, right_inds=right_inds)
+            A = TNLinearOperator1D(
+                svd_section,
+                start=start,
+                stop=stop,
+                left_inds=left_inds,
+                right_inds=right_inds,
+            )
 
         ltags = tags | ltags
         rtags = tags | rtags
 
-        TL, TR = tensor_split(A, left_inds=left_inds, right_inds=right_inds,
-                              method=method, cutoff=eps, absorb=absorb,
-                              max_bond=max_bond, cutoff_mode=cutoff_mode,
-                              renorm=renorm, ltags=ltags, rtags=rtags)
+        TL, TR = tensor_split(
+            A,
+            left_inds=left_inds,
+            right_inds=right_inds,
+            method=method,
+            cutoff=eps,
+            absorb=absorb,
+            max_bond=max_bond,
+            cutoff_mode=cutoff_mode,
+            renorm=renorm,
+            ltags=ltags,
+            rtags=rtags,
+        )
 
         leave |= TL
         leave |= TR
@@ -5163,8 +5286,9 @@ def replace_with_svd(self, where, left_inds, eps, *, which='any',
 
     replace_with_svd_ = functools.partialmethod(replace_with_svd, inplace=True)
 
-    def replace_section_with_svd(self, start, stop, eps,
-                                 **replace_with_svd_opts):
+    def replace_section_with_svd(
+        self, start, stop, eps, **replace_with_svd_opts
+    ):
         """Take a 1D tensor network, and replace a section with a SVD.
         See :meth:`~quimb.tensor.tensor_core.TensorNetwork.replace_with_svd`.
@@ -5185,19 +5309,23 @@ def replace_section_with_svd(self, start, stop, eps,
         TensorNetwork
         """
         return self.replace_with_svd(
-            where=slice(start, stop), start=start, stop=stop,
-            left_inds=bonds(self[start - 1], self[start]), eps=eps,
-            **replace_with_svd_opts)
+            where=slice(start, stop),
+            start=start,
+            stop=stop,
+            left_inds=bonds(self[start - 1], self[start]),
+            eps=eps,
+            **replace_with_svd_opts,
+        )
 
     def convert_to_zero(self):
-        """Inplace conversion of this network to an all zero tensor network.
-        """
+        """Inplace conversion of this network to an all zero tensor network."""
         outer_inds = self.outer_inds()
 
         for T in self:
-            new_shape = tuple(d if i in outer_inds else 1
-                              for d, i in zip(T.shape, T.inds))
-            T.modify(data=do('zeros', new_shape, dtype=T.dtype, like=T.data))
+            new_shape = tuple(
+                d if i in outer_inds else 1 for d, i in zip(T.shape, T.inds)
+            )
+            T.modify(data=do("zeros", new_shape, dtype=T.dtype, like=T.data))
 
     def _contract_between_tids(
         self,
@@ -5226,7 +5354,8 @@ def _contract_between_tids(
             t1.multiply_index_diagonal_(ix, g)
 
         t12 = tensor_contract(
-            t1, t2,
+            t1,
+            t2,
             output_inds=local_output_inds,
             preserve_tensor=True,
             **contract_opts,
        )
@@ -5251,19 +5380,21 @@ def contract_between(self, tags1, tags2, **contract_opts):
         contract_opts
             Supplied to :func:`~quimb.tensor.tensor_core.tensor_contract`.
         """
-        tid1, = self._get_tids_from_tags(tags1, which='all')
-        tid2, = self._get_tids_from_tags(tags2, which='all')
+        (tid1,) = self._get_tids_from_tags(tags1, which="all")
+        (tid2,) = self._get_tids_from_tags(tags2, which="all")
         self._contract_between_tids(tid1, tid2, **contract_opts)
 
     def contract_ind(self, ind, output_inds=None, **contract_opts):
-        """Contract tensors connected by ``ind``.
-        """
+        """Contract tensors connected by ``ind``."""
         tids = tuple(self._get_tids_from_inds(ind))
         output_inds = self.compute_contracted_inds(
-            *tids, output_inds=output_inds)
+            *tids, output_inds=output_inds
+        )
         tnew = tensor_contract(
-            *map(self.pop_tensor, tids), output_inds=output_inds,
-            preserve_tensor=True, **contract_opts
+            *map(self.pop_tensor, tids),
+            output_inds=output_inds,
+            preserve_tensor=True,
+            **contract_opts,
         )
         self.add_tensor(tnew, tid=tids[0], virtual=True)
@@ -5382,7 +5513,7 @@ def _compute_tree_gauges(self, tree, outputs):
                 t_outer = t_outer.gate(Gs.pop(ix), ix)
 
             # compute the reduced factor to accumulate inwards
-            new_G = t_outer.compute_reduced_factor('right', outer_ix, inner_ix)
+            new_G = t_outer.compute_reduced_factor("right", outer_ix, inner_ix)
 
             # store the normalized gauge associated with the tree bond
             Gs[inner_ix] = new_G / do("linalg.norm", new_G)
@@ -5399,7 +5530,7 @@ def _compute_tree_gauges(self, tree, outputs):
                 t_outer = t_outer.gate(Gs.pop(ix), ix)
 
             # compute the final reduced factor
-            Gout = t_outer.compute_reduced_factor('right', outer_ix, ind)
+            Gout = t_outer.compute_reduced_factor("right", outer_ix, ind)
             Gouts.append(Gout)
 
         return Gouts
@@ -5445,15 +5576,17 @@ def _compress_between_virtual_tree_tids(
         tr.gate_(Pr, bix)
 
     def _compute_bond_env(
-        self, tid1, tid2,
+        self,
+        tid1,
+        tid2,
         select_local_distance=None,
         select_local_opts=None,
         max_bond=None,
         cutoff=None,
-        method='contract_around',
+        method="contract_around",
         contract_around_opts=None,
         contract_compressed_opts=None,
-        optimize='auto-hq',
+        optimize="auto-hq",
         include=None,
         exclude=None,
     ):
@@ -5467,12 +5600,16 @@ def _compute_bond_env(
         else:
             # ... or just a local patch of the TN (with dangling bonds removed)
             select_local_opts = ensure_dict(select_local_opts)
-            select_local_opts.setdefault('reduce_outer', 'svd')
+            select_local_opts.setdefault("reduce_outer", "svd")
 
             tn_env = self._select_local_tids(
-                (tid1, tid2), max_distance=select_local_distance,
-                virtual=False, include=include, exclude=exclude,
-                **select_local_opts)
+                (tid1, tid2),
+                max_distance=select_local_distance,
+                virtual=False,
+                include=include,
+                exclude=exclude,
+                **select_local_opts,
+            )
 
             # not propagated by _select_local_tids
             tn_env.exponent = self.exponent
@@ -5480,25 +5617,30 @@ def _compute_bond_env(
         # cut the bond between the two target tensors in the local TN
         t1 = tn_env.tensor_map[tid1]
         t2 = tn_env.tensor_map[tid2]
-        bond, = t1.bonds(t2)
+        (bond,) = t1.bonds(t2)
         lcut = rand_uuid()
         rcut = rand_uuid()
         t1.reindex_({bond: lcut})
         t2.reindex_({bond: rcut})
 
         if max_bond is not None:
-            if method == 'contract_around':
+            if method == "contract_around":
                 tn_env._contract_around_tids(
-                    (tid1, tid2), max_bond=max_bond, cutoff=cutoff,
-                    **ensure_dict(contract_around_opts))
+                    (tid1, tid2),
+                    max_bond=max_bond,
+                    cutoff=cutoff,
+                    **ensure_dict(contract_around_opts),
+                )
 
-            elif method == 'contract_compressed':
+            elif method == "contract_compressed":
                 tn_env.contract_compressed_(
-                    max_bond=max_bond, cutoff=cutoff,
-                    **ensure_dict(contract_compressed_opts))
+                    max_bond=max_bond,
+                    cutoff=cutoff,
+                    **ensure_dict(contract_compressed_opts),
+                )
 
             else:
-                raise ValueError(f'Unknown method: {method}')
+                raise ValueError(f"Unknown method: {method}")
 
         return tn_env.to_dense([lcut], [rcut], optimize=optimize)
@@ -5508,23 +5650,23 @@ def _compress_between_full_bond_tids(
         tid2,
         max_bond,
         cutoff=0.0,
-        absorb='both',
+        absorb="both",
         renorm=False,
-        method='eigh',
+        method="eigh",
         select_local_distance=None,
         select_local_opts=None,
-        env_max_bond='max_bond',
-        env_cutoff='cutoff',
-        env_method='contract_around',
+        env_max_bond="max_bond",
+        env_cutoff="cutoff",
+        env_method="contract_around",
         contract_around_opts=None,
         contract_compressed_opts=None,
-        env_optimize='auto-hq',
+        env_optimize="auto-hq",
         include=None,
         exclude=None,
     ):
-        if env_max_bond == 'max_bond':
+        if env_max_bond == "max_bond":
             env_max_bond = max_bond
-        if env_cutoff == 'cutoff':
+        if env_cutoff == "cutoff":
             env_cutoff = cutoff
 
         ta = self.tensor_map[tid1]
@@ -5536,7 +5678,8 @@ def _compress_between_full_bond_tids(
             return
 
         E = self._compute_bond_env(
-            tid1, tid2,
+            tid1,
+            tid2,
             select_local_distance=select_local_distance,
             select_local_opts=select_local_opts,
             max_bond=env_max_bond,
@@ -5550,13 +5693,14 @@ def _compress_between_full_bond_tids(
         )
 
         Cl, Cr = decomp.similarity_compress(
-            E, max_bond, method=method, renorm=renorm)
+            E, max_bond, method=method, renorm=renorm
+        )
 
         # absorb them into the tensors to compress this bond
         ta.gate_(Cr, bond)
         tb.gate_(Cl.T, bond)
 
-        if absorb != 'both':
+        if absorb != "both":
             tensor_canonize_bond(ta, tb, absorb=absorb)
 
     def _compress_between_local_fit(
@@ -5565,35 +5709,42 @@ def _compress_between_local_fit(
         tid2,
         max_bond,
         cutoff=0.0,
-        absorb='both',
-        method='als',
+        absorb="both",
+        method="als",
         select_local_distance=1,
         select_local_opts=None,
         include=None,
         exclude=None,
-        **fit_opts
+        **fit_opts,
     ):
         if cutoff != 0.0:
             import warnings
+
             warnings.warn("Non-zero cutoff ignored by local fit compress.")
 
         select_local_opts = ensure_dict(select_local_opts)
         tn_loc_target = self._select_local_tids(
             (tid1, tid2),
-            max_distance=select_local_distance, virtual=False,
-            include=include, exclude=exclude, **select_local_opts)
+            max_distance=select_local_distance,
+            virtual=False,
+            include=include,
+            exclude=exclude,
+            **select_local_opts,
+        )
 
         tn_loc_compress = tn_loc_target.copy()
         tn_loc_compress._compress_between_tids(
-            tid1, tid2, max_bond=max_bond, cutoff=0.0)
+            tid1, tid2, max_bond=max_bond, cutoff=0.0
+        )
 
         tn_loc_opt = tn_loc_compress.fit_(
-            tn_loc_target, method=method, **fit_opts)
+            tn_loc_target, method=method, **fit_opts
+        )
 
         for tid, t in tn_loc_opt.tensor_map.items():
             self.tensor_map[tid].modify(data=t.data)
 
-        if absorb != 'both':
+        if absorb != "both":
             self._canonize_between_tids(tid1, tid2, absorb=absorb)
 
     def _compress_between_tids(
@@ -5602,17 +5753,17 @@ def _compress_between_tids(
         tid2,
         max_bond=None,
         cutoff=1e-10,
-        absorb='both',
+        absorb="both",
        canonize_distance=None,
        canonize_opts=None,
        canonize_after_distance=None,
        canonize_after_opts=None,
-        mode='basic',
+        mode="basic",
         equalize_norms=False,
         gauges=None,
         gauge_smudge=1e-6,
         callback=None,
-        **compress_opts
+        **compress_opts,
     ):
         ta = self.tensor_map[tid1]
         tb = self.tensor_map[tid2]
@@ -5628,9 +5779,10 @@ def _compress_between_tids(
             # special case - fixing any orthonormal basis for the left or
             # right tensor (whichever has smallest outer dimensions) will
             # produce the required compression without any SVD
-            compress_absorb = 'right' if lsize <= rsize else 'left'
+            compress_absorb = "right" if lsize <= rsize else "left"
             tensor_canonize_bond(
-                ta, tb,
+                ta,
+                tb,
                 absorb=compress_absorb,
                 gauges=gauges,
                 gauge_smudge=gauge_smudge,
            )
@@ -5638,7 +5790,8 @@ def _compress_between_tids(
 
             if absorb != compress_absorb:
                 tensor_canonize_bond(
-                    ta, tb,
+                    ta,
+                    tb,
                     absorb=absorb,
                     gauges=gauges,
                     gauge_smudge=gauge_smudge,
                )
@@ -5650,12 +5803,12 @@ def _compress_between_tids(
 
             return
 
-        compress_opts['max_bond'] = max_bond
-        compress_opts['cutoff'] = cutoff
-        compress_opts['absorb'] = absorb
+        compress_opts["max_bond"] = max_bond
+        compress_opts["cutoff"] = cutoff
+        compress_opts["absorb"] = absorb
         if gauges is not None:
-            compress_opts['gauges'] = gauges
-            compress_opts['gauge_smudge'] = gauge_smudge
+            compress_opts["gauges"] = gauges
+            compress_opts["gauge_smudge"] = gauge_smudge
 
         if isinstance(mode, str) and "virtual" in mode:
             # canonize distance is handled by the virtual tree
@@ -5673,13 +5826,13 @@ def _compress_between_tids(
         if canonize_distance:
             # gauge around pair by absorbing QR factors along bonds
             canonize_opts = ensure_dict(canonize_opts)
-            canonize_opts.setdefault('equalize_norms', equalize_norms)
+            canonize_opts.setdefault("equalize_norms", equalize_norms)
             self._canonize_around_tids(
                 (tid1, tid2),
                 gauges=gauges,
                 gauge_smudge=gauge_smudge,
                 max_distance=canonize_distance,
-                **canonize_opts
+                **canonize_opts,
             )
 
         if mode == 'basic':
@@ -5700,9 +5853,7 @@ def _compress_between_tids(
             )
         else:
             # assume callable
-            mode(
-                self, tid1, tid2, **compress_opts
-            )
+            mode(self, tid1, tid2, **compress_opts)
 
         if equalize_norms:
             self.strip_exponent(tid1, equalize_norms)
@@ -5715,7 +5866,7 @@ def _compress_between_tids(
                 tids=(tid1, tid2),
                 max_distance=canonize_after_distance,
                 gauges=gauges,
-                **canonize_after_opts
+                **canonize_after_opts,
             )
 
         if callback is not None:
@@ -5727,7 +5878,7 @@ def compress_between(
         tags2,
         max_bond=None,
         cutoff=1e-10,
-        absorb='both',
+        absorb="both",
         canonize_distance=0,
         canonize_opts=None,
         equalize_norms=False,
@@ -5777,22 +5928,23 @@ def compress_between(
         --------
         canonize_between
         """
-        tid1, = self._get_tids_from_tags(tags1, which='all')
-        tid2, = self._get_tids_from_tags(tags2, which='all')
+        (tid1,) = self._get_tids_from_tags(tags1, which="all")
+        (tid2,) = self._get_tids_from_tags(tags2, which="all")
 
         self._compress_between_tids(
-            tid1, tid2,
+            tid1,
+            tid2,
             max_bond=max_bond,
             cutoff=cutoff,
             absorb=absorb,
            canonize_distance=canonize_distance,
            canonize_opts=canonize_opts,
            equalize_norms=equalize_norms,
-            **compress_opts)
+            **compress_opts,
+        )
 
     def compress_all(self, inplace=False, **compress_opts):
-        """Inplace compress all bonds in this network.
-        """
+        """Compress all bonds in this network."""
         tn = self if inplace else self.copy()
 
         tn.fuse_multibonds_()
@@ -5825,8 +5977,12 @@ def sorter(t, tn, distances, connectivity):
         for tid1, tid2, _ in span:
             # absorb='right' shifts orthog center inwards
             tn._compress_between_tids(
-                tid1, tid2, absorb='right',
-                canonize_distance=float('inf'), **compress_opts)
+                tid1,
+                tid2,
+                absorb="right",
+                canonize_distance=float("inf"),
+                **compress_opts,
+            )
 
         return tn
 
@@ -5871,10 +6027,10 @@ def compress_all_1d(
 
         if canonize:
             for tida, tidb, _ in span:
-                tn._canonize_between_tids(tida, tidb, absorb='right')
-            compress_opts.setdefault('absorb', 'right')
+                tn._canonize_between_tids(tida, tidb, absorb="right")
+            compress_opts.setdefault("absorb", "right")
         else:
-            compress_opts.setdefault('absorb', 'both')
+            compress_opts.setdefault("absorb", "both")
 
         for tida, tidb, _ in reversed(span):
             tn._compress_between_tids(
@@ -5901,8 +6057,7 @@ def compress_all_simple(
         inplace=False,
         **gauge_simple_opts,
     ):
-        """
-        """
+        """Compress all bonds by simple gauging then truncating."""
         if max_iterations < 1:
             raise ValueError("Must have at least one iteration to compress.")
@@ -5919,7 +6074,7 @@ def compress_all_simple(
             tol=tol,
             smudge=smudge,
             power=power,
-            **gauge_simple_opts
+            **gauge_simple_opts,
         )
 
         # truncate the tensors
@@ -5954,7 +6109,7 @@ def _canonize_between_tids(
         self,
         tid1,
         tid2,
-        absorb='right',
+        absorb="right",
         gauges=None,
         gauge_smudge=1e-6,
         equalize_norms=False,
@@ -5963,18 +6118,19 @@ def _canonize_between_tids(
         Tl = self.tensor_map[tid1]
         Tr = self.tensor_map[tid2]
         tensor_canonize_bond(
-            Tl, Tr,
+            Tl,
+            Tr,
             absorb=absorb,
             gauges=gauges,
             gauge_smudge=gauge_smudge,
-            **canonize_opts
+            **canonize_opts,
         )
 
         if equalize_norms:
             self.strip_exponent(tid1, equalize_norms)
             self.strip_exponent(tid2, equalize_norms)
 
-    def canonize_between(self, tags1, tags2, absorb='right', **canonize_opts):
+    def canonize_between(self, tags1, tags2, absorb="right", **canonize_opts):
         r"""'Canonize' the bond between the two single tensors in this network
         specified by ``tags1`` and ``tags2`` using ``tensor_canonize_bond``::
@@ -6008,8 +6164,8 @@ def canonize_between(self, tags1, tags2, absorb='right', **canonize_opts):
         --------
         compress_between
         """
-        tid1, = self._get_tids_from_tags(tags1, which='all')
-        tid2, = self._get_tids_from_tags(tags2, which='all')
+        (tid1,) = self._get_tids_from_tags(tags1, which="all")
+        (tid2,) = self._get_tids_from_tags(tags2, which="all")
         self._canonize_between_tids(tid1, tid2, absorb=absorb, **canonize_opts)
 
     def reduce_inds_onto_bond(
@@ -6025,24 +6181,26 @@ def reduce_inds_onto_bond(
         of their respective tensors and onto the bond between them. This is
        an inplace operation.
        """
-        tida, = self._get_tids_from_inds(inda)
-        tidb, = self._get_tids_from_inds(indb)
+        (tida,) = self._get_tids_from_inds(inda)
+        (tidb,) = self._get_tids_from_inds(indb)
         ta, tb = self._tids_get(tida, tidb)
         bix = bonds(ta, tb)
 
         if ta.ndim > ndim_cutoff:
             self._split_tensor_tid(
-                tida, left_inds=None, right_inds=[inda, *bix], method='qr')
+                tida, left_inds=None, right_inds=[inda, *bix], method="qr"
+            )
             # get new location of ind
-            tida, = self._get_tids_from_inds(inda)
+            (tida,) = self._get_tids_from_inds(inda)
         else:
             drop_tags = False
 
         if tb.ndim > ndim_cutoff:
             self._split_tensor_tid(
-                tidb, left_inds=None, right_inds=[indb, *bix], method='qr')
+                tidb, left_inds=None, right_inds=[indb, *bix], method="qr"
+            )
             # get new location of ind
-            tidb, = self._get_tids_from_inds(indb)
+            (tidb,) = self._get_tids_from_inds(indb)
         else:
             drop_tags = False
@@ -6050,7 +6208,7 @@ def reduce_inds_onto_bond(
         tags = tags_to_oset(tags)
         if combine:
             self._contract_between_tids(tida, tidb)
-            tab, = self._inds_get(inda, indb)
+            (tab,) = self._inds_get(inda, indb)
 
             # modify with the desired tags
             if drop_tags:
@@ -6059,8 +6217,8 @@ def reduce_inds_onto_bond(
                 tab.modify(tags=tab.tags | tags)
 
         else:
-            ta, = self._inds_get(inda)
-            tb, = self._inds_get(indb)
+            (ta,) = self._inds_get(inda)
+            (tb,) = self._inds_get(indb)
             if drop_tags:
                 ta.modify(tags=tags)
                 tb.modify(tags=tags)
@@ -6161,7 +6319,6 @@ def subgraphs(self, virtual=False):
 
         # check all nodes
         while tids:
-
             # get a remaining node
             tid0 = tids.popright()
             queue = [tid0]
@@ -6190,8 +6347,8 @@ def get_tree_span(
         max_distance=None,
         include=None,
         exclude=None,
-        ndim_sort='max',
-        distance_sort='min',
+        ndim_sort="max",
+        distance_sort="min",
         sorter=None,
         weight_bonds=True,
         inwards=True,
@@ -6269,12 +6426,11 @@ def get_tree_span(
 
         # given equal connectivity compare neighbors based on
         # min/max distance and min/max ndim
-        distance_coeff = {'min': -1, 'max': 1, 'none': 0}[distance_sort]
-        ndim_coeff = {'min': -1, 'max': 1, 'none': 0}[ndim_sort]
+        distance_coeff = {"min": -1, "max": 1, "none": 0}[distance_sort]
+        ndim_coeff = {"min": -1, "max": 1, "none": 0}[ndim_sort]
 
        def _check_candidate(tid_surface, tid_neighb):
-            """Check the expansion of ``tid_surface`` to ``tid_neighb``.
-            """
+            """Check the expansion of ``tid_surface`` to ``tid_neighb``."""
             if (tid_neighb in region) or (tid_neighb not in allowed):
                 # we've already absorbed it, or we're not allowed to
                 return
@@ -6291,13 +6447,17 @@ def _check_candidate(tid_surface, tid_neighb):
             # keep track of how connected to the current surface potential new
             # nodes are
             if weight_bonds:
-                connectivity[tid_neighb] += math.log2(bonds_size(
-                    self.tensor_map[tid_surface], self.tensor_map[tid_neighb]
-                ))
+                connectivity[tid_neighb] += math.log2(
+                    bonds_size(
+                        self.tensor_map[tid_surface],
+                        self.tensor_map[tid_neighb],
+                    )
+                )
             else:
                 connectivity[tid_neighb] += 1
 
         if sorter is None:
+
             def _sorter(t):
                 # how to pick which tensor to absorb into the expanding surface
                 # here, choose the candidate that is most connected to current
@@ -6310,8 +6470,8 @@ def _sorter(t):
                 )
         else:
             _sorter = functools.partial(
-                sorter, tn=self, distances=distances,
-                connectivity=connectivity)
+                sorter, tn=self, distances=distances, connectivity=connectivity
+            )
 
         # setup the initial region and candidate nodes to expand to
         for tid_surface in region:
@@ -6351,12 +6511,12 @@ def _draw_tree_span_tids(
         max_distance=None,
         include=None,
         exclude=None,
-        ndim_sort='max',
-        distance_sort='min',
+        ndim_sort="max",
+        distance_sort="min",
         sorter=None,
         weight_bonds=True,
-        color='order',
-        colormap='Spectral',
+        color="order",
+        colormap="Spectral",
         **draw_opts,
     ):
         tn = self.copy()
@@ -6374,7 +6534,8 @@ def _draw_tree_span_tids(
             ndim_sort=ndim_sort,
             distance_sort=distance_sort,
             sorter=sorter,
-            weight_bonds=weight_bonds)
+            weight_bonds=weight_bonds,
+        )
 
         for i, (tid1, tid2, d) in enumerate(span):
             # get the tensors on either side of this tree edge
@@ -6383,18 +6544,19 @@ def _draw_tree_span_tids(
             # get the ind(s) connecting them
             tix |= oset(bonds(t1, t2))
 
-            if color == 'distance':
+            if color == "distance":
                 # tag the outer tensor with distance ``d``
-                t1.add_tag(f'D{d}')
+                t1.add_tag(f"D{d}")
                 ds.add(d)
-            elif color == 'order':
+            elif color == "order":
                 d = len(span) - i
-                t1.add_tag(f'D{d}')
+                t1.add_tag(f"D{d}")
                 ds.add(d)
 
        if colormap is not None:
            if isinstance(colormap, str):
                import matplotlib.cm
+
                cmap = getattr(matplotlib.cm, colormap)
            else:
                cmap = colormap
@@ -6402,26 +6564,26 @@ def _draw_tree_span_tids(
        else:
            custom_colors = None
 
-        draw_opts.setdefault('legend', False)
-        draw_opts.setdefault('edge_color', (0.85, 0.85, 0.85))
-        draw_opts.setdefault('highlight_inds', tix)
-        draw_opts.setdefault('custom_colors', custom_colors)
+        draw_opts.setdefault("legend", False)
+        draw_opts.setdefault("edge_color", (0.85, 0.85, 0.85))
+        draw_opts.setdefault("highlight_inds", tix)
+        draw_opts.setdefault("custom_colors", custom_colors)
 
-        return tn.draw(color=[f'D{d}' for d in sorted(ds)], **draw_opts)
+        return tn.draw(color=[f"D{d}" for d in sorted(ds)], **draw_opts)
 
     def draw_tree_span(
         self,
         tags,
-        which='all',
        min_distance=0,
        max_distance=None,
        include=None,
        exclude=None,
-        ndim_sort='max',
-        distance_sort='min',
+        which="all",
+        ndim_sort="max",
+        distance_sort="min",
         weight_bonds=True,
-        color='order',
-        colormap='Spectral',
+        color="order",
+        colormap="Spectral",
         **draw_opts,
     ):
         """Visualize a generated tree span out of the tensors tagged by
@@ -6464,7 +6626,8 @@ def draw_tree_span(
             weight_bonds=weight_bonds,
             color=color,
             colormap=colormap,
-            **draw_opts)
+            **draw_opts,
+        )
 
     graph_tree_span = draw_tree_span
 
@@ -6476,13 +6639,13 @@ def _canonize_around_tids(
         include=None,
         exclude=None,
         span_opts=None,
-        absorb='right',
         gauge_links=False,
-        link_absorb='both',
+        absorb="right",
+        link_absorb="both",
         inwards=True,
         gauges=None,
         gauge_smudge=1e-6,
-        **canonize_opts
+        **canonize_opts,
     ):
         span_opts = ensure_dict(span_opts)
         seq = self.get_tree_span(
@@ -6492,7 +6655,8 @@ def _canonize_around_tids(
             include=include,
             exclude=exclude,
             inwards=inwards,
-            **span_opts)
+            **span_opts,
+        )
 
         if gauge_links:
             # if specified we first gauge *between* the branches
@@ -6520,21 +6684,23 @@ def _canonize_around_tids(
             for _ in range(int(gauge_links)):
                 for tid1, tid2 in links:
                     self._canonize_between_tids(
-                        tid1, tid2,
+                        tid1,
+                        tid2,
                         absorb=link_absorb,
                         gauges=gauges,
                         gauge_smudge=gauge_smudge,
-                        **canonize_opts
+                        **canonize_opts,
                    )
 
        # gauge inwards *along* the branches
        for tid1, tid2, _ in seq:
            self._canonize_between_tids(
-                tid1, tid2,
+                tid1,
+                tid2,
                absorb=absorb,
                gauges=gauges,
                gauge_smudge=gauge_smudge,
-                **canonize_opts
+                **canonize_opts,
            )
 
        return self
@@ -6542,18 +6708,18 @@ def canonize_around(
         self,
         tags,
-        which='all',
+        which="all",
         min_distance=0,
         max_distance=None,
         include=None,
         exclude=None,
         span_opts=None,
-        absorb='right',
+        absorb="right",
         gauge_links=False,
-        link_absorb='both',
+        link_absorb="both",
         equalize_norms=False,
         inplace=False,
-        **canonize_opts
+        **canonize_opts,
     ):
         r"""Expand a locally canonical region around ``tags``::
@@ -6641,14 +6807,15 @@ def canonize_around(
             gauge_links=gauge_links,
             link_absorb=link_absorb,
             equalize_norms=equalize_norms,
-            **canonize_opts)
+            **canonize_opts,
+        )
 
     canonize_around_ = functools.partialmethod(canonize_around, inplace=True)
 
     def gauge_all_canonize(
         self,
         max_iterations=5,
-        absorb='both',
+        absorb="both",
         gauges=None,
         gauge_smudge=1e-6,
         equalize_norms=False,
@@ -6668,11 +6835,12 @@ def gauge_all_canonize(
                 # fused multibond (removed) or not a bond (len(tids != 2))
                 continue
             tn._canonize_between_tids(
-                tid1, tid2,
+                tid1,
+                tid2,
                 absorb=absorb,
                 gauges=gauges,
                 gauge_smudge=gauge_smudge,
-                **canonize_opts
+                **canonize_opts,
            )
 
        if equalize_norms:
@@ -6686,7 +6854,8 @@ def gauge_all_canonize(
         return tn
 
     gauge_all_canonize_ = functools.partialmethod(
-        gauge_all_canonize, inplace=True)
+        gauge_all_canonize, inplace=True
+    )
 
     def gauge_all_simple(
         self,
@@ -6720,6 +6889,7 @@ def gauge_all_simple(
 
         if progbar:
             import tqdm
+
             pbar = tqdm.tqdm()
         else:
             pbar = None
@@ -6760,12 +6930,13 @@ def gauge_all_simple(
 
                 # perform SVD to get new bond gauge
                 tensor_compress_bond(
-                    t1, t2, absorb=None, info=info, cutoff=0.0)
+                    t1, t2, absorb=None, info=info, cutoff=0.0
+                )
 
-                s = info['singular_values']
+                s = info["singular_values"]
                 smax = s[0]
                 new_gauge = s / smax
-                nfact = do('log10', smax) + nfact
+                nfact = do("log10", smax) + nfact
 
                 if (tol > 0.0) or (pbar is not None):
                     # check convergence
@@ -6776,7 +6947,7 @@ def gauge_all_simple(
                         # the singular values directly
                         old_gauge = 1.0
 
-                    sdiff = do('linalg.norm', old_gauge - new_gauge)
+                    sdiff = do("linalg.norm", old_gauge - new_gauge)
                     max_sdiff = max(max_sdiff, sdiff)
 
                 # update inner gauge and undo outer gauges
@@ -6794,8 +6965,7 @@ def gauge_all_simple(
             if pbar is not None:
                 pbar.update()
                 pbar.set_description(
-                    f"max|dS|={max_sdiff:.2e}, "
-                    f"nfact={nfact:.2f}"
+                    f"max|dS|={max_sdiff:.2e}, " f"nfact={nfact:.2f}"
                )
 
            not_converged = (tol == 0.0) or (max_sdiff > tol)
@@ -6805,7 +6975,7 @@ def gauge_all_simple(
             tn.exponent += nfact
         else:
             # redistribute the accrued scaling
-            tn.multiply_each_(10**(nfact / tn.num_tensors))
+            tn.multiply_each_(10 ** (nfact / tn.num_tensors))
 
         if not gauges_supplied:
             # absorb all bond gauges
@@ -6820,11 +6990,7 @@ def gauge_all_simple(
     gauge_all_simple_ = functools.partialmethod(gauge_all_simple, inplace=True)
 
     def gauge_all_random(
-        self,
-        max_iterations=1,
-        unitary=True,
-        seed=None,
-        inplace=False
+        self, max_iterations=1, unitary=True, seed=None, inplace=False
     ):
         """Gauge all the bonds in this network randomly. This is largely for
         testing purposes.
@@ -6848,11 +7014,11 @@ def gauge_all_random(
 
                 if unitary:
                     G = rand_uni(d, dtype=get_dtype_name(t1.data))
-                    G = do('array', G, like=t1.data)
+                    G = do("array", G, like=t1.data)
                     Ginv = dag(G)
                 else:
                     G = rand_matrix(d, dtype=get_dtype_name(t1.data))
-                    G = do('array', G, like=t1.data)
+                    G = do("array", G, like=t1.data)
                     Ginv = do("linalg.inv", G)
 
                 t1.gate_(G, ix)
@@ -6892,30 +7058,36 @@ def _gauge_local_tids(
         self,
         tids,
         max_distance=1,
-        max_iterations='max_distance',
-        method='canonize',
+        max_iterations="max_distance",
+        method="canonize",
         inwards=False,
         include=None,
         exclude=None,
-        **gauge_local_opts
+        **gauge_local_opts,
     ):
         """Iteratively gauge all bonds in the local tensor network defined by
         ``tids`` according to one of several strategies.
        """
-        if max_iterations == 'max_distance':
+        if max_iterations == "max_distance":
             max_iterations = max_distance
 
         tn_loc = self._select_local_tids(
-            tids, max_distance=max_distance, inwards=inwards,
-            virtual=True, include=include, exclude=exclude
+            tids,
+            max_distance=max_distance,
+            inwards=inwards,
+            virtual=True,
+            include=include,
+            exclude=exclude,
         )
 
         if method == "canonize":
             tn_loc.gauge_all_canonize_(
-                max_iterations=max_iterations, **gauge_local_opts)
+                max_iterations=max_iterations, **gauge_local_opts
+            )
         elif method == "simple":
             tn_loc.gauge_all_simple_(
-                max_iterations=max_iterations, **gauge_local_opts)
+                max_iterations=max_iterations, **gauge_local_opts
+            )
         elif method == "random":
             tn_loc.gauge_all_random_(**gauge_local_opts)
@@ -6924,12 +7096,12 @@ def _gauge_local_tids(
     def gauge_local(
         self,
         tags,
-        which='all',
+        which="all",
         max_distance=1,
-        max_iterations='max_distance',
-        method='canonize',
+        max_iterations="max_distance",
+        method="canonize",
         inplace=False,
-        **gauge_local_opts
+        **gauge_local_opts,
     ):
         """Iteratively gauge all bonds in the tagged sub tensor network
         according to one of several strategies.
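For orientation, the gauging variants reformatted above can be driven as in the following sketch. It is illustrative only and not part of the patch; it assumes quimb's `TN_rand_reg` constructor (a random regular-graph tensor network) is available, and uses the inplace partialmethods defined above:

    import quimb.tensor as qtn

    # random 3-regular network of 20 tensors with bond dimension 2
    tn = qtn.TN_rand_reg(n=20, reg=3, D=2, seed=42)

    # iteratively insert 'simple update' style diagonal gauges on every
    # bond until the singular values stop changing
    tn.gauge_all_simple_(max_iterations=100, tol=1e-6)

    # or scramble every bond with random (unitary) gauges - mainly useful
    # for testing gauge invariance of downstream routines
    tn.gauge_all_random_(seed=42)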
@@ -6937,8 +7109,12 @@ def gauge_local(
         """
         tn = self if inplace else self.copy()
         tids = self._get_tids_from_tags(tags, which)
         tn._gauge_local_tids(
-            tids, max_distance=max_distance, max_iterations=max_iterations,
-            method=method, **gauge_local_opts)
+            tids,
+            max_distance=max_distance,
+            max_iterations=max_iterations,
+            method=method,
+            **gauge_local_opts,
+        )
         return tn
 
     gauge_local_ = functools.partialmethod(gauge_local, inplace=True)
@@ -6985,8 +7161,8 @@ def gauge_simple_insert(
             g = _get(ix, None)
             if g is None:
                 continue
-            g = (g + smudge * g[0])**power
-            t, = self._inds_get(ix)
+            g = (g + smudge * g[0]) ** power
+            (t,) = self._inds_get(ix)
             t.multiply_index_diagonal_(ix, g)
             outer.append((t, ix, g))
@@ -7146,8 +7322,12 @@ def _contract_compressed_tid_sequence(
         if gauge_boundary_only:
             data_like = next(iter(tn.tensor_map.values())).data
             gauges = {
-                ix: do("ones", (tn.ind_size(ix),),
-                       dtype=data_like.dtype, like=data_like)
+                ix: do(
+                    "ones",
+                    (tn.ind_size(ix),),
+                    dtype=data_like.dtype,
+                    like=data_like,
+                )
                 for ix in tn.inner_inds()
             }
         else:
@@ -7161,19 +7341,19 @@ def _contract_compressed_tid_sequence(
         # options relating to locally canonizing around each compression
         if canonize_distance:
             canonize_opts = ensure_dict(canonize_opts)
-            canonize_opts.setdefault('equalize_norms', equalize_norms)
+            canonize_opts.setdefault("equalize_norms", equalize_norms)
             if gauge_boundary_only:
-                canonize_opts['include'] = boundary
+                canonize_opts["include"] = boundary
            else:
-                canonize_opts['include'] = None
+                canonize_opts["include"] = None
 
        # options relating to canonizing around tensors *after* compression
        if canonize_after_distance:
            canonize_after_opts = ensure_dict(canonize_after_opts)
            if gauge_boundary_only:
-                canonize_after_opts['include'] = boundary
+                canonize_after_opts["include"] = boundary
            else:
-                canonize_after_opts['include'] = None
+                canonize_after_opts["include"] = None
 
         def _do_contraction(tid1, tid2):
             """The inner closure that contracts the two tensors identified by
@@ -7184,7 +7364,10 @@ def _do_contraction(tid1, tid2):
 
             # new tensor is now at ``tid2``
             tn._contract_between_tids(
-                tid1, tid2, equalize_norms=equalize_norms, gauges=gauges,
+                tid1,
+                tid2,
+                equalize_norms=equalize_norms,
+                gauges=gauges,
             )
 
             # update the boundary
@@ -7221,9 +7404,9 @@ def _should_skip_compression(tid1, tid2):
                 return True
 
             if (
-                (not compress_matrices) and
-                (len(tn._get_neighbor_tids([tid1])) <= 2) and
-                (len(tn._get_neighbor_tids([tid1])) <= 2)
+                (not compress_matrices)
+                and (len(tn._get_neighbor_tids([tid1])) <= 2)
+                and (len(tn._get_neighbor_tids([tid2])) <= 2)
             ):
                 # both are effectively matrices
                 return True
@@ -7242,12 +7425,14 @@ def _should_skip_compression(tid1, tid2):
         if callable(max_bond):
             chi_fn = max_bond
         else:
+
             def chi_fn(d):
                 return max_bond
 
         if callable(cutoff):
             eps_fn = cutoff
         else:
+
             def eps_fn(d):
                 return cutoff
@@ -7263,7 +7448,6 @@ def _compress_neighbors(tid, t, d):
                 return
 
             for tid_neighb in tn._get_neighbor_tids(tid):
-
                 # first just check for accumulation of small multi-bonds
                 t_neighb = tn.tensor_map[tid_neighb]
                 tensor_fuse_squeeze(t, t_neighb, gauges=gauges)
@@ -7288,7 +7472,7 @@ def _compress_neighbors(tid, t, d):
                     equalize_norms=equalize_norms,
                     gauges=gauges,
                     gauge_smudge=gauge_smudge,
-                    **compress_opts
+                    **compress_opts,
                 )
 
                 if callback_post_compress is not None:
@@ -7298,6 +7482,7 @@ def _compress_neighbors(tid, t, d):
 
         if progbar:
             import tqdm
+
             max_size = 0.0
             pbar = tqdm.tqdm(total=num_contractions)
         else:
@@ -7309,14 +7494,14 @@ def _compress_neighbors(tid, t, d):
             tid1, tid2, *maybe_d = seq[i]
 
             if maybe_d:
-                d, = maybe_d
+                (d,) = maybe_d
            else:
-                d = float('inf')
+                d = float("inf")
 
            if compress_span:
                # only keep track of the next few contractions to ignore
                # (note if False whole seq is already excluded)
-                for s in seq[i + compress_span - 1:i + compress_span]:
+                for s in seq[i + compress_span - 1 : i + compress_span]:
                    dont_compress_pairs.add(frozenset(s[:2]))
 
            if compress_late:
@@ -7331,7 +7516,8 @@ def _compress_neighbors(tid, t, d):
                     new_size = math.log2(t_new.size)
                     max_size = max(max_size, new_size)
                     pbar.set_description(
-                        f"log2[SIZE]: {new_size:.2f}/{max_size:.2f}")
+                        f"log2[SIZE]: {new_size:.2f}/{max_size:.2f}"
+                    )
                     pbar.update()
 
                 if not compress_late:
@@ -7377,23 +7563,25 @@ def _contract_around_tids(
             tids,
             min_distance=min_distance,
             max_distance=max_distance,
-            **span_opts)
+            **span_opts,
+        )
 
         canonize_opts = ensure_dict(canonize_opts)
-        canonize_opts['exclude'] = oset(itertools.chain(
-            canonize_opts.get('exclude', ()), tids
-        ))
+        canonize_opts["exclude"] = oset(
+            itertools.chain(canonize_opts.get("exclude", ()), tids)
+        )
 
         return self._contract_compressed_tid_sequence(
             seq,
             max_bond=max_bond,
             cutoff=cutoff,
             compress_exclude=tids,
-            **kwargs
+            **kwargs,
        )
 
    def compute_centralities(self):
        import cotengra as ctg
+
        hg = ctg.get_hypergraph(
            {tid: t.inds for tid, t in self.tensor_map.items()}
        )
@@ -7424,7 +7612,7 @@ def contract_around_corner(self, **opts):
     def contract_around(
         self,
         tags,
-        which='all',
+        which="all",
         min_distance=0,
         max_distance=None,
         span_opts=None,
@@ -7450,7 +7638,7 @@ def contract_around(
         callback_post_compress=None,
         callback=None,
         inplace=False,
-        **kwargs
+        **kwargs,
     ):
         """Perform a compressed contraction inwards towards the tensors
         identified by ``tags``.
@@ -7484,7 +7672,7 @@ def contract_around(
             callback_post_compress=callback_post_compress,
             callback=callback,
             inplace=inplace,
-            **kwargs
+            **kwargs,
         )
 
     contract_around_ = functools.partialmethod(contract_around, inplace=True)
@@ -7516,7 +7704,7 @@ def contract_compressed(
         callback_post_compress=None,
         callback=None,
         progbar=False,
-        **kwargs
+        **kwargs,
     ):
         path = self.contraction_path(optimize, output_inds=output_inds)
@@ -7559,11 +7747,12 @@ def contract_compressed(
             callback_post_compress=callback_post_compress,
             callback=callback,
             progbar=progbar,
-            **kwargs
+            **kwargs,
         )
 
     contract_compressed_ = functools.partialmethod(
-        contract_compressed, inplace=True)
+        contract_compressed, inplace=True
+    )
 
     def new_bond(self, tags1, tags2, **opts):
         """Inplace addition of a dummy (size 1) bond between the single
@@ -7582,13 +7771,13 @@ def new_bond(self, tags1, tags2, **opts):
         --------
         new_bond
         """
-        tid1, = self._get_tids_from_tags(tags1, which='all')
-        tid2, = self._get_tids_from_tags(tags2, which='all')
+        (tid1,) = self._get_tids_from_tags(tags1, which="all")
+        (tid2,) = self._get_tids_from_tags(tags2, which="all")
         new_bond(self.tensor_map[tid1], self.tensor_map[tid2], **opts)
 
     def _cut_between_tids(self, tid1, tid2, left_ind, right_ind):
         TL, TR = self.tensor_map[tid1], self.tensor_map[tid2]
-        bnd, = bonds(TL, TR)
+        (bnd,) = bonds(TL, TR)
         TL.reindex_({bnd: left_ind})
         TR.reindex_({bnd: right_ind})
@@ -7597,8 +7786,8 @@ def cut_between(self, left_tags, right_tags, left_ind, right_ind):
         ``right_tags``, giving them the new inds ``left_ind`` and
         ``right_ind`` respectively.
        """
-        tid1, = self._get_tids_from_tags(left_tags)
-        tid2, = self._get_tids_from_tags(right_tags)
+        (tid1,) = self._get_tids_from_tags(left_tags)
+        (tid2,) = self._get_tids_from_tags(right_tags)
         self._cut_between_tids(tid1, tid2, left_ind, right_ind)
 
     def isel(self, selectors, inplace=False):
@@ -7642,7 +7831,7 @@ def sum_reduce(self, ind, inplace=False):
             Whether to perform the reduction inplace.
         """
         tn = self if inplace else self.copy()
-        t, = tn._inds_get(ind)
+        (t,) = tn._inds_get(ind)
         t.sum_reduce_(ind)
         return tn
@@ -7666,7 +7855,7 @@ def vector_reduce(self, ind, v, inplace=False):
         TensorNetwork
         """
         tn = self if inplace else self.copy()
-        t, = tn._inds_get(ind)
+        (t,) = tn._inds_get(ind)
         t.vector_reduce_(ind, v)
         return tn
@@ -7737,12 +7926,14 @@ def insert_operator(self, A, where1, where2, tags=None, inplace=False):
 
         d = A.shape[0]
         T1, T2 = tn[where1], tn[where2]
-        bnd, = bonds(T1, T2)
+        (bnd,) = bonds(T1, T2)
         db = T1.ind_size(bnd)
 
         if d != db:
-            raise ValueError(f"This operator has dimension {d} but needs "
-                             f"dimension {db}.")
+            raise ValueError(
+                f"This operator has dimension {d} but needs "
+                f"dimension {db}."
+            )
 
         # reindex one tensor, and add a new A tensor joining the bonds
         nbnd = rand_uuid()
@@ -7766,14 +7957,14 @@ def _insert_gauge_tids(
         t1, t2 = self._tids_get(tid1, tid2)
 
         if bond is None:
-            bond, = t1.bonds(t2)
+            (bond,) = t1.bonds(t2)
 
         if Uinv is None:
-            Uinv = do('linalg.inv', U)
+            Uinv = do("linalg.inv", U)
 
             # if we get wildly larger inverse due to singular U, try pseudo-inv
             if vdot(Uinv, Uinv) / vdot(U, U) > 1 / tol:
-                Uinv = do('linalg.pinv', U, rcond=tol**0.5)
+                Uinv = do("linalg.pinv", U, rcond=tol**0.5)
 
             # if still wildly larger inverse raise an error
             if vdot(Uinv, Uinv) / vdot(U, U) > 1 / tol:
@@ -7800,8 +7991,8 @@ def insert_gauge(self, U, where1, where2, Uinv=None, tol=1e-10):
             The inverse gauge, ``U @ Uinv == Uinv @ U == eye``, to insert.
             If not given will be calculated using :func:`numpy.linalg.inv`.
         """
-        tid1, = self._get_tids_from_tags(where1, which='all')
-        tid2, = self._get_tids_from_tags(where2, which='all')
+        (tid1,) = self._get_tids_from_tags(where1, which="all")
+        (tid2,) = self._get_tids_from_tags(where2, which="all")
         self._insert_gauge_tids(U, tid1, tid2, Uinv=Uinv, tol=tol)
 
     # ----------------------- contracting the network ----------------------- #
 
     def contract_tags(
         self,
         tags,
-        which='any',
+        which="any",
         output_inds=None,
         optimize=None,
         get=None,
         backend=None,
         preserve_tensor=False,
         inplace=False,
-        **contract_opts
+        **contract_opts,
     ):
         """Contract the tensors that match any or all of ``tags``.
@@ -7897,9 +8088,7 @@ def contract_tags(
 
         # whether we should let tensor_contract return a raw scalar
         preserve_tensor = (
-            preserve_tensor or
-            inplace or
-            (untagged_tn.num_tensors >= 1)
+            preserve_tensor or inplace or (untagged_tn.num_tensors >= 1)
         )
 
         contracted = tensor_contract(
@@ -7909,7 +8098,7 @@ def contract_tags(
             get=get,
             backend=backend,
             preserve_tensor=preserve_tensor,
-            **contract_opts
+            **contract_opts,
         )
 
         if (untagged_tn.num_tensors == 0) and (not inplace):
@@ -8005,24 +8194,25 @@ def contract(
         --------
         contract_tags, contract_cumulative
         """
-        opts['output_inds'] = output_inds
-        opts['optimize'] = optimize
-        opts['get'] = get
-        opts['backend'] = backend
-        opts['preserve_tensor'] = preserve_tensor
+        opts["output_inds"] = output_inds
+        opts["optimize"] = optimize
+        opts["get"] = get
+        opts["backend"] = backend
+        opts["preserve_tensor"] = preserve_tensor
 
         all_tags = (tags is all) or (tags is ...)
 
         if max_bond is not None:
             if not all_tags:
                 raise NotImplementedError
-            if opts.pop('get', None) is not None:
+            if opts.pop("get", None) is not None:
                 raise NotImplementedError
-            if opts.pop('backend', None) is not None:
+            if opts.pop("backend", None) is not None:
                 raise NotImplementedError
 
             return self.contract_compressed(
-                max_bond=max_bond, inplace=inplace, **opts)
+                max_bond=max_bond, inplace=inplace, **opts
+            )
 
         # this checks whether certain TN classes have a manually specified
         # contraction pattern (e.g. 1D along the line)
@@ -8046,7 +8236,7 @@ def contract_cumulative(
         preserve_tensor=False,
         equalize_norms=False,
         inplace=False,
-        **opts
+        **opts,
     ):
         """Cumulative contraction of tensor network. Contract the first set of
         tags, then that set with the next set, then both of those with the next
@@ -8087,7 +8277,7 @@ def contract_cumulative(
             c_tags |= tags_to_oset(tags)
 
             # perform the next contraction
-            tn.contract_tags_(c_tags, which='any', **opts)
+            tn.contract_tags_(c_tags, which="any", **opts)
 
             if tn.num_tensors == 1:
                 # nothing more to contract
@@ -8109,7 +8299,8 @@ def contraction_path(self, optimize=None, **contract_opts):
         if optimize is None:
             optimize = get_contract_strategy()
         return self.contract(
-            all, optimize=optimize, get='path', **contract_opts)
+            all, optimize=optimize, get="path", **contract_opts
+        )
 
     def contraction_info(self, optimize=None, **contract_opts):
         """Compute the ``opt_einsum.PathInfo`` object describing the
@@ -8119,7 +8310,8 @@ def contraction_info(self, optimize=None, **contract_opts):
         if optimize is None:
             optimize = get_contract_strategy()
         return self.contract(
-            all, optimize=optimize, get='path-info', **contract_opts)
+            all, optimize=optimize, get="path-info", **contract_opts
+        )
 
     def contraction_tree(
         self,
@@ -8162,28 +8354,23 @@ def contraction_cost(self, optimize=None, **contract_opts):
         return tree.contraction_cost()
 
     def __rshift__(self, tags_seq):
-        """Overload of '>>' for TensorNetwork.contract_cumulative.
-        """
+        """Overload of '>>' for TensorNetwork.contract_cumulative."""
         return self.contract_cumulative(tags_seq)
 
     def __irshift__(self, tags_seq):
-        """Overload of '>>=' for inplace TensorNetwork.contract_cumulative.
-        """
+        """Overload of '>>=' for inplace TensorNetwork.contract_cumulative."""
         return self.contract_cumulative(tags_seq, inplace=True)
 
     def __xor__(self, tags):
-        """Overload of '^' for TensorNetwork.contract.
-        """
+        """Overload of '^' for TensorNetwork.contract."""
         return self.contract(tags)
 
     def __ixor__(self, tags):
-        """Overload of '^=' for inplace TensorNetwork.contract.
-        """
+        """Overload of '^=' for inplace TensorNetwork.contract."""
         return self.contract(tags, inplace=True)
 
     def __matmul__(self, other):
-        """Overload "@" to mean full contraction with another network.
-        """
+        """Overload "@" to mean full contraction with another network."""
         return TensorNetwork((self, other)) ^ ...
 
     def as_network(self, virtual=True):
@@ -8193,14 +8380,27 @@ def as_network(self, virtual=True):
         """
         return self if virtual else self.copy()
 
-
-    def aslinearoperator(self, left_inds, right_inds, ldims=None, rdims=None,
-                         backend=None, optimize=None):
+    def aslinearoperator(
+        self,
+        left_inds,
+        right_inds,
+        ldims=None,
+        rdims=None,
+        backend=None,
+        optimize=None,
+    ):
         """View this ``TensorNetwork`` as a
         :class:`~quimb.tensor.tensor_core.TNLinearOperator`.
         """
-        return TNLinearOperator(self, left_inds, right_inds, ldims, rdims,
-                                optimize=optimize, backend=backend)
+        return TNLinearOperator(
+            self,
+            left_inds,
+            right_inds,
+            ldims,
+            rdims,
+            optimize=optimize,
+            backend=backend,
+        )
 
     @functools.wraps(tensor_split)
     def split(self, left_inds, right_inds=None, **split_opts):
@@ -8217,8 +8417,7 @@ def split(self, left_inds, right_inds=None, **split_opts):
         return T.split(**split_opts)
 
     def trace(self, left_inds, right_inds, **contract_opts):
-        """Trace over ``left_inds`` joined with ``right_inds``
-        """
+        """Trace over ``left_inds`` joined with ``right_inds``."""
         tn = self.reindex({u: l for u, l in zip(left_inds, right_inds)})
         return tn.contract_tags(..., **contract_opts)
 
@@ -8227,12 +8426,12 @@ def to_dense(self, *inds_seq, to_qarray=False, **contract_opts):
         for each of inds in ``inds_seqs``. E.g. to convert several sites
         into a density matrix: ``TN.to_dense(('k0', 'k1'), ('b0', 'b1'))``.
        """
-        tags = contract_opts.pop('tags', all)
+        tags = contract_opts.pop("tags", all)
         t = self.contract(
             tags,
             output_inds=tuple(concat(inds_seq)),
             preserve_tensor=True,
-            **contract_opts
+            **contract_opts,
         )
         return t.to_dense(*inds_seq, to_qarray=to_qarray)
@@ -8299,12 +8498,14 @@ def compute_reduced_factor(
         # contract to dense array
         tnd = self.reindex(ixmap).conj_() & self
         XX = tnd.to_dense(
-            ixmap.values(), ixmap.keys(),
-            optimize=optimize, **contract_opts
+            ixmap.values(), ixmap.keys(), optimize=optimize, **contract_opts
         )
 
         return decomp.squared_op_to_reduced_factor(
-            XX, d0, d1, right=(side == "right"),
+            XX,
+            d0,
+            d1,
+            right=(side == "right"),
         )
 
     def insert_compressor_between_regions(
@@ -8430,11 +8631,11 @@ def distance(self, *args, **kwargs):
     def fit(
         self,
         tn_target,
-        method='als',
+        method="als",
         tol=1e-9,
         inplace=False,
         progbar=False,
-        **fitting_opts
+        **fitting_opts,
     ):
         r"""Optimize the entries of this tensor network with respect to a least
         squares fit of ``tn_target`` which should have the same outer indices.
@@ -8480,14 +8681,14 @@ def fit(
         tensor_network_fit_als, tensor_network_fit_autodiff,
         tensor_network_distance
         """
-        check_opt('method', method, ('als', 'autodiff'))
-        fitting_opts['tol'] = tol
-        fitting_opts['inplace'] = inplace
-        fitting_opts['progbar'] = progbar
+        check_opt("method", method, ("als", "autodiff"))
+        fitting_opts["tol"] = tol
+        fitting_opts["inplace"] = inplace
+        fitting_opts["progbar"] = progbar
 
         tn_target = tn_target.as_network()
 
-        if method == 'autodiff':
+        if method == "autodiff":
             return tensor_network_fit_autodiff(self, tn_target, **fitting_opts)
         return tensor_network_fit_als(self, tn_target, **fitting_opts)
@@ -8500,24 +8701,20 @@ def tags(self):
         return oset(self.tag_map)
 
     def all_inds(self):
-        """Return a tuple of all indices in this network.
-        """
+        """Return a tuple of all indices in this network."""
         return tuple(self.ind_map)
 
     def ind_size(self, ind):
-        """Find the size of ``ind``.
-        """
+        """Find the size of ``ind``."""
         tid = next(iter(self.ind_map[ind]))
         return self.tensor_map[tid].ind_size(ind)
 
     def inds_size(self, inds):
-        """Return the total size of dimensions corresponding to ``inds``.
-        """
+        """Return the total size of dimensions corresponding to ``inds``."""
         return prod(map(self.ind_size, inds))
 
     def ind_sizes(self):
-        """Get dict of each index mapped to its size.
-        """
+        """Get dict of each index mapped to its size."""
         return {i: self.ind_size(i) for i in self.ind_map}
 
     def inner_inds(self):
@@ -8580,11 +8777,7 @@ def get_multibonds(
                 if ix not in exclude:
                     seen[tuple(sorted(tids))].append(ix)
 
-        return {
-            tuple(ixs): tids
-            for tids, ixs in seen.items()
-            if len(ixs) > 1
-        }
+        return {tuple(ixs): tids for tids, ixs in seen.items() if len(ixs) > 1}
 
     def get_hyperinds(self, output_inds=None):
         """Get a tuple of all 'hyperinds', defined as those indices which don't
@@ -8615,7 +8808,8 @@ def get_hyperinds(self, output_inds=None):
         output_inds = tags_to_oset(output_inds)
 
         return tuple(
-            ix for ix, tids in self.ind_map.items()
+            ix
+            for ix, tids in self.ind_map.items()
             if (len(tids) + int(ix in output_inds)) != 2
         )
@@ -8627,14 +8821,15 @@ def compute_contracted_inds(self, *tids, output_inds=None):
             output_inds = self._outer_inds
 
         # number of times each index appears on tensors
-        freqs = frequencies(concat(
-            self.tensor_map[tid].inds for tid in tids
-        ))
+        freqs = frequencies(concat(self.tensor_map[tid].inds for tid in tids))
 
         return tuple(
-            ix for ix, c in freqs.items() if
+            ix
+            for ix, c in freqs.items()
+            if
             # ind also appears elsewhere -> keep
-            (c != len(self.ind_map[ix])) or
+            (c != len(self.ind_map[ix]))
+            or
             # explicitly in output -> keep
             (ix in output_inds)
         )
@@ -8676,7 +8871,7 @@ def squeeze(
 
     squeeze_ = functools.partialmethod(squeeze, inplace=True)
 
-    def isometrize(self, method='qr', allow_no_left_inds=False, inplace=False):
+    def isometrize(self, method="qr", allow_no_left_inds=False, inplace=False):
         """Project every tensor in this network into an isometric form,
         assuming they have ``left_inds`` marked.
@@ -8728,14 +8923,16 @@ def isometrize(self, method='qr', allow_no_left_inds=False, inplace=False):
             if t.left_inds is None:
                 if allow_no_left_inds:
                     continue
-                raise ValueError("The tensor {} doesn't have left indices "
-                                 "marked using the `left_inds` attribute.")
+                raise ValueError(
+                    f"The tensor {t} doesn't have left indices "
+                    "marked using the `left_inds` attribute."
+                )
             t.isometrize_(method=method)
         return tn
 
     isometrize_ = functools.partialmethod(isometrize, inplace=True)
-    unitize = deprecated(isometrize, 'unitize', 'isometrize')
-    unitize_ = deprecated(isometrize_, 'unitize_', 'isometrize_')
+    unitize = deprecated(isometrize, "unitize", "isometrize")
+    unitize_ = deprecated(isometrize_, "unitize_", "isometrize_")
 
     def randomize(self, dtype=None, seed=None, inplace=False, **randn_opts):
         """Randomize every tensor in this TN - see
@@ -8792,14 +8989,14 @@ def strip_exponent(self, tid_or_tensor, value=None):
         stripped_factor = t.norm() / value
         t.modify(apply=lambda data: data / stripped_factor)
-        self.exponent = self.exponent + do('log10', stripped_factor)
+        self.exponent = self.exponent + do("log10", stripped_factor)
 
     def distribute_exponent(self):
         """Distribute the exponent ``p`` of this tensor network (i.e.
         corresponding to ``tn * 10**p``) equally among all tensors.
        """
         # multiply each tensor by the nth root of 10**exponent
-        x = 10**(self.exponent / self.num_tensors)
+        x = 10 ** (self.exponent / self.num_tensors)
         self.multiply_each_(x)
 
         # reset the exponent to zero
@@ -8934,23 +9131,30 @@ def expand_bond_dimension(
         for t in tn._inds_get(*inds_to_expand):
             # perform the array expansions
             pads = [
-                (0, 0) if ind not in inds_to_expand else
-                (0, max(new_bond_dim - d, 0))
+                (0, 0)
+                if ind not in inds_to_expand
+                else (0, max(new_bond_dim - d, 0))
                 for d, ind in zip(t.shape, t.inds)
             ]
 
             if rand_strength > 0:
-                edata = do('pad', t.data, pads, mode=rand_padder,
-                           rand_strength=rand_strength)
+                edata = do(
+                    "pad",
+                    t.data,
+                    pads,
+                    mode=rand_padder,
+                    rand_strength=rand_strength,
+                )
            else:
-                edata = do('pad', t.data, pads, mode='constant')
+                edata = do("pad", t.data, pads, mode="constant")
 
            t.modify(data=edata)
 
        return tn
 
     expand_bond_dimension_ = functools.partialmethod(
-        expand_bond_dimension, inplace=True)
+        expand_bond_dimension, inplace=True
+    )
 
     def flip(self, inds, inplace=False):
         """Flip the dimension corresponding to indices ``inds`` on all tensors
@@ -9016,7 +9220,6 @@ def rank_simplify(
         scalars = []
         count = collections.Counter()
         for tid, t in tuple(tn.tensor_map.items()):
-
             # remove floating scalar tensors -->
             # these have no indices so won't be caught otherwise
             if t.ndim == 0:
@@ -9038,14 +9241,16 @@ def rank_simplify(
 
         # sorted list of unique indices to check -> start with lowly connected
         def rank_weight(ind):
-            return (tn.ind_size(ind), -sum(tn.tensor_map[tid].ndim
-                                           for tid in tn.ind_map[ind]))
+            return (
+                tn.ind_size(ind),
+                -sum(tn.tensor_map[tid].ndim for tid in tn.ind_map[ind]),
+            )
 
         queue = oset(sorted(count, key=rank_weight))
 
         # number of tensors for which there will be more pairwise combinations
         # than max_combinations
-        combi_cutoff = int(0.5 * ((8 * max_combinations + 1)**0.5 + 1))
+        combi_cutoff = int(0.5 * ((8 * max_combinations + 1) ** 0.5 + 1))
 
         while queue:
             # get next index
@@ -9060,7 +9265,7 @@ def rank_simplify(
 
             # index only appears on one tensor and not in output -> can sum
             if count[ind] == 1:
-                tid, = tids
+                (tid,) = tids
                 t = tn.tensor_map[tid]
                 t.sum_reduce_(ind)
@@ -9081,11 +9286,10 @@ def rank_simplify(
                 tids = sorted(tids, key=lambda tid: tn.tensor_map[tid].ndim)
 
             for tid_a, tid_b in itertools.combinations(tids, 2):
-
                 ta = tn.tensor_map[tid_a]
                 tb = tn.tensor_map[tid_b]
 
-                cache_key = ('rs', tid_a, tid_b, id(ta.data), id(tb.data))
+                cache_key = ("rs", tid_a, tid_b, id(ta.data), id(tb.data))
                 if cache_key in cache:
                     continue
@@ -9148,7 +9352,7 @@ def rank_simplify(
             signs = []
             for s in scalars:
                 signs.append(do("sign", s))
-                tn.exponent += do("log10", do('abs', s))
+                tn.exponent += do("log10", do("abs", s))
             scalars = signs
 
         if tn.num_tensors:
@@ -9211,7 +9415,7 @@ def diagonal_reduce(
             tid = queue.pop()
             t = tn.tensor_map[tid]
 
-            cache_key = ('dr', tid, id(t.data))
+            cache_key = ("dr", tid, id(t.data))
             if cache_key in cache:
                 continue
@@ -9296,7 +9500,7 @@ def antidiag_gauge(
             tid = queue.pop()
             t = tn.tensor_map[tid]
 
-            cache_key = ('ag', tid, id(t.data))
+            cache_key = ("ag", tid, id(t.data))
             if cache_key in cache:
                 continue
@@ -9378,7 +9582,7 @@ def column_reduce(
             tid = queue.pop()
             t = tn.tensor_map[tid]
 
-            cache_key = ('cr', tid, id(t.data))
+            cache_key = ("cr", tid, id(t.data))
             if cache_key in cache:
                 continue
@@ -9435,9 +9639,8 @@ def split_simplify(
         cache = set()
 
         for tid, t in tuple(tn.tensor_map.items()):
-
             # id's are reused when objects go out of scope -> use tid as well
-            cache_key = ('sp', tid, id(t.data))
+            cache_key = ("sp",
tid, id(t.data)) if cache_key in cache: continue @@ -9446,7 +9649,7 @@ def split_simplify( tl, tr = t.split( lix, right_inds=rix, - get='tensors', + get="tensors", cutoff=atol, **split_opts, ) @@ -9486,9 +9689,10 @@ def gen_loops(self, max_loop_length=None): tuple[int] """ from cotengra.core import get_hypergraph + inputs = {tid: t.inds for tid, t in self.tensor_map.items()} - hg = get_hypergraph(inputs, accel='auto') - return hg.compute_loops(max_loop_length) + hg = get_hypergraph(inputs, accel="auto") + return hg.compute_loops(max_loop_length=max_loop_length) def _get_string_between_tids(self, tida, tidb): strings = [(tida,)] @@ -9517,10 +9721,7 @@ def tids_are_connected(self, tids): """ enum = range(len(tids)) groups = dict(zip(enum, enum)) - regions = [ - (oset([tid]), self._get_neighbor_tids(tid)) - for tid in tids - ] + regions = [(oset([tid]), self._get_neighbor_tids(tid)) for tid in tids] for i, j in itertools.combinations(enum, 2): mi = groups.get(i, i) mj = groups.get(j, j) @@ -9566,9 +9767,9 @@ def compute_shortest_distances(self, tids=None, exclude_inds=()): for diff_tid in visitors[tid] - old_visitors[tid]: any_change = True if ( - (tid in tids) and - (diff_tid in tids) and - (tid < diff_tid) + (tid in tids) + and (diff_tid in tids) + and (tid < diff_tid) ): distances[tid, diff_tid] = d @@ -9582,7 +9783,7 @@ def compute_shortest_distances(self, tids=None, exclude_inds=()): def compute_hierarchical_linkage( self, tids=None, - method='weighted', + method="weighted", optimal_ordering=True, exclude_inds=(), ): @@ -9593,11 +9794,10 @@ def compute_hierarchical_linkage( try: from cotengra import get_hypergraph + hg = get_hypergraph( - { - tid: t.inds - for tid, t in self.tensor_map.items() - }, accel="auto", + {tid: t.inds for tid, t in self.tensor_map.items()}, + accel="auto", ) for ix in exclude_inds: hg.remove_edge(ix) @@ -9610,7 +9810,7 @@ def compute_hierarchical_linkage( distances = self.compute_shortest_distances(tids, exclude_inds) - dinf = 10 * self.num_tensors + dinf = 10 * self.num_tensors y = [ distances.get(tuple(sorted((i, j))), dinf) for i, j in itertools.combinations(tids, 2) @@ -9623,18 +9823,20 @@ def compute_hierarchical_linkage( def compute_hierarchical_ssa_path( self, tids=None, - method='weighted', + method="weighted", optimal_ordering=True, exclude_inds=(), are_sorted=False, linkage=None, ): - """Compute a hierarchical grouping of ``tids``, as a ``ssa_path``. 
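For orientation, the loop machinery above can be driven as below: an
illustrative sketch assuming a standard quimb install with cotengra available
(it backs ``gen_loops`` via its hypergraph), using the ``TN2D_rand`` builder
purely as an example:

    import quimb.tensor as qtn

    tn = qtn.TN2D_rand(4, 4, D=2)      # a 4x4 square-lattice network

    # loops of tensor ids up to length 4, here the elementary plaquettes
    for loop in tn.gen_loops(max_loop_length=4):
        print(loop)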
- """ + """Compute a hierarchical grouping of ``tids``, as a ``ssa_path``.""" if linkage is None: linkage = self.compute_hierarchical_linkage( - tids, method=method, exclude_inds=exclude_inds, - optimal_ordering=optimal_ordering) + tids, + method=method, + exclude_inds=exclude_inds, + optimal_ordering=optimal_ordering, + ) sorted_ssa_path = ((int(x[0]), int(x[1])) for x in linkage) if are_sorted: @@ -9654,7 +9856,7 @@ def compute_hierarchical_ssa_path( def compute_hierarchical_ordering( self, tids=None, - method='weighted', + method="weighted", optimal_ordering=True, exclude_inds=(), linkage=None, @@ -9666,8 +9868,11 @@ def compute_hierarchical_ordering( if linkage is None: linkage = self.compute_hierarchical_linkage( - tids, method=method, exclude_inds=exclude_inds, - optimal_ordering=optimal_ordering) + tids, + method=method, + exclude_inds=exclude_inds, + optimal_ordering=optimal_ordering, + ) node2tid = {i: tid for i, tid in enumerate(sorted(tids))} return tuple(map(node2tid.__getitem__, hierarchy.leaves_list(linkage))) @@ -9676,7 +9881,7 @@ def compute_hierarchical_grouping( self, max_group_size, tids=None, - method='weighted', + method="weighted", optimal_ordering=True, exclude_inds=(), linkage=None, @@ -9691,12 +9896,18 @@ def compute_hierarchical_grouping( if linkage is None: linkage = self.compute_hierarchical_linkage( - tids, method=method, exclude_inds=exclude_inds, - optimal_ordering=optimal_ordering) + tids, + method=method, + exclude_inds=exclude_inds, + optimal_ordering=optimal_ordering, + ) ssa_path = self.compute_hierarchical_ssa_path( - tids=tids, method=method, exclude_inds=exclude_inds, - are_sorted=True, linkage=linkage, + tids=tids, + method=method, + exclude_inds=exclude_inds, + are_sorted=True, + linkage=linkage, ) # follow ssa_path, agglomerating groups as long they small enough @@ -9718,13 +9929,15 @@ def compute_hierarchical_grouping( # now sort groups by when their nodes in leaf ordering ordering = self.compute_hierarchical_ordering( - tids=tids, method=method, exclude_inds=exclude_inds, - optimal_ordering=optimal_ordering, linkage=linkage, + tids=tids, + method=method, + exclude_inds=exclude_inds, + optimal_ordering=optimal_ordering, + linkage=linkage, ) score = {tid: i for i, tid in enumerate(ordering)} groups = sorted( - groups.items(), - key=lambda kv: sum(map(score.__getitem__, kv[1])) + groups.items(), key=lambda kv: sum(map(score.__getitem__, kv[1])) ) return tuple(kv[1] for kv in groups) @@ -9750,7 +9963,7 @@ def pair_simplify( def gen_pairs(): # number of tensors for which there will be more pairwise # combinations than max_combinations - combi_cutoff = int(0.5 * ((8 * max_combinations + 1)**0.5 + 1)) + combi_cutoff = int(0.5 * ((8 * max_combinations + 1) ** 0.5 + 1)) while queue: ind = queue.pop() @@ -9763,7 +9976,8 @@ def gen_pairs(): # sort size of the tensors so that when we are limited by # max_combinations we check likely ones first tids = sorted( - tids, key=lambda tid: tn.tensor_map[tid].ndim) + tids, key=lambda tid: tn.tensor_map[tid].ndim + ) for _, (tid1, tid2) in zip( range(max_combinations), @@ -9773,10 +9987,13 @@ def gen_pairs(): yield tid1, tid2 for pair in gen_pairs(): - if cache is not None: - key = ('pc', frozenset((tid, id(tn.tensor_map[tid].data)) - for tid in pair)) + key = ( + "pc", + frozenset( + (tid, id(tn.tensor_map[tid].data)) for tid in pair + ), + ) if key in cache: continue @@ -9787,15 +10004,21 @@ def gen_pairs(): # don't check exponentially many bipartitions continue - t12 = tensor_contract(t1, t2, output_inds=inds, - 
preserve_tensor=True) + t12 = tensor_contract( + t1, t2, output_inds=inds, preserve_tensor=True + ) current_size = t1.size + t2.size cands = [] for lix, rix in gen_bipartitions(inds): - tl, tr = t12.split(left_inds=lix, right_inds=rix, - get='tensors', cutoff=cutoff, **split_opts) - new_size = (tl.size + tr.size) + tl, tr = t12.split( + left_inds=lix, + right_inds=rix, + get="tensors", + cutoff=cutoff, + **split_opts, + ) + new_size = tl.size + tr.size if new_size < current_size: cands.append((new_size / current_size, pair, tl, tr)) @@ -9834,7 +10057,7 @@ def loop_simplify( cache=None, equalize_norms=False, inplace=False, - **split_opts + **split_opts, ): """Try and simplify this tensor network by identifying loops and checking for low-rank decompositions across groupings of the loops @@ -9877,8 +10100,12 @@ def loop_simplify( continue if cache is not None: - key = ('lp', frozenset((tid, id(tn.tensor_map[tid].data)) - for tid in loop)) + key = ( + "lp", + frozenset( + (tid, id(tn.tensor_map[tid].data)) for tid in loop + ), + ) if key in cache: continue @@ -9894,17 +10121,22 @@ def loop_simplify( for left_inds, right_inds in gen_bipartitions(oix): if not ( tn.tids_are_connected(self._get_tids_from_inds(left_inds)) - and - tn.tids_are_connected(self._get_tids_from_inds(right_inds)) + and tn.tids_are_connected( + self._get_tids_from_inds(right_inds) + ) ): continue tl, tr = tensor_split( - tloop, left_inds=left_inds, right_inds=right_inds, - get='tensors', cutoff=cutoff, **split_opts + tloop, + left_inds=left_inds, + right_inds=right_inds, + get="tensors", + cutoff=cutoff, + **split_opts, ) - new_size = (tl.size + tr.size) + new_size = tl.size + tr.size if new_size < current_size: cands.append((new_size / current_size, loop, tl, tr)) @@ -9932,7 +10164,7 @@ def loop_simplify( def full_simplify( self, - seq='ADCR', + seq="ADCR", output_inds=None, atol=1e-12, equalize_norms=False, @@ -10022,31 +10254,38 @@ def full_simplify( if progbar: import tqdm + pbar = tqdm.tqdm() - pbar.set_description(f'{nt}, {ni}') + pbar.set_description(f"{nt}, {ni}") while (nt, ni) != (old_nt, old_ni): for meth in seq: - if progbar: pbar.update() pbar.set_description( - f'{meth} {tn.num_tensors}, {tn.num_indices}') + f"{meth} {tn.num_tensors}, {tn.num_indices}" + ) if meth in custom_methods: custom_methods[meth]( - tn, output_inds=output_inds, atol=atol, cache=cache) - elif meth == 'D': - tn.diagonal_reduce_(output_inds=ix_o, atol=atol, - cache=cache) - elif meth == 'R': - tn.rank_simplify_(output_inds=ix_o, cache=cache, - equalize_norms=equalize_norms, - **rank_simplify_opts) - elif meth == 'A': - tn.antidiag_gauge_(output_inds=ix_o, atol=atol, - cache=cache) - elif meth == 'C': + tn, output_inds=output_inds, atol=atol, cache=cache + ) + elif meth == "D": + tn.diagonal_reduce_( + output_inds=ix_o, atol=atol, cache=cache + ) + elif meth == "R": + tn.rank_simplify_( + output_inds=ix_o, + cache=cache, + equalize_norms=equalize_norms, + **rank_simplify_opts, + ) + elif meth == "A": + tn.antidiag_gauge_( + output_inds=ix_o, atol=atol, cache=cache + ) + elif meth == "C": tn.column_reduce_(output_inds=ix_o, atol=atol, cache=cache) elif meth == 'S': tn.split_simplify_(atol=atol, cache=cache, @@ -10084,7 +10323,7 @@ def full_simplify( def hyperinds_resolve( self, - mode='dense', + mode="dense", sorter=None, output_inds=None, inplace=False, @@ -10108,15 +10347,16 @@ def hyperinds_resolve( ------- TensorNetwork """ - check_opt('mode', mode, ('dense', 'mps', 'tree')) + check_opt("mode", mode, ("dense", "mps", "tree")) tn = self 
if inplace else self.copy() if output_inds is None: output_inds = self.outer_inds() - if sorter == 'centrality': + if sorter == "centrality": from cotengra.cotengra import nodes_to_centrality + cents = nodes_to_centrality( {tid: t.inds for tid, t in tn.tensor_map.items()} ) @@ -10124,7 +10364,7 @@ def hyperinds_resolve( def sorter(tid): return cents[tid] - if sorter == 'clustering': + if sorter == "clustering": tn_orig = tn.copy() ssa_path = None @@ -10135,16 +10375,19 @@ def sorter(tid): d = tn.ind_size(ix) tids = list(tids) - if sorter == 'clustering': - - if mode == 'tree': + if sorter == "clustering": + if mode == "tree": tids.sort() ssa_path = tn_orig.compute_hierarchical_ssa_path( - tids, optimal_ordering=False, exclude_inds=(ix,), - are_sorted=True) + tids, + optimal_ordering=False, + exclude_inds=(ix,), + are_sorted=True, + ) else: tids = tn_orig.compute_hierarchical_ordering( - tids, optimal_ordering=True, exclude_inds=(ix,)) + tids, optimal_ordering=True, exclude_inds=(ix,) + ) elif sorter is not None: tids.sort(key=sorter) @@ -10161,31 +10404,39 @@ def sorter(tid): copy_inds.append(ix) # inject new tensor(s) to connect dangling inds - if mode == 'dense': + if mode == "dense": copy_tensors.append( - COPY_tensor(d=d, inds=copy_inds, dtype=t.dtype)) - elif mode == 'mps': + COPY_tensor(d=d, inds=copy_inds, dtype=t.dtype) + ) + elif mode == "mps": copy_tensors.extend( - COPY_mps_tensors(d=d, inds=copy_inds, dtype=t.dtype)) - elif mode == 'tree': + COPY_mps_tensors(d=d, inds=copy_inds, dtype=t.dtype) + ) + elif mode == "tree": copy_tensors.extend( - COPY_tree_tensors(d=d, inds=copy_inds, dtype=t.dtype, - ssa_path=ssa_path)) + COPY_tree_tensors( + d=d, + inds=copy_inds, + dtype=t.dtype, + ssa_path=ssa_path, + ) + ) tn.add(copy_tensors) return tn hyperinds_resolve_ = functools.partialmethod( - hyperinds_resolve, inplace=True) + hyperinds_resolve, inplace=True + ) def compress_simplify( self, output_inds=None, atol=1e-6, - simplify_sequence_a='ADCRS', - simplify_sequence_b='RPL', - hyperind_resolve_mode='tree', - hyperind_resolve_sort='clustering', + simplify_sequence_a="ADCRS", + simplify_sequence_b="RPL", + hyperind_resolve_mode="tree", + hyperind_resolve_sort="clustering", final_resolve=False, split_method="svd", max_simplification_iterations=100, @@ -10201,24 +10452,24 @@ def compress_simplify( output_inds = self.outer_inds() simplify_opts = { - 'atol': atol, - 'equalize_norms': equalize_norms, - 'progbar': progbar, - 'output_inds': output_inds, - 'cache': set(), - 'split_method': split_method, + "atol": atol, + "equalize_norms": equalize_norms, + "progbar": progbar, + "output_inds": output_inds, + "cache": set(), + "split_method": split_method, **full_simplify_opts, } # order of tensors when converting hyperinds if callable(hyperind_resolve_sort) or (hyperind_resolve_sort is None): sorter = hyperind_resolve_sort - elif hyperind_resolve_sort == 'centrality': + elif hyperind_resolve_sort == "centrality": from cotengra.cotengra import nodes_to_centrality def sorter(tid): return cents[tid] - elif hyperind_resolve_sort == 'random': + elif hyperind_resolve_sort == "random": import random def sorter(tid): @@ -10228,15 +10479,15 @@ def sorter(tid): sorter = hyperind_resolve_sort hyperresolve_opts = { - 'mode': hyperind_resolve_mode, - 'sorter': sorter, - 'output_inds': output_inds, + "mode": hyperind_resolve_mode, + "sorter": sorter, + "output_inds": output_inds, } tn.full_simplify_(simplify_sequence_a, **simplify_opts) for i in range(max_simplification_iterations): nv, ne = tn.num_tensors, 
tn.num_indices
-            if hyperind_resolve_sort == 'centrality':
+            if hyperind_resolve_sort == "centrality":
                 # recompute centralities
                 cents = nodes_to_centrality(
                     {tid: t.inds for tid, t in tn.tensor_map.items()}
@@ -10244,15 +10495,14 @@ def sorter(tid):
             tn.hyperinds_resolve_(**hyperresolve_opts)
             tn.full_simplify_(simplify_sequence_b, **simplify_opts)
             tn.full_simplify_(simplify_sequence_a, **simplify_opts)
-            if (
-                (tn.num_tensors == 1) or
-                (tn.num_tensors > (1 - converged_tol) * nv and
-                 tn.num_indices > (1 - converged_tol) * ne)
+            if (tn.num_tensors == 1) or (
+                tn.num_tensors > (1 - converged_tol) * nv
+                and tn.num_indices > (1 - converged_tol) * ne
             ):
                 break

         if final_resolve:
-            if hyperind_resolve_sort == 'centrality':
+            if hyperind_resolve_sort == "centrality":
                 # recompute centralities
                 cents = nodes_to_centrality(
                     {tid: t.inds for tid, t in tn.tensor_map.items()}
@@ -10263,17 +10513,16 @@ def sorter(tid):
         return tn

     compress_simplify_ = functools.partialmethod(
-        compress_simplify, inplace=True)
+        compress_simplify, inplace=True
+    )

     def max_bond(self):
-        """Return the size of the largest bond in this network.
-        """
+        """Return the size of the largest bond in this network."""
         return max(t.max_dim() for t in self)

     @property
     def shape(self):
-        """Actual, i.e. exterior, shape of this TensorNetwork.
-        """
+        """Actual, i.e. exterior, shape of this TensorNetwork."""
         return tuple(di[0] for di in self.outer_dims_inds())

     @property
@@ -10287,8 +10536,7 @@ def iscomplex(self):
         return iscomplex(self)

     def astype(self, dtype, inplace=False):
-        """Convert the type of all tensors in this network to ``dtype``.
-        """
+        """Convert the type of all tensors in this network to ``dtype``."""
         TN = self if inplace else self.copy()
         for t in TN:
             t.astype(dtype, inplace=True)
@@ -10299,15 +10547,13 @@ def astype(self, dtype, inplace=False):
     def __getstate__(self):
         # This allows pickling, by removing all tensor owner weakrefs
         d = self.__dict__.copy()
-        d['tensor_map'] = {
-            k: t.copy() for k, t in d['tensor_map'].items()
-        }
+        d["tensor_map"] = {k: t.copy() for k, t in d["tensor_map"].items()}
         return d

     def __setstate__(self, state):
         # This allows pickling, by restoring the returned TN as owner
         self.__dict__ = state.copy()
-        for tid, t in self.__dict__['tensor_map'].items():
+        for tid, t in self.__dict__["tensor_map"].items():
             t.add_owner(self, tid=tid)

     def _repr_info(self):
@@ -10315,21 +10561,19 @@ def _repr_info(self):
         relevant info to this dict.
         """
         return {
-            'tensors': self.num_tensors,
-            'indices': self.num_indices,
+            "tensors": self.num_tensors,
+            "indices": self.num_indices,
         }

     def _repr_info_str(self):
-        """Render the general info as a string.
-        """
+        """Render the general info as a string."""
         return ", ".join(
             "{}={}".format(k, f"'{v}'" if isinstance(v, str) else v)
             for k, v in self._repr_info().items()
         )

     def _repr_html_(self):
-        """Render this TensorNetwork as HTML, for Jupyter notebooks.
-        """
+        """Render this TensorNetwork as HTML, for Jupyter notebooks."""
         s = ""
         s += "
" s += "" @@ -10348,29 +10592,24 @@ def _repr_html_(self): def __str__(self): return ( - f"{self.__class__.__name__}([{os.linesep}" + - "".join( - f" {repr(t)},{os.linesep}" - for t in self.tensors - ) + - f"], {self._repr_info_str()})" + f"{self.__class__.__name__}([{os.linesep}" + + "".join(f" {repr(t)},{os.linesep}" for t in self.tensors) + + f"], {self._repr_info_str()})" ) def __repr__(self): return f"{self.__class__.__name__}({self._repr_info_str()})" draw = draw_tn - draw_3d = functools.partialmethod( - draw, dim=3, backend='matplotlib3d' - ) - draw_interactive = functools.partialmethod( - draw, backend='plotly' - ) + draw_3d = functools.partialmethod(draw, dim=3, backend="matplotlib3d") + draw_interactive = functools.partialmethod(draw, backend="plotly") draw_3d_interactive = functools.partialmethod( - draw, dim=3, backend='plotly' + draw, dim=3, backend="plotly" ) graph = draw_tn + visualize_tensors = visualize_tensors + TNLO_HANDLED_FUNCTIONS = {} @@ -10419,8 +10658,17 @@ class TNLinearOperator(spla.LinearOperator): TNLinearOperator1D """ - def __init__(self, tns, left_inds, right_inds, ldims=None, rdims=None, - optimize=None, backend=None, is_conj=False): + def __init__( + self, + tns, + left_inds, + right_inds, + ldims=None, + rdims=None, + optimize=None, + backend=None, + is_conj=False, + ): if backend is None: self.backend = get_tensor_linop_backend() else: @@ -10470,14 +10718,18 @@ def _matvec(self, vec): in_data = conj(in_data) # cache the contractor - if 'matvec' not in self._contractors: + if "matvec" not in self._contractors: # generate a expression that acts directly on the data iT = Tensor(in_data, inds=self.right_inds) - self._contractors['matvec'] = tensor_contract( - *self._tensors, iT, output_inds=self.left_inds, - optimize=self.optimize, **self._kws) + self._contractors["matvec"] = tensor_contract( + *self._tensors, + iT, + output_inds=self.left_inds, + optimize=self.optimize, + **self._kws, + ) - fn = self._contractors['matvec'] + fn = self._contractors["matvec"] out_data = fn(*self._ins, in_data, backend=self.backend) if self.is_conj: @@ -10498,11 +10750,15 @@ def _matmat(self, mat): # cache the contractor if key not in self._contractors: # generate a expression that acts directly on the data - iT = Tensor(in_data, inds=(*self.right_inds, '_mat_ix')) - o_ix = (*self.left_inds, '_mat_ix') + iT = Tensor(in_data, inds=(*self.right_inds, "_mat_ix")) + o_ix = (*self.left_inds, "_mat_ix") self._contractors[key] = tensor_contract( - *self._tensors, iT, output_inds=o_ix, - optimize=self.optimize, **self._kws) + *self._tensors, + iT, + output_inds=o_ix, + optimize=self.optimize, + **self._kws, + ) fn = self._contractors[key] out_data = fn(*self._ins, in_data, backend=self.backend) @@ -10513,11 +10769,12 @@ def _matmat(self, mat): return do("reshape", out_data, (-1, d)) def trace(self): - if 'trace' not in self._contractors: + if "trace" not in self._contractors: tn = TensorNetwork(self._tensors) - self._contractors['trace'] = tn.trace( - self.left_inds, self.right_inds, optimize=self.optimize) - return self._contractors['trace'] + self._contractors["trace"] = tn.trace( + self.left_inds, self.right_inds, optimize=self.optimize + ) + return self._contractors["trace"] def copy(self, conj=False, transpose=False): if transpose: @@ -10532,8 +10789,14 @@ def copy(self, conj=False, transpose=False): else: is_conj = self.is_conj - return TNLinearOperator(self._tensors, *inds, *dims, is_conj=is_conj, - optimize=self.optimize, backend=self.backend) + return TNLinearOperator( + 
self._tensors, + *inds, + *dims, + is_conj=is_conj, + optimize=self.optimize, + backend=self.backend, + ) def conj(self): if self._conj_linop is None: @@ -10546,8 +10809,7 @@ def _transpose(self): return self._transpose_linop def _adjoint(self): - """Hermitian conjugate of this TNLO. - """ + """Hermitian conjugate of this TNLO.""" # cache the adjoint if self._adjoint_linop is None: self._adjoint_linop = self.copy(conj=True, transpose=True) @@ -10557,7 +10819,7 @@ def to_dense(self, *inds_seq, to_qarray=False, **contract_opts): """Convert this TNLinearOperator into a dense array, defaulting to grouping the left and right indices respectively. """ - contract_opts.setdefault('optimize', self.optimize) + contract_opts.setdefault("optimize", self.optimize) if self.is_conj: ts = (t.conj() for t in self._tensors) @@ -10568,34 +10830,40 @@ def to_dense(self, *inds_seq, to_qarray=False, **contract_opts): inds_seq = self.left_inds, self.right_inds return tensor_contract(*ts, **contract_opts).to_dense( - *inds_seq, to_qarray=to_qarray, + *inds_seq, + to_qarray=to_qarray, ) to_qarray = functools.partialmethod(to_dense, to_qarray=True) @functools.wraps(tensor_split) def split(self, **split_opts): - return tensor_split(self, left_inds=self.left_inds, - right_inds=self.right_inds, **split_opts) + return tensor_split( + self, + left_inds=self.left_inds, + right_inds=self.right_inds, + **split_opts, + ) @property def A(self): return self.to_dense() def astype(self, dtype): - """Convert this ``TNLinearOperator`` to type ``dtype``. - """ + """Convert this ``TNLinearOperator`` to type ``dtype``.""" return TNLinearOperator( (t.astype(dtype) for t in self._tensors), - left_inds=self.left_inds, right_inds=self.right_inds, - ldims=self.ldims, rdims=self.rdims, - optimize=self.optimize, backend=self.backend, + left_inds=self.left_inds, + right_inds=self.right_inds, + ldims=self.ldims, + rdims=self.rdims, + optimize=self.optimize, + backend=self.backend, ) def __array_function__(self, func, types, args, kwargs): - if ( - (func not in TNLO_HANDLED_FUNCTIONS) or - (not all(issubclass(t, self.__class__) for t in types)) + if (func not in TNLO_HANDLED_FUNCTIONS) or ( + not all(issubclass(t, self.__class__) for t in types) ): return NotImplemented return TNLO_HANDLED_FUNCTIONS[func](*args, **kwargs) @@ -10605,6 +10873,7 @@ def tnlo_implements(np_function): """Register an __array_function__ implementation for TNLinearOperator objects. """ + def decorator(func): TNLO_HANDLED_FUNCTIONS[np_function] = func return func @@ -10641,7 +10910,7 @@ class PTensor(Tensor): PTensor """ - __slots__ = ('_data', '_inds', '_tags', '_left_inds', '_owners') + __slots__ = ("_data", "_inds", "_tags", "_left_inds", "_owners") def __init__(self, fn, params, inds=(), tags=None, left_inds=None): super().__init__( @@ -10663,13 +10932,12 @@ def from_parray(cls, parray, inds=(), tags=None, left_inds=None): return obj def copy(self): - """Copy this parametrized tensor. - """ + """Copy this parametrized tensor.""" return PTensor.from_parray( self._data.copy(), inds=self.inds, tags=self.tags, - left_inds=self.left_inds + left_inds=self.left_inds, ) def _set_data(self, x): @@ -10679,7 +10947,8 @@ def _set_data(self, x): "another ``PArray``. You can chain another function with the " "``.modify(apply=fn)`` method. 
Alternatively you can convert "
                "this ``PTensor`` to a normal ``Tensor`` with "
-                "``t.unparametrize()``")
+                "``t.unparametrize()``"
+            )
         self._data = x

     @property
@@ -10695,13 +10964,11 @@ def fn(self, x):
         self._data.fn = x

     def get_params(self):
-        """Get the parameters of this ``PTensor``.
-        """
+        """Get the parameters of this ``PTensor``."""
         return self._data.params

     def set_params(self, params):
-        """Set the parameters of this ``PTensor``.
-        """
+        """Set the parameters of this ``PTensor``."""
         self._data.params = params

     @property
@@ -10718,8 +10985,7 @@ def shape(self):

     @property
     def backend(self):
-        """The backend inferred from the data.
-        """
+        """The backend inferred from the data."""
         return infer_backend(self.params)

     def _apply_function(self, fn):
@@ -10739,8 +11005,7 @@ def conj(self, inplace=False):

     conj_ = functools.partialmethod(conj, inplace=True)

     def unparametrize(self):
-        """Turn this PTensor into a normal Tensor.
-        """
+        """Turn this PTensor into a normal Tensor."""
         return Tensor(
             data=self.data,
             inds=self.inds,
@@ -10763,7 +11028,7 @@ class IsoTensor(Tensor):
     when its data is changed.
     """

-    __slots__ = ('_data', '_inds', '_tags', '_left_inds', '_owners')
+    __slots__ = ("_data", "_inds", "_tags", "_left_inds", "_owners")

     def modify(self, **kwargs):
         kwargs.setdefault("left_inds", self.left_inds)
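To round off the parametrized-tensor pattern implemented above: a ``PTensor``
stores ``params`` plus a function producing the actual array, so the
parametrization survives later operations. A minimal sketch, illustrative
only; ``expm_antisym`` is a hypothetical helper defined here, not part of
quimb:

    import numpy as np
    import scipy.linalg
    import quimb.tensor as qtn

    def expm_antisym(params):
        # antisymmetrize then exponentiate -> an orthogonal array
        return scipy.linalg.expm(params - params.T)

    pt = qtn.PTensor(
        expm_antisym, params=np.random.randn(4, 4), inds=("a", "b")
    )
    print(pt.shape)          # shape of fn(params), here (4, 4)
    print(pt.get_params())   # the raw parameters, as stored

    t = pt.unparametrize()   # drop the parametrization -> plain Tensor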