From 3b577e1e597f3cd0a6aaf46c02405f53609f4c2a Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 29 Jan 2025 18:10:42 -0800 Subject: [PATCH] HV1BP: update interface options --- quimb/experimental/belief_propagation/diis.py | 5 +- .../experimental/belief_propagation/hv1bp.py | 157 +++++++++++++++++- .../test_belief_propagation/test_hv1bp.py | 12 +- 3 files changed, 159 insertions(+), 15 deletions(-) diff --git a/quimb/experimental/belief_propagation/diis.py b/quimb/experimental/belief_propagation/diis.py index 1a7cb18a..f9ea5d0b 100644 --- a/quimb/experimental/belief_propagation/diis.py +++ b/quimb/experimental/belief_propagation/diis.py @@ -184,7 +184,6 @@ def _extrapolate(self): coeffs = [self.scalar(c) for c in coeffs[1:]] # construct linear combination of previous guesses! - # xnew = np.zeros_like(self.guesses[0]) xnew = ar.do("zeros_like", self.guesses[0], like=self.backend) for ci, xi in zip(coeffs, self.guesses): xnew += ci * xi @@ -216,8 +215,6 @@ def update(self, y): same tree structure as `y`. """ # convert from pytree -> single real vector - # copy is important so that sequence of - # guesses are not the same object y = self.vectorizer.pack(y) x = self.guesses[self.head] if x is None: @@ -228,7 +225,7 @@ def update(self, y): xnext = self._extrapolate() self.head = (self.head + 1) % self.max_history - # # TODO: make copy backend agnostic + # NOTE: copy seems to be necessary here to avoid in-place modifications self.guesses[self.head] = ar.do("copy", xnext, like=self.backend) # convert new extrapolated guess back to pytree diff --git a/quimb/experimental/belief_propagation/hv1bp.py b/quimb/experimental/belief_propagation/hv1bp.py index 0f940b9f..dd48b16f 100644 --- a/quimb/experimental/belief_propagation/hv1bp.py +++ b/quimb/experimental/belief_propagation/hv1bp.py @@ -587,12 +587,13 @@ def get_messages(self): self.input_locs_t, ) - def contract(self, strip_exponent=False): + def contract(self, strip_exponent=False, check_zero=True): # TODO: do this in batched form directly return contract_hyper_messages( self.tn, self.get_messages(), strip_exponent=strip_exponent, + check_zero=check_zero, backend=self.backend, ) @@ -602,9 +603,16 @@ def contract_hv1bp( messages=None, max_iterations=1000, tol=5e-6, - smudge_factor=1e-12, damping=0.0, + diis=False, + update="parallel", + normalize="L2", + distance="L2", + tol_abs=None, + tol_rolling_diff=None, + smudge_factor=1e-12, strip_exponent=False, + check_zero=True, info=None, progbar=False, ): @@ -621,14 +629,46 @@ def contract_hv1bp( The maximum number of iterations to perform. tol : float, optional The convergence tolerance for messages. + damping : float, optional + The damping factor to use, 0.0 means no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. smudge_factor : float, optional A small number to add to the denominator of messages to avoid division by zero. Note when this happens the numerator will also be zero. - damping : float, optional - The damping factor to use, 0.0 means no damping. strip_exponent : bool, optional Whether to strip the exponent from the final result. If ``True`` then the returned result is ``(mantissa, exponent)``. + check_zero : bool, optional + Whether to check for zero values and return zero early. info : dict, optional If specified, update this dictionary with information about the belief propagation run. @@ -643,15 +683,24 @@ def contract_hv1bp( tn, messages=messages, damping=damping, + update=update, + normalize=normalize, + distance=distance, smudge_factor=smudge_factor, ) bp.run( max_iterations=max_iterations, tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, info=info, progbar=progbar, ) - return bp.contract(strip_exponent=strip_exponent) + return bp.contract( + strip_exponent=strip_exponent, + check_zero=check_zero, + ) def run_belief_propagation_hv1bp( @@ -660,6 +709,12 @@ def run_belief_propagation_hv1bp( max_iterations=1000, tol=5e-6, damping=0.0, + diis=False, + update="parallel", + normalize="L2", + distance="L2", + tol_abs=None, + tol_rolling_diff=None, smudge_factor=1e-12, info=None, progbar=False, @@ -680,6 +735,36 @@ def run_belief_propagation_hv1bp( The convergence tolerance. damping : float, optional The damping factor to use, 0.0 means no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. smudge_factor : float, optional A small number to add to the denominator of messages to avoid division by zero. Note when this happens the numerator will also be zero. @@ -697,9 +782,23 @@ def run_belief_propagation_hv1bp( Whether the algorithm converged. """ bp = HV1BP( - tn, messages=messages, damping=damping, smudge_factor=smudge_factor + tn, + messages=messages, + damping=damping, + smudge_factor=smudge_factor, + update=update, + normalize=normalize, + distance=distance, + ) + bp.run( + max_iterations=max_iterations, + tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, + info=info, + progbar=progbar, ) - bp.run(max_iterations=max_iterations, tol=tol, info=info, progbar=progbar) return bp.get_messages(), bp.converged @@ -710,6 +809,12 @@ def sample_hv1bp( max_iterations=1000, tol=1e-2, damping=0.0, + diis=False, + update="parallel", + normalize="L2", + distance="L2", + tol_abs=None, + tol_rolling_diff=None, smudge_factor=1e-12, bias=False, seed=None, @@ -732,6 +837,38 @@ def sample_hv1bp( The maximum number of iterations for each message passing run. tol : float, optional The convergence tolerance for each message passing run. + damping : float, optional + The damping factor to use, 0.0 means no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. smudge_factor : float, optional A small number to add to each message to avoid zeros. Making this large is similar to adding a temperature, which can aid convergence but @@ -796,6 +933,12 @@ def sample_hv1bp( max_iterations=max_iterations, tol=tol, damping=damping, + diis=diis, + update=update, + normalize=normalize, + distance=distance, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, smudge_factor=smudge_factor, ) diff --git a/tests/test_tensor/test_belief_propagation/test_hv1bp.py b/tests/test_tensor/test_belief_propagation/test_hv1bp.py index 82932ea7..1f5fcd25 100644 --- a/tests/test_tensor/test_belief_propagation/test_hv1bp.py +++ b/tests/test_tensor/test_belief_propagation/test_hv1bp.py @@ -9,11 +9,12 @@ @pytest.mark.parametrize("damping", [0.0, 0.1, 0.5]) -def test_contract_hyper(damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_hyper(damping, diis): htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense") info = {} num_solutions = contract_hv1bp( - htn, damping=damping, info=info, progbar=True + htn, damping=damping, diis=diis, info=info, progbar=True ) assert info["converged"] assert num_solutions == pytest.approx(309273226, rel=0.1) @@ -29,11 +30,14 @@ def test_contract_tree_exact(): @pytest.mark.parametrize("damping", [0.0, 0.1, 0.5]) -def test_contract_normal(damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_normal(damping, diis): tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) Z = tn.contract() info = {} - Z_bp = contract_hv1bp(tn, damping=damping, info=info, progbar=True) + Z_bp = contract_hv1bp( + tn, damping=damping, diis=diis, info=info, progbar=True + ) assert info["converged"] assert Z == pytest.approx(Z_bp, rel=1e-1)