Skip to content

Commit

Permalink
HV1BP: update interface options
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 30, 2025
1 parent 5a24907 commit 3b577e1
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 15 deletions.
5 changes: 1 addition & 4 deletions quimb/experimental/belief_propagation/diis.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def _extrapolate(self):
coeffs = [self.scalar(c) for c in coeffs[1:]]

# construct linear combination of previous guesses!
# xnew = np.zeros_like(self.guesses[0])
xnew = ar.do("zeros_like", self.guesses[0], like=self.backend)
for ci, xi in zip(coeffs, self.guesses):
xnew += ci * xi
Expand Down Expand Up @@ -216,8 +215,6 @@ def update(self, y):
same tree structure as `y`.
"""
# convert from pytree -> single real vector
# copy is important so that sequence of
# guesses are not the same object
y = self.vectorizer.pack(y)
x = self.guesses[self.head]
if x is None:
Expand All @@ -228,7 +225,7 @@ def update(self, y):
xnext = self._extrapolate()

self.head = (self.head + 1) % self.max_history
# # TODO: make copy backend agnostic
# NOTE: copy seems to be necessary here to avoid in-place modifications
self.guesses[self.head] = ar.do("copy", xnext, like=self.backend)

# convert new extrapolated guess back to pytree
Expand Down
157 changes: 150 additions & 7 deletions quimb/experimental/belief_propagation/hv1bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,13 @@ def get_messages(self):
self.input_locs_t,
)

def contract(self, strip_exponent=False):
def contract(self, strip_exponent=False, check_zero=True):
# TODO: do this in batched form directly
return contract_hyper_messages(
self.tn,
self.get_messages(),
strip_exponent=strip_exponent,
check_zero=check_zero,
backend=self.backend,
)

Expand All @@ -602,9 +603,16 @@ def contract_hv1bp(
messages=None,
max_iterations=1000,
tol=5e-6,
smudge_factor=1e-12,
damping=0.0,
diis=False,
update="parallel",
normalize="L2",
distance="L2",
tol_abs=None,
tol_rolling_diff=None,
smudge_factor=1e-12,
strip_exponent=False,
check_zero=True,
info=None,
progbar=False,
):
Expand All @@ -621,14 +629,46 @@ def contract_hv1bp(
The maximum number of iterations to perform.
tol : float, optional
The convergence tolerance for messages.
damping : float, optional
The damping factor to use, 0.0 means no damping.
diis : bool or dict, optional
Whether to use direct inversion in the iterative subspace to
help converge the messages by extrapolating to low error guesses.
If a dict, should contain options for the DIIS algorithm. The
relevant options are {`max_history`, `beta`, `rcond`}.
update : {'parallel'}, optional
Whether to update messages sequentially or in parallel.
normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
How to normalize messages after each update. If None choose
automatically. If a callable, it should take a message and return the
normalized message. If a string, it should be one of 'L1', 'L2',
'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
but also normalizes the phase of the message, by default used for
complex dtypes.
distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
How to compute the distance between messages to check for convergence.
If None choose automatically. If a callable, it should take two
messages and return the distance. If a string, it should be one of
'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
norms. 'L2phased' is like 'L2' but also normalizes the phases of the
messages, by default used for complex dtypes if phased normalization is
not already being used.
tol_abs : float, optional
The absolute convergence tolerance for maximum message update
distance, if not given then taken as ``tol``.
tol_rolling_diff : float, optional
The rolling mean convergence tolerance for maximum message update
distance, if not given then taken as ``tol``. This is used to stop
running when the messages are just bouncing around the same level,
without any overall upward or downward trends, roughly speaking.
smudge_factor : float, optional
A small number to add to the denominator of messages to avoid division
by zero. Note when this happens the numerator will also be zero.
damping : float, optional
The damping factor to use, 0.0 means no damping.
strip_exponent : bool, optional
Whether to strip the exponent from the final result. If ``True``
then the returned result is ``(mantissa, exponent)``.
check_zero : bool, optional
Whether to check for zero values and return zero early.
info : dict, optional
If specified, update this dictionary with information about the
belief propagation run.
Expand All @@ -643,15 +683,24 @@ def contract_hv1bp(
tn,
messages=messages,
damping=damping,
update=update,
normalize=normalize,
distance=distance,
smudge_factor=smudge_factor,
)
bp.run(
max_iterations=max_iterations,
tol=tol,
diis=diis,
tol_abs=tol_abs,
tol_rolling_diff=tol_rolling_diff,
info=info,
progbar=progbar,
)
return bp.contract(strip_exponent=strip_exponent)
return bp.contract(
strip_exponent=strip_exponent,
check_zero=check_zero,
)


def run_belief_propagation_hv1bp(
Expand All @@ -660,6 +709,12 @@ def run_belief_propagation_hv1bp(
max_iterations=1000,
tol=5e-6,
damping=0.0,
diis=False,
update="parallel",
normalize="L2",
distance="L2",
tol_abs=None,
tol_rolling_diff=None,
smudge_factor=1e-12,
info=None,
progbar=False,
Expand All @@ -680,6 +735,36 @@ def run_belief_propagation_hv1bp(
The convergence tolerance.
damping : float, optional
The damping factor to use, 0.0 means no damping.
diis : bool or dict, optional
Whether to use direct inversion in the iterative subspace to
help converge the messages by extrapolating to low error guesses.
If a dict, should contain options for the DIIS algorithm. The
relevant options are {`max_history`, `beta`, `rcond`}.
update : {'parallel'}, optional
Whether to update messages sequentially or in parallel.
normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
How to normalize messages after each update. If None choose
automatically. If a callable, it should take a message and return the
normalized message. If a string, it should be one of 'L1', 'L2',
'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
but also normalizes the phase of the message, by default used for
complex dtypes.
distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
How to compute the distance between messages to check for convergence.
If None choose automatically. If a callable, it should take two
messages and return the distance. If a string, it should be one of
'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
norms. 'L2phased' is like 'L2' but also normalizes the phases of the
messages, by default used for complex dtypes if phased normalization is
not already being used.
tol_abs : float, optional
The absolute convergence tolerance for maximum message update
distance, if not given then taken as ``tol``.
tol_rolling_diff : float, optional
The rolling mean convergence tolerance for maximum message update
distance, if not given then taken as ``tol``. This is used to stop
running when the messages are just bouncing around the same level,
without any overall upward or downward trends, roughly speaking.
smudge_factor : float, optional
A small number to add to the denominator of messages to avoid division
by zero. Note when this happens the numerator will also be zero.
Expand All @@ -697,9 +782,23 @@ def run_belief_propagation_hv1bp(
Whether the algorithm converged.
"""
bp = HV1BP(
tn, messages=messages, damping=damping, smudge_factor=smudge_factor
tn,
messages=messages,
damping=damping,
smudge_factor=smudge_factor,
update=update,
normalize=normalize,
distance=distance,
)
bp.run(
max_iterations=max_iterations,
tol=tol,
diis=diis,
tol_abs=tol_abs,
tol_rolling_diff=tol_rolling_diff,
info=info,
progbar=progbar,
)
bp.run(max_iterations=max_iterations, tol=tol, info=info, progbar=progbar)
return bp.get_messages(), bp.converged


