
HV1BP: vectorized contraction
jcmgray committed Jan 30, 2025
1 parent 3b577e1 commit 0958058
Showing 1 changed file with 206 additions and 46 deletions.
252 changes: 206 additions & 46 deletions quimb/experimental/belief_propagation/hv1bp.py
@@ -114,15 +114,15 @@ def _compute_all_tensor_messages_tree_batched(bx, bm):

# contract the right messages to get new left array
xl = array_contract(
- arrays=(bx, *(mr[i] for i in range(mr.shape[0]))),
+ arrays=(bx, *(mr[p] for p in range(mr.shape[0]))),
inputs=((-1, *js), *((-1, j) for j in jr)),
output=(-1, *jl),
backend=backend,
)

# contract the left messages to get new right array
xr = array_contract(
- arrays=(bx, *(ml[i] for i in range(ml.shape[0]))),
+ arrays=(bx, *(ml[p] for p in range(ml.shape[0]))),
inputs=((-1, *js), *((-1, j) for j in jl)),
output=(-1, *jr),
backend=backend,
@@ -141,7 +141,7 @@ def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12):

ndim = len(bm)
x_inds = (-1, *range(ndim))
- m_inds = [(-1, i) for i in range(ndim)]
+ m_inds = [(-1, p) for p in range(ndim)]
bmx = array_contract(
arrays=(bx, *bm),
inputs=(x_inds, *m_inds),
@@ -151,13 +151,13 @@ def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12):
bminv = 1 / (bm + smudge_factor)
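# (the smudge factor keeps the division well defined for message entries
# at or near zero)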

mouts = []
- for i in range(ndim):
+ for p in range(ndim):
# sum all but the pth index, apply the inverse message to that
mouts.append(
array_contract(
- arrays=(bmx, bminv[i]),
- inputs=(x_inds, m_inds[i]),
- output=m_inds[i],
+ arrays=(bmx, bminv[p]),
+ inputs=(x_inds, m_inds[p]),
+ output=m_inds[p],
)
)
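# i.e. each outgoing message is the full product contraction with its own
# incoming message divided back out, so all ``ndim`` outputs reuse one
# shared contraction rather than each being recontracted from scratch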

@@ -188,14 +188,13 @@ def _compute_output_single_m(bm, normalize, smudge_factor=1e-12):
def _update_output_to_input_single_batched(
batched_input,
batched_output,
- maskin,
- maskout,
+ mask,
_distance_fn,
damping=0.0,
):
# do a vectorized update
- select_in = (maskin[:, 0], maskin[:, 1], slice(None))
- select_out = (maskout[:, 0], maskout[:, 1], slice(None))
+ select_in = (mask[0, 0, :], mask[0, 1, :], slice(None))
+ select_out = (mask[1, 0, :], mask[1, 1, :], slice(None))
bim = batched_input[select_in]
bom = batched_output[select_out]
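# (``mask`` has shape (2, 2, n): the first axis selects input/output, the
# second selects the position/batch coordinate, and the last runs over the
# n stacked message locations -- see ``initialize_messages_batched``)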

@@ -214,6 +213,84 @@ def _update_output_to_input_single_batched(
return mdiff


def _gather_zb(zb, power=1.0):
"""Given a vector of local contraction estimates `zb`, compute their
product, avoiding underflow/overflow by accumulating the sign and exponent
separately.

Parameters
----------
zb : array
The local contraction estimates.
power : float, optional
Raise the final result to this power.

Returns
-------
sign : float
The accumulated sign or phase.
exponent : float
The accumulated exponent.
"""
zb_mag = ar.do("abs", zb)
zb_phase = zb / zb_mag

# accumulate sign and exponent separately
sign = ar.do("prod", zb_phase)
exponent = ar.do("sum", ar.do("log10", zb_mag))

if power != 1.0:
sign **= power
exponent *= power

return sign, exponent
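
# --- illustrative sketch (not part of this file): the same sign/exponent
# trick in plain numpy, for real-valued entries ---
import numpy as np

_zb_demo = np.array([1e160, -1e160, 1e160])
# np.prod(_zb_demo) overflows to -inf, but accumulating the sign and the
# base-10 exponent separately stays finite:
_demo_sign = np.prod(np.sign(_zb_demo))              # -1.0
_demo_exponent = np.sum(np.log10(np.abs(_zb_demo)))  # 480.0
# the represented value is _demo_sign * 10**_demo_exponent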


def _contract_index_region_single(bm):
# take product over input position and sum over variable
zb = ar.do("sum", ar.do("prod", bm, axis=0), axis=1)
# that just leaves broadcast dimension to take product over
return _gather_zb(zb)
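
# sketch of the assumed stacked shapes (illustrative, not from the file):
# ``bm`` gathers all rank-r index messages as (rank, batch, dim), so e.g.
#
#     bm.shape == (3, 100, 2)  ->  ar.do("prod", bm, axis=0).shape == (100, 2)
#                              ->  zb.shape == (100,)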


def _contract_tensor_region_single(rank, batched_tensors, batched_inputs_t):
bt = batched_tensors[rank]
bms = batched_inputs_t[rank]
# contract every tensor of rank `rank` with its messages
zb = array_contract(
[bt, *bms],
inputs=[tuple(range(-1, rank))] + [(-1, r) for r in range(rank)],
output=(-1,),
)
return _gather_zb(zb)
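
# for example, at rank 2 the contraction above is the batched einsum
# "Xab,Xa,Xb->X", with X the stack dimension -- every stacked tensor is
# closed off by one message per leg, leaving one scalar per tensor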


def _contract_messages_pair_single(
ranki,
ranko,
mask,
batched_inputs_m,
batched_inputs_t,
):
bmm = batched_inputs_m[ranki]
bmt = batched_inputs_t[ranko]
select_in = (mask[0, 0, :], mask[0, 1, :], slice(None))
select_out = (mask[1, 0, :], mask[1, 1, :], slice(None))

bml = bmm[select_in]
bmr = bmt[select_out]

zb = array_contract(
[bml, bmr],
inputs=[(-1, 0), (-1, 0)],
output=(-1,),
)

# individual message regions have counting factor -1,
# i.e. we are dividing by all of them
return _gather_zb(zb, power=-1.0)
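
# together, the three region helpers above assemble a Bethe-style
# estimate of the full contraction,
#
#     Z ~= prod_t Z_t * prod_i Z_i / prod_(t,i) Z_(t,i)
#
# with tensor and index regions counted once, and every shared message
# pair divided back out (hence ``power=-1.0``)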


def _extract_messages_from_inputs_batched(
batched_inputs_m,
batched_inputs_t,
@@ -222,10 +299,10 @@
):
"""Get all messages as a dict from the batch stacked input form."""
messages = {}
- for pair, (rank, i, b) in input_locs_m.items():
- messages[pair] = batched_inputs_m[rank][i, b, :]
- for pair, (rank, i, b) in input_locs_t.items():
- messages[pair] = batched_inputs_t[rank][i, b, :]
+ for pair, (rank, p, b) in input_locs_m.items():
+ messages[pair] = batched_inputs_m[rank][p, b, :]
+ for pair, (rank, p, b) in input_locs_t.items():
+ messages[pair] = batched_inputs_t[rank][p, b, :]
return messages


@@ -389,6 +466,15 @@ def initialize_messages_batched(self, messages=None):
_stack = ar.get_lib_fn(self.backend, "stack")
_array = ar.get_lib_fn(self.backend, "array")

# here we are stacking all contractions with matching rank
#
# rank: number of incident messages to a tensor or hyper index
# pos (p): which of those messages we are (0, 1, ..., rank - 1)
# batch position (b): which position in the stack we are
#
# _m = messages incident to indices
# _t = messages incident to tensors
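#
# e.g. 100 hyper indices each shared by 3 tensors contribute a single
# stacked (3, 100, dim) array under the key ``rank=3`` (dim here being
# an assumed uniform index dimension)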

# prepare index messages
batched_inputs_m = {}
input_locs_m = {}
@@ -401,13 +487,13 @@ def initialize_messages_batched(self, messages=None):
except KeyError:
batch = batched_inputs_m[rank] = [[] for _ in range(rank)]

- for i, tid in enumerate(tids):
- batch_i = batch[i]
+ for p, tid in enumerate(tids):
+ batch_p = batch[p]
# position in the stack
- b = len(batch_i)
- input_locs_m[tid, ix] = (rank, i, b)
- output_locs_m[ix, tid] = (rank, i, b)
- batch_i.append(messages[tid, ix])
+ b = len(batch_p)
+ input_locs_m[tid, ix] = (rank, p, b)
+ output_locs_m[ix, tid] = (rank, p, b)
+ batch_p.append(messages[tid, ix])

# prepare tensor messages
batched_tensors = {}
@@ -428,21 +514,21 @@ def initialize_messages_batched(self, messages=None):
batch = batched_inputs_t[rank] = [[] for _ in range(rank)]
batch_t = batched_tensors[rank] = []

- for i, ix in enumerate(t.inds):
- batch_i = batch[i]
+ for p, ix in enumerate(t.inds):
+ batch_p = batch[p]
# position in the stack
- b = len(batch_i)
- input_locs_t[ix, tid] = (rank, i, b)
- output_locs_t[tid, ix] = (rank, i, b)
- batch_i.append(messages[ix, tid])
+ b = len(batch_p)
+ input_locs_t[ix, tid] = (rank, p, b)
+ output_locs_t[tid, ix] = (rank, p, b)
+ batch_p.append(messages[ix, tid])

batch_t.append(t.data)

# stack messages into single arrays
for batched_inputs in (batched_inputs_m, batched_inputs_t):
for key, batch in batched_inputs.items():
batched_inputs[key] = _stack(
- tuple(_stack(batch_i) for batch_i in batch)
+ tuple(_stack(batch_p) for batch_p in batch)
)
# stack all tensors of each rank into a single array
for rank, tensors in batched_tensors.items():
@@ -456,18 +542,26 @@ def initialize_messages_batched(self, messages=None):
(masks_t, input_locs_t, output_locs_m),
]:
for pair in input_locs:
- (ranki, ii, bi) = input_locs[pair]
- (ranko, io, bo) = output_locs[pair]
+ (ranki, pi, bi) = input_locs[pair]
+ (ranko, po, bo) = output_locs[pair]
# we can vectorize over all distinct pairs of ranks
key = (ranki, ranko)
try:
- maskin, maskout = masks[key]
+ ma_pi, ma_po, ma_bi, ma_bo = masks[key]
except KeyError:
- maskin, maskout = masks[key] = [], []
- maskin.append([ii, bi])
- maskout.append([io, bo])
+ ma_pi, ma_po, ma_bi, ma_bo = masks[key] = [], [], [], []

- for key, (maskin, maskout) in masks.items():
- masks[key] = _array(maskin), _array(maskout)
+ ma_pi.append(pi)
+ ma_bi.append(bi)
+ ma_po.append(po)
+ ma_bo.append(bo)

+ for key, (ma_pi, ma_po, ma_bi, ma_bo) in masks.items():
+ # first dimension is in/out
+ # second dimension is position or batch
+ # third dimension is stack index
+ mask = _array([[ma_pi, ma_bi], [ma_po, ma_bo]])
+ masks[key] = mask

self.batched_inputs_m = batched_inputs_m
self.batched_inputs_t = batched_inputs_t
@@ -535,12 +629,11 @@ def _update_outputs_to_inputs_batched(
(
batched_inputs[ranki],
batched_outputs[ranko],
- maskin,
- maskout,
+ mask,
self._distance_fn,
self.damping,
)
- for (ranki, ranko), (maskin, maskout) in masks.items()
+ for (ranki, ranko), mask in masks.items()
)

if self.pool is None:
@@ -555,25 +648,25 @@

def iterate(self, tol=None):
# first we compute new tensor output messages
- self.batched_outputs_t = self._compute_outputs_batched(
+ batched_outputs_t = self._compute_outputs_batched(
batched_inputs=self.batched_inputs_t,
batched_tensors=self.batched_tensors,
)
# update the index input messages with these
t_max_mdiff = self._update_outputs_to_inputs_batched(
self.batched_inputs_m,
- self.batched_outputs_t,
+ batched_outputs_t,
self.masks_m,
)

# compute index messages
- self.batched_outputs_m = self._compute_outputs_batched(
+ batched_outputs_m = self._compute_outputs_batched(
batched_inputs=self.batched_inputs_m,
)
# update the tensor input messages
m_max_mdiff = self._update_outputs_to_inputs_batched(
self.batched_inputs_t,
- self.batched_outputs_m,
+ batched_outputs_m,
self.masks_t,
)
return max(t_max_mdiff, m_max_mdiff)
@@ -587,8 +680,75 @@ def get_messages(self):
self.input_locs_t,
)

- def contract(self, strip_exponent=False, check_zero=True):
- # TODO: do this in batched form directly
+ def contract(self, strip_exponent=False, check_zero=False):
"""Estimate the contraction of the tensor network using the current
messages. Uses batched vectorized contractions for speed.

Parameters
----------
strip_exponent : bool, optional
Whether to strip the exponent from the final result. If ``True``
then the returned result is ``(mantissa, exponent)``.
check_zero : bool, optional
Whether to check for zero values and return zero early. Currently
``True`` is not implemented for HV1BP.

Returns
-------
scalar or (scalar, float)
"""
if check_zero:
raise NotImplementedError("check_zero not implemented for HV1BP.")

fn_args = []
# for each rank contract index region estimate
for bm in self.batched_inputs_m.values():
fn_args.append((_contract_index_region_single, (bm,)))
# for each rank contract tensor region estimate
for rank in self.batched_tensors:
fn_args.append(
(
_contract_tensor_region_single,
(rank, self.batched_tensors, self.batched_inputs_t),
)
)
# for each pair of ranks, contract the message-pair region
# estimate, which we divide out (power=-1.0)
for (ranki, ranko), mask in self.masks_m.items():
fn_args.append(
(
_contract_messages_pair_single,
(
ranki,
ranko,
mask,
self.batched_inputs_m,
self.batched_inputs_t,
),
)
)

if self.pool is None:
results = [fn(*args) for fn, args in fn_args]
else:
futs = [self.pool.submit(fn, *args) for fn, args in fn_args]
results = [fut.result() for fut in futs]

sign = 1.0
exponent = 0.0
for s, e in results:
sign *= s
exponent += e

if strip_exponent:
return sign, exponent
else:
return sign * 10**exponent
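
# usage sketch (hypothetical, assuming ``bp`` is an HV1BP instance that
# has already been iterated to convergence):
#
#     mantissa, exponent = bp.contract(strip_exponent=True)
#     z = mantissa * 10**exponent   # same value as bp.contract()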

def contract_dense(self, strip_exponent=False, check_zero=True):
"""Slow contraction via explicit extranting individual dense messages.
This supports check_zero=True and may be useful for debugging.
"""
return contract_hyper_messages(
self.tn,
self.get_messages(),
@@ -612,7 +772,7 @@ def contract_hv1bp(
tol_rolling_diff=None,
smudge_factor=1e-12,
strip_exponent=False,
- check_zero=True,
+ check_zero=False,
info=None,
progbar=False,
):

