diff --git a/docs/changelog.md b/docs/changelog.md
index 62b5bfa7..713fda7a 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -11,6 +11,8 @@ Release notes for `quimb`.
 - belief propagation, implement DIIS (direct inversion in the iterative subspace)
 - belief propagation, unify various aspects such as message normalization and distance.
 - belief propagation, add a `plot` method.
+- belief propagation, add a `contract_every` option.
+- HV1BP: vectorize both contraction and message initialization
 - add `qu.plot_multi_series_zoom` for plotting multiple series with a zoomed inset, useful for various convergence plots such as BP

 **Bug fixes:**
diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/experimental/belief_propagation/bp_common.py
index d25cb277..fb52c2e3 100644
--- a/quimb/experimental/belief_propagation/bp_common.py
+++ b/quimb/experimental/belief_propagation/bp_common.py
@@ -71,6 +71,10 @@ class BeliefPropagationCommon:
         norms. 'L2phased' is like 'L2' but also normalizes the phases of the
         messages, by default used for complex dtypes if phased normalization
         is not already being used.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
     inplace : bool, optional
         Whether to perform any operations inplace on the input tensor network.
     """
@@ -83,6 +87,7 @@ def __init__(
         self,
         tn,
         *,
         damping=0.0,
         update="sequential",
         normalize=None,
         distance=None,
+        contract_every=None,
         inplace=False,
     ):
         self.tn = tn if inplace else tn.copy()
@@ -109,9 +114,13 @@ def __init__(
             distance = "L2"
         self.distance = distance
+        self.contract_every = contract_every

         self.n = 0
+        self.converged = False
         self.mdiffs = []
         self.rdiffs = []
+        self.zval_its = []
+        self.zvals = []

     @property
     def damping(self):
@@ -124,10 +133,10 @@ def damping(self, damping):
         else:
             self._damping = damping

-            def damp(old, new):
-                return damping * old + (1 - damping) * new
+            def fn_damping(old, new):
+                return damping * old + (1 - damping) * new

-            self.fn_damping = damp
+            self.fn_damping = fn_damping

     @property
     def normalize(self):
@@ -247,6 +256,16 @@ def _distance_fn(x, y):
         self._distance = distance
         self._distance_fn = _distance_fn

+    def _maybe_contract(self):
+        should_contract = (
+            (self.contract_every is not None)
+            and (self.n % self.contract_every == 0)
+            and ((not self.zval_its) or (self.zval_its[-1] != self.n))
+        )
+        if should_contract:
+            self.zval_its.append(self.n)
+            self.zvals.append(self.contract())
+
     def run(
         self,
         max_iterations=1000,
@@ -311,6 +330,8 @@ def run(
         rdm = RollingDiffMean()
         self.converged = False
         while not self.converged and it < max_iterations:
+            self._maybe_contract()
+
             # perform a single iteration of BP
             # we supply tol here for use with local convergence
             result = self.iterate(tol=tol)
@@ -350,6 +371,8 @@ def run(
             it += 1
             self.n += 1

+        self._maybe_contract()
+
         # finally:
         if pbar is not None:
             pbar.close()
@@ -368,10 +391,15 @@ def run(
             info["max_mdiff"] = max_mdiff
             info["rolling_abs_mean_diff"] = rdm.absmeandiff()

-    def plot(self, **kwargs):
+    def plot(self, zvals_yscale="asinh", **kwargs):
         from quimb import plot_multi_series_zoom

         data = {
+            "zvals": {
+                "x": self.zval_its,
+                "y": self.zvals,
+                "yscale": zvals_yscale,
+            },
             "mdiffs": self.mdiffs,
             "rdiffs": self.rdiffs,
         }
@@ -381,6 +409,16 @@ def plot(self, **kwargs):
         kwargs.setdefault("yscale", "log")
         return plot_multi_series_zoom(data, **kwargs)

+    @property
+    def mdiff(self):
+        try:
+            return self.mdiffs[-1]
+        except IndexError:
+            return float("nan")
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(n={self.n}, mdiff={self.mdiff:.3g})"
+

 def initialize_hyper_messages(
     tn,
diff --git a/quimb/experimental/belief_propagation/d1bp.py b/quimb/experimental/belief_propagation/d1bp.py
index 89d11751..b66526a3 100644
--- a/quimb/experimental/belief_propagation/d1bp.py
+++ b/quimb/experimental/belief_propagation/d1bp.py
@@ -89,8 +89,12 @@ class D1BP(BeliefPropagationCommon):
     local_convergence : bool, optional
         Whether to allow messages to locally converge - i.e. if all their
         input messages have converged then stop updating them.
-    fill_fn : callable, optional
-        If specified, use this function to fill in the initial messages.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
+    inplace : bool, optional
+        Whether to perform any operations inplace on the input tensor network.

     Attributes
     ----------
@@ -115,6 +119,7 @@ def __init__(
         distance=None,
         local_convergence=True,
         message_init_function=None,
+        contract_every=None,
         inplace=False,
     ):
         super().__init__(
@@ -123,6 +128,7 @@ def __init__(
             update=update,
             normalize=normalize,
             distance=distance,
+            contract_every=contract_every,
             inplace=inplace,
         )
diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py
index b2e8680e..b7134504 100644
--- a/quimb/experimental/belief_propagation/d2bp.py
+++ b/quimb/experimental/belief_propagation/d2bp.py
@@ -67,6 +67,12 @@ class D2BP(BeliefPropagationCommon):
     local_convergence : bool, optional
         Whether to allow messages to locally converge - i.e. if all their
         input messages have converged then stop updating them.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
+    inplace : bool, optional
+        Whether to perform any operations inplace on the input tensor network.
     contract_opts
         Other options supplied to ``cotengra.array_contract``.
     """
@@ -82,8 +88,9 @@ def __init__(
         update="sequential",
         normalize=None,
         distance=None,
-        inplace=False,
         local_convergence=True,
+        contract_every=None,
+        inplace=False,
         **contract_opts,
     ):
         super().__init__(
@@ -92,6 +99,7 @@ def __init__(
             update=update,
             normalize=normalize,
             distance=distance,
+            contract_every=contract_every,
             inplace=inplace,
         )
diff --git a/quimb/experimental/belief_propagation/hv1bp.py b/quimb/experimental/belief_propagation/hv1bp.py
index 481dfd50..03b2323b 100644
--- a/quimb/experimental/belief_propagation/hv1bp.py
+++ b/quimb/experimental/belief_propagation/hv1bp.py
@@ -349,6 +349,12 @@ class HV1BP(BeliefPropagationCommon):
     thread_pool : bool or int, optional
         Whether to use a thread pool for parallelization, if ``True`` use
         the default number of threads, if an integer use that many threads.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
+    inplace : bool, optional
+        Whether to perform any operations inplace on the input tensor network.
     """

     def __init__(
@@ -360,9 +366,10 @@ def __init__(
         update="parallel",
         normalize="L2",
         distance="L2",
-        inplace=False,
         smudge_factor=1e-12,
         thread_pool=False,
+        contract_every=None,
+        inplace=False,
     ):
         super().__init__(
             tn,
@@ -370,6 +377,7 @@ def __init__(
             update=update,
             normalize=normalize,
             distance=distance,
+            contract_every=contract_every,
             inplace=inplace,
         )
@@ -458,14 +466,26 @@ def _distance_fn(bx, by):
         self._distance_fn = _distance_fn

     def initialize_messages_batched(self, messages=None):
-        if messages is None:
-            # XXX: explicit use uniform distribution to avoid non-vectorized
-            # contractions?
-            messages = initialize_hyper_messages(self.tn)
-
         _stack = ar.get_lib_fn(self.backend, "stack")
         _array = ar.get_lib_fn(self.backend, "array")

+        if isinstance(messages, dict):
+            # 'dense' (i.e. non-batch) messages explicitly supplied
+            message_init_fn = None
+        elif callable(messages):
+            # custom message initialization function
+            message_init_fn = messages
+            messages = None
+        elif messages == "dense":
+            # explicitly create dense messages first
+            message_init_fn = None
+            messages = initialize_hyper_messages(self.tn)
+        elif messages is None:
+            # default to uniform messages
+            message_init_fn = ar.get_lib_fn(self.backend, "ones")
+        else:
+            raise ValueError(f"Unrecognized messages={messages}")
+
         # here we are stacking all contractions with matching rank
         #
         #     rank: number of incident messages to a tensor or hyper index
@@ -479,27 +499,37 @@ def initialize_messages_batched(self, messages=None):
         batched_inputs_m = {}
         input_locs_m = {}
         output_locs_m = {}
+        shapes_m = {}
+
         for ix, tids in self.tn.ind_map.items():
             # all updates of the same rank can be performed simultaneously
             rank = len(tids)
             try:
                 batch = batched_inputs_m[rank]
+                shape = shapes_m[rank]
             except KeyError:
                 batch = batched_inputs_m[rank] = [[] for _ in range(rank)]
+                shape = shapes_m[rank] = [rank, 0, self.tn.ind_size(ix)]
+
+            # batch index
+            b = shape[1]

             for p, tid in enumerate(tids):
-                batch_p = batch[p]
-                # position in the stack
-                b = len(batch_p)
+                if message_init_fn is None:
+                    # we'll stack the messages later
+                    batch[p].append(messages[tid, ix])
                 input_locs_m[tid, ix] = (rank, p, b)
                 output_locs_m[ix, tid] = (rank, p, b)
-                batch_p.append(messages[tid, ix])
+
+            # increment batch index
+            shape[1] = b + 1

         # prepare tensor messages
         batched_tensors = {}
         batched_inputs_t = {}
         input_locs_t = {}
         output_locs_t = {}
+        shapes_t = {}
+
         for tid, t in self.tn.tensor_map.items():
             # all updates of the same rank can be performed simultaneously
             rank = t.ndim
@@ -510,26 +540,39 @@ def initialize_messages_batched(self, messages=None):
             try:
                 batch = batched_inputs_t[rank]
                 batch_t = batched_tensors[rank]
+                shape = shapes_t[rank]
             except KeyError:
                 batch = batched_inputs_t[rank] = [[] for _ in range(rank)]
                 batch_t = batched_tensors[rank] = []
+                shape = shapes_t[rank] = [rank, 0, t.shape[0]]
+
+            # batch index
+            b = shape[1]

             for p, ix in enumerate(t.inds):
-                batch_p = batch[p]
-                # position in the stack
-                b = len(batch_p)
+                if message_init_fn is None:
+                    # we'll stack the messages later
+                    batch[p].append(messages[ix, tid])
                 input_locs_t[ix, tid] = (rank, p, b)
                 output_locs_t[tid, ix] = (rank, p, b)
-                batch_p.append(messages[ix, tid])

             batch_t.append(t.data)
+            # increment batch index
+            shape[1] = b + 1
+
+        # combine or create batch message arrays
+        for batched_inputs, shapes in zip(
+            (batched_inputs_m, batched_inputs_t), (shapes_m, shapes_t)
+        ):
+            for rank, batch in batched_inputs.items():
+                if isinstance(messages, dict):
+                    # stack given messages into single arrays
+                    batched_inputs[rank] = _stack(
+                        tuple(_stack(batch_p) for batch_p in batch)
+                    )
+                else:
+                    # create message arrays directly
+                    batched_inputs[rank] = message_init_fn(shapes[rank])

-        # stack messages in into single arrays
-        for batched_inputs in (batched_inputs_m, batched_inputs_t):
-            for key, batch in batched_inputs.items():
-                batched_inputs[key] = _stack(
-                    tuple(_stack(batch_p) for batch_p in batch)
-                )
         # stack all tensors of each rank into a single array
         for rank, tensors in batched_tensors.items():
             batched_tensors[rank] = _stack(tensors)
@@ -671,7 +714,7 @@ def iterate(self, tol=None):
         )
         return max(t_max_mdiff, m_max_mdiff)

-    def get_messages(self):
+    def get_messages_dense(self):
         """Get messages in individual form from the batched stacks."""
         return _extract_messages_from_inputs_batched(
             self.batched_inputs_m,
@@ -680,6 +723,17 @@ def get_messages(self):
             self.input_locs_t,
         )

+    def get_messages(self):
+        import warnings
+
+        warnings.warn(
+            "get_messages() is deprecated, or in the future it might return "
+            "the batch messages, use get_messages_dense() instead.",
+            DeprecationWarning,
+        )
+
+        return self.get_messages_dense()
+
     def contract(self, strip_exponent=False, check_zero=False):
         """Estimate the contraction of the tensor network using the current
         messages. Uses batched vectorized contractions for speed.
@@ -751,7 +805,7 @@ def contract_dense(self, strip_exponent=False, check_zero=True):
         """
         return contract_hyper_messages(
             self.tn,
-            self.get_messages(),
+            self.get_messages_dense(),
             strip_exponent=strip_exponent,
             check_zero=check_zero,
             backend=self.backend,
@@ -959,7 +1013,7 @@ def run_belief_propagation_hv1bp(
         info=info,
         progbar=progbar,
     )
-    return bp.get_messages(), bp.converged
+    return bp.get_messages_dense(), bp.converged


 def sample_hv1bp(
diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/experimental/belief_propagation/l1bp.py
index 6db25b39..694af682 100644
--- a/quimb/experimental/belief_propagation/l1bp.py
+++ b/quimb/experimental/belief_propagation/l1bp.py
@@ -51,6 +51,12 @@ class L1BP(BeliefPropagationCommon):
         input messages have converged then stop updating them.
     optimize : str or PathOptimizer, optional
         The path optimizer to use when contracting the messages.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
+    inplace : bool, optional
+        Whether to perform any operations inplace on the input tensor network.
     contract_opts
         Other options supplied to ``cotengra.array_contract``.
     """
@@ -64,10 +70,11 @@ def __init__(
         update="sequential",
         normalize=None,
         distance=None,
-        inplace=False,
         local_convergence=True,
         optimize="auto-hq",
         message_init_function=None,
+        contract_every=None,
+        inplace=False,
         **contract_opts,
     ):
         super().__init__(
@@ -76,6 +83,7 @@ def __init__(
             update=update,
             normalize=normalize,
             distance=distance,
+            contract_every=contract_every,
             inplace=inplace,
         )
diff --git a/quimb/experimental/belief_propagation/l2bp.py b/quimb/experimental/belief_propagation/l2bp.py
index 62f2ba66..a06d03ba 100644
--- a/quimb/experimental/belief_propagation/l2bp.py
+++ b/quimb/experimental/belief_propagation/l2bp.py
@@ -54,8 +54,6 @@ class L2BP(BeliefPropagationCommon):
         norms. 'L2phased' is like 'L2' but also normalizes the phases of the
         messages, by default used for complex dtypes if phased normalization
         is not already being used.
-    inplace : bool, optional
-        Whether to perform any operations inplace on the input tensor network.
     symmetrize : bool or callable, optional
         Whether to symmetrize the messages, i.e. for each message ensure that
         it is hermitian with respect to its bra and ket indices. If a callable
@@ -65,6 +63,12 @@ class L2BP(BeliefPropagationCommon):
         input messages have converged then stop updating them.
     optimize : str or PathOptimizer, optional
         The path optimizer to use when contracting the messages.
+    contract_every : int, optional
+        If not None, 'contract' (via BP) the tensor network every
+        ``contract_every`` iterations. The resulting values are stored in
+        ``zvals`` at corresponding points ``zval_its``.
+    inplace : bool, optional
+        Whether to perform any operations inplace on the input tensor network.
     contract_opts
         Other options supplied to ``cotengra.array_contract``.
     """
@@ -78,10 +82,11 @@ def __init__(
         update="sequential",
         normalize=None,
         distance=None,
-        inplace=False,
         symmetrize=True,
         local_convergence=True,
         optimize="auto-hq",
+        contract_every=None,
+        inplace=False,
         **contract_opts,
     ):
         super().__init__(
@@ -90,6 +95,7 @@ def __init__(
             update=update,
             normalize=normalize,
             distance=distance,
+            contract_every=contract_every,
             inplace=inplace,
         )
diff --git a/tests/test_tensor/test_belief_propagation/test_hv1bp.py b/tests/test_tensor/test_belief_propagation/test_hv1bp.py
index 1f5fcd25..4b232666 100644
--- a/tests/test_tensor/test_belief_propagation/test_hv1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_hv1bp.py
@@ -20,11 +20,18 @@ def test_contract_hyper(damping, diis):
     assert num_solutions == pytest.approx(309273226, rel=0.1)


-def test_contract_tree_exact():
+@pytest.mark.parametrize("messages", [None, "dense", "random"])
+def test_contract_tree_exact(messages):
     tn = qtn.TN_rand_tree(20, 3)
     Z = tn.contract()
     info = {}
-    Z_bp = contract_hv1bp(tn, info=info, progbar=True)
+
+    if messages == "random":
+
+        def messages(shape):
+            return qu.randn(shape, dist="uniform")
+
+    Z_bp = contract_hv1bp(tn, messages=messages, info=info, progbar=True)
     assert info["converged"]
     assert Z == pytest.approx(Z_bp, rel=1e-12)
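
Usage note (not part of the diff): a minimal sketch of how the new `contract_every` option and the `zvals` / `zval_its` / `mdiff` attributes added in `bp_common.py` might be used. The `D1BP` import path and the `TN_rand_tree` helper are taken from the files touched above; the exact `run` signature is an assumption.

```python
# Sketch only: assumes the experimental BP API shown in the diff above.
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.d1bp import D1BP

# small random tree tensor network (same helper the new test uses)
tn = qtn.TN_rand_tree(20, 3)

# record a BP 'contraction' estimate every 5 iterations
bp = D1BP(tn, contract_every=5)
bp.run(max_iterations=100, tol=1e-10)

print(bp)           # new __repr__, e.g. "D1BP(n=..., mdiff=...)"
print(bp.zval_its)  # iterations at which the network was contracted
print(bp.zvals)     # corresponding BP estimates of the contraction value
# bp.plot()         # plots zvals, mdiffs and rdiffs (requires matplotlib)
```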
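Similarly, a sketch of the new vectorized message initialization handled by `HV1BP.initialize_messages_batched`: messages may now be a dict of dense messages, the string `"dense"`, a custom array-factory callable, or `None` for uniform (ones) messages created directly in batched form. Calling the method directly after construction, and the import paths, are assumptions based on the hunks above; the callable variant mirrors the new parametrized test.

```python
# Sketch only: mirrors the options handled in initialize_messages_batched.
import quimb as qu
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.hv1bp import HV1BP

tn = qtn.TN_rand_tree(20, 3)

# default: uniform messages, created directly as batched arrays
bp = HV1BP(tn)

# re-initialize from explicit dense messages, stacked into batched arrays ...
bp.initialize_messages_batched(messages="dense")

# ... or from a custom fill function, as in the new test
bp.initialize_messages_batched(
    messages=lambda shape: qu.randn(shape, dist="uniform")
)

bp.run()
Z_bp = bp.contract()

# per-edge messages can still be extracted from the batched stacks;
# get_messages() now emits a DeprecationWarning in favour of this
messages = bp.get_messages_dense()
```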