HV1BP: vectorized contraction and initialization
jcmgray committed Jan 30, 2025
1 parent 0958058 commit aaafbed
Showing 8 changed files with 164 additions and 35 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
@@ -11,6 +11,8 @@ Release notes for `quimb`.
- belief propagation, implement DIIS (direct inversion in the iterative subspace)
- belief propagation, unify various aspects such as message normalization and distance.
- belief propagation, add a `plot` method.
+ - belief propagation, add a `contract_every` option.
+ - HV1BP: vectorize both contraction and message initialization
- add `qu.plot_multi_series_zoom` for plotting multiple series with a zoomed inset, useful for various convergence plots such as BP

**Bug fixes:**
44 changes: 41 additions & 3 deletions quimb/experimental/belief_propagation/bp_common.py
@@ -71,6 +71,10 @@ class BeliefPropagationCommon:
norms. 'L2phased' is like 'L2' but also normalizes the phases of the
messages, by default used for complex dtypes if phased normalization is
not already being used.
+ contract_every : int, optional
+     If not None, 'contract' (via BP) the tensor network every
+     ``contract_every`` iterations. The resulting values are stored in
+     ``zvals`` at corresponding points ``zval_its``.
inplace : bool, optional
Whether to perform any operations inplace on the input tensor network.
"""
@@ -83,6 +87,7 @@ def __init__(
update="sequential",
normalize=None,
distance=None,
+ contract_every=None,
inplace=False,
):
self.tn = tn if inplace else tn.copy()
@@ -109,9 +114,13 @@ def __init__(
distance = "L2"
self.distance = distance

+ self.contract_every = contract_every
self.n = 0
self.converged = False
self.mdiffs = []
self.rdiffs = []
+ self.zval_its = []
+ self.zvals = []

@property
def damping(self):
@@ -124,10 +133,10 @@ def damping(self, damping):
else:
self._damping = damping

- def damp(old, new):
+ def fn_damping(old, new):
return damping * old + (1 - damping) * new

- self.fn_damping = damp
+ self.fn_damping = fn_damping

@property
def normalize(self):
@@ -247,6 +256,16 @@ def _distance_fn(x, y):
self._distance = distance
self._distance_fn = _distance_fn

+ def _maybe_contract(self):
+     should_contract = (
+         (self.contract_every is not None)
+         and (self.n % self.contract_every == 0)
+         and ((not self.zval_its) or (self.zval_its[-1] != self.n))
+     )
+     if should_contract:
+         self.zval_its.append(self.n)
+         self.zvals.append(self.contract())

def run(
self,
max_iterations=1000,
@@ -311,6 +330,8 @@ def run(
rdm = RollingDiffMean()
self.converged = False
while not self.converged and it < max_iterations:
+ self._maybe_contract()

# perform a single iteration of BP
# we supply tol here for use with local convergence
result = self.iterate(tol=tol)
@@ -350,6 +371,8 @@ def run(
it += 1
self.n += 1

+ self._maybe_contract()

# finally:
if pbar is not None:
pbar.close()
@@ -368,10 +391,15 @@ def run(
info["max_mdiff"] = max_mdiff
info["rolling_abs_mean_diff"] = rdm.absmeandiff()

- def plot(self, **kwargs):
+ def plot(self, zvals_yscale="asinh", **kwargs):
from quimb import plot_multi_series_zoom

data = {
"zvals": {
"x": self.zval_its,
"y": self.zvals,
"yscale": zvals_yscale,
},
"mdiffs": self.mdiffs,
"rdiffs": self.rdiffs,
}
@@ -381,6 +409,16 @@ def plot(self, **kwargs):
kwargs.setdefault("yscale", "log")
return plot_multi_series_zoom(data, **kwargs)

+ @property
+ def mdiff(self):
+     try:
+         return self.mdiffs[-1]
+     except IndexError:
+         return float("nan")
+
+ def __repr__(self):
+     return f"{self.__class__.__name__}(n={self.n}, mdiff={self.mdiff:.3g})"


def initialize_hyper_messages(
tn,
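As a usage note, with these changes any of the BP classes can now periodically record contraction estimates while iterating. A minimal sketch (the choice of network, subclass, and parameter values is illustrative only, not part of the commit):

```python
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.d1bp import D1BP

# a small 2D classical Ising partition function as a test network
tn = qtn.TN2D_classical_ising_partition_function(8, 8, beta=0.3)

# record a BP contraction estimate every 5 iterations
bp = D1BP(tn, contract_every=5)
bp.run(max_iterations=100)

print(bp)           # the new __repr__, e.g. "D1BP(n=..., mdiff=...)"
print(bp.zval_its)  # iterations at which estimates were recorded
print(bp.zvals)     # the corresponding contraction estimates
bp.plot()           # now includes the 'zvals' series alongside mdiffs/rdiffs
```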
10 changes: 8 additions & 2 deletions quimb/experimental/belief_propagation/d1bp.py
@@ -89,8 +89,12 @@ class D1BP(BeliefPropagationCommon):
local_convergence : bool, optional
Whether to allow messages to locally converge - i.e. if all their
input messages have converged then stop updating them.
- fill_fn : callable, optional
-     If specified, use this function to fill in the initial messages.
+ contract_every : int, optional
+     If not None, 'contract' (via BP) the tensor network every
+     ``contract_every`` iterations. The resulting values are stored in
+     ``zvals`` at corresponding points ``zval_its``.
+ inplace : bool, optional
+     Whether to perform any operations inplace on the input tensor network.
Attributes
----------
@@ -115,6 +119,7 @@ def __init__(
distance=None,
local_convergence=True,
message_init_function=None,
+ contract_every=None,
inplace=False,
):
super().__init__(
@@ -123,6 +128,7 @@ def __init__(
update=update,
normalize=normalize,
distance=distance,
+ contract_every=contract_every,
inplace=inplace,
)

10 changes: 9 additions & 1 deletion quimb/experimental/belief_propagation/d2bp.py
@@ -67,6 +67,12 @@ class D2BP(BeliefPropagationCommon):
local_convergence : bool, optional
Whether to allow messages to locally converge - i.e. if all their
input messages have converged then stop updating them.
+ contract_every : int, optional
+     If not None, 'contract' (via BP) the tensor network every
+     ``contract_every`` iterations. The resulting values are stored in
+     ``zvals`` at corresponding points ``zval_its``.
+ inplace : bool, optional
+     Whether to perform any operations inplace on the input tensor network.
contract_opts
Other options supplied to ``cotengra.array_contract``.
"""
@@ -82,8 +88,9 @@ def __init__(
update="sequential",
normalize=None,
distance=None,
- inplace=False,
local_convergence=True,
+ contract_every=None,
+ inplace=False,
**contract_opts,
):
super().__init__(
@@ -92,6 +99,7 @@ def __init__(
update=update,
normalize=normalize,
distance=distance,
+ contract_every=contract_every,
inplace=inplace,
)

100 changes: 77 additions & 23 deletions quimb/experimental/belief_propagation/hv1bp.py
@@ -349,6 +349,12 @@ class HV1BP(BeliefPropagationCommon):
thread_pool : bool or int, optional
Whether to use a thread pool for parallelization, if ``True`` use the
default number of threads, if an integer use that many threads.
+ contract_every : int, optional
+     If not None, 'contract' (via BP) the tensor network every
+     ``contract_every`` iterations. The resulting values are stored in
+     ``zvals`` at corresponding points ``zval_its``.
+ inplace : bool, optional
+     Whether to perform any operations inplace on the input tensor network.
"""

def __init__(
@@ -360,16 +366,18 @@ def __init__(
update="parallel",
normalize="L2",
distance="L2",
- inplace=False,
smudge_factor=1e-12,
thread_pool=False,
+ contract_every=None,
+ inplace=False,
):
super().__init__(
tn,
damping=damping,
update=update,
normalize=normalize,
distance=distance,
+ contract_every=contract_every,
inplace=inplace,
)

@@ -458,14 +466,26 @@ def _distance_fn(bx, by):
self._distance_fn = _distance_fn

def initialize_messages_batched(self, messages=None):
- if messages is None:
-     # XXX: explicit use uniform distribution to avoid non-vectorized
-     # contractions?
-     messages = initialize_hyper_messages(self.tn)

_stack = ar.get_lib_fn(self.backend, "stack")
_array = ar.get_lib_fn(self.backend, "array")

+ if isinstance(messages, dict):
+     # 'dense' (i.e. non-batch) messages explicitly supplied
+     message_init_fn = None
+ elif callable(messages):
+     # custom message initialization function
+     message_init_fn = messages
+     messages = None
+ elif messages == "dense":
+     # explicitly create dense messages first
+     message_init_fn = None
+     messages = initialize_hyper_messages(self.tn)
+ elif messages is None:
+     # default to uniform messages
+     message_init_fn = ar.get_lib_fn(self.backend, "ones")
+ else:
+     raise ValueError(f"Unrecognized messages={messages}")

# here we are stacking all contractions with matching rank
#
# rank: number of incident messages to a tensor or hyper index
@@ -479,27 +499,37 @@ def initialize_messages_batched(self, messages=None):
batched_inputs_m = {}
input_locs_m = {}
output_locs_m = {}
+ shapes_m = {}

for ix, tids in self.tn.ind_map.items():
# all updates of the same rank can be performed simultaneously
rank = len(tids)
try:
batch = batched_inputs_m[rank]
+ shape = shapes_m[rank]
except KeyError:
batch = batched_inputs_m[rank] = [[] for _ in range(rank)]
+ shape = shapes_m[rank] = [rank, 0, self.tn.ind_size(ix)]

+ # batch index
+ b = shape[1]
for p, tid in enumerate(tids):
- batch_p = batch[p]
- # position in the stack
- b = len(batch_p)
+ if message_init_fn is None:
+     # we'll stack the messages later
+     batch[p].append(messages[tid, ix])
input_locs_m[tid, ix] = (rank, p, b)
output_locs_m[ix, tid] = (rank, p, b)
- batch_p.append(messages[tid, ix])

+ # increment batch index
+ shape[1] = b + 1

# prepare tensor messages
batched_tensors = {}
batched_inputs_t = {}
input_locs_t = {}
output_locs_t = {}
+ shapes_t = {}

for tid, t in self.tn.tensor_map.items():
# all updates of the same rank can be performed simultaneously
rank = t.ndim
Expand All @@ -510,26 +540,39 @@ def initialize_messages_batched(self, messages=None):
try:
batch = batched_inputs_t[rank]
batch_t = batched_tensors[rank]
+ shape = shapes_t[rank]
except KeyError:
batch = batched_inputs_t[rank] = [[] for _ in range(rank)]
batch_t = batched_tensors[rank] = []
+ shape = shapes_t[rank] = [rank, 0, t.shape[0]]

+ # batch index
+ b = shape[1]
for p, ix in enumerate(t.inds):
- batch_p = batch[p]
- # position in the stack
- b = len(batch_p)
+ if message_init_fn is None:
+     # we'll stack the messages later
+     batch[p].append(messages[ix, tid])
input_locs_t[ix, tid] = (rank, p, b)
output_locs_t[tid, ix] = (rank, p, b)
- batch_p.append(messages[ix, tid])

batch_t.append(t.data)
+ # increment batch index
+ shape[1] = b + 1

+ # combine or create batch message arrays
+ for batched_inputs, shapes in zip(
+     (batched_inputs_m, batched_inputs_t), (shapes_m, shapes_t)
+ ):
+     for rank, batch in batched_inputs.items():
+         if isinstance(messages, dict):
+             # stack given messages into single arrays
+             batched_inputs[rank] = _stack(
+                 tuple(_stack(batch_p) for batch_p in batch)
+             )
+         else:
+             # create message arrays directly
+             batched_inputs[rank] = message_init_fn(shapes[rank])

- # stack messages in into single arrays
- for batched_inputs in (batched_inputs_m, batched_inputs_t):
-     for key, batch in batched_inputs.items():
-         batched_inputs[key] = _stack(
-             tuple(_stack(batch_p) for batch_p in batch)
-         )
# stack all tensors of each rank into a single array
for rank, tensors in batched_tensors.items():
batched_tensors[rank] = _stack(tensors)
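In other words, `initialize_messages_batched` now only falls back to building dense per-index messages when explicitly asked to, and otherwise creates (or is given) the batched arrays directly. A sketch of the accepted forms, assuming (as in the previous API) that the constructor forwards its `messages` argument here; the hyper network is just a placeholder:

```python
import numpy as np
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.hv1bp import HV1BP

# a hyper tensor network, e.g. with one hyper index per Ising site
htn = qtn.HTN2D_classical_ising_partition_function(6, 6, beta=0.3)

# messages=None -> uniform messages, created directly in batched form
bp = HV1BP(htn)

# a callable -> used as a vectorized initializer, called once per rank
# with the batched shape [rank, batch_size, ind_size]
bp = HV1BP(htn, messages=lambda shape: np.random.uniform(0.9, 1.1, shape))

# "dense" -> build per-index messages first, then stack them (old behavior)
bp = HV1BP(htn, messages="dense")
```

The `(rank, p, b)` locations recorded in `input_locs_*` and `output_locs_*` are what make this possible: each message can be addressed inside the stacked arrays without the dense dictionary ever existing.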
@@ -671,7 +714,7 @@ def iterate(self, tol=None):
)
return max(t_max_mdiff, m_max_mdiff)

- def get_messages(self):
+ def get_messages_dense(self):
"""Get messages in individual form from the batched stacks."""
return _extract_messages_from_inputs_batched(
self.batched_inputs_m,
@@ -680,6 +723,17 @@ def get_messages(self):
self.input_locs_t,
)

+ def get_messages(self):
+     import warnings
+
+     warnings.warn(
+         "get_messages() is deprecated, or in the future it might return "
+         "the batch messages, use get_messages_dense() instead.",
+         DeprecationWarning,
+     )
+
+     return self.get_messages_dense()

def contract(self, strip_exponent=False, check_zero=False):
"""Estimate the contraction of the tensor network using the current
messages. Uses batched vectorized contractions for speed.
@@ -751,7 +805,7 @@ def contract_dense(self, strip_exponent=False, check_zero=True):
"""
return contract_hyper_messages(
self.tn,
- self.get_messages(),
+ self.get_messages_dense(),
strip_exponent=strip_exponent,
check_zero=check_zero,
backend=self.backend,
@@ -959,7 +1013,7 @@ def run_belief_propagation_hv1bp(
info=info,
progbar=progbar,
)
- return bp.get_messages(), bp.converged
+ return bp.get_messages_dense(), bp.converged


def sample_hv1bp(
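Finally, a short sketch of the renamed accessor alongside the vectorized contraction (same placeholder network as above; per the docstrings, `contract` estimates the full contraction from the current messages):

```python
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.hv1bp import HV1BP

htn = qtn.HTN2D_classical_ising_partition_function(6, 6, beta=0.3)

bp = HV1BP(htn, contract_every=10)
bp.run(max_iterations=200)

# batched, vectorized estimate of the full contraction (here Z)
Z = bp.contract()

# per-index message dictionary, extracted from the batched stacks
messages = bp.get_messages_dense()

# the old name still works for now, but emits a DeprecationWarning
messages = bp.get_messages()
```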