From d410482ca2f7b345252a15c174243f4d41b54087 Mon Sep 17 00:00:00 2001
From: Johnnie Gray
Date: Thu, 30 Jan 2025 16:06:27 -0800
Subject: [PATCH] move BP from experimental to quimb.tensor.belief_propagation

---
 docs/changelog.md                             |   8 +-
 .../belief_propagation/__init__.py            | 112 +-----
 .../belief_propagation/regions.py             | 194 ---------
 quimb/tensor/belief_propagation/__init__.py   | 110 ++++++
 .../belief_propagation/bp_common.py           |   0
 .../belief_propagation/d1bp.py                |   0
 .../belief_propagation/d2bp.py                |   0
 .../belief_propagation/diis.py                |   0
 .../belief_propagation/hd1bp.py               |   0
 .../belief_propagation/hv1bp.py               |   0
 .../belief_propagation/l1bp.py                |   0
 .../belief_propagation/l2bp.py                |   0
 quimb/tensor/belief_propagation/regions.py    | 373 ++++++++++++++++++
 quimb/tensor/tensor_arbgeom.py                |   6 +-
 quimb/tensor/tensor_arbgeom_compress.py       |   4 +-
 .../test_belief_propagation/test_d1bp.py      |  11 +-
 .../test_belief_propagation/test_d2bp.py      |  16 +-
 .../test_belief_propagation/test_hd1bp.py     |  16 +-
 .../test_belief_propagation/test_hv1bp.py     |  13 +-
 .../test_belief_propagation/test_l1bp.py      |  21 +-
 .../test_belief_propagation/test_l2bp.py      |  15 +-
 21 files changed, 536 insertions(+), 363 deletions(-)
 delete mode 100644 quimb/experimental/belief_propagation/regions.py
 create mode 100644 quimb/tensor/belief_propagation/__init__.py
 rename quimb/{experimental => tensor}/belief_propagation/bp_common.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/d1bp.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/d2bp.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/diis.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/hd1bp.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/hv1bp.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/l1bp.py (100%)
 rename quimb/{experimental => tensor}/belief_propagation/l2bp.py (100%)
 create mode 100644 quimb/tensor/belief_propagation/regions.py

diff --git a/docs/changelog.md b/docs/changelog.md
index 713fda7a..4f1ba3f7 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -5,6 +5,10 @@ Release notes for `quimb`.
 (whats-new-1-10-1)=
 ## v1.10.1 (unreleased)
+**Breaking Changes**
+
+- move belief propagation to `quimb.tensor.belief_propagation`
+
 
 **Enhancements:**
 
 - [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg
@@ -128,7 +132,7 @@ Release notes for `quimb`.
 - add [`TensorNetwork.drape_bond_between`](quimb.tensor.tensor_core.TensorNetwork.drape_bond_between) for 'draping' an existing bond between two tensors through a third
 - add [`Tensor.new_ind_pair_with_identity`](quimb.tensor.tensor_core.Tensor.new_ind_pair_with_identity)
 - TN2D, TN3D and arbitrary geom classical partition function builders ([`TN_classical_partition_function_from_edges`](quimb.tensor.tensor_builder.TN_classical_partition_function_from_edges)) now all support `outputs=` kwarg specifying non-marginalized variables
-- add simple dense 1-norm belief propagation algorithm [`D1BP`](quimb.experimental.belief_propagation.d1bp.D1BP)
+- add simple dense 1-norm belief propagation algorithm [`D1BP`](quimb.tensor.belief_propagation.d1bp.D1BP)
 - add [`qtn.enforce_1d_like`](quimb.tensor.tensor_1d_compress.enforce_1d_like) for checking whether a tensor network is 1D-like, including automatically adding strings of identities between non-local bonds, expanding applicability of [`tensor_network_1d_compress`](quimb.tensor.tensor_1d_compress.tensor_network_1d_compress)
 - add [`MatrixProductState.canonicalize`](quimb.tensor.tensor_1d.MatrixProductState.canonicalize) as (by default *non-inplace*) version of `canonize`, to follow the pattern of other tensor network methods. `canonize` is now an alias for `canonicalize_` [note trailing underscore].
 - add [`MatrixProductState.left_canonicalize`](quimb.tensor.tensor_1d.MatrixProductState.left_canonicalize) as (by default *non-inplace*) version of `left_canonize`, to follow the pattern of other tensor network methods. `left_canonize` is now an alias for `left_canonicalize_` [note trailing underscore].
@@ -339,7 +343,7 @@
   [Tensor.idxmax](quimb.tensor.Tensor.idxmax) for finding the index of the
   minimum/maximum element.
 - 2D and 3D classical partition function TN builders: allow output indices.
-- [`quimb.experimental.belief_propagation`]([`quimb.experimental.belief_propagation`]):
+- [`quimb.tensor.belief_propagation`](quimb.tensor.belief_propagation):
   add various 1-norm/2-norm dense/lazy BP algorithms.
 
 **Bug fixes:**
diff --git a/quimb/experimental/belief_propagation/__init__.py b/quimb/experimental/belief_propagation/__init__.py
index 7000590a..56cafcfc 100644
--- a/quimb/experimental/belief_propagation/__init__.py
+++ b/quimb/experimental/belief_propagation/__init__.py
@@ -1,110 +1,8 @@
-"""Belief propagation (BP) routines. There are three potential categorizations
-of BP and each combination of them is potentially valid specific algorithm.
+import warnings
 
-1-norm vs 2-norm BP
--------------------
+from quimb.tensor.belief_propagation import *
 
-- 1-norm (normal): BP runs directly on the tensor network, messages have size
-  ``d`` where ``d`` is the size of the bond(s) connecting two tensors or
-  regions.
-- 2-norm (quantum): BP runs on the squared tensor network, messages have size
-  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
-  regions. Each local tensor or region is partially traced (over dangling
-  indices) with its conjugate to create a single node.
-
-
-Graph vs Hypergraph BP
-----------------------
-
-- Graph (simple): the tensor network lives on a graph, where indices either
-  appear on two tensors (a bond), or appear on a single tensor (are outputs).
-  In this case, messages are exchanged directly between tensors.
-- Hypergraph: the tensor network lives on a hypergraph, where indices can
-  appear on any number of tensors. In this case, the update procedure is two
-  parts, first all 'tensor' messages are computed, these are then used in the
-  second step to compute all the 'index' messages, which are then fed back into
-  the 'tensor' message update and so forth. For 2-norm BP one likely needs to
-  specify which indices are outputs and should be traced over.
-
-The hypergraph case of course includes the graph case, but since the 'index'
-message update is simply the identity, it is convenient to have a separate
-simpler implementation, where the standard TN bond vs physical index
-definitions hold.
-
-
-Dense vs Vectorized vs Lazy BP
-------------------------------
-
-- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If all
-  multibonds have been fused, then each message is a vector (1-norm case) or
-  matrix (2-norm case).
-- Vectorized: the same as the above, but all matching tensor update and message
-  updates are stacked and performed simultaneously. This can be enormously more
-  efficient for large numbers of small tensors.
-- Lazy: each node is potentially a tensor network itself with arbitrary inner
-  structure and number of bonds connecting to other nodes. The message are
-  generally tensors and each update is a lazy contraction, which is potentially
-  much cheaper / requires less memory than forming the 'dense' node for large
-  tensors.
-
-(There is also the MPS flavor where each node has a 1D structure and the
-messages are matrix product states, with updates involving compression.)
-
-
-Overall that gives 12 possible BP flavors, some implemented here:
-
-- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
-- [x] (HD2BP) hyper, dense, 2-norm
-- [x] (HV1BP) hyper, vectorized, 1-norm
-- [ ] (HV2BP) hyper, vectorized, 2-norm
-- [ ] (HL1BP) hyper, lazy, 1-norm
-- [ ] (HL2BP) hyper, lazy, 2-norm
-- [x] (D1BP) simple, dense, 1-norm - simple BP for simple tensor networks
-- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
-- [ ] (V1BP) simple, vectorized, 1-norm
-- [ ] (V2BP) simple, vectorized, 2-norm
-- [x] (L1BP) simple, lazy, 1-norm
-- [x] (L2BP) simple, lazy, 2-norm
-
-The 2-norm methods can be used to compress bonds or estimate the 2-norm.
-The 1-norm methods can be used to estimate the 1-norm, i.e. contracted value.
-Both methods can be used to compute index marginals and thus perform sampling.
-
-The vectorized methods can be extremely fast for large numbers of small
-tensors, but do currently require all dimensions to match.
-
-The dense and lazy methods can can converge messages *locally*, i.e. only
-update messages adjacent to messages which have changed.
-""" - -from .bp_common import combine_local_contractions, initialize_hyper_messages -from .d1bp import D1BP, contract_d1bp -from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp -from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp -from .hv1bp import HV1BP, contract_hv1bp, sample_hv1bp -from .l1bp import L1BP, contract_l1bp -from .l2bp import L2BP, compress_l2bp, contract_l2bp -from .regions import RegionGraph - -__all__ = ( - "combine_local_contractions", - "compress_d2bp", - "compress_l2bp", - "contract_d1bp", - "contract_d2bp", - "contract_hd1bp", - "contract_hv1bp", - "contract_l1bp", - "contract_l2bp", - "D1BP", - "D2BP", - "HD1BP", - "HV1BP", - "initialize_hyper_messages", - "L1BP", - "L2BP", - "RegionGraph", - "sample_d2bp", - "sample_hd1bp", - "sample_hv1bp", +warnings.warn( + "Most functionality of 'quimb.experimental.belief_propagation' " + "has been moved to `quimb.tensor.belief_propagation`.", ) diff --git a/quimb/experimental/belief_propagation/regions.py b/quimb/experimental/belief_propagation/regions.py deleted file mode 100644 index fb059ba5..00000000 --- a/quimb/experimental/belief_propagation/regions.py +++ /dev/null @@ -1,194 +0,0 @@ -class RegionGraph: - def __init__(self, regions=(), autocomplete=True): - self.lookup = {} - self.parents = {} - self.children = {} - self.counts = {} - for region in regions: - self.add_region(region) - if autocomplete: - self.autocomplete() - - @property - def regions(self): - return tuple(self.children) - - def neighbor_regions(self, region): - """Get all regions that intersect with the given region.""" - region = frozenset(region) - - other_regions = set.union(*(self.lookup[node] for node in region)) - other_regions.discard(region) - return other_regions - - def add_region(self, region): - """Add a new region and update parent-child relationships. - - Parameters - ---------- - region : Sequence[Hashable] - The new region to add. - """ - region = frozenset(region) - - if region in self.parents: - # already added - return - - # populate data structures - self.parents[region] = set() - self.children[region] = set() - for node in region: - # collect regions that contain nodes for fast neighbor lookup - self.lookup.setdefault(node, set()).add(region) - - # add parent-child relationships - for other in self.neighbor_regions(region): - if region.issubset(other): - self.parents[region].add(other) - self.children[other].add(region) - elif other.issubset(region): - self.children[region].add(other) - self.parents[other].add(region) - - # prune redundant parents and children - children = sorted(self.children[region], key=len) - for i, c in enumerate(children): - if any(c.issubset(cc) for cc in children[i + 1 :]): - # child is a subset of larger child -> remove - self.children[region].remove(c) - self.parents[c].remove(region) - - parents = sorted(self.parents[region], key=len, reverse=True) - for i, p in enumerate(parents): - if any(p.issuperset(pp) for pp in parents[i + 1 :]): - # parent is a superset of smaller parent -> remove - self.parents[region].remove(p) - self.children[p].remove(region) - - self.counts.clear() - - def autocomplete(self): - """Add all missing intersecting sub-regions.""" - for r in self.regions: - for other in self.neighbor_regions(r): - self.add_region(r & other) - - def autoextend(self, regions=None): - """Extend this region graph upwards by adding in all pairwise unions of - regions. If regions is specified, take this as one set of pairs. 
- """ - if regions is None: - regions = self.regions - - neighbors = {} - for r in regions: - for other in self.neighbor_regions(r): - neighbors.setdefault(r, []).append(other) - - for r, others in neighbors.items(): - for other in others: - self.add_region(r | other) - - def get_parents(self, region): - """Get all ancestors that contain the given region, but do not contain - any other regions that themselves contain the given region. - """ - return self.parents[region] - - def get_children(self, region): - """Get all regions that are contained by the given region, but are not - contained by any other descendents of the given region. - """ - return self.children[region] - - def get_ancestors(self, region): - """Get all regions that contain the given region, not just direct - parents. - """ - seen = set() - queue = [region] - while queue: - r = queue.pop() - for rp in self.parents[r]: - if rp not in seen: - seen.add(rp) - queue.append(rp) - return seen - - def get_descendents(self, region): - """Get all regions that are contained by the given region, not just - direct children. - """ - seen = set() - queue = [region] - while queue: - r = queue.pop() - for rc in self.children[r]: - if rc not in seen: - seen.add(rc) - queue.append(rc) - return seen - - def get_count(self, region): - """Get the count of the given region, i.e. the correct weighting to - apply when summing over all regions to avoid overcounting. - """ - try: - C = self.counts[region] - except KeyError: - # n.b. cache is cleared when any new region is added - C = self.counts[region] = 1 - sum( - self.get_count(a) for a in self.get_ancestors(region) - ) - return C - - def get_total_count(self): - return sum(map(self.get_count, self.regions)) - - def get_level(self, region): - """Get the level of the given region, i.e. the distance to an ancestor - with no parents. - """ - if not self.parents[region]: - return 0 - else: - return min(self.get_level(p) for p in self.parents[region]) - 1 - - def draw(self, pos=None, a=20, scale=1.0, radius=0.1, **drawing_opts): - from quimb.schematic import Drawing, hash_to_color - - if pos is None: - pos = {node: node for node in self.lookup} - - def get_draw_pos(coo): - return tuple(scale * s for s in pos[coo]) - - sizes = {len(r) for r in self.regions} - levelmap = {s: i for i, s in enumerate(sorted(sizes))} - - d = Drawing(a=a, **drawing_opts) - for region in sorted(self.regions, key=len, reverse=True): - # level = self.get_level(region) - # level = len(region) - level = levelmap[len(region)] - - coos = [(*get_draw_pos(coo), 2.0 * level) for coo in region] - - d.patch_around( - coos, - radius=radius, - # edgecolor=hash_to_color(str(region)), - facecolor=hash_to_color(str(region)), - alpha=1 / 3, - linestyle="", - linewidth=3, - ) - - return d.fig, d.ax - - def __repr__(self): - return ( - f"" - ) diff --git a/quimb/tensor/belief_propagation/__init__.py b/quimb/tensor/belief_propagation/__init__.py new file mode 100644 index 00000000..7000590a --- /dev/null +++ b/quimb/tensor/belief_propagation/__init__.py @@ -0,0 +1,110 @@ +"""Belief propagation (BP) routines. There are three potential categorizations +of BP and each combination of them is potentially valid specific algorithm. + +1-norm vs 2-norm BP +------------------- + +- 1-norm (normal): BP runs directly on the tensor network, messages have size + ``d`` where ``d`` is the size of the bond(s) connecting two tensors or + regions. 
+- 2-norm (quantum): BP runs on the squared tensor network, messages have size
+  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions. Each local tensor or region is partially traced (over dangling
+  indices) with its conjugate to create a single node.
+
+
+Graph vs Hypergraph BP
+----------------------
+
+- Graph (simple): the tensor network lives on a graph, where indices either
+  appear on two tensors (a bond), or appear on a single tensor (are outputs).
+  In this case, messages are exchanged directly between tensors.
+- Hypergraph: the tensor network lives on a hypergraph, where indices can
+  appear on any number of tensors. In this case, the update procedure has two
+  parts: first all 'tensor' messages are computed; these are then used in the
+  second step to compute all the 'index' messages, which are then fed back into
+  the 'tensor' message update and so forth. For 2-norm BP one likely needs to
+  specify which indices are outputs and should be traced over.
+
+The hypergraph case of course includes the graph case, but since the 'index'
+message update is simply the identity, it is convenient to have a separate
+simpler implementation, where the standard TN bond vs physical index
+definitions hold.
+
+
+Dense vs Vectorized vs Lazy BP
+------------------------------
+
+- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If all
+  multibonds have been fused, then each message is a vector (1-norm case) or
+  matrix (2-norm case).
+- Vectorized: the same as the above, but all matching tensor update and message
+  updates are stacked and performed simultaneously. This can be enormously more
+  efficient for large numbers of small tensors.
+- Lazy: each node is potentially a tensor network itself with arbitrary inner
+  structure and number of bonds connecting to other nodes. The messages are
+  generally tensors and each update is a lazy contraction, which is potentially
+  much cheaper / requires less memory than forming the 'dense' node for large
+  tensors.
+
+(There is also the MPS flavor where each node has a 1D structure and the
+messages are matrix product states, with updates involving compression.)
+
+
+Overall that gives 12 possible BP flavors, some implemented here:
+
+- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
+- [x] (HD2BP) hyper, dense, 2-norm
+- [x] (HV1BP) hyper, vectorized, 1-norm
+- [ ] (HV2BP) hyper, vectorized, 2-norm
+- [ ] (HL1BP) hyper, lazy, 1-norm
+- [ ] (HL2BP) hyper, lazy, 2-norm
+- [x] (D1BP) simple, dense, 1-norm - simple BP for simple tensor networks
+- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
+- [ ] (V1BP) simple, vectorized, 1-norm
+- [ ] (V2BP) simple, vectorized, 2-norm
+- [x] (L1BP) simple, lazy, 1-norm
+- [x] (L2BP) simple, lazy, 2-norm
+
+The 2-norm methods can be used to compress bonds or estimate the 2-norm.
+The 1-norm methods can be used to estimate the 1-norm, i.e. contracted value.
+Both methods can be used to compute index marginals and thus perform sampling.
+
+The vectorized methods can be extremely fast for large numbers of small
+tensors, but do currently require all dimensions to match.
+
+The dense and lazy methods can converge messages *locally*, i.e. only
+update messages adjacent to messages which have changed.
+""" + +from .bp_common import combine_local_contractions, initialize_hyper_messages +from .d1bp import D1BP, contract_d1bp +from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp +from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp +from .hv1bp import HV1BP, contract_hv1bp, sample_hv1bp +from .l1bp import L1BP, contract_l1bp +from .l2bp import L2BP, compress_l2bp, contract_l2bp +from .regions import RegionGraph + +__all__ = ( + "combine_local_contractions", + "compress_d2bp", + "compress_l2bp", + "contract_d1bp", + "contract_d2bp", + "contract_hd1bp", + "contract_hv1bp", + "contract_l1bp", + "contract_l2bp", + "D1BP", + "D2BP", + "HD1BP", + "HV1BP", + "initialize_hyper_messages", + "L1BP", + "L2BP", + "RegionGraph", + "sample_d2bp", + "sample_hd1bp", + "sample_hv1bp", +) diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/tensor/belief_propagation/bp_common.py similarity index 100% rename from quimb/experimental/belief_propagation/bp_common.py rename to quimb/tensor/belief_propagation/bp_common.py diff --git a/quimb/experimental/belief_propagation/d1bp.py b/quimb/tensor/belief_propagation/d1bp.py similarity index 100% rename from quimb/experimental/belief_propagation/d1bp.py rename to quimb/tensor/belief_propagation/d1bp.py diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/tensor/belief_propagation/d2bp.py similarity index 100% rename from quimb/experimental/belief_propagation/d2bp.py rename to quimb/tensor/belief_propagation/d2bp.py diff --git a/quimb/experimental/belief_propagation/diis.py b/quimb/tensor/belief_propagation/diis.py similarity index 100% rename from quimb/experimental/belief_propagation/diis.py rename to quimb/tensor/belief_propagation/diis.py diff --git a/quimb/experimental/belief_propagation/hd1bp.py b/quimb/tensor/belief_propagation/hd1bp.py similarity index 100% rename from quimb/experimental/belief_propagation/hd1bp.py rename to quimb/tensor/belief_propagation/hd1bp.py diff --git a/quimb/experimental/belief_propagation/hv1bp.py b/quimb/tensor/belief_propagation/hv1bp.py similarity index 100% rename from quimb/experimental/belief_propagation/hv1bp.py rename to quimb/tensor/belief_propagation/hv1bp.py diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/tensor/belief_propagation/l1bp.py similarity index 100% rename from quimb/experimental/belief_propagation/l1bp.py rename to quimb/tensor/belief_propagation/l1bp.py diff --git a/quimb/experimental/belief_propagation/l2bp.py b/quimb/tensor/belief_propagation/l2bp.py similarity index 100% rename from quimb/experimental/belief_propagation/l2bp.py rename to quimb/tensor/belief_propagation/l2bp.py diff --git a/quimb/tensor/belief_propagation/regions.py b/quimb/tensor/belief_propagation/regions.py new file mode 100644 index 00000000..b56866f8 --- /dev/null +++ b/quimb/tensor/belief_propagation/regions.py @@ -0,0 +1,373 @@ +"""Region graph functionality - for GBP and cluster expansions. +""" + +import functools +import itertools + + +def cached_region_property(name): + """Decorator for caching information about regions.""" + + def wrapper(meth): + @functools.wraps(meth) + def getter(self, region): + try: + return self.info[region][name] + except KeyError: + region_info = self.info.setdefault(region, {}) + region_info[name] = value = meth(self, region) + return value + + return getter + + return wrapper + + +class RegionGraph: + """A graph of regions, where each region is a set of nodes. For generalized + belief propagation or cluster expansion methods. 
+ """ + + def __init__(self, regions=(), autocomplete=True, autoprune=True): + self.lookup = {} + self.parents = {} + self.children = {} + self.info = {} + + for region in regions: + self.add_region(region) + if autocomplete: + self.autocomplete() + if autoprune: + self.autoprune() + + def reset_info(self): + """Remove all cached region properties. + """ + self.info.clear() + + @property + def regions(self): + return tuple(self.children) + + def get_overlapping(self, region): + """Get all regions that intersect with the given region.""" + region = frozenset(region) + return { + other_region + for node in region + for other_region in self.lookup[node] + if other_region != region + } + + def add_region(self, region): + """Add a new region and update parent-child relationships. + + Parameters + ---------- + region : Sequence[Hashable] + The new region to add. + """ + region = frozenset(region) + + if region in self.parents: + # already added + return + + # populate data structures + self.parents[region] = set() + self.children[region] = set() + for node in region: + # collect regions that contain nodes for fast neighbor lookup + self.lookup.setdefault(node, set()).add(region) + + # add parent-child relationships + for other in self.get_overlapping(region): + if region.issubset(other): + self.parents[region].add(other) + self.children[other].add(region) + elif other.issubset(region): + self.children[region].add(other) + self.parents[other].add(region) + + # prune redundant parents and children + children = sorted(self.children[region], key=len) + for i, c in enumerate(children): + if any(c.issubset(cc) for cc in children[i + 1 :]): + # child is a subset of larger child -> remove + self.children[region].remove(c) + self.parents[c].remove(region) + + parents = sorted(self.parents[region], key=len, reverse=True) + for i, p in enumerate(parents): + if any(p.issuperset(pp) for pp in parents[i + 1 :]): + # parent is a superset of smaller parent -> remove + self.parents[region].remove(p) + self.children[p].remove(region) + + for c in self.children[region]: + if p.issuperset(c): + # parent is a superset of child -> ensure no link + self.parents[c].discard(p) + self.children[p].discard(c) + + self.reset_info() + + def remove_region(self, region): + """Remove a region and update parent-child relationships. + """ + # remove from lookup + for node in region: + self.lookup[node].remove(region) + + # remove from parents and children, joining those up + parents = self.parents.pop(region) + children = self.children.pop(region) + for p in parents: + self.children[p].remove(region) + self.children[p].update(children) + for c in children: + self.parents[c].remove(region) + self.parents[c].update(parents) + + self.reset_info() + + def autocomplete(self): + """Add all missing intersecting sub-regions.""" + for r in self.regions: + for other in self.get_overlapping(r): + self.add_region(r & other) + + def autoprune(self): + """Remove all regions with a count of zero.""" + for r in self.regions: + if self.get_count(r) == 0: + self.remove_region(r) + + def autoextend(self, regions=None): + """Extend this region graph upwards by adding in all pairwise unions of + regions. If regions is specified, take this as one set of pairs. 
+ """ + if regions is None: + regions = self.regions + + neighbors = {} + for r in regions: + for other in self.get_overlapping(r): + neighbors.setdefault(r, []).append(other) + + for r, others in neighbors.items(): + for other in others: + self.add_region(r | other) + + def get_parents(self, region): + """Get all ancestors that contain the given region, but do not contain + any other regions that themselves contain the given region. + """ + return self.parents[region] + + def get_children(self, region): + """Get all regions that are contained by the given region, but are not + contained by any other descendents of the given region. + """ + return self.children[region] + + @cached_region_property("ancestors") + def get_ancestors(self, region): + """Get all regions that contain the given region, not just direct + parents. + """ + seen = set() + queue = [region] + while queue: + r = queue.pop() + for rp in self.parents[r]: + if rp not in seen: + seen.add(rp) + queue.append(rp) + return seen + + @cached_region_property("descendents") + def get_descendents(self, region): + """Get all regions that are contained by the given region, not just + direct children. + """ + seen = set() + queue = [region] + while queue: + r = queue.pop() + for rc in self.children[r]: + if rc not in seen: + seen.add(rc) + queue.append(rc) + return seen + + @cached_region_property("coparent_pairs") + def get_coparent_pairs(self, region): + """Get all regions which are direct parents of any descendant of the + given region, but not themselves descendants of the given region. + """ + # start with direct parents + coparent_pairs = [(p, region) for p in self.get_parents(region)] + + # get all descendents + rds = self.get_descendents(region) + + # exclude the region and its descendents + seen = {region, *rds} + + # for each descendant + for rd in rds: + # add only its parents... + for rdp in self.get_parents(rd): + # ... which are not themselves descendents + if rdp not in seen: + coparent_pairs.append((rdp, rd)) + seen.add(rdp) + + return coparent_pairs + + @cached_region_property("count") + def get_count(self, region): + """Get the count of the given region, i.e. the correct weighting to + apply when summing over all regions to avoid overcounting. + """ + return 1 - sum(self.get_count(a) for a in self.get_ancestors(region)) + + def get_total_count(self): + """Get the total count of all regions.""" + return sum(map(self.get_count, self.regions)) + + @cached_region_property("level") + def get_level(self, region): + """Get the level of the given region, i.e. the distance to an ancestor + with no parents. + """ + if not self.parents[region]: + return 0 + else: + return min(self.get_level(p) for p in self.get_parents(region)) - 1 + + @cached_region_property("message_parts") + def get_message_parts(self, pair): + """Get the three contribution groups for a GBP message from region + `source` to region `target`. 1. The part of region `source` that is + not part of target, i.e. the factors to include. 2. The messages that + appear in the numerator of the update equation. 3. The messages that + appear in the denominator of the update equation. + + Parameters + ---------- + source : Region + The source region, should be a parent of `target`. + target : Region + The target region, should be a child of `source`. + + Returns + ------- + factors : Region + The difference of `source` and `target`, which will include the + factors to appear in the numerator of the update equation. 
+        pairs_mul : set[(Region, Region)]
+            The messages that appear in the numerator of the update equation,
+            after cancelling out those that appear in the denominator.
+        pairs_div : set[(Region, Region)]
+            The messages that appear in the denominator of the update equation,
+            after cancelling out those that appear in the numerator.
+        """
+        source, target = pair
+        factors = source - target
+
+        # we want to cancel out messages that appear in both of:
+        # the messages that go into the belief of region `source`
+        source_pairs = set(self.get_coparent_pairs(source))
+        # the messages that go into the belief of region `target`
+        target_pairs = set(self.get_coparent_pairs(target))
+        # the current message to be updated by defn appears directly in the
+        # update numerator, but also target belief region, so can be cancelled
+        target_pairs.remove((source, target))
+
+        pairs_mul = source_pairs - target_pairs
+        pairs_div = target_pairs - source_pairs
+
+        return factors, pairs_mul, pairs_div
+
+    def check(self):
+        """Run some basic consistency checks on the region graph."""
+        for r, rps in self.parents.items():
+            for rp in rps:
+                assert r.issubset(rp)
+                assert r in self.get_children(rp)
+
+        for r in self.regions:
+            for rd in self.get_descendents(r):
+                assert r.issuperset(rd)
+                assert r in self.get_ancestors(rd)
+
+            for ra in self.get_ancestors(r):
+                assert r.issubset(ra)
+                assert r in self.get_descendents(ra)
+
+            rps = self.get_parents(r)
+            for rpa, rpb in itertools.combinations(rps, 2):
+                assert not rpa.issubset(rpb)
+                assert not rpb.issubset(rpa)
+
+            rcs = self.get_children(r)
+            for rca, rcb in itertools.combinations(rcs, 2):
+                assert not rca.issubset(rcb)
+                assert not rcb.issubset(rca)
+
+    def draw(self, pos=None, a=20, scale=1.0, radius=0.1, **drawing_opts):
+        from quimb.schematic import Drawing, hash_to_color
+
+        if pos is None:
+            pos = {node: node for node in self.lookup}
+
+        def get_draw_pos(coo):
+            return tuple(scale * s for s in pos[coo])
+
+        sizes = {len(r) for r in self.regions}
+        levelmap = {s: i for i, s in enumerate(sorted(sizes))}
+        centers = {}
+
+        d = Drawing(a=a, **drawing_opts)
+        for region in sorted(self.regions, key=len, reverse=True):
+            # level = self.get_level(region)
+            # level = len(region)
+            level = levelmap[len(region)]
+
+            coos = [(*get_draw_pos(coo), 2.0 * level) for coo in region]
+
+            average_coo = tuple(map(sum, zip(*coos)))
+            centers[region] = tuple(c / len(coos) for c in average_coo)
+
+            d.patch_around(
+                coos,
+                radius=radius,
+                # edgecolor=hash_to_color(str(region)),
+                facecolor=hash_to_color(str(region)),
+                alpha=1 / 3,
+                linestyle="",
+                linewidth=3,
+            )
+
+        for region in self.regions:
+            for child in self.get_children(region):
+                d.line(
+                    centers[region],
+                    centers[child],
+                    linewidth=.5,
+                    linestyle="-",
+                    color=(.5, .5, .5),
+                    alpha=0.5,
+                    arrowhead={},
+                )
+
+        return d.fig, d.ax
+
+    def __repr__(self):
+        return (
+            f"<RegionGraph(regions={len(self.regions)}, "
+            f"total_count={self.get_total_count()})>"
+        )
diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py
index 6197d68d..01c4d4fb 100644
--- a/quimb/tensor/tensor_arbgeom.py
+++ b/quimb/tensor/tensor_arbgeom.py
@@ -1926,7 +1926,7 @@ def local_expectation_loop_expansion(
     -------
     expec : scalar
     """
-    from quimb.experimental.belief_propagation import RegionGraph
+    from quimb.tensor.belief_propagation import RegionGraph
 
     info = info if info is not None else {}
     info.setdefault("tns", {})
@@ -2147,7 +2147,7 @@ def local_expectation_cluster_expansion(
     -------
     expec : scalar
     """
-    from quimb.experimental.belief_propagation import RegionGraph
+    from quimb.tensor.belief_propagation import RegionGraph
 
     info = info if info is not None else {}
     info.setdefault("tns", {})
@@ -2277,7 +2277,7 @@ def norm_cluster_expansion(
     """Compute the norm of this tensor network by expanding it in terms of
     clusters of tensors.
     """
-    from quimb.experimental.belief_propagation import RegionGraph
+    from quimb.tensor.belief_propagation import RegionGraph
 
     if isinstance(clusters, int):
         max_cluster_size = clusters
diff --git a/quimb/tensor/tensor_arbgeom_compress.py b/quimb/tensor/tensor_arbgeom_compress.py
index 6dbdbcce..3ffca4e3 100644
--- a/quimb/tensor/tensor_arbgeom_compress.py
+++ b/quimb/tensor/tensor_arbgeom_compress.py
@@ -453,13 +453,13 @@ def tensor_network_ag_compress_l2bp(
         Whether to perform the compression inplace.
     compress_opts
         Supplied to
-        :func:`~quimb.experimental.belief_propagation.l2bp.compress_l2bp`.
+        :func:`~quimb.tensor.belief_propagation.l2bp.compress_l2bp`.
 
     Returns
     -------
     TensorNetwork
     """
-    from quimb.experimental.belief_propagation.l2bp import compress_l2bp
+    from quimb.tensor.belief_propagation.l2bp import compress_l2bp
 
     if not canonize:
         compress_opts.setdefault("max_iterations", 1)
diff --git a/tests/test_tensor/test_belief_propagation/test_d1bp.py b/tests/test_tensor/test_belief_propagation/test_d1bp.py
index 9d23b87a..24ee7644 100644
--- a/tests/test_tensor/test_belief_propagation/test_d1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_d1bp.py
@@ -2,10 +2,7 @@
 
 import quimb as qu
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation import (
-    D1BP,
-    contract_d1bp,
-)
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("local_convergence", [False, True])
@@ -13,7 +10,7 @@ def test_contract_tree_exact(local_convergence):
     tn = qtn.TN_rand_tree(20, 3)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_d1bp(
+    Z_bp = qbp.contract_d1bp(
         tn, info=info, local_convergence=local_convergence, progbar=True
     )
     assert info["converged"]
@@ -26,7 +23,7 @@ def test_contract_normal(damping, diis):
     tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_d1bp(
+    Z_bp = qbp.contract_d1bp(
         tn, damping=damping, diis=diis, info=info, progbar=True
     )
     assert info["converged"]
@@ -36,7 +33,7 @@ def test_contract_normal(damping, diis):
 def test_get_gauged_tn():
     tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
     Z = tn.contract()
-    bp = D1BP(tn)
+    bp = qbp.D1BP(tn)
     bp.run()
     Zbp = bp.contract()
     assert Z == pytest.approx(Zbp, rel=1e-1)
diff --git a/tests/test_tensor/test_belief_propagation/test_d2bp.py b/tests/test_tensor/test_belief_propagation/test_d2bp.py
index e6c8afcf..e6ecf9bf 100644
--- a/tests/test_tensor/test_belief_propagation/test_d2bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_d2bp.py
@@ -1,11 +1,7 @@
 import pytest
 
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation.d2bp import (
-    compress_d2bp,
-    contract_d2bp,
-    sample_d2bp,
-)
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("damping", [0.0, 0.1])
@@ -16,7 +12,7 @@ def test_contract(damping, dtype, diis):
     # normalize exactly
     peps /= (peps.H @ peps) ** 0.5
     info = {}
-    N_ap = contract_d2bp(
+    N_ap = qbp.contract_d2bp(
         peps, damping=damping, diis=diis, info=info, progbar=True
     )
     assert info["converged"]
@@ -29,7 +25,7 @@ def test_tree_exact(dtype, local_convergence):
     psi = qtn.TN_rand_tree(20, 3, 2, dtype=dtype, seed=42)
     norm2 = psi.H @ psi
     info = {}
-    norm2_bp = contract_d2bp(
+    norm2_bp = qbp.contract_d2bp(
         psi, info=info, local_convergence=local_convergence, progbar=True
     )
     assert info["converged"]
@@ -46,7 +42,7 @@ def test_compress(damping, dtype, diis):
     peps_c1 = peps.compress_all(max_bond=2)
     info = {}
     peps_c2 = peps.copy()
-    compress_d2bp(
+    qbp.compress_d2bp(
         peps_c2,
         max_bond=2,
         damping=damping,
@@ -67,7 +63,7 @@ def test_sample(dtype):
     peps = qtn.PEPS.rand(3, 4, 3, seed=42, dtype=dtype)
     # normalize exactly
     peps /= (peps.H @ peps) ** 0.5
-    config, peps_config, omega = sample_d2bp(peps, seed=42, progbar=True)
+    config, peps_config, omega = qbp.sample_d2bp(peps, seed=42, progbar=True)
     assert all(ix in config for ix in peps.site_inds)
     assert 0.0 < omega < 1.0
     assert peps_config.outer_inds() == ()
@@ -75,7 +71,7 @@ def test_sample(dtype):
     ptotal = 0.0
     nrepeat = 4
     for _ in range(nrepeat):
-        _, peps_config, _ = sample_d2bp(peps, seed=42, progbar=True)
+        _, peps_config, _ = qbp.sample_d2bp(peps, seed=42, progbar=True)
         ptotal += abs(peps_config.contract()) ** 2
 
     # check we are doing better than random guessing
diff --git a/tests/test_tensor/test_belief_propagation/test_hd1bp.py b/tests/test_tensor/test_belief_propagation/test_hd1bp.py
index 2c010a5e..5f71b17d 100644
--- a/tests/test_tensor/test_belief_propagation/test_hd1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_hd1bp.py
@@ -2,18 +2,14 @@
 
 import quimb as qu
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation.hd1bp import (
-    HD1BP,
-    contract_hd1bp,
-    sample_hd1bp,
-)
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("damping", [0.0, 0.1])
 def test_contract_hyper(damping):
     htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense")
     info = {}
-    num_solutions = contract_hd1bp(
+    num_solutions = qbp.contract_hd1bp(
         htn, damping=damping, info=info, progbar=True
     )
     assert info["converged"]
@@ -25,7 +21,7 @@ def test_contract_tree_exact(normalize):
     tn = qtn.TN_rand_tree(20, 3)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_hd1bp(
+    Z_bp = qbp.contract_hd1bp(
         tn,
         info=info,
         normalize=normalize,
@@ -41,7 +37,7 @@ def test_contract_normal(damping, diis):
     tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_hd1bp(
+    Z_bp = qbp.contract_hd1bp(
         tn, damping=damping, diis=diis, info=info, progbar=True
     )
     assert info["converged"]
@@ -52,7 +48,7 @@ def test_contract_normal(damping, diis):
 def test_sample(damping):
     nvars = 20
     htn = qtn.HTN_random_ksat(3, nvars, alpha=2.0, seed=42, mode="dense")
-    config, tn_config, omega = sample_hd1bp(
+    config, tn_config, omega = qbp.sample_hd1bp(
         htn, damping=damping, seed=42, progbar=True
     )
     assert len(config) == nvars
@@ -64,7 +60,7 @@ def test_sample(damping):
 def test_get_gauged_tn():
     tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
     Z = tn.contract()
-    bp = HD1BP(tn)
+    bp = qbp.HD1BP(tn)
     bp.run()
     Zbp = bp.contract()
     assert Z == pytest.approx(Zbp, rel=1e-1)
diff --git a/tests/test_tensor/test_belief_propagation/test_hv1bp.py b/tests/test_tensor/test_belief_propagation/test_hv1bp.py
index 4b232666..fedee99a 100644
--- a/tests/test_tensor/test_belief_propagation/test_hv1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_hv1bp.py
@@ -2,10 +2,7 @@
 
 import quimb as qu
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation.hv1bp import (
-    contract_hv1bp,
-    sample_hv1bp,
-)
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
@@ -13,7 +10,7 @@ def test_contract_hyper(damping, diis):
     htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense")
     info = {}
-    num_solutions = contract_hv1bp(
+    num_solutions = qbp.contract_hv1bp(
         htn, damping=damping, diis=diis, info=info, progbar=True
     )
     assert info["converged"]
@@ -31,7 +28,7 @@ def test_contract_tree_exact(messages):
     def messages(shape):
         return qu.randn(shape, dist="uniform")
 
-    Z_bp = contract_hv1bp(tn, messages=messages, info=info, progbar=True)
+    Z_bp = qbp.contract_hv1bp(tn, messages=messages, info=info, progbar=True)
     assert info["converged"]
     assert Z == pytest.approx(Z_bp, rel=1e-12)
 
@@ -42,7 +39,7 @@ def test_contract_normal(damping, diis):
     tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_hv1bp(
+    Z_bp = qbp.contract_hv1bp(
         tn, damping=damping, diis=diis, info=info, progbar=True
     )
     assert info["converged"]
@@ -53,7 +50,7 @@ def test_contract_normal(damping, diis):
 def test_sample(damping):
     nvars = 20
     htn = qtn.HTN_random_ksat(3, nvars, alpha=2.0, seed=42, mode="dense")
-    config, tn_config, omega = sample_hv1bp(
+    config, tn_config, omega = qbp.sample_hv1bp(
         htn, damping=damping, seed=42, progbar=True
     )
     assert len(config) == nvars
diff --git a/tests/test_tensor/test_belief_propagation/test_l1bp.py b/tests/test_tensor/test_belief_propagation/test_l1bp.py
index 697b919b..eb7ca779 100644
--- a/tests/test_tensor/test_belief_propagation/test_l1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_l1bp.py
@@ -2,8 +2,7 @@
 
 import quimb as qu
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation.l1bp import contract_l1bp
-from quimb.experimental.belief_propagation.d2bp import contract_d2bp
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
@@ -13,7 +12,7 @@ def test_contract_tree_exact(dtype, local_convergence, normalize):
     tn = qtn.TN_rand_tree(10, 3, seed=42, dtype=dtype)
     Z_ex = tn.contract()
     info = {}
-    Z_bp = contract_l1bp(
+    Z_bp = qbp.contract_l1bp(
         tn,
         info=info,
         normalize=normalize,
@@ -31,7 +30,7 @@ def test_contract_loopy_approx(dtype, damping, diis):
     tn = qtn.TN2D_rand(3, 4, 5, dtype=dtype, dist="uniform")
     Z_ex = tn.contract()
     info = {}
-    Z_bp = contract_l1bp(
+    Z_bp = qbp.contract_l1bp(
         tn,
         damping=damping,
         diis=diis,
@@ -50,7 +49,7 @@ def test_contract_double_loopy_approx(dtype, damping, update):
     tn = peps.H & peps
     Z_ex = tn.contract()
     info = {}
-    Z_bp1 = contract_l1bp(
+    Z_bp1 = qbp.contract_l1bp(
         tn,
         damping=damping,
         update=update,
@@ -60,7 +59,7 @@ def test_contract_double_loopy_approx(dtype, damping, update):
     assert info["converged"]
     assert Z_bp1 == pytest.approx(Z_ex, rel=0.3)
     # compare with 2-norm BP on the peps directly
-    Z_bp2 = contract_d2bp(peps)
+    Z_bp2 = qbp.contract_d2bp(peps)
     assert Z_bp1 == pytest.approx(Z_bp2, rel=5e-5)
 
 
@@ -94,7 +93,7 @@ def test_contract_tree_triple_sandwich_exact(dtype):
     tn = bra.H | op | ket
     Z_ex = tn.contract()
     info = {}
-    Z_bp = contract_l1bp(tn, info=info, progbar=True)
+    Z_bp = qbp.contract_l1bp(tn, info=info, progbar=True)
     assert info["converged"]
     assert Z_ex == pytest.approx(Z_bp, rel=5e-6)
 
@@ -120,7 +119,7 @@ def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping, diis):
     tn = ket.H | G_ket
     Z_ex = tn.contract()
     info = {}
-    Z_bp = contract_l1bp(
+    Z_bp = qbp.contract_l1bp(
         tn,
         damping=damping,
         diis=diis,
@@ -134,7 +133,7 @@ def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping, diis):
 def test_contract_cluster_approx():
     tn = qtn.TN2D_classical_ising_partition_function(8, 8, 0.4, h=0.2)
     f_ex = qu.log(tn.contract())
-    f_bp = qu.log(contract_l1bp(tn))
+    f_bp = qu.log(qbp.contract_l1bp(tn))
     assert f_bp == pytest.approx(f_ex, rel=0.3)
     cluster_tags = []
     for i in range(0, 8, 2):
@@ -147,7 +146,7 @@ def test_contract_cluster_approx():
         cluster_tags.append(cluster_tag)
     info = {}
     f_bp2 = qu.log(
-        contract_l1bp(tn, site_tags=cluster_tags, info=info, progbar=True)
+        qbp.contract_l1bp(tn, site_tags=cluster_tags, info=info, progbar=True)
     )
     assert info["converged"]
     assert f_bp == pytest.approx(f_ex, rel=0.1)
@@ -161,7 +160,7 @@ def test_mps():
     psiG = psi.copy()
     psiG.gate_(qu.pauli("X"), 5, contract=True)
     expec = psi.H & psiG
-    O = contract_l1bp(
+    O = qbp.contract_l1bp(
         expec,
         site_tags=[f"I{i}" for i in range(L)],
     )
diff --git a/tests/test_tensor/test_belief_propagation/test_l2bp.py b/tests/test_tensor/test_belief_propagation/test_l2bp.py
index 668df51e..f82576cd 100644
--- a/tests/test_tensor/test_belief_propagation/test_l2bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_l2bp.py
@@ -1,10 +1,7 @@
 import pytest
 
 import quimb.tensor as qtn
-from quimb.experimental.belief_propagation.l2bp import (
-    compress_l2bp,
-    contract_l2bp,
-)
+import quimb.tensor.belief_propagation as qbp
 
 
 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
@@ -12,7 +9,7 @@ def test_contract_tree_exact(dtype):
     psi = qtn.TN_rand_tree(20, 3, 2, dtype=dtype)
     norm2 = psi.H @ psi
     info = {}
-    norm2_bp = contract_l2bp(psi, info=info, progbar=True)
+    norm2_bp = qbp.contract_l2bp(psi, info=info, progbar=True)
     assert info["converged"]
     assert norm2_bp == pytest.approx(norm2, rel=1e-5)
 
@@ -22,7 +19,7 @@ def test_contract_loopy_approx(dtype):
     peps = qtn.PEPS.rand(3, 4, 3, dtype=dtype, seed=42)
     norm_ex = peps.H @ peps
     info = {}
-    norm_bp = contract_l2bp(peps, damping=0.1, info=info, progbar=True)
+    norm_bp = qbp.contract_l2bp(peps, damping=0.1, info=info, progbar=True)
     assert info["converged"]
     assert norm_bp == pytest.approx(norm_ex, rel=0.2)
 
@@ -35,7 +32,7 @@ def test_compress_loopy(damping, dtype):
     # local, naive compression scheme
     peps_c1 = peps.compress_all(max_bond=2)
     info = {}
-    peps_c2 = compress_l2bp(
+    peps_c2 = qbp.compress_l2bp(
         peps, max_bond=2, damping=damping, info=info, progbar=True
     )
     assert info["converged"]
@@ -61,7 +58,7 @@ def test_contract_double_layer_tree_exact(dtype):
     norm_ex = tn.H @ tn
 
     info = {}
-    norm_bp = contract_l2bp(tn, info=info, progbar=True)
+    norm_bp = qbp.contract_l2bp(tn, info=info, progbar=True)
     assert info["converged"]
     assert norm_bp == pytest.approx(norm_ex, rel=1e-6)
 
@@ -85,7 +82,7 @@ def test_compress_double_layer_loopy(dtype, damping, update):
 
     # compress using BP
     info = {}
-    tn_bp = compress_l2bp(
+    tn_bp = qbp.compress_l2bp(
         tn_lazy,
         max_bond=3,
        damping=damping,
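
Usage sketch (illustrative only, not part of the patch itself): after the move, the BP routines are used from their new location exactly as the updated tests above do. Everything below uses only functions that appear in this diff, mirroring `test_d1bp.py`, `test_d2bp.py` and `test_l2bp.py`:

```python
import quimb as qu
import quimb.tensor as qtn
import quimb.tensor.belief_propagation as qbp

# a small positive 2D tensor network, as in test_d1bp.py
tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)

# estimate the contracted value with simple dense 1-norm BP
info = {}
Z_bp = qbp.contract_d1bp(tn, info=info)
assert info["converged"]

# 2-norm BP on a PEPS: estimate the norm, then compress with lazy 2-norm BP
peps = qtn.PEPS.rand(3, 4, 3, seed=42)
norm2_bp = qbp.contract_d2bp(peps)
peps_c = qbp.compress_l2bp(peps, max_bond=2)
```

The old `quimb.experimental.belief_propagation` path still imports, via the `import *` shim above, but now emits the warning added in this patch.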
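A second illustrative sketch (again not part of the patch) of the counting logic in the expanded `RegionGraph` from `quimb/tensor/belief_propagation/regions.py`: for two overlapping 2x2 plaquettes, `autocomplete` (on by default) inserts their intersection, which picks up a counting number of -1 so that nothing is overcounted:

```python
from quimb.tensor.belief_propagation import RegionGraph

# two 2x2 plaquettes sharing the edge {(0, 1), (1, 1)}
rg = RegionGraph([
    [(0, 0), (0, 1), (1, 0), (1, 1)],
    [(0, 1), (0, 2), (1, 1), (1, 2)],
])

for r in sorted(rg.regions, key=len, reverse=True):
    # each plaquette has count 1, the shared edge gets 1 - 2 = -1
    print(sorted(r), rg.get_count(r))

# overall, every node is counted exactly once
assert rg.get_total_count() == 1
rg.check()  # run the region graph's internal consistency assertions
```

Note `get_count` now goes through the `cached_region_property` decorator and the `info` cache rather than the old `self.counts` dict, so repeated queries are cheap until the graph is next modified.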