diff --git a/quimb/experimental/belief_propagation/hv1bp.py b/quimb/experimental/belief_propagation/hv1bp.py index dd48b16f..481dfd50 100644 --- a/quimb/experimental/belief_propagation/hv1bp.py +++ b/quimb/experimental/belief_propagation/hv1bp.py @@ -114,7 +114,7 @@ def _compute_all_tensor_messages_tree_batched(bx, bm): # contract the right messages to get new left array xl = array_contract( - arrays=(bx, *(mr[i] for i in range(mr.shape[0]))), + arrays=(bx, *(mr[p] for p in range(mr.shape[0]))), inputs=((-1, *js), *((-1, j) for j in jr)), output=(-1, *jl), backend=backend, @@ -122,7 +122,7 @@ def _compute_all_tensor_messages_tree_batched(bx, bm): # contract the left messages to get new right array xr = array_contract( - arrays=(bx, *(ml[i] for i in range(ml.shape[0]))), + arrays=(bx, *(ml[p] for p in range(ml.shape[0]))), inputs=((-1, *js), *((-1, j) for j in jl)), output=(-1, *jr), backend=backend, @@ -141,7 +141,7 @@ def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12): ndim = len(bm) x_inds = (-1, *range(ndim)) - m_inds = [(-1, i) for i in range(ndim)] + m_inds = [(-1, p) for p in range(ndim)] bmx = array_contract( arrays=(bx, *bm), inputs=(x_inds, *m_inds), @@ -151,13 +151,13 @@ def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12): bminv = 1 / (bm + smudge_factor) mouts = [] - for i in range(ndim): + for p in range(ndim): # sum all but ith index, apply inverse gate to that mouts.append( array_contract( - arrays=(bmx, bminv[i]), - inputs=(x_inds, m_inds[i]), - output=m_inds[i], + arrays=(bmx, bminv[p]), + inputs=(x_inds, m_inds[p]), + output=m_inds[p], ) ) @@ -188,14 +188,13 @@ def _compute_output_single_m(bm, normalize, smudge_factor=1e-12): def _update_output_to_input_single_batched( batched_input, batched_output, - maskin, - maskout, + mask, _distance_fn, damping=0.0, ): # do a vectorized update - select_in = (maskin[:, 0], maskin[:, 1], slice(None)) - select_out = (maskout[:, 0], maskout[:, 1], slice(None)) + 
select_in = (mask[0, 0, :], mask[0, 1, :], slice(None)) + select_out = (mask[1, 0, :], mask[1, 1, :], slice(None)) bim = batched_input[select_in] bom = batched_output[select_out] @@ -214,6 +213,84 @@ def _update_output_to_input_single_batched( return mdiff +def _gather_zb(zb, power=1.0): + """Given a vector of local contraction estimates `zb`, compute their + product, avoiding underflow/overflow by accumulating the sign and exponent + separately. + + Parameters + ---------- + zb : array + The local contraction estimates. + power : float, optional + Raise the final result to this power. + + Returns + ------- + sign : float + The accumulated sign or phase. + exponent : float + The accumulated exponent. + """ + zb_mag = ar.do("abs", zb) + zb_phase = zb / zb_mag + + # accumulate sign and exponent separately + sign = ar.do("prod", zb_phase) + exponent = ar.do("sum", ar.do("log10", zb_mag)) + + if power != 1.0: + sign **= power + exponent *= power + + return sign, exponent + + +def _contract_index_region_single(bm): + # take product over input position and sum over variable + zb = ar.do("sum", ar.do("prod", bm, axis=0), axis=1) + # that just leaves broadcast dimension to take product over + return _gather_zb(zb) + + +def _contract_tensor_region_single(rank, batched_tensors, batched_inputs_t): + bt = batched_tensors[rank] + bms = batched_inputs_t[rank] + # contract every tensor of rank `rank` with its messages + zb = array_contract( + [bt, *bms], + inputs=[tuple(range(-1, rank))] + [(-1, r) for r in range(rank)], + output=(-1,), + ) + return _gather_zb(zb) + + +def _contract_messages_pair_single( + ranki, + ranko, + mask, + batched_inputs_m, + batched_inputs_t, +): + bmm = batched_inputs_m[ranki] + bmt = batched_inputs_t[ranko] + select_in = (mask[0, 0, :], mask[0, 1, :], slice(None)) + select_out = (mask[1, 0, :], mask[1, 1, :], slice(None)) + + bml = bmm[select_in] + bmr = bmt[select_out] + + zb = array_contract( + [bml, bmr], + inputs=[(-1, 0), (-1, 0)], + 
output=(-1,), + ) + + # individual message regions have a counting factor of -1 + # i.e. we are dividing by all of them + return _gather_zb(zb, power=-1.0) + + def _extract_messages_from_inputs_batched( batched_inputs_m, batched_inputs_t, @@ -222,10 +299,10 @@ def _extract_messages_from_inputs_batched( ): """Get all messages as a dict from the batch stacked input form.""" messages = {} - for pair, (rank, i, b) in input_locs_m.items(): - messages[pair] = batched_inputs_m[rank][i, b, :] - for pair, (rank, i, b) in input_locs_t.items(): - messages[pair] = batched_inputs_t[rank][i, b, :] + for pair, (rank, p, b) in input_locs_m.items(): + messages[pair] = batched_inputs_m[rank][p, b, :] + for pair, (rank, p, b) in input_locs_t.items(): + messages[pair] = batched_inputs_t[rank][p, b, :] return messages @@ -389,6 +466,15 @@ def initialize_messages_batched(self, messages=None): _stack = ar.get_lib_fn(self.backend, "stack") _array = ar.get_lib_fn(self.backend, "array") + # here we are stacking all contractions with matching rank + # + # rank: number of incident messages to a tensor or hyper index + # pos (p): which of those messages we are (0, 1, ..., rank - 1) + # batch position (b): which position in the stack we are + # + # _m = messages incident to indices + # _t = messages incident to tensors + # prepare index messages batched_inputs_m = {} input_locs_m = {} @@ -401,13 +487,13 @@ def initialize_messages_batched(self, messages=None): except KeyError: batch = batched_inputs_m[rank] = [[] for _ in range(rank)] - for i, tid in enumerate(tids): - batch_i = batch[i] + for p, tid in enumerate(tids): + batch_p = batch[p] # position in the stack - b = len(batch_i) - input_locs_m[tid, ix] = (rank, i, b) - output_locs_m[ix, tid] = (rank, i, b) - batch_i.append(messages[tid, ix]) + b = len(batch_p) + input_locs_m[tid, ix] = (rank, p, b) + output_locs_m[ix, tid] = (rank, p, b) + batch_p.append(messages[tid, ix]) # prepare tensor messages batched_tensors = {} @@ -428,13 +514,13 @@ def 
initialize_messages_batched(self, messages=None): batch = batched_inputs_t[rank] = [[] for _ in range(rank)] batch_t = batched_tensors[rank] = [] - for i, ix in enumerate(t.inds): - batch_i = batch[i] + for p, ix in enumerate(t.inds): + batch_p = batch[p] # position in the stack - b = len(batch_i) - input_locs_t[ix, tid] = (rank, i, b) - output_locs_t[tid, ix] = (rank, i, b) - batch_i.append(messages[ix, tid]) + b = len(batch_p) + input_locs_t[ix, tid] = (rank, p, b) + output_locs_t[tid, ix] = (rank, p, b) + batch_p.append(messages[ix, tid]) batch_t.append(t.data) @@ -442,7 +528,7 @@ def initialize_messages_batched(self, messages=None): for batched_inputs in (batched_inputs_m, batched_inputs_t): for key, batch in batched_inputs.items(): batched_inputs[key] = _stack( - tuple(_stack(batch_i) for batch_i in batch) + tuple(_stack(batch_p) for batch_p in batch) ) # stack all tensors of each rank into a single array for rank, tensors in batched_tensors.items(): @@ -456,18 +542,26 @@ def initialize_messages_batched(self, messages=None): (masks_t, input_locs_t, output_locs_m), ]: for pair in input_locs: - (ranki, ii, bi) = input_locs[pair] - (ranko, io, bo) = output_locs[pair] + (ranki, pi, bi) = input_locs[pair] + (ranko, po, bo) = output_locs[pair] + # we can vectorize over all distinct pairs of ranks key = (ranki, ranko) try: - maskin, maskout = masks[key] + ma_pi, ma_po, ma_bi, ma_bo = masks[key] except KeyError: - maskin, maskout = masks[key] = [], [] - maskin.append([ii, bi]) - maskout.append([io, bo]) + ma_pi, ma_po, ma_bi, ma_bo = masks[key] = [], [], [], [] - for key, (maskin, maskout) in masks.items(): - masks[key] = _array(maskin), _array(maskout) + ma_pi.append(pi) + ma_bi.append(bi) + ma_po.append(po) + ma_bo.append(bo) + + for key, (ma_pi, ma_po, ma_bi, ma_bo) in masks.items(): + # first dimension is in/out + # second dimension is position or batch + # third dimension is stack index + mask = _array([[ma_pi, ma_bi], [ma_po, ma_bo]]) + masks[key] = mask 
self.batched_inputs_m = batched_inputs_m self.batched_inputs_t = batched_inputs_t @@ -535,12 +629,11 @@ def _update_outputs_to_inputs_batched( ( batched_inputs[ranki], batched_outputs[ranko], - maskin, - maskout, + mask, self._distance_fn, self.damping, ) - for (ranki, ranko), (maskin, maskout) in masks.items() + for (ranki, ranko), mask in masks.items() ) if self.pool is None: @@ -555,25 +648,25 @@ def _update_outputs_to_inputs_batched( def iterate(self, tol=None): # first we compute new tensor output messages - self.batched_outputs_t = self._compute_outputs_batched( + batched_outputs_t = self._compute_outputs_batched( batched_inputs=self.batched_inputs_t, batched_tensors=self.batched_tensors, ) # update the index input messages with these t_max_mdiff = self._update_outputs_to_inputs_batched( self.batched_inputs_m, - self.batched_outputs_t, + batched_outputs_t, self.masks_m, ) # compute index messages - self.batched_outputs_m = self._compute_outputs_batched( + batched_outputs_m = self._compute_outputs_batched( batched_inputs=self.batched_inputs_m, ) # update the tensor input messages m_max_mdiff = self._update_outputs_to_inputs_batched( self.batched_inputs_t, - self.batched_outputs_m, + batched_outputs_m, self.masks_t, ) return max(t_max_mdiff, m_max_mdiff) @@ -587,8 +680,75 @@ def get_messages(self): self.input_locs_t, ) - def contract(self, strip_exponent=False, check_zero=True): - # TODO: do this in batched form directly + def contract(self, strip_exponent=False, check_zero=False): + """Estimate the contraction of the tensor network using the current + messages. Uses batched vectorized contractions for speed. + + Parameters + ---------- + strip_exponent : bool, optional + Whether to strip the exponent from the final result. If ``True`` + then the returned result is ``(mantissa, exponent)``. + check_zero : bool, optional + Whether to check for zero values and return zero early. Currently + ``True`` is not implemented for HV1BP. 
+ + Returns + ------- + scalar or (scalar, float) + """ + if check_zero: + raise NotImplementedError("check_zero not implemented for HV1BP.") + + fn_args = [] + # for each rank contract index region estimate + for bm in self.batched_inputs_m.values(): + fn_args.append((_contract_index_region_single, (bm,))) + # for each rank contract tensor region estimate + for rank in self.batched_tensors: + fn_args.append( + ( + _contract_tensor_region_single, + (rank, self.batched_tensors, self.batched_inputs_t), + ) + ) + # for each pair of ranks contract messages pair + # region estimate which we divide by (power=-1.0) + for (ranki, ranko), mask in self.masks_m.items(): + fn_args.append( + ( + _contract_messages_pair_single, + ( + ranki, + ranko, + mask, + self.batched_inputs_m, + self.batched_inputs_t, + ), + ) + ) + + if self.pool is None: + results = [fn(*args) for fn, args in fn_args] + else: + futs = [self.pool.submit(fn, *args) for fn, args in fn_args] + results = [fut.result() for fut in futs] + + sign = 1.0 + exponent = 0.0 + for s, e in results: + sign *= s + exponent += e + + if strip_exponent: + return sign, exponent + else: + return sign * 10**exponent + + def contract_dense(self, strip_exponent=False, check_zero=True): + """Slow contraction by explicitly extracting individual dense messages. + This supports check_zero=True and may be useful for debugging. + """ return contract_hyper_messages( self.tn, self.get_messages(), @@ -612,7 +772,7 @@ def contract_hv1bp( tol_rolling_diff=None, smudge_factor=1e-12, strip_exponent=False, - check_zero=True, + check_zero=False, info=None, progbar=False, ):