From df192666a472ec51d236a31c471897d600fdf94e Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 24 Apr 2024 23:28:10 -0700 Subject: [PATCH] add MPO.from_dense and MPO auto filling --- docs/changelog.md | 17 ++ quimb/tensor/tensor_1d.py | 281 ++++++++++++++++++++++---- quimb/tensor/tensor_arbgeom.py | 6 +- quimb/tensor/tensor_builder.py | 49 ++--- quimb/tensor/tensor_core.py | 115 +++++++++++ tests/test_tensor/test_tensor_1d.py | 42 +++- tests/test_tensor/test_tensor_core.py | 14 ++ 7 files changed, 458 insertions(+), 66 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index c96f681d..edffb7cf 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -3,6 +3,23 @@ Release notes for `quimb`. +(whats-new-1-8-1)= +## v1.8.1 (unreleased) + +**Enhancements:** + +- add [`MatrixProductOperator.from_dense`](quimb.tensor.tensor_1d.MatrixProductOperator.from_dense) for constructing MPOs from dense matrices, including on only a subset of sites +- add [`MatrixProductOperator.fill_empty_sites_with_identities`](quimb.tensor.tensor_1d.MatrixProductOperator.fill_empty_sites_with_identities) for 'completing' an MPO which only has tensors on a subset of sites with identities +- add [`TensorNetwork.drape_bond_between`](quimb.tensor.tensor_core.TensorNetwork.drape_bond_between) for 'draping' an existing bond between two tensors through a third +- add [`Tensor.new_ind_pair_with_identity`](quimb.tensor.tensor_core.Tensor.new_ind_pair_with_identity) +- TN2D, TN3D and arbitrary geom classical partition function builders now all support `outputs=` kwarg specifying non-marginalized variables +- add simple dense 1-norm belief propagation algorithm [`D1BP`](quimb.experimental.belief_propagation.d1bp.D1BP) + +**Bug fixes:** + +- [`Circuit.apply_gate_raw`](quimb.tensor.circuit.Circuit.apply_gate_raw): fix kwarg bug ({pull}`226`) + + (whats-new-1-8-0)= ## v1.8.0 (2024-04-10) diff --git a/quimb/tensor/tensor_1d.py b/quimb/tensor/tensor_1d.py index ce7fa08a..91b7b9bc 100644 --- 
a/quimb/tensor/tensor_1d.py +++ b/quimb/tensor/tensor_1d.py @@ -4,16 +4,23 @@ import functools import itertools import operator -from math import log2 +from math import log, log2 from numbers import Integral import scipy.sparse.linalg as spla -from autoray import conj, dag, do, get_dtype_name, reshape, transpose +from autoray import conj, dag, do, get_dtype_name, reshape, size, transpose import quimb as qu from ..linalg.base_linalg import norm_trace_dense -from ..utils import deprecated, ensure_dict, partition_all, print_multi_line +from ..utils import ( + check_opt, + deprecated, + ensure_dict, + pairwise, + partition_all, + print_multi_line, +) from . import array_ops as ops from .tensor_arbgeom import ( TensorNetworkGen, @@ -28,6 +35,7 @@ Tensor, TensorNetwork, bonds, + new_bond, oset, rand_uuid, tags_to_oset, @@ -1501,7 +1509,13 @@ def from_fill_fn( @classmethod def from_dense( - cls, psi, dims, site_ind_id="k{}", site_tag_id="I{}", **split_opts + cls, + psi, + dims=2, + tags=None, + site_ind_id="k{}", + site_tag_id="I{}", + **split_opts, ): """Create a ``MatrixProductState`` directly from a dense vector @@ -1509,8 +1523,11 @@ def from_dense( ---------- psi : array_like The dense state to convert to MPS from. - dims : sequence of int - Physical subsystem dimensions of each site. + dims : int or sequence of int + Physical subsystem dimensions of each site. If a single int, all + sites have this same dimension, by default, 2. + tags : str or sequence of str, optional + Global tags to attach to all tensors. site_ind_id : str, optional How to index the physical sites, see :class:`~quimb.tensor.tensor_1d.MatrixProductState`. 
@@ -1540,48 +1557,52 @@ def from_dense( """ set_default_compress_mode(split_opts) # ensure compression is canonical / optimal - split_opts.setdefault("absorb", "left") + split_opts.setdefault("absorb", "right") - L = len(dims) - inds = [site_ind_id.format(i) for i in range(L)] + # make sure array_like + psi = ops.asarray(psi) - def gen_tensors(): - # split - # <-- : yield - # : : - # OOOOOOO--O-O-O - # ||||||| | | | - # ....... - # left_inds - tm = Tensor( - data=reshape(ops.asarray(psi), dims), - inds=inds, - ) - for i in range(L - 1, 0, -1): - tm, tr = tm.split( - left_inds=inds[:i], - get="tensors", - rtags=site_tag_id.format(i), - **split_opts, - ) - yield tr - tm.add_tag(site_tag_id.format(0)) - yield tm - - # the reverse is purely asthetic so the tensors are stored in the - # TN dictionary in the same order as the sites - ts = tuple(gen_tensors())[::-1] - - mps = TensorNetwork(ts) - # cast as correct TN class - return mps.view_as_( - cls, + if isinstance(dims, Integral): + # assume all sites have the same dimension + L = round(log(size(psi), dims)) + dims = (dims,) * L + else: + dims = tuple(dims) + L = len(dims) + + # create a bare MPS TN object + mps = cls.new( L=L, cyclic=False, site_ind_id=site_ind_id, site_tag_id=site_tag_id, ) + inds = [mps.site_ind(i) for i in range(L)] + + tm = Tensor(data=reshape(psi, dims), inds=inds) + for i in range(L - 1): + # progressively split off one more physical index + tl, tm = tm.split( + left_inds=None, + right_inds=inds[i + 1 :], + ltags=mps.site_tag(i), + get="tensors", + **split_opts, + ) + # add left tensor + mps |= tl + + # add final right tensor + tm.add_tag(mps.site_tag(L - 1)) + mps |= tm + + # add global tags + if tags is not None: + mps.add_tag(tags) + + return mps + def add_MPS(self, other, inplace=False, **kwargs): """Add another MatrixProductState to this one.""" return tensor_network_ag_sum(self, other, inplace=inplace, **kwargs) @@ -2890,6 +2911,186 @@ def from_fill_fn( return mpo + @classmethod + def 
from_dense( + cls, + A, + dims=2, + sites=None, + L=None, + tags=None, + site_tag_id="I{}", + upper_ind_id="k{}", + lower_ind_id="b{}", + **split_opts, + ): + """Build an MPO from a raw dense matrix. + + Parameters + ---------- + A : array + The dense operator, it should be reshapeable to ``(*dims, *dims)``. + dims : int, sequence of int, optional + The physical subdimensions of the operator. If any integer, assume + all sites have the same dimension. If a sequence, the dimension of + each site. Default is 2. + sites : sequence of int, optional + The sites to place the operator on. If None, will place it on + first `len(dims)` sites. + L : int, optional + The total number of sites in the MPO, if the operator represents + only a subset. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + site_tag_id : str, optional + The string to use to label the site tags. + upper_ind_id : str, optional + The string to use to label the upper physical indices. + lower_ind_id : str, optional + The string to use to label the lower physical indices. + split_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + + Returns + ------- + MatrixProductOperator + """ + set_default_compress_mode(split_opts) + # ensure compression is canonical / optimal + split_opts.setdefault("absorb", "right") + + # make sure array_like + A = ops.asarray(A) + + if isinstance(dims, Integral): + # assume all sites have the same dimension + ng = round(log(size(A), dims) / 2) + dims = (dims,) * ng + else: + dims = tuple(dims) + ng = len(dims) + + if sites is None: + sorted_sites = sites = range(ng) + else: + sorted_sites = sorted(sites) + + if L is None: + L = max(sites) + 1 + + # create a bare MPO TN object + mpo = cls.new( + L=L, + cyclic=False, + upper_ind_id=upper_ind_id, + lower_ind_id=lower_ind_id, + site_tag_id=site_tag_id, + ) + + # initial inds and tensor contains desired site order ... 
+ uix = [mpo.upper_ind(i) for i in sites] + lix = [mpo.lower_ind(i) for i in sites] + tm = Tensor(data=reshape(A, (*dims, *dims)), inds=uix + lix) + + # ... but want to create MPO in sorted site order + uix = [mpo.upper_ind(i) for i in sorted_sites] + lix = [mpo.lower_ind(i) for i in sorted_sites] + + for i, site in enumerate(sorted_sites[:-1]): + # progressively split off one more pair of physical indices + tl, tm = tm.split( + left_inds=None, + right_inds=uix[i + 1 :] + lix[i + 1 :], + ltags=mpo.site_tag(site), + get="tensors", + **split_opts, + ) + # add left tensor + mpo |= tl + + # add final right tensor + tm.add_tag(mpo.site_tag(sorted_sites[-1])) + mpo |= tm + + # add global tags + if tags is not None: + mpo.add_tag(tags) + + return mpo + + def fill_empty_sites_with_identities( + self, mode="full", phys_dim=None, inplace=False + ): + """Fill any empty sites of this MPO with identity tensors, adding + size 1 bonds or draping existing bonds where necessary such that the + resulting tensor has nearest neighbor bonds only. + + Parameters + ---------- + mode : {'full', 'minimal'}, optional + Whether to fill in all sites, including at either end, or simply + the minimal range covering the min to max current sites present. + phys_dim : int, optional + The physical dimension of the identity tensors to add. If not + specified, will use the upper physical dimension of the first + present site. + inplace : bool, optional + Whether to perform the operation inplace. + + Returns + ------- + MatrixProductOperator + The modified MPO. 
+ """ + check_opt("mode", mode, ("full", "minimal")) + + mpo = self if inplace else self.copy() + + sites_present = tuple(mpo.gen_sites_present()) + sites_present_set = set(sites_present) + sitei = sites_present[0] + sitef = sites_present[-1] + + t0 = mpo[sitei] + + if phys_dim is None: + d = mpo.phys_dim(sitei) + + if mode == "full": + site_range = range(mpo.L) + else: # mode == "minimal" + site_range = range(sitei, sitef + 1) + + for site in site_range: + if site not in sites_present_set: + mpo |= Tensor( + data=do("eye", d, dtype=t0.dtype, like=t0.data), + inds=(mpo.upper_ind(site), mpo.lower_ind(site)), + tags=mpo.site_tag(site), + ) + + for si, sj in pairwise(sites_present): + if bonds(mpo[si], mpo[sj]): + # need to drape existing bonds + for i in range(si, sj - 1): + mpo.drape_bond_between_(i, sj, i + 1) + else: + # just add bond dim 1 + for i in range(si, sj): + new_bond(mpo[i], mpo[i + 1]) + + if mode == "full": + for i in range(0, sitei): + new_bond(mpo[i], mpo[i + 1]) + for i in range(sitef, mpo.L - 1): + new_bond(mpo[i], mpo[i + 1]) + + return mpo + + fill_empty_sites_with_identities_ = functools.partialmethod( + fill_empty_sites_with_identities, inplace=True + ) + def add_MPO(self, other, inplace=False, **kwargs): return tensor_network_ag_sum(self, other, inplace=inplace, **kwargs) diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py index 13f78c4a..139a09f1 100644 --- a/quimb/tensor/tensor_arbgeom.py +++ b/quimb/tensor/tensor_arbgeom.py @@ -177,14 +177,16 @@ def tensor_network_apply_op_vec( f"Invalid `which_A`: {which_A}, should be 'lower' or 'upper'." 
) - x.reindex_sites_(inner_ind_id) + # only want to reindex on sites that being acted on + sites_present_in_A = tuple(A.gen_sites_present()) + x.reindex_sites_(inner_ind_id, where=sites_present_in_A) # combine the tensor networks x |= A if contract: # optionally contract all tensor at each site - for site in x.gen_sites_present(): + for site in sites_present_in_A: x ^= site if fuse_multibonds: diff --git a/quimb/tensor/tensor_builder.py b/quimb/tensor/tensor_builder.py index b5d369d0..d9662dc3 100644 --- a/quimb/tensor/tensor_builder.py +++ b/quimb/tensor/tensor_builder.py @@ -2658,7 +2658,6 @@ def TN_classical_partition_function_from_edges( to_contract = collections.defaultdict(list) ts = [] for node_a, node_b in gen_unique_edges(edges): - # the variable indices (unless the node is # an output, these will be contracted) ix_a = ind_id.format(node_a) @@ -3978,15 +3977,13 @@ def MPO_zeros(L, phys_dim=2, dtype="float64", cyclic=False, **mpo_opts): ------- MatrixProductOperator """ - cyc_dim = (1,) if cyclic else () - def gen_arrays(): - yield np.zeros((*cyc_dim, 1, phys_dim, phys_dim), dtype=dtype) - for _ in range(L - 2): - yield np.zeros((1, 1, phys_dim, phys_dim), dtype=dtype) - yield np.zeros((1, *cyc_dim, phys_dim, phys_dim), dtype=dtype) + def fill_fn(shape): + return np.zeros(shape, dtype=dtype) - return MatrixProductOperator(gen_arrays(), **mpo_opts) + return MatrixProductOperator.from_fill_fn( + fill_fn, L=L, bond_dim=1, phys_dim=phys_dim, cyclic=cyclic, **mpo_opts + ) def MPO_zeros_like(mpo, **mpo_opts): @@ -4089,25 +4086,31 @@ def MPO_rand( mpo_opts Supplied to :class:`~quimb.tensor.tensor_1d.MatrixProductOperator`. 
""" - cyc_shp = (bond_dim,) if cyclic else () - - shapes = [ - (*cyc_shp, bond_dim, phys_dim, phys_dim), - *((bond_dim, bond_dim, phys_dim, phys_dim),) * (L - 2), - (bond_dim, *cyc_shp, phys_dim, phys_dim), - ] + base_fill_fn = get_rand_fill_fn( + dtype=dtype, dist=dist, loc=loc, scale=scale + ) - def gen_data(shape): - data = randn(shape, dtype=dtype, dist=dist, loc=loc, scale=scale) - if not herm: - return data + if not herm: - trans = (0, 2, 1) if len(shape) == 3 else (0, 1, 3, 2) - return data + data.transpose(*trans).conj() + def fill_fn(shape): + return sensibly_scale(base_fill_fn(shape)) + else: - arrays = map(sensibly_scale, map(gen_data, shapes)) + def fill_fn(shape): + data = base_fill_fn(shape) + trans = (0, 2, 1) if len(shape) == 3 else (0, 1, 3, 2) + data += data.transpose(*trans).conj() + return sensibly_scale(data) - rmpo = MatrixProductOperator(arrays, **mpo_opts) + rmpo = MatrixProductOperator.from_fill_fn( + fill_fn, + L=L, + bond_dim=bond_dim, + phys_dim=phys_dim, + cyclic=cyclic, + shape="lrud", + **mpo_opts, + ) if normalize: rmpo /= (rmpo.H @ rmpo) ** 0.5 diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index ae5e677d..eec167bf 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -2173,6 +2173,47 @@ def new_ind_with_identity(self, name, left_inds, right_inds, axis=0): new_inds.insert(axis, name) self.modify(data=new_data, inds=new_inds) + def new_ind_pair_with_identity( + self, new_left_ind, new_right_ind, d, inplace=False, + ): + """Expand this tensor with two new indices of size ``d``, by taking an + (outer) tensor product with the identity operator. The two new indices + are added as axes at the start of the tensor. + + Parameters + ---------- + new_left_ind : str + Name of the new left index. + new_right_ind : str + Name of the new right index. + d : int + Size of the new indices. + inplace : bool, optional + Whether to perform the expansion inplace. 
+ + Returns + ------- + Tensor + """ + t = self if inplace else self.copy() + + # tensor product identity in + x_id = do("eye", d, dtype=t.dtype, like=t.data) + output = tuple(range(t.ndim + 2)) + new_data = array_contract( + arrays=(x_id, t.data), + inputs=(output[:2], output[2:]), + output=output + ) + # update indices + new_inds = (new_left_ind, new_right_ind, *t.inds) + t.modify(data=new_data, inds=new_inds) + return t + + new_ind_pair_with_identity_ = functools.partialmethod( + new_ind_pair_with_identity, inplace=True + ) + def conj(self, inplace=False): """Conjugate this tensors data (does nothing to indices).""" t = self if inplace else self.copy() @@ -8151,6 +8192,80 @@ def cut_bond(self, bond, new_left_ind=None, new_right_ind=None): tr.reindex_({bond: new_right_ind}) return new_left_ind, new_right_ind + def drape_bond_between( + self, + tagsa, + tagsb, + tags_target, + left_ind=None, + right_ind=None, + inplace=False, + ): + r"""Take the bond(s) connecting the tensors tagged at ``tagsa`` and + ``tagsb``, and 'drape' it through the tensor tagged at ``tags_target``, + effectively adding an identity tensor between the two and contracting + it with the third:: + + ┌─┐ ┌─┐ ┌─┐ ┌─┐ + ─┤A├─Id─┤B├─ ─┤A├─┐ ┌─┤B├─ + └─┘ └─┘ └─┘ │ │ └─┘ + left_ind│ │right_ind + ┌─┐ --> ├─┤ + ─┤C├─ ─┤D├─ + └┬┘ └┬┘ where D = C ⊗ Id + │ │ + + This increases the size of the target tensor by ``d**2``, and + disconnects the tensors at ``tagsa`` and ``tagsb``. + + Parameters + ---------- + tagsa : str or sequence of str + The tag(s) identifying the first tensor. + tagsb : str or sequence of str + The tag(s) identifying the second tensor. + tags_target : str or sequence of str + The tag(s) identifying the target tensor. + left_ind : str, optional + The new index to give to the left tensor. + right_ind : str, optional + The new index to give to the right tensor. + inplace : bool, optional + Whether to perform the draping inplace. 
+ + Returns + ------- + TensorNetwork + """ + # TODO: tids version? + tn = self if inplace else self.copy() + + ta = tn[tagsa] + tb = tn[tagsb] + _, bix, _ = tensor_make_single_bond(ta, tb) + d = ta.ind_size(bix) + + if left_ind is None: + left_ind = rand_uuid() + if left_ind != bix: + ta.reindex_({bix: left_ind}) + + if right_ind is None: + right_ind = rand_uuid() + elif right_ind == left_ind: + raise ValueError("right_ind cannot be the same as left_ind") + if right_ind != bix: + tb.reindex_({bix: right_ind}) + + t_target = tn[tags_target] + t_target.new_ind_pair_with_identity_(left_ind, right_ind, d) + + return tn + + drape_bond_between_ = functools.partialmethod( + drape_bond_between, inplace=True + ) + def isel(self, selectors, inplace=False): """Select specific values for some dimensions/indices of this tensor network, thereby removing them. diff --git a/tests/test_tensor/test_tensor_1d.py b/tests/test_tensor/test_tensor_1d.py index a40681ec..0bceac9d 100644 --- a/tests/test_tensor/test_tensor_1d.py +++ b/tests/test_tensor/test_tensor_1d.py @@ -82,7 +82,7 @@ def test_trans_invar(self): def test_from_dense(self): L = 8 psi = qu.rand_ket(2**L) - mps = MatrixProductState.from_dense(psi, dims=[2] * L) + mps = MatrixProductState.from_dense(psi) assert mps.tags == oset(f"I{i}" for i in range(L)) assert mps.site_inds == tuple(f"k{i}" for i in range(L)) assert mps.L == L @@ -956,6 +956,46 @@ def test_permute_arrays(self): Af = mpo.to_qarray() assert_allclose(A0, Af) + def test_from_dense(self): + A = qu.rand_uni(2**4) + mpo = MatrixProductOperator.from_dense(A) + assert mpo.L == 4 + assert_allclose(A, mpo.to_dense()) + + def test_from_dense_sites(self): + dims = [2, 3, 4, 5] + A = qu.rand_uni(2 * 3 * 4 * 5) + sites = [3, 1, 0, 2] + mpo = MatrixProductOperator.from_dense(A, dims, sites=sites) + assert mpo.L == 4 + perm = [sites.index(i) for i in range(4)] + assert_allclose(qu.permute(A, dims, perm), mpo.to_dense()) + + def test_fill_empty_sites_with_identities(self): + 
mps = MPS_rand_state(7, 3) + k = mps.to_dense() + A, B, C = (qu.rand_uni(2) for _ in range(3)) + Ak = qu.ikron((A, B, C), [2] * 7, [5, 2, 3]) @ k + + ABC = A & B & C + mpo = MatrixProductOperator.from_dense(ABC, sites=[5, 2, 3], L=7) + assert mpo.bond_size(2, 3) == 1 + assert mpo.num_tensors == 3 + assert mpo[3].bonds(mpo[5]) + mpo.fill_empty_sites_with_identities_("minimal") + assert not mpo[3].bonds(mpo[5]) + assert mpo.num_tensors == 4 + assert_allclose( + mps.gate_with_op_lazy(mpo).to_dense(), + Ak, + ) + mpo.fill_empty_sites_with_identities_("full") + assert mpo.num_tensors == 7 + assert_allclose( + mps.gate_with_op_lazy(mpo).to_dense(), + Ak, + ) + # --------------------------------------------------------------------------- # # Test specific 1D instances # diff --git a/tests/test_tensor/test_tensor_core.py b/tests/test_tensor/test_tensor_core.py index 1a016b03..a8f20b0a 100644 --- a/tests/test_tensor/test_tensor_core.py +++ b/tests/test_tensor/test_tensor_core.py @@ -1615,6 +1615,20 @@ def test_cut_bond(self): assert ta.inds == ("a", "b", "l") assert tb.inds == ("r", "d", "e") + def test_drape_bond_between(self): + tx = qtn.rand_tensor([2, 3, 4], ['a', 'b', 'c'], tags="X") + ty = qtn.rand_tensor([3, 4, 6], ['b', 'd', 'e'], tags="Y") + tz = qtn.rand_tensor([5], ['f'], tags="Z") + tn = (tx | ty | tz) + assert tn.num_indices == 6 + assert len(tn.subgraphs()) == 2 + te = tn.contract() + tn.drape_bond_between_("X", "Y", "Z") + assert tn.num_indices == 7 + assert len(tn.subgraphs()) == 1 + t = tn.contract() + assert t.distance_normalized(te) == pytest.approx(0.0) + def test_draw(self): import matplotlib from matplotlib import pyplot as plt