Expand All @@ -710,6 +809,12 @@ def sample_hv1bp(
max_iterations=1000,
tol=1e-2,
damping=0.0,
diis=False,
update="parallel",
normalize="L2",
distance="L2",
tol_abs=None,
tol_rolling_diff=None,
smudge_factor=1e-12,
bias=False,
seed=None,
Expand All @@ -732,6 +837,38 @@ def sample_hv1bp(
The maximum number of iterations for each message passing run.
tol : float, optional
The convergence tolerance for each message passing run.
damping : float, optional
The damping factor to use, 0.0 means no damping.
diis : bool or dict, optional
Whether to use direct inversion in the iterative subspace to
help converge the messages by extrapolating to low error guesses.
If a dict, should contain options for the DIIS algorithm. The
relevant options are {`max_history`, `beta`, `rcond`}.
update : {'parallel'}, optional
Whether to update messages sequentially or in parallel.
normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
How to normalize messages after each update. If None choose
automatically. If a callable, it should take a message and return the
normalized message. If a string, it should be one of 'L1', 'L2',
'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
but also normalizes the phase of the message, by default used for
complex dtypes.
distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
How to compute the distance between messages to check for convergence.
If None choose automatically. If a callable, it should take two
messages and return the distance. If a string, it should be one of
'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
norms. 'L2phased' is like 'L2' but also normalizes the phases of the
messages, by default used for complex dtypes if phased normalization is
not already being used.
tol_abs : float, optional
The absolute convergence tolerance for maximum message update
distance, if not given then taken as ``tol``.
tol_rolling_diff : float, optional
The rolling mean convergence tolerance for maximum message update
distance, if not given then taken as ``tol``. This is used to stop
running when the messages are just bouncing around the same level,
without any overall upward or downward trends, roughly speaking.
smudge_factor : float, optional
A small number to add to each message to avoid zeros. Making this large
is similar to adding a temperature, which can aid convergence but
Expand Down Expand Up @@ -796,6 +933,12 @@ def sample_hv1bp(
max_iterations=max_iterations,
tol=tol,
damping=damping,
diis=diis,
update=update,
normalize=normalize,
distance=distance,
tol_abs=tol_abs,
tol_rolling_diff=tol_rolling_diff,
smudge_factor=smudge_factor,
)

Expand Down
12 changes: 8 additions & 4 deletions tests/test_tensor/test_belief_propagation/test_hv1bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@


@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
def test_contract_hyper(damping):
@pytest.mark.parametrize("diis", [False, True])
def test_contract_hyper(damping, diis):
htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense")
info = {}
num_solutions = contract_hv1bp(
htn, damping=damping, info=info, progbar=True
htn, damping=damping, diis=diis, info=info, progbar=True
)
assert info["converged"]
assert num_solutions == pytest.approx(309273226, rel=0.1)
Expand All @@ -29,11 +30,14 @@ def test_contract_tree_exact():


@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
def test_contract_normal(damping):
@pytest.mark.parametrize("diis", [False, True])
def test_contract_normal(damping, diis):
tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
Z = tn.contract()
info = {}
Z_bp = contract_hv1bp(tn, damping=damping, info=info, progbar=True)
Z_bp = contract_hv1bp(
tn, damping=damping, diis=diis, info=info, progbar=True
)
assert info["converged"]
assert Z == pytest.approx(Z_bp, rel=1e-1)

Expand Down

0 comments on commit 3b577e1

Please sign in to comment.