diff --git a/docs/changelog.md b/docs/changelog.md index 89d48192..62b5bfa7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -8,6 +8,10 @@ Release notes for `quimb`. **Enhancements:** - [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg +- belief propagation, implement DIIS (direct inversion in the iterative subspace) +- belief propagation, unify various aspects such as message normalization and distance. +- belief propagation, add a `plot` method. +- add `qu.plot_multi_series_zoom` for plotting multiple series with a zoomed inset, useful for various convergence plots such as BP **Bug fixes:** diff --git a/quimb/__init__.py b/quimb/__init__.py index 908d7a43..ed15a95c 100644 --- a/quimb/__init__.py +++ b/quimb/__init__.py @@ -20,488 +20,481 @@ import warnings # some useful math -from math import pi, cos, sin, tan, exp, log, log2, log10, sqrt +from math import cos, exp, log, log2, log10, pi, sin, sqrt, tan + +# Functions for calculating properties +from .calc import ( + bell_decomp, + concurrence, + correlation, + cprint, + decomp, + dephase, + ent_cross_matrix, + entropy, + entropy_subsys, + fidelity, + heisenberg_energy, + is_degenerate, + is_eigenvector, + kraus_op, + logarithmic_negativity, + logneg, + logneg_subsys, + measure, + mutinf, + mutinf_subsys, + mutual_information, + negativity, + one_way_classical_information, + page_entropy, + partial_transpose, + pauli_correlations, + pauli_decomp, + projector, + purify, + qid, + quantum_discord, + schmidt_gap, + simulate_counts, + tr_sqrt, + tr_sqrt_subsys, + trace_distance, +) # Core functions from .core import ( - qarray, - prod, - isket, + bra, + chop, + dag, + dim_compress, + dim_map, + dop, + dot, + expec, + expectation, + explt, + eye, + get_thread_pool, + identity, + ikron, + infer_size, isbra, - isop, - isvec, - issparse, isdense, - isreal, isherm, + isket, + isop, ispos, - mul, - dag, - dot, - vdot, - rdot, + isreal, + issparse, + isvec, + itrace, + ket, + kron, + kronpow, ldmul, - rdmul, - outer, - explt, - get_thread_pool, + mul, + nmlz, normalize, - chop, - quimbify, + outer, + partial_trace, + permute, + pkron, + prod, + ptr, + qarray, qu, - ket, - bra, - dop, + quimbify, + rdmul, + rdot, sparse, - infer_size, - trace, - identity, - eye, speye, - dim_map, - dim_compress, - kron, - kronpow, - ikron, - pkron, - permute, - itrace, - partial_trace, - expectation, - expec, - nmlz, tr, - ptr, + trace, + vdot, ) -# Linear algebra functions -from .linalg.base_linalg import ( - eigensystem, - eig, - eigh, - eigvals, - eigvalsh, - eigvecs, - eigvecsh, - eigensystem_partial, - groundstate, - groundenergy, - bound_spectrum, - eigh_window, - eigvalsh_window, - eigvecsh_window, - svd, - svds, - norm, - expm, - sqrtm, - expm_multiply, - Lazy, -) -from .linalg.rand_linalg import rsvd, estimate_rank -from .linalg.mpi_launcher import get_mpi_pool, can_use_mpi_pool +# Evolution class and methods +from .evo import Evolution # Generating objects from .gen.operators import ( - spin_operator, - pauli, - hadamard, - phase_gate, - S_gate, - T_gate, - U_gate, - rotation, + CNOT, Rx, Ry, Rz, + S_gate, + T_gate, + U_gate, + Wsqrt, Xsqrt, Ysqrt, Zsqrt, - Wsqrt, - swap, - iswap, - fsim, - fsimg, - ncontrolled_gate, - controlled, - CNOT, - cX, - cY, - cZ, ccX, ccY, ccZ, + controlled, controlled_swap, + create, cswap, + cX, + cY, + cZ, + destroy, fredkin, - toffoli, + fsim, + fsimg, + hadamard, ham_heis, + ham_heis_2D, + ham_hubbard_hardcore, ham_ising, - ham_XY, - ham_XXZ, ham_j1j2, ham_mbl, - 
ham_heis_2D, - zspin_projector, - create, - destroy, + ham_XXZ, + ham_XY, + iswap, + ncontrolled_gate, num, - ham_hubbard_hardcore, -) -from .gen.states import ( - basis_vec, - up, - zplus, - down, - zminus, - plus, - xplus, - minus, - xminus, - yplus, - yminus, - bloch_state, - bell_state, - singlet, - thermal_state, - neel_state, - singlet_pairs, - werner_state, - ghz_state, - w_state, - levi_civita, - perm_state, - graph_state_1d, - computational_state, + pauli, + phase_gate, + rotation, + spin_operator, + swap, + toffoli, + zspin_projector, ) from .gen.rand import ( - randn, + gen_rand_haar_states, rand, - rand_matrix, + rand_haar_state, rand_herm, - rand_pos, - rand_rho, + rand_iso, rand_ket, - rand_uni, - rand_haar_state, - gen_rand_haar_states, - rand_mix, - rand_product_state, + rand_matrix, rand_matrix_product_state, + rand_mera, + rand_mix, rand_mps, + rand_pos, + rand_product_state, + rand_rho, rand_seperable, - rand_iso, - rand_mera, + rand_uni, + randn, seed_rand, set_rand_bitgen, ) - -# Functions for calculating properties -from .calc import ( - fidelity, - purify, - entropy, - entropy_subsys, - mutual_information, - mutinf, - mutinf_subsys, - schmidt_gap, - tr_sqrt, - tr_sqrt_subsys, - partial_transpose, - negativity, - logarithmic_negativity, - logneg, - logneg_subsys, - concurrence, - one_way_classical_information, - quantum_discord, - trace_distance, - cprint, - decomp, - pauli_decomp, - bell_decomp, - correlation, - pauli_correlations, - ent_cross_matrix, - qid, - is_degenerate, - is_eigenvector, - page_entropy, - heisenberg_energy, - dephase, - kraus_op, - projector, - measure, - simulate_counts, +from .gen.states import ( + basis_vec, + bell_state, + bloch_state, + computational_state, + down, + ghz_state, + graph_state_1d, + levi_civita, + minus, + neel_state, + perm_state, + plus, + singlet, + singlet_pairs, + thermal_state, + up, + w_state, + werner_state, + xminus, + xplus, + yminus, + yplus, + zminus, + zplus, ) - -# Evolution class and methods -from .evo import Evolution - from .linalg.approx_spectral import ( approx_spectral_function, + entropy_subsys_approx, + logneg_subsys_approx, + negativity_subsys_approx, tr_abs_approx, tr_exp_approx, tr_sqrt_approx, tr_xlogx_approx, - entropy_subsys_approx, - logneg_subsys_approx, - negativity_subsys_approx, xlogx, ) + +# Linear algebra functions +from .linalg.base_linalg import ( + Lazy, + bound_spectrum, + eig, + eigensystem, + eigensystem_partial, + eigh, + eigh_window, + eigvals, + eigvalsh, + eigvalsh_window, + eigvecs, + eigvecsh, + eigvecsh_window, + expm, + expm_multiply, + groundenergy, + groundstate, + norm, + sqrtm, + svd, + svds, +) +from .linalg.mpi_launcher import can_use_mpi_pool, get_mpi_pool +from .linalg.rand_linalg import estimate_rank, rsvd from .utils import ( - save_to_disk, + LRU, + format_number_with_error, load_from_disk, oset, - LRU, - tree_map, + save_to_disk, tree_apply, tree_flatten, + tree_map, tree_unflatten, - format_number_with_error, +) +from .utils_plot import ( NEUTRAL_STYLE, default_to_neutral_style, + plot_multi_series_zoom, ) - warnings.filterwarnings("ignore", message="Caching is not available when ") __all__ = [ - # Accel ----------------------------------------------------------------- # - "qarray", - "prod", - "isket", - "isbra", - "isop", - "isvec", - "issparse", - "isdense", - "isreal", - "isherm", - "ispos", - "mul", - "dag", - "dot", - "vdot", - "rdot", - "ldmul", - "rdmul", - "outer", - "explt", - # Core ------------------------------------------------------------------ # - 
"normalize", - "chop", - "quimbify", - "qu", - "ket", + "approx_spectral_function", + "basis_vec", + "bell_decomp", + "bell_state", + "bloch_state", + "bound_spectrum", "bra", - "dop", - "sparse", - "infer_size", - "trace", - "identity", - "eye", - "speye", - "dim_map", + "can_use_mpi_pool", + "ccX", + "ccY", + "ccZ", + "chop", + "CNOT", + "computational_state", + "concurrence", + "controlled_swap", + "controlled", + "correlation", + "cos", + "cprint", + "create", + "cswap", + "cX", + "cY", + "cZ", + "dag", + "decomp", + "default_to_neutral_style", + "dephase", + "destroy", "dim_compress", - "kron", - "kronpow", - "ikron", - "pkron", - "permute", - "itrace", - "partial_trace", - "expectation", - "expec", - "nmlz", - "tr", - "ptr", - # Linalg ---------------------------------------------------------------- # - "eigensystem", + "dim_map", + "dop", + "dot", + "down", "eig", + "eigensystem_partial", + "eigensystem", + "eigh_window", "eigh", "eigvals", + "eigvalsh_window", "eigvalsh", "eigvecs", - "eigvecsh", - "eigensystem_partial", - "groundstate", - "groundenergy", - "bound_spectrum", - "eigh_window", - "eigvalsh_window", "eigvecsh_window", - "svd", - "svds", - "norm", - "Lazy", - "rsvd", + "eigvecsh", + "ent_cross_matrix", + "entropy_subsys_approx", + "entropy_subsys", + "entropy", "estimate_rank", - # Gen ------------------------------------------------------------------- # - "spin_operator", - "pauli", - "hadamard", - "phase_gate", - "T_gate", - "S_gate", - "U_gate", - "rotation", - "Rx", - "Ry", - "Rz", - "Xsqrt", - "Ysqrt", - "Zsqrt", - "Wsqrt", - "swap", - "iswap", + "Evolution", + "exp", + "expec", + "expectation", + "explt", + "expm_multiply", + "expm", + "eye", + "fidelity", + "format_number_with_error", + "fredkin", "fsim", "fsimg", - "ncontrolled_gate", - "controlled", - "CNOT", - "cX", - "cY", - "cZ", - "ccX", - "ccY", - "ccZ", - "controlled_swap", - "cswap", - "fredkin", - "toffoli", + "gen_rand_haar_states", + "get_mpi_pool", + "get_thread_pool", + "ghz_state", + "graph_state_1d", + "groundenergy", + "groundstate", + "hadamard", + "ham_heis_2D", "ham_heis", + "ham_hubbard_hardcore", "ham_ising", - "ham_XY", - "ham_XXZ", "ham_j1j2", "ham_mbl", - "ham_heis_2D", - "create", - "destroy", - "num", - "ham_hubbard_hardcore", - "zspin_projector", - "basis_vec", - "up", - "zplus", - "down", - "zminus", - "plus", - "xplus", + "ham_XXZ", + "ham_XY", + "heisenberg_energy", + "identity", + "ikron", + "infer_size", + "is_degenerate", + "is_eigenvector", + "isbra", + "isdense", + "isherm", + "isket", + "isop", + "ispos", + "isreal", + "issparse", + "isvec", + "iswap", + "itrace", + "ket", + "kraus_op", + "kron", + "kronpow", + "Lazy", + "ldmul", + "levi_civita", + "load_from_disk", + "log", + "log10", + "log2", + "logarithmic_negativity", + "logneg_subsys_approx", + "logneg_subsys", + "logneg", + "LRU", + "measure", "minus", - "xminus", - "yplus", - "yminus", - "bloch_state", - "bell_state", - "singlet", - "thermal_state", + "mul", + "mutinf_subsys", + "mutinf", + "mutual_information", + "ncontrolled_gate", "neel_state", - "singlet_pairs", - "werner_state", - "ghz_state", - "w_state", - "levi_civita", + "negativity_subsys_approx", + "negativity", + "NEUTRAL_STYLE", + "nmlz", + "norm", + "normalize", + "num", + "one_way_classical_information", + "oset", + "outer", + "page_entropy", + "partial_trace", + "partial_transpose", + "pauli_correlations", + "pauli_decomp", + "pauli", "perm_state", - "graph_state_1d", - "rand_matrix", + "permute", + "phase_gate", + "pi", + "pkron", + 
"plot_multi_series_zoom", + "plus", + "prod", + "projector", + "ptr", + "purify", + "qarray", + "qid", + "qu", + "quantum_discord", + "quimbify", + "rand_haar_state", "rand_herm", - "rand_pos", - "rand_rho", + "rand_iso", "rand_ket", - "rand_uni", - "rand_haar_state", - "gen_rand_haar_states", + "rand_matrix_product_state", + "rand_matrix", + "rand_mera", "rand_mix", "rand_mps", - "randn", - "rand", + "rand_pos", "rand_product_state", - "rand_matrix_product_state", + "rand_rho", "rand_seperable", - "rand_iso", - "rand_mera", + "rand_uni", + "rand", + "randn", + "rdmul", + "rdot", + "rotation", + "rsvd", + "Rx", + "Ry", + "Rz", + "S_gate", + "save_to_disk", + "schmidt_gap", "seed_rand", "set_rand_bitgen", - "computational_state", - # Calc ------------------------------------------------------------------ # - "expm", - "sqrtm", - "expm_multiply", - "fidelity", - "purify", - "entropy", - "entropy_subsys", - "mutual_information", - "mutinf", - "mutinf_subsys", - "schmidt_gap", - "tr_sqrt", - "tr_sqrt_subsys", - "partial_transpose", - "negativity", - "logarithmic_negativity", - "logneg", - "logneg_subsys", - "concurrence", - "one_way_classical_information", - "quantum_discord", - "trace_distance", - "cprint", - "decomp", - "pauli_decomp", - "bell_decomp", - "correlation", - "pauli_correlations", - "ent_cross_matrix", - "qid", - "is_degenerate", - "is_eigenvector", - "page_entropy", - "heisenberg_energy", - "dephase", - "kraus_op", - "projector", - "measure", "simulate_counts", - # Evo ------------------------------------------------------------------- # - "Evolution", - # Approx spectral ------------------------------------------------------- # - "approx_spectral_function", + "sin", + "singlet_pairs", + "singlet", + "sparse", + "speye", + "spin_operator", + "sqrt", + "sqrtm", + "svd", + "svds", + "swap", + "T_gate", + "tan", + "thermal_state", + "toffoli", "tr_abs_approx", "tr_exp_approx", "tr_sqrt_approx", + "tr_sqrt_subsys", + "tr_sqrt", "tr_xlogx_approx", - "entropy_subsys_approx", - "logneg_subsys_approx", - "negativity_subsys_approx", - # Some misc useful math ------------------------------------------------- # - "pi", - "cos", - "sin", - "tan", - "exp", - "log", - "log2", - "log10", - "sqrt", - "xlogx", - # Utils ----------------------------------------------------------------- # - "save_to_disk", - "load_from_disk", - "get_thread_pool", - "get_mpi_pool", - "can_use_mpi_pool", - "oset", - "LRU", - "tree_map", + "tr", + "trace_distance", + "trace", "tree_apply", "tree_flatten", + "tree_map", "tree_unflatten", - "format_number_with_error", - "NEUTRAL_STYLE", - "default_to_neutral_style", + "U_gate", + "up", + "vdot", + "w_state", + "werner_state", + "Wsqrt", + "xlogx", + "xminus", + "xplus", + "Xsqrt", + "yminus", + "yplus", + "Ysqrt", + "zminus", + "zplus", + "zspin_projector", + "Zsqrt", ] diff --git a/quimb/calc.py b/quimb/calc.py index 87d7f72e..0eaaac8c 100644 --- a/quimb/calc.py +++ b/quimb/calc.py @@ -1,14 +1,16 @@ """Functions for more advanced calculations of quantities and properties of quantum objects. 
""" -import numbers -import itertools -import functools + import collections -from math import sin, cos, pi, log, log2, sqrt +import functools +import itertools +import numbers +from math import cos, log, log2, pi, sin, sqrt import numpy as np import numpy.linalg as nla +from cotengra import array_contract from scipy.optimize import minimize from .core import ( @@ -32,23 +34,21 @@ tr, zeroify, ) -from .linalg.base_linalg import eigh, eigvalsh, norm, sqrtm, norm_trace_dense +from .gen.operators import pauli +from .gen.states import basis_vec, bell_state, bloch_state from .linalg.approx_spectral import ( entropy_subsys_approx, gen_bipartite_spectral_fn, logneg_subsys_approx, tr_sqrt_subsys_approx, ) -from .gen.operators import pauli -from .gen.states import basis_vec, bell_state, bloch_state +from .linalg.base_linalg import eigh, eigvalsh, norm, norm_trace_dense, sqrtm from .utils import ( frequencies, int2tup, keymap, ) -from .tensor.contraction import array_contract - def fidelity(p1, p2, squared=False): """Fidelity between two quantum states. By default, the unsquared fidelity diff --git a/quimb/experimental/__init__.py b/quimb/experimental/__init__.py index e69de29b..59c87277 100644 --- a/quimb/experimental/__init__.py +++ b/quimb/experimental/__init__.py @@ -0,0 +1,3 @@ +"""Submodule for experimental features that are likely untested and subject to +change. +""" \ No newline at end of file diff --git a/quimb/experimental/belief_propagation/__init__.py b/quimb/experimental/belief_propagation/__init__.py index f700980d..7000590a 100644 --- a/quimb/experimental/belief_propagation/__init__.py +++ b/quimb/experimental/belief_propagation/__init__.py @@ -77,7 +77,7 @@ update messages adjacent to messages which have changed. """ -from .bp_common import initialize_hyper_messages +from .bp_common import combine_local_contractions, initialize_hyper_messages from .d1bp import D1BP, contract_d1bp from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp @@ -87,6 +87,7 @@ from .regions import RegionGraph __all__ = ( + "combine_local_contractions", "compress_d2bp", "compress_l2bp", "contract_d1bp", diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/experimental/belief_propagation/bp_common.py index d2d5cb96..d25cb277 100644 --- a/quimb/experimental/belief_propagation/bp_common.py +++ b/quimb/experimental/belief_propagation/bp_common.py @@ -43,15 +43,252 @@ class BeliefPropagationCommon: Parameters ---------- - max_iterations : int, optional - The maximum number of iterations to perform. - tol : float, optional - The convergence tolerance for messages. - progbar : bool, optional - Whether to show a progress bar. + tn : TensorNetwork + The tensor network to perform belief propagation on. + damping : float or callable, optional + The damping factor to apply to messages. This simply mixes some part + of the old message into the new one, with the final message being + ``damping * old + (1 - damping) * new``. This makes convergence more + reliable but slower. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially (newly computed messages are + immediately used for other updates in the same iteration round) or in + parallel (all messages are comptued using messages from the previous + round only). Sequential generally helps convergence but parallel can + possibly converge to differnt solutions. 
+ normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + inplace : bool, optional + Whether to perform any operations inplace on the input tensor network. """ - def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False): + def __init__( + self, + tn, + *, + damping=0.0, + update="sequential", + normalize=None, + distance=None, + inplace=False, + ): + self.tn = tn if inplace else tn.copy() + self.backend = self.tn.backend + self.dtype = self.tn.dtype + self.sign = 1.0 + self.exponent = tn.exponent + self.damping = damping + self.update = update + + if normalize is None: + if "complex" in self.dtype: + normalize = "L2phased" + else: + normalize = "L2" + self.normalize = normalize + + if distance is None: + if ("complex" in self.dtype) and ( + callable(normalize) or ("phased" not in normalize) + ): + distance = "L2phased" + else: + distance = "L2" + self.distance = distance + + self.n = 0 + self.mdiffs = [] + self.rdiffs = [] + + @property + def damping(self): + return self._damping + + @damping.setter + def damping(self, damping): + if callable(damping): + self.fn_damping = self._damping = damping + else: + self._damping = damping + + def damp(old, new): + return damping * old + (1 - damping) * new + + self.fn_damping = damp + + @property + def normalize(self): + return self._normalize + + @normalize.setter + def normalize(self, normalize): + if callable(normalize): + # custom normalization function + _normalize_fn = normalize + + elif normalize == "L1": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _normalize_fn(x): + return x / _sum(_abs(x)) + + elif normalize == "L2": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _normalize_fn(x): + return x / (_sum(_abs(x) ** 2) ** 0.5) + + elif normalize == "L2phased": + _sum = ar.get_lib_fn(self.backend, "sum") + _abs = ar.get_lib_fn(self.backend, "abs") + + def _normalize_fn(x): + xnrm = float(_sum(_abs(x) ** 2)) ** 0.5 + sumx = complex(_sum(x)) + if sumx != 0.0: + if sumx.imag == 0.0: + sumx = sumx.real + sumx /= abs(sumx) + xnrm *= sumx + return x / xnrm + + elif normalize == "Linf": + _abs = ar.get_lib_fn(self.backend, "abs") + _max = ar.get_lib_fn(self.backend, "max") + + def _normalize_fn(x): + return x / _max(_abs(x)) + + else: + raise ValueError(f"Unrecognized normalize={normalize}") + + self._normalize = normalize + self._normalize_fn = _normalize_fn + + @property + def distance(self): + return self._distance + + @distance.setter + def distance(self, distance): + if callable(distance): + _distance_fn = distance + + elif distance == "L1": + _abs = 
ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _distance_fn(x, y): + return float(_sum(_abs(x - y))) + + elif distance == "L2": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _distance_fn(x, y): + return float(_sum(_abs(x - y) ** 2) ** 0.5) + + elif distance == "L2phased": + _conj = ar.get_lib_fn(self.backend, "conj") + _sum = ar.get_lib_fn(self.backend, "sum") + _abs = ar.get_lib_fn(self.backend, "abs") + + def _distance_fn(x, y): + xnorm = _sum(_abs(x) ** 2) ** 0.5 + ynorm = _sum(_abs(y) ** 2) ** 0.5 + # cosine similarity with phase + cs = _sum(_conj(x) * y) + phase = cs / _abs(cs) + xn = x / xnorm + yn = y / (ynorm * phase) + # L2 distance between normalized, phased vectors + return float(_sum(_abs(xn - yn) ** 2) ** 0.5) + + elif distance == "Linf": + _abs = ar.get_lib_fn(self.backend, "abs") + _max = ar.get_lib_fn(self.backend, "max") + + def _distance_fn(x, y): + return float(_max(_abs(x - y))) + + elif distance == "cosine": + # this is like L2phased, but with less precision + _conj = ar.get_lib_fn(self.backend, "conj") + _sum = ar.get_lib_fn(self.backend, "sum") + _abs = ar.get_lib_fn(self.backend, "abs") + + def _distance_fn(x, y): + xnorm = float(_sum(_abs(x) ** 2) ** 0.5) + ynorm = float(_sum(_abs(y) ** 2) ** 0.5) + # compute cosine similarity + cs = float(_abs(_sum(_conj(x) * y)) / (xnorm * ynorm)) + # clip to avoid numerical issues + cs = min(max(cs, -1.0), 1.0) + return (2 - 2 * cs) ** 0.5 + + else: + raise ValueError(f"Unrecognized distance={distance}") + + self._distance = distance + self._distance_fn = _distance_fn + + def run( + self, + max_iterations=1000, + diis=False, + tol=5e-6, + tol_abs=None, + tol_rolling_diff=None, + info=None, + progbar=False, + ): + """ + Parameters + ---------- + max_iterations : int, optional + The maximum number of iterations to perform. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + tol : float, optional + The convergence tolerance for messages. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. + info : dict, optional + If supplied, the following information will be added to it: + ``converged`` (bool), ``iterations`` (int), ``max_mdiff`` (float), + ``rolling_abs_mean_diff`` (float). + progbar : bool, optional + Whether to show a progress bar. 
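Note: a minimal numpy sketch of the damping rule and two of the distance measures documented above; the actual implementations dispatch through ``autoray``'s ``get_lib_fn`` so they work with any array backend.

```python
import numpy as np

def damp(old, new, damping=0.1):
    # damping * old + (1 - damping) * new, as described above
    return damping * old + (1 - damping) * new

def distance_l2(x, y):
    return float(np.sum(np.abs(x - y) ** 2) ** 0.5)

def distance_cosine(x, y):
    # clipped cosine similarity mapped to a distance via (2 - 2*cs)**0.5
    cs = float(abs(np.vdot(x, y)) / (np.linalg.norm(x) * np.linalg.norm(y)))
    cs = min(max(cs, -1.0), 1.0)
    return (2 - 2 * cs) ** 0.5

x, y = np.array([1.0, 0.0]), np.array([0.0, 1.0])
# for unit vectors with real non-negative overlap the two measures agree
assert np.isclose(distance_l2(x, y), distance_cosine(x, y))
```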
+ """ + if tol_abs is None: + tol_abs = tol + if tol_rolling_diff is None: + tol_rolling_diff = tol + if progbar: import tqdm @@ -59,30 +296,63 @@ def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False): else: pbar = None - try: - it = 0 - rdm = RollingDiffMean() - self.converged = False - while not self.converged and it < max_iterations: - # perform a single iteration of BP - # we supply tol here for use with local convergence - nconv, ncheck, max_mdiff = self.iterate(tol=tol) - it += 1 + if diis: + from .diis import DIIS + if isinstance(diis, dict): + self._diis = DIIS(**diis) + diis = True + else: + self._diis = DIIS() + else: + self._diis = None + + it = 0 + rdm = RollingDiffMean() + self.converged = False + while not self.converged and it < max_iterations: + # perform a single iteration of BP + # we supply tol here for use with local convergence + result = self.iterate(tol=tol) + + if diis: + # extrapolate new guess for messages + self.messages = self._diis.update(self.messages) + + if isinstance(result, dict): + max_mdiff = result.get("max_mdiff", float("inf")) + else: + max_mdiff = result + result = dict() + + self.mdiffs.append(max_mdiff) + + if pbar is not None: + msg = f"max|dM|={max_mdiff:.3g}" + + nconv = result.get("nconv", None) + if nconv is not None: + ncheck = result.get("ncheck", None) + msg += f" nconv: {nconv}/{ncheck} " + + pbar.set_description(msg, refresh=False) + pbar.update() + + # check covergence criteria + self.converged |= max_mdiff < tol_abs + if tol_rolling_diff > 0.0: # check rolling mean convergence rdm.update(max_mdiff) - self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol) + amd = rdm.absmeandiff() + self.converged |= amd < tol_rolling_diff + self.rdiffs.append(amd) - if pbar is not None: - pbar.set_description( - f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}", - refresh=False, - ) - pbar.update() + it += 1 + self.n += 1 - finally: - if pbar is not None: - pbar.close() + # finally: + if pbar is not None: + pbar.close() if tol != 0.0 and not self.converged: import warnings @@ -98,6 +368,19 @@ def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False): info["max_mdiff"] = max_mdiff info["rolling_abs_mean_diff"] = rdm.absmeandiff() + def plot(self, **kwargs): + from quimb import plot_multi_series_zoom + + data = { + "mdiffs": self.mdiffs, + "rdiffs": self.rdiffs, + } + if getattr(self, "_diis", None) is not None: + data["diis.lambdas"] = self._diis.lambdas + + kwargs.setdefault("yscale", "log") + return plot_multi_series_zoom(data, **kwargs) + def initialize_hyper_messages( tn, @@ -160,32 +443,65 @@ def initialize_hyper_messages( def combine_local_contractions( - tvals, - mvals, - backend, + values, + backend=None, strip_exponent=False, - check_for_zero=True, + check_zero=True, + mantissa=None, + exponent=None, ): - _abs = ar.get_lib_fn(backend, "abs") - _log10 = ar.get_lib_fn(backend, "log10") + """Combine a product of local contractions into a single value, avoiding + overflow/underflow by accumulating the mantissa and exponent separately. - mantissa = 1 - exponent = 0 - for vt in tvals: - avt = _abs(vt) + Parameters + ---------- + values : sequence of (scalar, int) + The values to combine, each with a power to be raised to. + backend : str, optional + The backend to use. Infered from the first value if not given. + strip_exponent : bool, optional + Whether to return the mantissa and exponent separately. + check_zero : bool, optional + Whether to check for zero values and return zero early. 
+ mantissa : float, optional + The initial mantissa to accumulate into. + exponent : float, optional + The initial exponent to accumulate into. - if check_for_zero and (avt == 0.0): + Returns + ------- + result : float or (float, float) + The combined value, or the mantissa and exponent separately. + """ + if mantissa is None: + mantissa = 1.0 + if exponent is None: + exponent = 0.0 + + _abs = _log10 = None + for x, power in values: + if _abs is None: + # lazily get functions + if backend is None: + backend = ar.infer_backend(x) + _abs = ar.get_lib_fn(backend, "abs") + _log10 = ar.get_lib_fn(backend, "log10") + + # factor into phase and magnitude + x_mag = _abs(x) + x_phase = x / x_mag + + if check_zero and (x_mag == 0.0): + # checking explicitly avoids errors from taking log(0) if strip_exponent: return 0.0, 0.0 else: return 0.0 - mantissa = mantissa * (vt / avt) - exponent = exponent + _log10(avt) - for mt in mvals: - amt = _abs(mt) - mantissa = mantissa / (mt / amt) - exponent = exponent - _log10(amt) + # accumulate the mantissa and exponent separately, + # accounting for the local power / counting factor + mantissa = mantissa * x_phase**power + exponent = exponent + power * _log10(x_mag) if strip_exponent: return mantissa, exponent @@ -197,13 +513,13 @@ def contract_hyper_messages( tn, messages, strip_exponent=False, + check_zero=True, backend=None, ): """Estimate the contraction of ``tn`` given ``messages``, via the exponential of the Bethe free entropy. """ - tvals = [] - mvals = [] + zvals = [] for tid, t in tn.tensor_map.items(): if backend is None: @@ -216,25 +532,29 @@ def contract_hyper_messages( inputs.append((i,)) # local message overlap correction - mvals.append( - qtn.array_contract( - (messages[tid, ix], messages[ix, tid]), - inputs=((0,), (0,)), - output=(), - ) + z = qtn.array_contract( + (messages[tid, ix], messages[ix, tid]), + inputs=((0,), (0,)), + output=(), ) + zvals.append((z, -1)) # local factor free entropy - tvals.append(qtn.array_contract(arrays, inputs, output=())) + z = qtn.array_contract(arrays, inputs, output=()) + zvals.append((z, 1)) for ix, tids in tn.ind_map.items(): arrays = tuple(messages[tid, ix] for tid in tids) inputs = tuple((0,) for _ in tids) # local variable free entropy - tvals.append(qtn.array_contract(arrays, inputs, output=())) + z = qtn.array_contract(arrays, inputs, output=()) + zvals.append((z, 1)) return combine_local_contractions( - tvals, mvals, backend, strip_exponent=strip_exponent + zvals, + backend=backend, + strip_exponent=strip_exponent, + check_zero=check_zero, ) @@ -317,6 +637,16 @@ def compute_all_index_marginals_from_messages(tn, messages): return {ix: compute_index_marginal(tn, ix, messages) for ix in tn.ind_map} +def normalize_message_pair(mi, mj): + """Normalize a pair of messages such that ` = 1` and + ` = ` (but in general != 1). 
+ """ + nij = ar.do("abs", mi @ mj) ** 0.5 + nii = (mi @ mi) ** 0.25 + njj = (mj @ mj) ** 0.25 + return mi / (nij * nii / njj), mj / (nij * njj / nii) + + def maybe_get_thread_pool(thread_pool): """Get a thread pool if requested.""" if thread_pool is False: diff --git a/quimb/experimental/belief_propagation/d1bp.py b/quimb/experimental/belief_propagation/d1bp.py index 616cf87a..89d11751 100644 --- a/quimb/experimental/belief_propagation/d1bp.py +++ b/quimb/experimental/belief_propagation/d1bp.py @@ -10,12 +10,14 @@ import autoray as ar +from quimb.tensor import Tensor, rand_uuid from quimb.tensor.contraction import array_contract from quimb.utils import oset from .bp_common import ( BeliefPropagationCommon, combine_local_contractions, + normalize_message_pair, ) from .hd1bp import ( compute_all_tensor_messages_tree, @@ -23,10 +25,6 @@ def initialize_messages(tn, fill_fn=None): - - backend = ar.infer_backend(next(t.data for t in tn)) - _sum = ar.get_lib_fn(backend, "sum") - messages = {} for ix, tids in tn.ind_map.items(): if len(tids) != 2: @@ -44,7 +42,7 @@ def initialize_messages(tn, fill_fn=None): inputs=(tuple(range(t_from.ndim)),), output=(t_from.inds.index(ix),), ) - messages[ix, tid_to] = m / _sum(m) + messages[ix, tid_to] = m return messages @@ -62,10 +60,32 @@ class D1BP(BeliefPropagationCommon): messages : dict[(str, int), array_like], optional The initial messages to use, effectively defaults to all ones if not specified. - damping : float, optional - The damping factor to use, 0.0 means no damping. + damping : float or callable, optional + The damping factor to apply to messages. This simply mixes some part + of the old message into the new one, with the final message being + ``damping * old + (1 - damping) * new``. This makes convergence more + reliable but slower. update : {'sequential', 'parallel'}, optional - Whether to update messages sequentially or in parallel. + Whether to update messages sequentially (newly computed messages are + immediately used for other updates in the same iteration round) or in + parallel (all messages are comptued using messages from the previous + round only). Sequential generally helps convergence but parallel can + possibly converge to differnt solutions. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. 
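Note: the same normalization written out for plain 1D numpy messages — a sketch of the ``normalize_message_pair`` helper added above; the real function uses ``ar.do`` so it is backend agnostic.

```python
import numpy as np

def normalize_message_pair(mi, mj):
    # scale the pair so <mi|mj> = 1 and <mi|mi> = <mj|mj>
    nij = abs(mi @ mj) ** 0.5
    nii = (mi @ mi) ** 0.25
    njj = (mj @ mj) ** 0.25
    return mi / (nij * nii / njj), mj / (nij * njj / nii)

mi, mj = np.random.rand(4), np.random.rand(4)
mi, mj = normalize_message_pair(mi, mj)
assert np.isclose(mi @ mj, 1.0)       # <mi|mj> = 1 (positive overlap here)
assert np.isclose(mi @ mi, mj @ mj)   # <mi|mi> = <mj|mj>
```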
@@ -87,29 +107,26 @@ class D1BP(BeliefPropagationCommon): def __init__( self, tn, + *, messages=None, damping=0.0, update="sequential", + normalize=None, + distance=None, local_convergence=True, message_init_function=None, + inplace=False, ): - self.tn = tn - self.damping = damping - self.local_convergence = local_convergence - self.update = update - - self.backend = next(t.backend for t in tn) - _abs = ar.get_lib_fn(self.backend, "abs") - _sum = ar.get_lib_fn(self.backend, "sum") - - def _normalize(x): - return x / _sum(x) - - def _distance(x, y): - return _sum(_abs(x - y)) + super().__init__( + tn=tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) - self._normalize = _normalize - self._distance = _distance + self.local_convergence = local_convergence if messages is None: self.messages = initialize_messages(self.tn, message_init_function) @@ -143,19 +160,25 @@ def _compute_ms(tid): [self.messages[ix, tid] for ix in t.inds], self.backend, ) - new_ms = [self._normalize(m) for m in new_ms] + new_ms = [self._normalize_fn(m) for m in new_ms] new_ks = [self.key_pairs[ix, tid] for ix in t.inds] return new_ks, new_ms - def _update_m(key, data): + def _update_m(key, new_m): nonlocal nconv, max_mdiff - m = self.messages[key] - if self.damping != 0.0: - data = (1 - self.damping) * data + self.damping * m + old_m = self.messages[key] + + # pre-damp distance + mdiff = self._distance_fn(old_m, new_m) + + if self.damping: + new_m = self.fn_damping(old_m, new_m) + + # # post-damp distance + # mdiff = self._distance_fn(old_m, new_m) - mdiff = float(self._distance(m, data)) if mdiff > tol: # mark distination tid for update new_touched.add(key[1]) @@ -163,15 +186,15 @@ def _update_m(key, data): nconv += 1 max_mdiff = max(max_mdiff, mdiff) - self.messages[key] = data + self.messages[key] = new_m if self.update == "sequential": # compute each new message and immediately re-insert it while self.touched: tid = self.touched.pop() keys, new_ms = _compute_ms(tid) - for key, data in zip(keys, new_ms): - _update_m(key, data) + for key, new_m in zip(keys, new_ms): + _update_m(key, new_m) elif self.update == "parallel": new_data = {} @@ -179,16 +202,21 @@ def _update_m(key, data): while self.touched: tid = self.touched.pop() keys, new_ms = _compute_ms(tid) - for key, data in zip(keys, new_ms): - new_data[key] = data + for key, new_m in zip(keys, new_ms): + new_data[key] = new_m # insert all new messages - for key, data in new_data.items(): - _update_m(key, data) + for key, new_m in new_data.items(): + _update_m(key, new_m) self.touched = new_touched - return nconv, ncheck, max_mdiff - def normalize_messages(self): + return { + "nconv": nconv, + "ncheck": ncheck, + "max_mdiff": max_mdiff, + } + + def normalize_message_pairs(self): """Normalize all messages such that for each bond ` = 1` and ` = ` (but in general != 1). """ @@ -198,11 +226,30 @@ def normalize_messages(self): tida, tidb = tids mi = self.messages[ix, tida] mj = self.messages[ix, tidb] - nij = abs(mi @ mj)**0.5 - nii = (mi @ mi)**0.25 - njj = (mj @ mj)**0.25 - self.messages[ix, tida] = mi / (nij * nii / njj) - self.messages[ix, tidb] = mj / (nij * njj / nii) + mi, mj = normalize_message_pair(mi, mj) + self.messages[ix, tida] = mi + self.messages[ix, tidb] = mj + + def normalize_tensors(self, strip_exponent=True): + """Normalize every local tensor contraction so that it equals 1. Gather + the overall normalization factor into ``self.exponent`` and the sign + into ``self.sign`` by default. 
+ + Parameters + ---------- + strip_exponent : bool, optional + Whether to collect the sign and exponent. If ``False`` then the + value of the BP contraction is set to 1. + """ + for tid, t in self.tn.tensor_map.items(): + tval = self.local_tensor_contract(tid) + tabs = ar.do("abs", tval) + tsgn = tval / tabs + tlog = ar.do("log10", tabs) + t /= tsgn * tabs + if strip_exponent: + self.sign = tsgn * self.sign + self.exponent = tlog + self.exponent def get_gauged_tn(self): """Gauge the original TN by inserting the BP-approximated transfer @@ -218,54 +265,200 @@ def get_gauged_tn(self): ma = self.messages[ka] mb = self.messages[kb] - el, ev = ar.do('linalg.eig', ar.do('outer', ma, mb)) - k = ar.do('argsort', -ar.do('abs', el)) + el, ev = ar.do("linalg.eig", ar.do("outer", ma, mb)) + k = ar.do("argsort", -ar.do("abs", el)) ev = ev[:, k] Uinv = ev - U = ar.do('linalg.inv', ev) + U = ar.do("linalg.inv", ev) tng._insert_gauge_tids(U, tida, tidb, Uinv) return tng - def contract(self, strip_exponent=False): - tvals = [] - for tid, t in self.tn.tensor_map.items(): - arrays = [t.data] - inputs = [tuple(range(t.ndim))] - for i, ix in enumerate(t.inds): - m = self.messages[ix, tid] - arrays.append(m) - inputs.append((i,)) - tvals.append( - array_contract( - arrays=arrays, - inputs=inputs, - output=(), - ) - ) + def get_cluster(self, tids): + """Get the region of tensors given by `tids`, with the messages + on the border contracted in, removing those dangling indices. - mvals = [] - for ix, tids in self.tn.ind_map.items(): - if len(tids) != 2: + Parameters + ---------- + tids : sequence of int + The tensor ids forming a region. + + Returns + ------- + TensorNetwork + """ + # take copy as we are going contract messages in + tnr = self.tn._select_tids(tids, virtual=False) + oixr = tnr.outer_inds() + for ix in oixr: + # get the tensor this index belongs to + (tid,) = tnr._get_tids_from_inds(ix) + t = tnr.tensor_map[tid] + # contract the message in, removing index + t.vector_reduce_(ix, self.messages[ix, tid]) + return tnr + + def local_tensor_contract(self, tid): + """Contract the messages around tensor ``tid``.""" + t = self.tn.tensor_map[tid] + arrays = [t.data] + inputs = [tuple(range(t.ndim))] + for i, ix in enumerate(t.inds): + m = self.messages[ix, tid] + arrays.append(m) + inputs.append((i,)) + + return array_contract( + arrays=arrays, + inputs=inputs, + output=(), + ) + + def local_message_contract(self, ix): + """Contract the messages at index ``ix``.""" + tids = self.tn.ind_map[ix] + if len(tids) != 2: + return None + tida, tidb = tids + return self.messages[ix, tida] @ self.messages[ix, tidb] + + def contract( + self, + strip_exponent=False, + check_zero=True, + **kwargs, + ): + """Estimate the contraction of the tensor network.""" + + zvals = [ + (self.local_tensor_contract(tid), 1) for tid in self.tn.tensor_map + ] + [(self.local_message_contract(ix), -1) for ix in self.tn.ind_map] + + return combine_local_contractions( + zvals, + self.backend, + strip_exponent=strip_exponent, + check_zero=check_zero, + mantissa=self.sign, + exponent=self.exponent, + **kwargs, + ) + + def contract_with_loops( + self, + max_loop_length=None, + min_loop_length=1, + optimize="auto-hq", + strip_exponent=False, + check_zero=True, + **contract_opts, + ): + """Estimate the contraction of the tensor network, including loop + corrections. 
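Note: the various ``contract_*`` methods all reduce to ``combine_local_contractions`` over ``(value, power)`` pairs; a minimal pure-python sketch of that accumulation, assuming plain scalars.

```python
import math

def combine(values, mantissa=1.0, exponent=0.0):
    # accumulate sign/phase in the mantissa and magnitude in a base-10
    # exponent so a product of many local terms can't over/underflow
    for x, power in values:
        mag = abs(x)
        if mag == 0.0:
            return 0.0
        mantissa *= (x / mag) ** power
        exponent += power * math.log10(mag)
    return mantissa * 10**exponent

# tensor terms enter with power +1, message overlaps with power -1
assert math.isclose(combine([(3.0, 1), (2.0, 1), (1.5, -1)]), 3.0 * 2.0 / 1.5)
```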
+ """ + self.normalize_message_pairs() + self.normalize_tensors() + + zvals = [] + + for loop in self.tn.gen_paths_loops(max_loop_length=max_loop_length): + if len(loop) < min_loop_length: continue - tida, tidb = tids - mvals.append( - self.messages[ix, tida] @ self.messages[ix, tidb] - ) + + # get the loop local patch + ltn = self.tn.select_path(loop) + + # attach boundary messages + + for ix, tids in tuple(ltn.ind_map.items()): + if ix in loop: + continue + + elif len(tids) == 1: + # outer index -> cap it with messages + (tid,) = tids + ltn |= Tensor(self.messages[ix, tid], [ix]) + + else: + # non-loop internal index -> cut it with messages + tida, tidb = tids + ma = self.messages[ix, tida] + mb = self.messages[ix, tidb] + lix = rand_uuid() + rix = rand_uuid() + ltn._cut_between_tids(tida, tidb, lix, rix) + ltn |= Tensor(ma, [lix]) + ltn |= Tensor(mb, [rix]) + + zvals.append((ltn.contract(optimize=optimize, **contract_opts), 1)) return combine_local_contractions( - tvals, mvals, self.backend, strip_exponent=strip_exponent + zvals, + backend=self.backend, + strip_exponent=strip_exponent, + check_zero=check_zero, + mantissa=self.sign, + exponent=self.exponent, ) + def contract_cluster_expansion( + self, + clusters=None, + autocomplete=True, + strip_exponent=False, + check_zero=True, + optimize="auto-hq", + **contract_opts, + ): + from .regions import RegionGraph + + if isinstance(clusters, int): + max_cluster_size = clusters + clusters = None + else: + max_cluster_size = None + + if clusters is None: + clusters = tuple( + self.tn.gen_regions(max_region_size=max_cluster_size) + ) + else: + clusters = tuple(clusters) + + rg = RegionGraph(clusters, autocomplete=autocomplete) + + zvals = [] + for r in rg.regions: + c = rg.get_count(r) + tnr = self.get_cluster(r) + zr = tnr.contract(optimize=optimize, **contract_opts) + + zvals.append((zr, c)) + + return combine_local_contractions( + zvals, + backend=self.backend, + strip_exponent=strip_exponent, + check_zero=check_zero, + mantissa=self.sign, + exponent=self.exponent, + ) def contract_d1bp( tn, + *, max_iterations=1000, tol=5e-6, damping=0.0, + diis=False, update="sequential", + normalize=None, + distance=None, + tol_abs=None, + tol_rolling_diff=None, local_convergence=True, strip_exponent=False, + check_zero=True, info=None, progbar=False, **contract_opts, @@ -279,22 +472,52 @@ def contract_d1bp( The tensor network to contract, it should have no dangling or hyper indices. max_iterations : int, optional - The maximum number of iterations to run for. + The maximum number of iterations to perform. tol : float, optional The convergence tolerance for messages. damping : float, optional The damping parameter to use, defaults to no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. update : {'sequential', 'parallel'}, optional Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. 
+ distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. strip_exponent : bool, optional - Whether to strip the exponent from the final result. If ``True`` - then the returned result is ``(mantissa, exponent)``. + Whether to return the mantissa and exponent separately. + check_zero : bool, optional + Whether to check for zero values and return zero early. info : dict, optional - If specified, update this dictionary with information about the - belief propagation run. + If supplied, the following information will be added to it: + ``converged`` (bool), ``iterations`` (int), ``max_mdiff`` (float), + ``rolling_abs_mean_diff`` (float). progbar : bool, optional Whether to show a progress bar. """ @@ -303,14 +526,20 @@ def contract_d1bp( damping=damping, local_convergence=local_convergence, update=update, + normalize=normalize, + distance=distance, **contract_opts, ) bp.run( max_iterations=max_iterations, + diis=diis, tol=tol, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, info=info, progbar=progbar, ) return bp.contract( strip_exponent=strip_exponent, + check_zero=check_zero, ) diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py index 3ff31c50..b2e8680e 100644 --- a/quimb/experimental/belief_propagation/d2bp.py +++ b/quimb/experimental/belief_propagation/d2bp.py @@ -1,3 +1,5 @@ +import contextlib + import autoray as ar import quimb.tensor as qtn @@ -6,7 +8,9 @@ from .bp_common import ( BeliefPropagationCommon, combine_local_contractions, + normalize_message_pair, ) +from .regions import RegionGraph class D2BP(BeliefPropagationCommon): @@ -34,10 +38,32 @@ class D2BP(BeliefPropagationCommon): Computed automatically if not specified. optimize : str or PathOptimizer, optional The path optimizer to use when contracting the messages. - damping : float, optional - The damping factor to use, 0.0 means no damping. - update : {'parallel', 'sequential'}, optional - Whether to update all messages in parallel or sequentially. + damping : float or callable, optional + The damping factor to apply to messages. This simply mixes some part + of the old message into the new one, with the final message being + ``damping * old + (1 - damping) * new``. This makes convergence more + reliable but slower. 
+ update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially (newly computed messages are + immediately used for other updates in the same iteration round) or in + parallel (all messages are comptued using messages from the previous + round only). Sequential generally helps convergence but parallel can + possibly converge to differnt solutions. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. @@ -48,41 +74,36 @@ class D2BP(BeliefPropagationCommon): def __init__( self, tn, + *, messages=None, output_inds=None, optimize="auto-hq", damping=0.0, update="sequential", + normalize=None, + distance=None, + inplace=False, local_convergence=True, **contract_opts, ): - from quimb.tensor.contraction import array_contract_expression + super().__init__( + tn=tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) - self.tn = tn self.contract_opts = contract_opts self.contract_opts.setdefault("optimize", optimize) - self.damping = damping self.local_convergence = local_convergence - self.update = update if output_inds is None: self.output_inds = set(self.tn.outer_inds()) else: self.output_inds = set(output_inds) - self.backend = next(t.backend for t in tn) - _abs = ar.get_lib_fn(self.backend, "abs") - _sum = ar.get_lib_fn(self.backend, "sum") - - def _normalize(x): - return x / _sum(x) - - def _distance(x, y): - return _sum(_abs(x - y)) - - self._normalize = _normalize - self._distance = _distance - if messages is None: self.messages = {} else: @@ -113,36 +134,39 @@ def _distance(x, y): self.touch_map[ix, tid] = this_touchmap if (ix, tid) not in self.messages: - tm = (t_in.reindex({ix: jx}).conj_() @ t_in).data - self.messages[ix, tid] = self._normalize(tm.data) + m = (t_in.reindex({ix: jx}).conj_() @ t_in).data + self.messages[ix, tid] = self._normalize_fn(m) # for efficiency setup all the contraction expressions ahead of time for ix, tids in self.tn.ind_map.items(): - if ix in self.output_inds: - continue + if ix not in self.output_inds: + self.build_expr(ix) - for tida, tidb in (sorted(tids), sorted(tids, reverse=True)): - ta = self.tn.tensor_map[tida] - kix = ta.inds - bix = tuple( - i if i in self.output_inds else i + "*" for i in kix - ) - inputs = [kix, bix] - data = [ta.data, ta.data.conj()] - shapes = [ta.shape, ta.shape] - for i in kix: - if (i != ix) and i not in self.output_inds: - inputs.append((i + "*", i)) - data.append((i, tida)) - shapes.append(self.messages[i, 
tida].shape) - - expr = array_contract_expression( - inputs=inputs, - output=(ix + "*", ix), - shapes=shapes, - **self.contract_opts, - ) - self.exprs[ix, tidb] = expr, data + def build_expr(self, ix): + from quimb.tensor.contraction import array_contract_expression + + tids = self.tn.ind_map[ix] + + for tida, tidb in (sorted(tids), sorted(tids, reverse=True)): + ta = self.tn.tensor_map[tida] + kix = ta.inds + bix = tuple(i if i in self.output_inds else i + "*" for i in kix) + inputs = [kix, bix] + data = [ta.data, ta.data.conj()] + shapes = [ta.shape, ta.shape] + for i in kix: + if (i != ix) and i not in self.output_inds: + inputs.append((i + "*", i)) + data.append((i, tida)) + shapes.append(self.messages[i, tida].shape) + + expr = array_contract_expression( + inputs=inputs, + output=(ix + "*", ix), + shapes=shapes, + **self.contract_opts, + ) + self.exprs[ix, tidb] = expr, data def update_touched_from_tids(self, *tids): """Specify that the messages for the given ``tids`` have changed.""" @@ -184,21 +208,22 @@ def _compute_m(key): expr, data = self.exprs[key] m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:])) # enforce hermiticity and normalize - return self._normalize(m + ar.dag(m)) + return self._normalize_fn(m + ar.dag(m)) def _update_m(key, new_m): nonlocal nconv, max_mdiff old_m = self.messages[key] - if self.damping > 0.0: - new_m = self._normalize( - self.damping * old_m + (1 - self.damping) * new_m - ) - try: - mdiff = float(self._distance(old_m, new_m)) - except (TypeError, ValueError): - # handle e.g. lazy arrays - mdiff = float("inf") + + # pre-damp distance + mdiff = self._distance_fn(old_m, new_m) + + if self.damping: + new_m = self.fn_damping(old_m, new_m) + + # # post-damp distance + # mdiff = self._distance_fn(old_m, new_m) + if mdiff > tol: # mark touching messages for update new_touched.update(self.touch_map[key]) @@ -226,7 +251,11 @@ def _update_m(key, new_m): self.touched = new_touched - return nconv, ncheck, max_mdiff + return { + "nconv": nconv, + "ncheck": ncheck, + "max_mdiff": max_mdiff, + } def compute_marginal(self, ind): """Compute the marginal for the index ``ind``.""" @@ -264,7 +293,32 @@ def compute_marginal(self, ind): p = ar.do("real", p) return p / ar.do("sum", p) - def contract(self, strip_exponent=False): + def normalize_message_pairs(self): + """Normalize a pair of messages such that ` = 1` and + ` = ` (but in general != 1). + """ + _reshape = ar.get_lib_fn(self.backend, "reshape") + + for ix, tids in self.tn.ind_map.items(): + if len(tids) != 2: + continue + tida, tidb = tids + ml = self.messages[ix, tida] + mr = self.messages[ix, tidb] + + nml, nmr = normalize_message_pair( + _reshape(ml, (-1,)), + _reshape(mr, (-1,)), + ) + + self.messages[ix, tida] = _reshape(nml, ml.shape) + self.messages[ix, tidb] = _reshape(nmr, mr.shape) + + def contract( + self, + strip_exponent=False, + check_zero=True, + ): """Estimate the total contraction, i.e. the 2-norm. 
Parameters @@ -277,7 +331,7 @@ def contract(self, strip_exponent=False): ------- scalar or (scalar, float) """ - tvals = [] + zvals = [] for tid, t in self.tn.tensor_map.items(): arrays = [t.data, ar.do("conj", t.data)] @@ -298,9 +352,8 @@ def contract(self, strip_exponent=False): tval = qtn.array_contract( arrays, inputs, output, **self.contract_opts ) - tvals.append(tval) + zvals.append((tval, 1)) - mvals = [] for ix, tids in self.tn.ind_map.items(): if ix in self.output_inds: continue @@ -310,10 +363,93 @@ def contract(self, strip_exponent=False): mval = qtn.array_contract( (ml, mr), ((1, 2), (1, 2)), (), **self.contract_opts ) - mvals.append(mval) + # counting factor is -1 i.e. divide by the message + zvals.append((mval, -1)) return combine_local_contractions( - tvals, mvals, self.backend, strip_exponent=strip_exponent + zvals, + backend=self.backend, + strip_exponent=strip_exponent, + check_zero=check_zero, + ) + + def contract_cluster_expansion( + self, + clusters=None, + autocomplete=True, + optimize="auto-hq", + strip_exponent=False, + check_zero=True, + info=None, + progbar=False, + **contract_opts, + ): + self.normalize_message_pairs() + + if isinstance(clusters, int): + max_cluster_size = clusters + clusters = None + else: + max_cluster_size = None + + if clusters is None: + clusters = tuple( + self.tn.gen_regions(max_region_size=max_cluster_size) + ) + else: + clusters = tuple(clusters) + + rg = RegionGraph(clusters, autocomplete=autocomplete) + + for tid in self.tn.tensor_map: + rg.add_region([tid]) + + if info is None: + info = {} + info.setdefault("contractions", {}) + contractions = info["contractions"] + + zvals = [] + + if progbar: + import tqdm + + it = tqdm.tqdm(rg.regions) + else: + it = rg.regions + + for region in it: + counting_factor = rg.get_count(region) + + if counting_factor == 0: + continue + + try: + zr = contractions[region] + except KeyError: + k = self.tn._select_tids(region, virtual=False) + b = k.conj() + # apply gauge by contracting messages into ket layer + for oix in k.outer_inds(): + if oix in self.output_inds: + continue + (tid,) = k.ind_map[oix] + m = self.messages[oix, tid] + t = k.tensor_map[tid] + t.gate_(m, oix) + zr = (k | b).contract( + optimize=optimize, + **contract_opts, + ) + contractions[region] = zr + + zvals.append((zr, counting_factor)) + + return combine_local_contractions( + zvals, + backend=self.backend, + strip_exponent=strip_exponent, + check_zero=check_zero, ) def compress( @@ -371,24 +507,163 @@ def compress( return tn + def gauge_insert(self, tn, smudge=1e-12): + """Insert the sqrt of messages on the boundary of a part of the main BP + TN. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to insert the messages into. + smudge : float, optional + Smudge factor to avoid numerical issues, the eigenvalues of the + messages are clipped to be at least the largest eigenvalue times + this factor. + + Returns + ------- + list[tuple[Tensor, str, array_like]] + The sequence of tensors, indices and inverse gauges to apply to + reverse the gauges applied. 
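Note: a numpy sketch of the 'square root of a message' used by ``gauge_insert``: eigendecompose the hermitian message, clip small eigenvalues relative to the largest by ``smudge``, then split into a gauge and its inverse.

```python
import numpy as np

def message_sqrt(m, smudge=1e-12):
    s2, W = np.linalg.eigh(m)                 # eigenvalues ascending
    s2 = np.clip(s2, s2[-1] * smudge, None)   # avoid exact zeros
    s = np.sqrt(s2)
    msqrt = s[:, None] * W.conj().T           # roughly ldmul(s, dag(W))
    msqrt_inv = W / s[None, :]                # roughly rddiv(W, s)
    return msqrt, msqrt_inv

A = np.random.randn(4, 4)
m = A @ A.T                                   # hermitian positive message
msqrt, msqrt_inv = message_sqrt(m)
assert np.allclose(msqrt.conj().T @ msqrt, m)
assert np.allclose(msqrt @ msqrt_inv, np.eye(4))
```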
+ """ + outer = [] + + _eigh = ar.get_lib_fn(self.backend, "linalg.eigh") + _clip = ar.get_lib_fn(self.backend, "clip") + _sqrt = ar.get_lib_fn(self.backend, "sqrt") + + for ix in tn.outer_inds(): + # get the tensor and dangling index + (tid,) = tn.ind_map[ix] + try: + m = self.messages[ix, tid] + except KeyError: + # could be phsyical index or not generated yet + continue + t = tn.tensor_map[tid] + + # compute the 'square root' of the message + s2, W = _eigh(m) + s2 = _clip(s2, s2[-1] * smudge, None) + s = _sqrt(s2) + msqrt = qtn.decomp.ldmul(s, ar.dag(W)) + msqrt_inv = qtn.decomp.rddiv(W, s) + t.gate_(msqrt, ix) + outer.append((t, ix, msqrt_inv)) + + return outer + + @contextlib.contextmanager + def gauge_temp(self, tn, ungauge_outer=True): + """Context manager to temporarily gauge a tensor network, presumably a + subnetwork of the main BP network, using the current messages, and then + un-gauge it afterwards. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to gauge. + ungauge_outer : bool, optional + Whether to un-gauge the outer indices of the tensor network. + """ + outer = self.gauge_insert(tn) + try: + yield outer + finally: + if ungauge_outer: + for t, ix, msqrt_inv in outer: + t.gate_(msqrt_inv, ix) + + def gate_( + self, + G, + where, + max_bond=None, + cutoff=0.0, + cutoff_mode="rsum2", + renorm=0, + tn=None, + **gate_opts, + ): + """Apply a gate to the tensor network at the specified sites, using + the current messages to gauge the tensors. + """ + if len(where) == 1: + # single site gate + self.tn.gate_(G, where, contract=True) + return + + gate_opts.setdefault("contract", "reduce-split") + + if tn is None: + tn = self.tn + site_tags = tuple(map(tn.site_tag, where)) + tn_where = tn.select_any(site_tags) + + with self.gauge_temp(tn_where): + # contract and split the gate + tn_where.gate_( + G, + where, + max_bond=max_bond, + cutoff=cutoff, + cutoff_mode=cutoff_mode, + renorm=renorm, + **gate_opts, + ) + + # update the messages for this bond + taga, tagb = site_tags + (tida,) = tn._get_tids_from_tags(taga) + (tidb,) = tn._get_tids_from_tags(tagb) + ta = tn.tensor_map[tida] + tb = tn.tensor_map[tidb] + lix, (ix,), rix = qtn.group_inds(ta, tb) + + # make use of the fact that we already have gauged tensors + A = ta.to_dense(lix, (ix,)) + B = tb.to_dense((ix,), rix) + ma = ar.dag(A) @ A + mb = B @ ar.dag(B) + + shape_changed = self.messages[ix, tidb].shape != ma.shape + + self.messages[ix, tidb] = ma + self.messages[ix, tida] = mb + + # mark the sites as touched + self.update_touched_from_tids(tida, tidb) + if shape_changed: + # rebuild the contraction expressions if shapes changed + for cix in (*lix, ix, *rix): + if cix not in self.output_inds: + self.build_expr(cix) + def contract_d2bp( tn, + *, messages=None, output_inds=None, - optimize="auto-hq", + max_iterations=1000, + tol=5e-6, damping=0.0, + diis=False, update="sequential", + normalize=None, + distance=None, + tol_abs=None, + tol_rolling_diff=None, local_convergence=True, - max_iterations=1000, - tol=5e-6, + optimize="auto-hq", strip_exponent=False, + check_zero=True, info=None, progbar=False, **contract_opts, ): """Estimate the norm squared of ``tn`` using dense 2-norm belief - propagation. + propagation (no hyper indices). Parameters ---------- @@ -397,28 +672,58 @@ def contract_d2bp( messages : dict[(str, int), array_like], optional The initial messages to use, effectively defaults to all ones if not specified. 
+ output_inds : set[str], optional + The indices to consider as output (dangling) indices of the tn. + Computed automatically if not specified. max_iterations : int, optional The maximum number of iterations to perform. tol : float, optional The convergence tolerance for messages. - output_inds : set[str], optional - The indices to consider as output (dangling) indices of the tn. - Computed automatically if not specified. - optimize : str or PathOptimizer, optional - The path optimizer to use when contracting the messages. damping : float, optional The damping parameter to use, defaults to no damping. - update : {'parallel', 'sequential'}, optional - Whether to update all messages in parallel or sequentially. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. + optimize : str or PathOptimizer, optional + The path optimizer to use when contracting the messages. strip_exponent : bool, optional - Whether to strip the exponent from the final result. If ``True`` - then the returned result is ``(mantissa, exponent)``. + Whether to return the mantissa and exponent separately. + check_zero : bool, optional + Whether to check for zero values and return zero early. info : dict, optional - If specified, update this dictionary with information about the - belief propagation run. + If supplied, the following information will be added to it: + ``converged`` (bool), ``iterations`` (int), ``max_mdiff`` (float), + ``rolling_abs_mean_diff`` (float). progbar : bool, optional Whether to show a progress bar. 
contract_opts @@ -433,18 +738,26 @@ def contract_d2bp( messages=messages, output_inds=output_inds, optimize=optimize, - damping=damping, local_convergence=local_convergence, + damping=damping, update=update, + normalize=normalize, + distance=distance, **contract_opts, ) bp.run( max_iterations=max_iterations, + diis=diis, tol=tol, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, info=info, progbar=progbar, ) - return bp.contract(strip_exponent=strip_exponent) + return bp.contract( + strip_exponent=strip_exponent, + check_zero=check_zero, + ) def compress_d2bp( @@ -455,12 +768,17 @@ def compress_d2bp( renorm=0, messages=None, output_inds=None, - optimize="auto-hq", + max_iterations=1000, + tol=5e-6, damping=0.0, + diis=False, update="sequential", + normalize=None, + distance=None, + tol_abs=None, + tol_rolling_diff=None, local_convergence=True, - max_iterations=1000, - tol=5e-6, + optimize="auto-hq", inplace=False, info=None, progbar=False, @@ -479,25 +797,55 @@ def compress_d2bp( The cutoff to use when compressing. cutoff_mode : int, optional The cutoff mode to use when compressing. + renorm : float, optional + Whether to renormalize the singular values when compressing. messages : dict[(str, int), array_like], optional The initial messages to use, effectively defaults to all ones if not specified. + output_inds : set[str], optional + The indices to consider as output (dangling) indices of the tn. + Computed automatically if not specified. max_iterations : int, optional The maximum number of iterations to perform. tol : float, optional The convergence tolerance for messages. - output_inds : set[str], optional - The indices to consider as output (dangling) indices of the tn. - Computed automatically if not specified. - optimize : str or PathOptimizer, optional - The path optimizer to use when contracting the messages. damping : float, optional The damping parameter to use, defaults to no damping. - update : {'parallel', 'sequential'}, optional - Whether to update all messages in parallel or sequentially. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. 
+ tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. + optimize : str or PathOptimizer, optional + The path optimizer to use when contracting the messages. inplace : bool, optional Whether to perform the compression inplace. info : dict, optional @@ -519,12 +867,18 @@ def compress_d2bp( optimize=optimize, damping=damping, update=update, + normalize=normalize, + distance=distance, local_convergence=local_convergence, + inplace=inplace, **contract_opts, ) bp.run( max_iterations=max_iterations, tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, info=info, progbar=progbar, ) @@ -545,6 +899,14 @@ def sample_d2bp( tol=1e-2, bias=None, seed=None, + optimize="auto-hq", + damping=0.0, + diis=False, + update="sequential", + normalize=None, + distance=None, + tol_abs=None, + tol_rolling_diff=None, local_convergence=True, progbar=False, **contract_opts, @@ -570,6 +932,40 @@ def sample_d2bp( done by raising the probability of each bit-string to this power. seed : int, optional A random seed for reproducibility. + optimize : str or PathOptimizer, optional + The path optimizer to use when contracting the messages. + damping : float, optional + The damping parameter to use, defaults to no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. 
if all their input messages have converged then stop updating them. @@ -600,10 +996,21 @@ def sample_d2bp( bp = D2BP( tn, messages=messages, + optimize=optimize, + damping=damping, + update=update, + normalize=normalize, + distance=distance, local_convergence=local_convergence, **contract_opts, ) - bp.run(max_iterations=max_iterations, tol=tol) + bp.run( + max_iterations=max_iterations, + tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, + ) marginals = dict.fromkeys(output_inds) @@ -640,12 +1047,23 @@ def sample_d2bp( bp = D2BP( tn, - messages=bp.messages, + messages=messages, + optimize=optimize, + damping=damping, + update=update, + normalize=normalize, + distance=distance, local_convergence=local_convergence, **contract_opts, ) bp.update_touched_from_tids(*tids) - bp.run(tol=tol, max_iterations=max_iterations) + bp.run( + max_iterations=max_iterations, + tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, + ) if progbar: pbar.close() diff --git a/quimb/experimental/belief_propagation/diis.py b/quimb/experimental/belief_propagation/diis.py new file mode 100644 index 00000000..fd915947 --- /dev/null +++ b/quimb/experimental/belief_propagation/diis.py @@ -0,0 +1,236 @@ +import autoray as ar +from quimb.tensor import Tensor +from quimb.utils import ( + tree_map, + tree_unflatten, + tree_apply, + Leaf, +) + + +class ArrayInfo: + __slots__ = ("shape", "size") + + def __init__(self, shape, size): + self.shape = shape + self.size = size + + +class Vectorizer: + """Object for mapping back and forth between any nested pytree of arrays + or Tensors and a single flat vector. + + Parameters + ---------- + tree : pytree of array, optional + Any nested container of arrays, which will be flattened and packed into + a single vector. + """ + + def __init__(self, tree=None, backend=None): + self.infos = None + self.d = None + self.ref_tree = None + self.backend = backend + self._concatenate = None + self._reshape = None + if tree is not None: + self.setup(tree) + + def setup(self, tree): + self.infos = [] + self.d = 0 + + def extracter(x): + + if isinstance(x, Tensor): + array = x.data + size = x.size + info = x + else: + array = x + shape = ar.do("shape", x.shape, like=self.backend) + size = ar.do("size", x, like=self.backend) + info = ArrayInfo(shape, size) + + if self.backend is None: + # set backend from first array encountered + self.backend = ar.infer_backend(array) + self._concatenate = ar.get_lib_fn(self.backend, "concatenate") + self._reshape = ar.get_lib_fn(self.backend, "reshape") + + self.infos.append(info) + self.d += size + return Leaf + + self.ref_tree = tree_map(extracter, tree) + + def pack(self, tree): + """Take ``arrays`` and pack their values into attribute `.{name}`, by + default `.vector`. 
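+
+        Returns
+        -------
+        vector : array
+            The single flat vector formed by concatenating all of the
+            arrays in ``tree``, each reshaped to 1D.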
+ """ + if self.infos is None: + self.setup(tree) + + def extractor(x): + if isinstance(x, Tensor): + x = x.data + arrays.append(self._reshape(x, -1)) + + arrays = [] + tree_apply(extractor, tree) + return self._concatenate(tuple(arrays)) + + def unpack(self, vector): + """Turn the single, flat ``vector`` into a sequence of arrays.""" + + def _gen_arrays(): + i = 0 + for info in self.infos: + # get the linear slice + f = i + info.size + array = self._reshape(vector[i:f], info.shape) + i = f + if isinstance(info, Tensor): + # insert array back into tensor, inplace + info.modify(data=array) + yield info + else: + yield array + + return tree_unflatten(_gen_arrays(), self.ref_tree) + + def __repr__(self): + return f"" + + +class DIIS: + """Direct Inversion in the Iterative Subspace (DIIS) method (AKA Pulay + mixing) [1] for converging fixed-point iterations. + + [1] P. Pulay, Convergence acceleration of iterative sequences. The case of + SCF iteration, 1980, Elsevier, https://doi.org/10.1016/0009-2614(80)80396-4. + + Parameters + ---------- + max_history : int + Maximum number of previous guesses to use in extrapolation. + beta : float + Mixing parameter, 0.0 means only use input guesses, 1.0 means only use + extrapolated guesses (original Pulay mixing). Default is 1.0. + rcond : float + Cutoff for small singular values in the pseudo-inverse of the B matrix. + Default is 1e-14. + """ + + def __init__(self, max_history=6, beta=1.0, rcond=1e-14): + self.max_history = max_history + self.beta = beta + self.rcond = rcond + + # storage + self.vectorizer = Vectorizer() + self.guesses = [None] * max_history + self.errors = [None] * max_history + self.lambdas = [] + self.head = self.max_history - 1 + self.B = None + self.y = None + + def _extrapolate(self): + # TODO: make this backend agnostic + import numpy as np + + if self.B is None: + dtype = ar.get_dtype_name(self.guesses[0]) + self.B = np.zeros((self.max_history + 1,) * 2, dtype=dtype) + self.y = np.zeros(self.max_history + 1, dtype=dtype) + self.B[1:, 0] = self.B[0, 1:] = self.y[0] = 1.0 + + # number of error estimates we have + d = sum(e is not None for e in self.errors) + i = self.head + error_i_conj = self.errors[i].conj() + for j in range(d): + cij = error_i_conj @ self.errors[j] + self.B[i + 1, j + 1] = cij + if i != j: + self.B[j + 1, i + 1] = cij.conj() + + # solve for coefficients, taking into account rank deficiency + Binv = np.linalg.pinv( + self.B[: d + 1, : d + 1], + rcond=self.rcond, + hermitian=True, + ) + coeffs = Binv @ self.y[: d + 1] + + # first entry is -ve. lagrange multiplier -> estimated next residual + self.lambdas.append(-coeffs[0]) + + # construct linear combination of previous guesses! + xnew = np.zeros_like(self.guesses[0]) + for ci, xi in zip(coeffs[1:], self.guesses): + xnew += ci * xi + + if self.beta != 0.0: + # allow custom mix of x + xnew: + # https://prefetch.eu/know/concept/pulay-mixing/ + # i.e. use not just x_i but also f(x_i) -> y_i + # original Pulay mixing is beta=1.0 == only xnews + for ci, ei in zip(coeffs[1:], self.errors): + xnew += (self.beta * ci) * ei + + return xnew + + def update(self, y): + """Given new output `y[i]` (the result of `f(x[i])`), update the + internal state and return the extrapolated next guess `x[i+1]`. + + Parameters + ---------- + y : pytree of array + The output of the function `f(x)`. Can be any arbitrary nested + tree structure with arrays treated at leaves. 
+ + Returns + ------- + xnext : pytree of array + The next guess `x[i+1]` to pass to the function `f(x)`, with the + same tree structure as `y`. + """ + # convert from pytree -> single real vector + # copy is important so that sequence of + # guesses are not the same object + y = self.vectorizer.pack(y) + x = self.guesses[self.head] + if x is None: + # first guess (no extrapolation) + xnext = y + else: + self.errors[self.head] = y - x + xnext = self._extrapolate() + + self.head = (self.head + 1) % self.max_history + # # TODO: make copy backend agnostic + self.guesses[self.head] = xnext.copy() + + # convert new extrapolated guess back to pytree + return self.vectorizer.unpack(xnext) + + +class DIISPyscf: + """Thin wrapper around the PySCF DIIS implementation to handle arbitrary + pytrees of arrays, for testing purposes.""" + + def __init__(self, max_history=6): + from pyscf.lib.diis import DIIS as PDIIS + + self.pdiis = PDIIS() + self.pdiis.space = max_history + self.vectorizer = Vectorizer() + + def update(self, y): + y = self.vectorizer.pack(y) + xnext = self.pdiis.update(y) + return self.vectorizer.unpack(xnext) diff --git a/quimb/experimental/belief_propagation/hd1bp.py b/quimb/experimental/belief_propagation/hd1bp.py index 9233bedb..39024205 100644 --- a/quimb/experimental/belief_propagation/hd1bp.py +++ b/quimb/experimental/belief_propagation/hd1bp.py @@ -8,6 +8,7 @@ - [ ] implement sequential update """ + import autoray as ar import quimb.tensor as qtn @@ -174,81 +175,6 @@ def compute_all_tensor_messages_tree(x, ms, backend=None): return mouts -def iterate_belief_propagation_basic( - tn, - messages, - damping=None, - smudge_factor=1e-12, - tol=None, -): - """Run a single iteration of belief propagation. This is the basic version - that does not vectorize contractions. - - Parameters - ---------- - tn : TensorNetwork - The tensor network to run BP on. - messages : dict - The current messages. For every index and tensor id pair, there should - be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``. - smudge_factor : float, optional - A small number to add to the denominator of messages to avoid division - by zero. Note when this happens the numerator will also be zero. - - Returns - ------- - new_messages : dict - The new messages. - """ - backend = ar.infer_backend(next(iter(messages.values()))) - - # _sum = ar.get_lib_fn(backend, "sum") - # n.b. 
at small sizes python sum is faster than numpy sum
-    _sum = ar.get_lib_fn(backend, "sum")
-    # _max = ar.get_lib_fn(backend, "max")
-    _abs = ar.get_lib_fn(backend, "abs")
-
-    def _normalize_and_insert(k, m, max_dm):
-        # normalize and insert
-        m = m / _sum(m)
-
-        old_m = messages[k]
-
-        if damping is not None:
-            # mix old and new
-            m = damping * old_m + (1 - damping) * m
-
-        # compare to the old messages
-        dm = _sum(_abs(m - old_m))
-        max_dm = max(dm, max_dm)
-
-        # set and return the max diff so far
-        messages[k] = m
-        return max_dm
-
-    max_dm = 0.0
-
-    # hyper index messages
-    for ix, tids in tn.ind_map.items():
-        ms = compute_all_hyperind_messages_prod(
-            [messages[tid, ix] for tid in tids], smudge_factor
-        )
-        for tid, m in zip(tids, ms):
-            max_dm = _normalize_and_insert((ix, tid), m, max_dm)
-
-    # tensor messages
-    for tid, t in tn.tensor_map.items():
-        inds = t.inds
-        ms = compute_all_tensor_messages_tree(
-            t.data,
-            [messages[ix, tid] for ix in inds],
-        )
-        for ix, m in zip(inds, ms):
-            max_dm = _normalize_and_insert((tid, ix), m, max_dm)
-
-    return messages, max_dm
-
-
 class HD1BP(BeliefPropagationCommon):
     """Object interface for hyper, dense, 1-norm belief propagation. This is
     standard belief propagation in tensor network form.
@@ -259,37 +185,127 @@ class HD1BP(BeliefPropagationCommon):
         The tensor network to run BP on.
     messages : dict, optional
         Initial messages to use, if not given then uniform messages are used.
+    damping : float or callable, optional
+        The damping factor to apply to messages. This simply mixes some part
+        of the old message into the new one, with the final message being
+        ``damping * old + (1 - damping) * new``. This makes convergence more
+        reliable but slower.
+    update : {'sequential', 'parallel'}, optional
+        Whether to update messages sequentially (newly computed messages are
+        immediately used for other updates in the same iteration round) or in
+        parallel (all messages are computed using messages from the previous
+        round only). Sequential generally helps convergence but parallel can
+        possibly converge to different solutions.
+    normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
+        How to normalize messages after each update. If None choose
+        automatically. If a callable, it should take a message and return the
+        normalized message. If a string, it should be one of 'L1', 'L2',
+        'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
+        but also normalizes the phase of the message, by default used for
+        complex dtypes.
+    distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
+        How to compute the distance between messages to check for convergence.
+        If None choose automatically. If a callable, it should take two
+        messages and return the distance. If a string, it should be one of
+        'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
+        norms. 'L2phased' is like 'L2' but also normalizes the phases of the
+        messages, by default used for complex dtypes if phased normalization is
+        not already being used.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
""" def __init__( self, tn, + *, messages=None, - damping=None, + damping=0.0, + update="sequential", + normalize=None, + distance=None, smudge_factor=1e-12, + inplace=False, ): - self.tn = tn - self.backend = next(t.backend for t in tn) + super().__init__( + tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) + self.smudge_factor = smudge_factor - self.damping = damping - if messages is None: + + if callable(messages): + messages = initialize_hyper_messages( + tn, fill_fn=messages, smudge_factor=smudge_factor + ) + elif messages is None: messages = initialize_hyper_messages( tn, smudge_factor=smudge_factor ) self.messages = messages - def iterate(self, **kwargs): - self.messages, max_dm = iterate_belief_propagation_basic( - self.tn, - self.messages, - damping=self.damping, - smudge_factor=self.smudge_factor, - **kwargs, - ) - return None, None, max_dm + def iterate(self, tol=None): + if self.update == "sequential": + new_messages = self.messages + else: + new_messages = {} + + def _normalize_and_insert(key, new_m, max_mdiff): + # normalize and insert + new_m = self._normalize_fn(new_m) + old_m = self.messages[key] + + # pre-damp distance + mdiff = self._distance_fn(old_m, new_m) + + if self.damping: + new_m = self.fn_damping(old_m, new_m) + + # # post-damp distance + # mdiff = self._distance_fn(old_m, new_m) + + max_mdiff = max(mdiff, max_mdiff) + + # set and return the max diff so far + new_messages[key] = new_m + return max_mdiff + + max_mdiff = 0.0 + + # hyper index messages + for ix, tids in self.tn.ind_map.items(): + ms = compute_all_hyperind_messages_prod( + [self.messages[tid, ix] for tid in tids], + self.smudge_factor, + ) + for tid, m in zip(tids, ms): + max_mdiff = _normalize_and_insert((ix, tid), m, max_mdiff) + + if self.update == "parallel": + self.messages.update(new_messages) + new_messages.clear() + + # tensor messages + for tid, t in self.tn.tensor_map.items(): + inds = t.inds + ms = compute_all_tensor_messages_tree( + t.data, + [self.messages[ix, tid] for ix in inds], + ) + for ix, m in zip(inds, ms): + max_mdiff = _normalize_and_insert((tid, ix), m, max_mdiff) + + if self.update == "parallel": + self.messages.update(new_messages) + + return max_mdiff def get_gauged_tn(self): """Assuming the supplied tensor network has no hyper or dangling @@ -305,15 +321,15 @@ def get_gauged_tn(self): ma = self.messages[ka] mb = self.messages[kb] - el, ev = ar.do('linalg.eig', ar.do('outer', ma, mb)) - k = ar.do('argsort', -ar.do('abs', el)) + el, ev = ar.do("linalg.eig", ar.do("outer", ma, mb)) + k = ar.do("argsort", -ar.do("abs", el)) ev = ev[:, k] Uinv = ev - U = ar.do('linalg.inv', ev) + U = ar.do("linalg.inv", ev) tng._insert_gauge_tids(U, tida, tidb, Uinv) return tng - def contract(self, strip_exponent=False): + def contract(self, strip_exponent=False, check_zero=True): """Estimate the total contraction, i.e. the exponential of the 'Bethe free entropy'. """ @@ -321,6 +337,7 @@ def contract(self, strip_exponent=False): self.tn, self.messages, strip_exponent=strip_exponent, + check_zero=check_zero, backend=self.backend, ) @@ -331,8 +348,15 @@ def contract_hd1bp( max_iterations=1000, tol=5e-6, damping=0.0, + diis=False, + update="sequential", + normalize=None, + distance=None, + tol_abs=None, + tol_rolling_diff=None, smudge_factor=1e-12, strip_exponent=False, + check_zero=True, info=None, progbar=False, ): @@ -350,13 +374,45 @@ def contract_hd1bp( tol : float, optional The convergence tolerance for messages. 
damping : float, optional - The damping factor to use, 0.0 means no damping. + The damping parameter to use, defaults to no damping. + diis : bool or dict, optional + Whether to use direct inversion in the iterative subspace to + help converge the messages by extrapolating to low error guesses. + If a dict, should contain options for the DIIS algorithm. The + relevant options are {`max_history`, `beta`, `rcond`}. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + tol_abs : float, optional + The absolute convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. + tol_rolling_diff : float, optional + The rolling mean convergence tolerance for maximum message update + distance, if not given then taken as ``tol``. This is used to stop + running when the messages are just bouncing around the same level, + without any overall upward or downward trends, roughly speaking. smudge_factor : float, optional A small number to add to the denominator of messages to avoid division by zero. Note when this happens the numerator will also be zero. strip_exponent : bool, optional Whether to strip the exponent from the final result. If ``True`` then the returned result is ``(mantissa, exponent)``. + check_zero : bool, optional + Whether to check for zero values and return zero early. info : dict, optional If specified, update this dictionary with information about the belief propagation run. @@ -371,15 +427,24 @@ def contract_hd1bp( tn, messages=messages, damping=damping, + update=update, + normalize=normalize, + distance=distance, smudge_factor=smudge_factor, ) bp.run( max_iterations=max_iterations, tol=tol, + diis=diis, + tol_abs=tol_abs, + tol_rolling_diff=tol_rolling_diff, info=info, progbar=progbar, ) - return bp.contract(strip_exponent=strip_exponent) + return bp.contract( + strip_exponent=strip_exponent, + check_zero=check_zero, + ) def run_belief_propagation_hd1bp( diff --git a/quimb/experimental/belief_propagation/hv1bp.py b/quimb/experimental/belief_propagation/hv1bp.py index e77400e5..0f940b9f 100644 --- a/quimb/experimental/belief_propagation/hv1bp.py +++ b/quimb/experimental/belief_propagation/hv1bp.py @@ -1,5 +1,4 @@ -"""Hyper, vectorized, 1-norm, belief propagation. -""" +"""Hyper, vectorized, 1-norm, belief propagation.""" import autoray as ar @@ -13,104 +12,6 @@ ) -def initialize_messages_batched(tn, messages=None): - """Initialize batched messages for belief propagation, as the uniform - distribution. 
- """ - if messages is None: - messages = initialize_hyper_messages(tn) - - backend = ar.infer_backend(next(iter(messages.values()))) - _stack = ar.get_lib_fn(backend, "stack") - _array = ar.get_lib_fn(backend, "array") - - # prepare index messages - batched_inputs_m = {} - input_locs_m = {} - output_locs_m = {} - for ix, tids in tn.ind_map.items(): - rank = len(tids) - try: - batch = batched_inputs_m[rank] - except KeyError: - batch = batched_inputs_m[rank] = [[] for _ in range(rank)] - - for i, tid in enumerate(tids): - batch_i = batch[i] - # position in the stack - b = len(batch_i) - input_locs_m[tid, ix] = (rank, i, b) - output_locs_m[ix, tid] = (rank, i, b) - batch_i.append(messages[tid, ix]) - - # prepare tensor messages - batched_tensors = {} - batched_inputs_t = {} - input_locs_t = {} - output_locs_t = {} - for tid, t in tn.tensor_map.items(): - rank = t.ndim - if rank == 0: - continue - - try: - batch = batched_inputs_t[rank] - batch_t = batched_tensors[rank] - except KeyError: - batch = batched_inputs_t[rank] = [[] for _ in range(rank)] - batch_t = batched_tensors[rank] = [] - - for i, ix in enumerate(t.inds): - batch_i = batch[i] - # position in the stack - b = len(batch_i) - input_locs_t[ix, tid] = (rank, i, b) - output_locs_t[tid, ix] = (rank, i, b) - batch_i.append(messages[ix, tid]) - - batch_t.append(t.data) - - # stack messages in into single arrays - for batched_inputs in (batched_inputs_m, batched_inputs_t): - for key, batch in batched_inputs.items(): - batched_inputs[key] = _stack( - tuple(_stack(batch_i) for batch_i in batch) - ) - for rank, tensors in batched_tensors.items(): - batched_tensors[rank] = _stack(tensors) - - # make numeric masks for updating output to input messages - masks_m = {} - masks_t = {} - for masks, input_locs, output_locs in [ - (masks_m, input_locs_m, output_locs_t), - (masks_t, input_locs_t, output_locs_m), - ]: - for pair in input_locs: - (ranki, ii, bi) = input_locs[pair] - (ranko, io, bo) = output_locs[pair] - key = (ranki, ranko) - try: - maskin, maskout = masks[key] - except KeyError: - maskin, maskout = masks[key] = [], [] - maskin.append([ii, bi]) - maskout.append([io, bo]) - - for key, (maskin, maskout) in masks.items(): - masks[key] = _array(maskin), _array(maskout) - - return ( - batched_inputs_m, - batched_inputs_t, - batched_tensors, - input_locs_m, - input_locs_t, - masks_m, - masks_t, - ) - - def _compute_all_hyperind_messages_tree_batched(bm): """ """ ndim = len(bm) @@ -236,7 +137,6 @@ def _compute_all_tensor_messages_tree_batched(bx, bm): def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12): backend = ar.infer_backend_multi(bx, bm) - _einsum = ar.get_lib_fn(backend, "einsum") _stack = ar.get_lib_fn(backend, "stack") ndim = len(bm) @@ -267,129 +167,51 @@ def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12): def _compute_output_single_t( bm, bx, - _reshape, - _sum, + normalize, smudge_factor=1e-12, ): # tensor messages bmo = _compute_all_tensor_messages_tree_batched(bx, bm) # bmo = _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor) - # normalize - bmo /= _reshape(_sum(bmo, axis=-1), (*ar.shape(bmo)[:-1], 1)) + normalize(bmo) return bmo -def _compute_output_single_m(bm, _reshape, _sum, smudge_factor=1e-12): +def _compute_output_single_m(bm, normalize, smudge_factor=1e-12): # index messages # bmo = _compute_all_hyperind_messages_tree_batched(bm) bmo = _compute_all_hyperind_messages_prod_batched(bm, smudge_factor) - # normalize - bmo /= _reshape(_sum(bmo, axis=-1), 
(*ar.shape(bmo)[:-1], 1)) + normalize(bmo) return bmo -def _compute_outputs_batched( - batched_inputs, - batched_tensors=None, - smudge_factor=1e-12, - _pool=None, -): - """Given stacked messsages and tensors, compute stacked output messages.""" - backend = ar.infer_backend(next(iter(batched_inputs.values()))) - _sum = ar.get_lib_fn(backend, "sum") - _reshape = ar.get_lib_fn(backend, "reshape") - - if batched_tensors is not None: - # tensor messages - f = _compute_output_single_t - f_args = { - rank: (bm, batched_tensors[rank], _reshape, _sum, smudge_factor) - for rank, bm in batched_inputs.items() - } - else: - # index messages - f = _compute_output_single_m - f_args = { - rank: (bm, _reshape, _sum, smudge_factor) - for rank, bm in batched_inputs.items() - } - - batched_outputs = {} - if _pool is None: - # sequential process - for rank, args in f_args.items(): - batched_outputs[rank] = f(*args) - else: - # parallel process - for rank, args in f_args.items(): - batched_outputs[rank] = _pool.submit(f, *args) - for key, fut in batched_outputs.items(): - batched_outputs[key] = fut.result() - - return batched_outputs - - def _update_output_to_input_single_batched( - bi, - bo, + batched_input, + batched_output, maskin, maskout, - _max, - _sum, - _abs, + _distance_fn, damping=0.0, ): # do a vectorized update select_in = (maskin[:, 0], maskin[:, 1], slice(None)) select_out = (maskout[:, 0], maskout[:, 1], slice(None)) - bim = bi[select_in] - bom = bo[select_out] + bim = batched_input[select_in] + bom = batched_output[select_out] - if damping > 0.0: - bim = (1 - damping) * bom + damping * bim + # pre-damp distance + mdiff = _distance_fn(bim, bom) - # first check the change - dm = _max(_sum(_abs(bim - bom), axis=-1)) - - # update the input - bi[select_in] = bom - - return dm + if damping != 0.0: + bom = damping * bim + (1 - damping) * bom + # # post-damp distance + # mdiff = _distance_fn(bim, bom) -def _update_outputs_to_inputs_batched( - batched_inputs, batched_outputs, masks, damping=0.0, _pool=None -): - """Update the stacked input messages from the stacked output messages.""" - backend = ar.infer_backend(next(iter(batched_outputs.values()))) - _max = ar.get_lib_fn(backend, "max") - _sum = ar.get_lib_fn(backend, "sum") - _abs = ar.get_lib_fn(backend, "abs") - - f = _update_output_to_input_single_batched - f_args = ( - ( - batched_inputs[ranki], - batched_outputs[ranko], - maskin, - maskout, - _max, - _sum, - _abs, - damping, - ) - for (ranki, ranko), (maskin, maskout) in masks.items() - ) - - if _pool is None: - # sequential process - dms = (f(*args) for args in f_args) - else: - # parallel process - futs = [_pool.submit(f, *args) for args in f_args] - dms = (fut.result() for fut in futs) + # update the input + batched_input[select_in] = bom - return max(dms) + return mdiff def _extract_messages_from_inputs_batched( @@ -407,51 +229,6 @@ def _extract_messages_from_inputs_batched( return messages -def iterate_belief_propagation_batched( - batched_inputs_m, - batched_inputs_t, - batched_tensors, - masks_m, - masks_t, - smudge_factor=1e-12, - damping=0.0, - _pool=None, -): - """ """ - # compute tensor messages - batched_outputs_t = _compute_outputs_batched( - batched_inputs=batched_inputs_t, - batched_tensors=batched_tensors, - smudge_factor=smudge_factor, - _pool=_pool, - ) - # update the index input messages - t_max_dm = _update_outputs_to_inputs_batched( - batched_inputs_m, - batched_outputs_t, - masks_m, - damping=damping, - _pool=_pool, - ) - - # compute index messages - batched_outputs_m = 
_compute_outputs_batched(
-        batched_inputs=batched_inputs_m,
-        batched_tensors=None,
-        smudge_factor=smudge_factor,
-        _pool=_pool,
-    )
-    # update the tensor input messages
-    m_max_dm = _update_outputs_to_inputs_batched(
-        batched_inputs_t,
-        batched_outputs_m,
-        masks_t,
-        damping=damping,
-        _pool=_pool,
-    )
-    return batched_inputs_m, batched_inputs_t, max(t_max_dm, m_max_dm)
-
-
 class HV1BP(BeliefPropagationCommon):
     """Object interface for hyper, vectorized, 1-norm, belief propagation. This
     is the fast version of belief propagation possible when there are many,
@@ -463,6 +240,32 @@ class HV1BP(BeliefPropagationCommon):
         The tensor network to run BP on.
     messages : dict, optional
         Initial messages to use, if not given then uniform messages are used.
+    damping : float or callable, optional
+        The damping factor to apply to messages. This simply mixes some part
+        of the old message into the new one, with the final message being
+        ``damping * old + (1 - damping) * new``. This makes convergence more
+        reliable but slower.
+    update : {'sequential', 'parallel'}, optional
+        Whether to update messages sequentially (newly computed messages are
+        immediately used for other updates in the same iteration round) or in
+        parallel (all messages are computed using messages from the previous
+        round only). Sequential generally helps convergence but parallel can
+        possibly converge to different solutions.
+    normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
+        How to normalize messages after each update. If None choose
+        automatically. If a callable, it should take a message and return the
+        normalized message. If a string, it should be one of 'L1', 'L2',
+        'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
+        but also normalizes the phase of the message, by default used for
+        complex dtypes.
+    distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
+        How to compute the distance between messages to check for convergence.
+        If None choose automatically. If a callable, it should take two
+        messages and return the distance. If a string, it should be one of
+        'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
+        norms. 'L2phased' is like 'L2' but also normalizes the phases of the
+        messages, by default used for complex dtypes if phased normalization is
+        not already being used.
     smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
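As a concrete illustration of the update rule that the ``damping``, ``normalize`` and ``distance`` options above describe, here is a minimal NumPy sketch (not part of the patch; the real implementation is backend-agnostic via autoray and acts on the stacked message batches built in the hunks below). Each stacked message is 'L2'-normalized along its last axis, the convergence distance is measured before damping, and damping then mixes the old message back in as ``damping * old + (1 - damping) * new``:

    import numpy as np

    def batched_message_update(old, new, damping=0.1):
        # 'L2' normalize each message in the (n_messages, dim) stack
        new = new / np.sum(np.abs(new) ** 2, axis=-1, keepdims=True) ** 0.5
        # 'L2' distance: largest 2-norm difference over all messages
        mdiff = np.max(np.sum(np.abs(new - old) ** 2, axis=-1)) ** 0.5
        # damp: mix part of the old message back into the new one
        damped = damping * old + (1 - damping) * new
        return damped, float(mdiff)

    old = np.random.rand(10, 4)  # 10 stacked messages of size 4
    new = np.random.rand(10, 4)
    old, mdiff = batched_message_update(old, new)  # iterate until mdiff < tol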
@@ -474,42 +277,306 @@ class HV1BP(BeliefPropagationCommon): def __init__( self, tn, + *, messages=None, - smudge_factor=1e-12, damping=0.0, + update="parallel", + normalize="L2", + distance="L2", + inplace=False, + smudge_factor=1e-12, thread_pool=False, ): - self.tn = tn - self.backend = next(t.backend for t in tn) + super().__init__( + tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) + + if update != "parallel": + raise ValueError("Only parallel update supported.") + self.smudge_factor = smudge_factor - self.damping = damping self.pool = maybe_get_thread_pool(thread_pool) - ( + self.initialize_messages_batched(messages) + + @property + def normalize(self): + return self._normalize + + @normalize.setter + def normalize(self, normalize): + if callable(normalize): + # custom normalization function + _normalize = normalize + elif normalize == "L1": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _normalize(bx): + bxn = _sum(_abs(bx), axis=-1, keepdims=True) + bx /= bxn + + elif normalize == "L2": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + + def _normalize(bx): + bxn = _sum(_abs(bx) ** 2, axis=-1, keepdims=True) ** 0.5 + bx /= bxn + + elif normalize == "Linf": + _abs = ar.get_lib_fn(self.backend, "abs") + _max = ar.get_lib_fn(self.backend, "max") + + def _normalize(bx): + bxn = _max(_abs(bx), axis=-1, keepdims=True) + bx /= bxn + + else: + raise ValueError(f"Unrecognized normalize={normalize}") + + self._normalize = _normalize + + @property + def distance(self): + return self._distance + + @distance.setter + def distance(self, distance): + if callable(distance): + # custom normalization function + _distance_fn = distance + + elif distance == "L1": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + _max = ar.get_lib_fn(self.backend, "max") + + def _distance_fn(bx, by): + return _max(_sum(_abs(bx - by), axis=-1)) + + elif distance == "L2": + _abs = ar.get_lib_fn(self.backend, "abs") + _sum = ar.get_lib_fn(self.backend, "sum") + _max = ar.get_lib_fn(self.backend, "max") + + def _distance_fn(bx, by): + return _max(_sum(_abs(bx - by) ** 2, axis=-1)) ** 0.5 + + elif distance == "Linf": + _abs = ar.get_lib_fn(self.backend, "abs") + _max = ar.get_lib_fn(self.backend, "max") + + def _distance_fn(bx, by): + return _max(_abs(bx - by)) + + else: + raise ValueError(f"Unrecognized distance={distance}") + + self._distance = distance + self._distance_fn = _distance_fn + + def initialize_messages_batched(self, messages=None): + if messages is None: + # XXX: explicit use uniform distribution to avoid non-vectorized + # contractions? 
+ messages = initialize_hyper_messages(self.tn) + + _stack = ar.get_lib_fn(self.backend, "stack") + _array = ar.get_lib_fn(self.backend, "array") + + # prepare index messages + batched_inputs_m = {} + input_locs_m = {} + output_locs_m = {} + for ix, tids in self.tn.ind_map.items(): + # all updates of the same rank can be performed simultaneously + rank = len(tids) + try: + batch = batched_inputs_m[rank] + except KeyError: + batch = batched_inputs_m[rank] = [[] for _ in range(rank)] + + for i, tid in enumerate(tids): + batch_i = batch[i] + # position in the stack + b = len(batch_i) + input_locs_m[tid, ix] = (rank, i, b) + output_locs_m[ix, tid] = (rank, i, b) + batch_i.append(messages[tid, ix]) + + # prepare tensor messages + batched_tensors = {} + batched_inputs_t = {} + input_locs_t = {} + output_locs_t = {} + for tid, t in self.tn.tensor_map.items(): + # all updates of the same rank can be performed simultaneously + rank = t.ndim + if rank == 0: + # floating scalars are not updated + continue + + try: + batch = batched_inputs_t[rank] + batch_t = batched_tensors[rank] + except KeyError: + batch = batched_inputs_t[rank] = [[] for _ in range(rank)] + batch_t = batched_tensors[rank] = [] + + for i, ix in enumerate(t.inds): + batch_i = batch[i] + # position in the stack + b = len(batch_i) + input_locs_t[ix, tid] = (rank, i, b) + output_locs_t[tid, ix] = (rank, i, b) + batch_i.append(messages[ix, tid]) + + batch_t.append(t.data) + + # stack messages in into single arrays + for batched_inputs in (batched_inputs_m, batched_inputs_t): + for key, batch in batched_inputs.items(): + batched_inputs[key] = _stack( + tuple(_stack(batch_i) for batch_i in batch) + ) + # stack all tensors of each rank into a single array + for rank, tensors in batched_tensors.items(): + batched_tensors[rank] = _stack(tensors) + + # make numeric masks for updating output to input messages + masks_m = {} + masks_t = {} + for masks, input_locs, output_locs in [ + (masks_m, input_locs_m, output_locs_t), + (masks_t, input_locs_t, output_locs_m), + ]: + for pair in input_locs: + (ranki, ii, bi) = input_locs[pair] + (ranko, io, bo) = output_locs[pair] + key = (ranki, ranko) + try: + maskin, maskout = masks[key] + except KeyError: + maskin, maskout = masks[key] = [], [] + maskin.append([ii, bi]) + maskout.append([io, bo]) + + for key, (maskin, maskout) in masks.items(): + masks[key] = _array(maskin), _array(maskout) + + self.batched_inputs_m = batched_inputs_m + self.batched_inputs_t = batched_inputs_t + self.batched_tensors = batched_tensors + self.input_locs_m = input_locs_m + self.input_locs_t = input_locs_t + self.masks_m = masks_m + self.masks_t = masks_t + + @property + def messages(self): + return (self.batched_inputs_m, self.batched_inputs_t) + + @messages.setter + def messages(self, messages): + self.batched_inputs_m, self.batched_inputs_t = messages + + def _compute_outputs_batched( + self, + batched_inputs, + batched_tensors=None, + ): + """Given stacked messsages and optionally tensors, compute stacked + output messages, possibly using parallel pool. 
+ """ + + if batched_tensors is not None: + # tensor messages + f = _compute_output_single_t + f_args = { + rank: (bm, batched_tensors[rank], self.normalize) + for rank, bm in batched_inputs.items() + } + else: + # index messages + f = _compute_output_single_m + f_args = { + rank: (bm, self.normalize, self.smudge_factor) + for rank, bm in batched_inputs.items() + } + + batched_outputs = {} + if self.pool is None: + # sequential process + for rank, args in f_args.items(): + batched_outputs[rank] = f(*args) + else: + # parallel process + for rank, args in f_args.items(): + batched_outputs[rank] = self.pool.submit(f, *args) + for key, fut in batched_outputs.items(): + batched_outputs[key] = fut.result() + + return batched_outputs + + def _update_outputs_to_inputs_batched( + self, + batched_inputs, + batched_outputs, + masks, + ): + """Update the stacked input messages from the stacked output messages.""" + f = _update_output_to_input_single_batched + f_args = ( + ( + batched_inputs[ranki], + batched_outputs[ranko], + maskin, + maskout, + self._distance_fn, + self.damping, + ) + for (ranki, ranko), (maskin, maskout) in masks.items() + ) + + if self.pool is None: + # sequential process + mdiffs = (f(*args) for args in f_args) + else: + # parallel process + futs = [self.pool.submit(f, *args) for args in f_args] + mdiffs = (fut.result() for fut in futs) + + return max(mdiffs) + + def iterate(self, tol=None): + # first we compute new tensor output messages + self.batched_outputs_t = self._compute_outputs_batched( + batched_inputs=self.batched_inputs_t, + batched_tensors=self.batched_tensors, + ) + # update the index input messages with these + t_max_mdiff = self._update_outputs_to_inputs_batched( self.batched_inputs_m, - self.batched_inputs_t, - self.batched_tensors, - self.input_locs_m, - self.input_locs_t, + self.batched_outputs_t, self.masks_m, - self.masks_t, - ) = initialize_messages_batched(tn, messages) + ) - def iterate(self, **kwargs): - ( - self.batched_inputs_m, - self.batched_inputs_t, - max_dm, - ) = iterate_belief_propagation_batched( - self.batched_inputs_m, + # compute index messages + self.batched_outputs_m = self._compute_outputs_batched( + batched_inputs=self.batched_inputs_m, + ) + # update the tensor input messages + m_max_mdiff = self._update_outputs_to_inputs_batched( self.batched_inputs_t, - self.batched_tensors, - self.masks_m, + self.batched_outputs_m, self.masks_t, - damping=self.damping, - smudge_factor=self.smudge_factor, - _pool=self.pool, ) - return None, None, max_dm + return max(t_max_mdiff, m_max_mdiff) def get_messages(self): """Get messages in individual form from the batched stacks.""" @@ -521,6 +588,7 @@ def get_messages(self): ) def contract(self, strip_exponent=False): + # TODO: do this in batched form directly return contract_hyper_messages( self.tn, self.get_messages(), diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/experimental/belief_propagation/l1bp.py index 4822c938..6db25b39 100644 --- a/quimb/experimental/belief_propagation/l1bp.py +++ b/quimb/experimental/belief_propagation/l1bp.py @@ -1,5 +1,3 @@ -import autoray as ar - import quimb.tensor as qtn from quimb.utils import oset @@ -22,10 +20,32 @@ class L1BP(BeliefPropagationCommon): The tags identifying the sites in ``tn``, each tag forms a region, which should not overlap. If the tensor network is structured, then these are inferred automatically. - damping : float, optional - The damping parameter to use, defaults to no damping. 
-    update : {'parallel', 'sequential'}, optional
-        Whether to update all messages in parallel or sequentially.
+    damping : float or callable, optional
+        The damping factor to apply to messages. This simply mixes some part
+        of the old message into the new one, with the final message being
+        ``damping * old + (1 - damping) * new``. This makes convergence more
+        reliable but slower.
+    update : {'sequential', 'parallel'}, optional
+        Whether to update messages sequentially (newly computed messages are
+        immediately used for other updates in the same iteration round) or in
+        parallel (all messages are computed using messages from the previous
+        round only). Sequential generally helps convergence but parallel can
+        possibly converge to different solutions.
+    normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional
+        How to normalize messages after each update. If None choose
+        automatically. If a callable, it should take a message and return the
+        normalized message. If a string, it should be one of 'L1', 'L2',
+        'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2'
+        but also normalizes the phase of the message, by default used for
+        complex dtypes.
+    distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional
+        How to compute the distance between messages to check for convergence.
+        If None choose automatically. If a callable, it should take two
+        messages and return the distance. If a string, it should be one of
+        'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding
+        norms. 'L2phased' is like 'L2' but also normalizes the phases of the
+        messages, by default used for complex dtypes if phased normalization is
+        not already being used.
     local_convergence : bool, optional
         Whether to allow messages to locally converge - i.e. if all their
         input messages have converged then stop updating them.
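A hedged usage sketch of this lazy 1-norm interface, via the ``contract_l1bp`` driver defined later in this file. The partition-function builder ``qtn.TN2D_classical_ising_partition_function`` and the specific lattice size and inverse temperature are illustrative assumptions, not part of the patch:

    import quimb.tensor as qtn
    from quimb.experimental.belief_propagation.l1bp import contract_l1bp

    # an 8x8 classical Ising partition function as a (site-tagged) TN,
    # from which the BP regions / site_tags can be inferred automatically
    tn = qtn.TN2D_classical_ising_partition_function(8, 8, beta=0.4)

    info = {}
    mantissa, exponent = contract_l1bp(
        tn,
        damping=0.1,          # mix 10% of each old message back in
        update="sequential",  # reuse new messages within the same sweep
        strip_exponent=True,  # return a (mantissa, exponent) pair
        info=info,            # populated with convergence information
    )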
@@ -39,17 +59,27 @@ def __init__( self, tn, site_tags=None, + *, damping=0.0, update="sequential", + normalize=None, + distance=None, + inplace=False, local_convergence=True, optimize="auto-hq", message_init_function=None, **contract_opts, ): - self.backend = next(t.backend for t in tn) - self.damping = damping + super().__init__( + tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) + self.local_convergence = local_convergence - self.update = update self.optimize = optimize self.contract_opts = contract_opts @@ -66,33 +96,6 @@ def __init__( ) = create_lazy_community_edge_map(tn, site_tags) self.touched = oset() - self._abs = ar.get_lib_fn(self.backend, "abs") - self._max = ar.get_lib_fn(self.backend, "max") - self._sum = ar.get_lib_fn(self.backend, "sum") - _real = ar.get_lib_fn(self.backend, "real") - _argmax = ar.get_lib_fn(self.backend, "argmax") - _reshape = ar.get_lib_fn(self.backend, "reshape") - self._norm = ar.get_lib_fn(self.backend, "linalg.norm") - - def _normalize(x): - - # sx = self._sum(x) - # sphase = sx / self._abs(sx) - # smag = self._norm(x)**0.5 - # return x / (smag * sphase) - - return x / self._sum(x) - # return x / self._norm(x) - # return x / self._max(x) - # fx = _reshape(x, (-1,)) - # return x / fx[_argmax(self._abs(_real(fx)))] - - def _distance(x, y): - return self._sum(self._abs(x - y)) - - self._normalize = _normalize - self._distance = _distance - # for each meta bond create initial messages self.messages = {} for pair, bix in self.edges.items(): @@ -110,7 +113,7 @@ def _distance(x, y): **self.contract_opts, ) # normalize - tm.modify(apply=self._normalize) + tm.modify(apply=self._normalize_fn) else: shape = tuple(tn_i.ind_size(ix) for ix in bix) tm = qtn.Tensor( @@ -156,20 +159,21 @@ def _compute_m(key): optimize=self.optimize, **self.contract_opts, ) - return self._normalize(tm_new.data) + return self._normalize_fn(tm_new.data) def _update_m(key, data): nonlocal nconv, max_mdiff tm = self.messages[key] - if callable(self.damping): - damping_m = self.damping() - data = (1 - damping_m) * data + damping_m * tm.data - elif self.damping != 0.0: - data = (1 - self.damping) * data + self.damping * tm.data + # pre-damp distance + mdiff = self._distance_fn(data, tm.data) + + if self.damping: + data = self.fn_damping(data, tm.data) - mdiff = float(self._distance(tm.data, data)) + # # post-damp distance + # mdiff = self._distance_fn(data, tm.data) if mdiff > tol: # mark touching messages for update @@ -198,10 +202,14 @@ def _update_m(key, data): _update_m(key, data) self.touched = new_touched - return nconv, ncheck, max_mdiff - - def contract(self, strip_exponent=False): - tvals = [] + return { + "nconv": nconv, + "ncheck": ncheck, + "max_mdiff": max_mdiff, + } + + def contract(self, strip_exponent=False, check_zero=True): + zvals = [] for site, tn_ic in self.local_tns.items(): if site in self.neighbors: tval = qtn.tensor_contract( @@ -218,9 +226,8 @@ def contract(self, strip_exponent=False): optimize=self.optimize, **self.contract_opts, ) - tvals.append(tval) + zvals.append((tval, 1)) - mvals = [] for i, j in self.edges: mval = qtn.tensor_contract( self.messages[i, j], @@ -228,24 +235,28 @@ def contract(self, strip_exponent=False): optimize=self.optimize, **self.contract_opts, ) - mvals.append(mval) + # power / counting factor is -1 for messages + zvals.append((mval, -1)) return combine_local_contractions( - tvals, mvals, self.backend, strip_exponent=strip_exponent + zvals, + backend=self.backend, + 
strip_exponent=strip_exponent, + check_zero=check_zero, ) - def normalize_messages(self): + def normalize_message_pairs(self): """Normalize all messages such that for each bond ` = 1` and ` = ` (but in general != 1). """ for i, j in self.edges: tmi = self.messages[i, j] tmj = self.messages[j, i] - nij = abs(tmi @ tmj)**0.5 - nii = (tmi @ tmi)**0.25 - njj = (tmj @ tmj)**0.25 - tmi /= (nij * nii / njj) - tmj /= (nij * njj / nii) + nij = abs(tmi @ tmj) ** 0.5 + nii = (tmi @ tmi) ** 0.25 + njj = (tmj @ tmj) ** 0.25 + tmi /= nij * nii / njj + tmj /= nij * njj / nii def contract_l1bp( @@ -255,6 +266,7 @@ def contract_l1bp( site_tags=None, damping=0.0, update="sequential", + diis=False, local_convergence=True, optimize="auto-hq", strip_exponent=False, @@ -308,6 +320,7 @@ def contract_l1bp( bp.run( max_iterations=max_iterations, tol=tol, + diis=diis, info=info, progbar=progbar, ) diff --git a/quimb/experimental/belief_propagation/l2bp.py b/quimb/experimental/belief_propagation/l2bp.py index 15ff5ea0..62f2ba66 100644 --- a/quimb/experimental/belief_propagation/l2bp.py +++ b/quimb/experimental/belief_propagation/l2bp.py @@ -12,6 +12,10 @@ ) +def _identity(x): + return x + + class L2BP(BeliefPropagationCommon): """Lazy (as in multiple uncontracted tensors per site) 2-norm (as in for wavefunctions and operators) belief propagation. @@ -24,10 +28,38 @@ class L2BP(BeliefPropagationCommon): The tags identifying the sites in ``tn``, each tag forms a region, which should not overlap. If the tensor network is structured, then these are inferred automatically. - damping : float, optional - The damping parameter to use, defaults to no damping. - update : {'parallel', 'sequential'}, optional - Whether to update all messages in parallel or sequentially. + damping : float or callable, optional + The damping factor to apply to messages. This simply mixes some part + of the old message into the new one, with the final message being + ``damping * old + (1 - damping) * new``. This makes convergence more + reliable but slower. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially (newly computed messages are + immediately used for other updates in the same iteration round) or in + parallel (all messages are comptued using messages from the previous + round only). Sequential generally helps convergence but parallel can + possibly converge to differnt solutions. + normalize : {'L1', 'L2', 'L2phased', 'Linf', callable}, optional + How to normalize messages after each update. If None choose + automatically. If a callable, it should take a message and return the + normalized message. If a string, it should be one of 'L1', 'L2', + 'L2phased', 'Linf' for the corresponding norms. 'L2phased' is like 'L2' + but also normalizes the phase of the message, by default used for + complex dtypes. + distance : {'L1', 'L2', 'L2phased', 'Linf', 'cosine', callable}, optional + How to compute the distance between messages to check for convergence. + If None choose automatically. If a callable, it should take two + messages and return the distance. If a string, it should be one of + 'L1', 'L2', 'L2phased', 'Linf', or 'cosine' for the corresponding + norms. 'L2phased' is like 'L2' but also normalizes the phases of the + messages, by default used for complex dtypes if phased normalization is + not already being used. + inplace : bool, optional + Whether to perform any operations inplace on the input tensor network. + symmetrize : bool or callable, optional + Whether to symmetrize the messages, i.e. 
for each message ensure that + it is hermitian with respect to its bra and ket indices. If a callable + it should take a message and return the symmetrized message. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. @@ -41,16 +73,27 @@ def __init__( self, tn, site_tags=None, + *, damping=0.0, update="sequential", + normalize=None, + distance=None, + inplace=False, + symmetrize=True, local_convergence=True, optimize="auto-hq", **contract_opts, ): - self.backend = next(t.backend for t in tn) - self.damping = damping + super().__init__( + tn, + damping=damping, + update=update, + normalize=normalize, + distance=distance, + inplace=inplace, + ) + self.local_convergence = local_convergence - self.update = update self.optimize = optimize self.contract_opts = contract_opts @@ -67,25 +110,8 @@ def __init__( ) = create_lazy_community_edge_map(tn, site_tags) self.touched = oset() - _abs = ar.get_lib_fn(self.backend, "abs") - _sum = ar.get_lib_fn(self.backend, "sum") - _transpose = ar.get_lib_fn(self.backend, "transpose") - _conj = ar.get_lib_fn(self.backend, "conj") - - def _normalize(x): - return x / _sum(x) - - def _symmetrize(x): - N = ar.ndim(x) - perm = (*range(N // 2, N), *range(0, N // 2)) - return x + _conj(_transpose(x, perm)) - - def _distance(x, y): - return _sum(_abs(x - y)) - - self._normalize = _normalize - self._symmetrize = _symmetrize - self._distance = _distance + # these are all settable properties + self.symmetrize = symmetrize # initialize messages self.messages = {} @@ -106,8 +132,8 @@ def _distance(x, y): drop_tags=True, **self.contract_opts, ) - tm.modify(apply=self._symmetrize) - tm.modify(apply=self._normalize) + tm.modify(apply=self._symmetrize_fn) + tm.modify(apply=self._normalize_fn) self.messages[i, j] = tm # initialize contractions @@ -144,6 +170,36 @@ def _distance(x, y): (tn_i_left, *tks, tn_i_right), virtual=True ) + @property + def symmetrize(self): + return self._symmetrize + + @symmetrize.setter + def symmetrize(self, symmetrize): + if callable(symmetrize): + # explicit function + self._symmetrize = True + self._symmetrize_fn = symmetrize + + elif symmetrize: + # default symmetrization + _transpose = ar.get_lib_fn(self.backend, "transpose") + _conj = ar.get_lib_fn(self.backend, "conj") + + def _symmetrize_fn(x): + N = ar.ndim(x) + perm = (*range(N // 2, N), *range(0, N // 2)) + # XXX: do this blockwise for block/fermi arrays? + return x + _conj(_transpose(x, perm)) + + self._symmetrize = True + self._symmetrize_fn = _symmetrize_fn + + else: + # no symmetrization + self._symmetrize = False + self._symmetrize_fn = _identity + def iterate(self, tol=5e-6): if (not self.local_convergence) or (not self.touched): # assume if asked to iterate that we want to check all messages @@ -171,8 +227,8 @@ def _compute_m(key): optimize=self.optimize, **self.contract_opts, ) - tm_new.modify(apply=self._symmetrize) - tm_new.modify(apply=self._normalize) + tm_new.modify(apply=self._symmetrize_fn) + tm_new.modify(apply=self._normalize_fn) return tm_new.data def _update_m(key, data): @@ -180,14 +236,14 @@ def _update_m(key, data): tm = self.messages[key] - if self.damping > 0.0: - data = (1 - self.damping) * data + self.damping * tm.data + # pre-damp distance + mdiff = self._distance_fn(data, tm.data) + + if self.damping: + data = self.fn_damping(data, tm.data) - try: - mdiff = float(self._distance(tm.data, data)) - except (TypeError, ValueError): - # handle e.g. 
lazy arrays
-            mdiff = float("inf")
+        # # post-damp distance
+        # mdiff = self._distance_fn(data, tm.data)

         if mdiff > tol:
             # mark touching messages for update
@@ -217,26 +273,31 @@ def _update_m(key, data):

         self.touched = new_touched

-        return nconv, ncheck, max_mdiff
+        return {
+            "nconv": nconv,
+            "ncheck": ncheck,
+            "max_mdiff": max_mdiff,
+        }

-    def normalize_messages(self):
+    def normalize_message_pairs(self):
         """Normalize all messages such that for each bond `<mi|mj> = 1` and
-        `<mi|mi> = <mj|mj>` (but in general != 1).
+        `<mi|mi> = <mj|mj>` (but in general != 1). This is different to
+        normalizing each message.
         """
         for i, j in self.edges:
             tmi = self.messages[i, j]
             tmj = self.messages[j, i]
-            nij = (tmi @ tmj)**0.5
-            nii = (tmi @ tmi)**0.25
-            njj = (tmj @ tmj)**0.25
-            tmi /= (nij * nii / njj)
-            tmj /= (nij * njj / nii)
+            nij = (tmi @ tmj) ** 0.5
+            nii = (tmi @ tmi) ** 0.25
+            njj = (tmj @ tmj) ** 0.25
+            tmi /= nij * nii / njj
+            tmj /= nij * njj / nii

-    def contract(self, strip_exponent=False):
+    def contract(self, strip_exponent=False, check_zero=True):
         """Estimate the contraction of the norm squared using the current
         messages.
         """
-        tvals = []
+        zvals = []
         for i, ket in self.local_tns.items():
             # we allow missing keys here for tensors which are just
             # disconnected but still appear in local_tns
@@ -250,22 +311,23 @@ def contract(self, strip_exponent=False):
                     bra,
                 )
             )
-            tvals.append(
-                tni.contract(all, optimize=self.optimize, **self.contract_opts)
-            )
+            z = tni.contract(all, optimize=self.optimize, **self.contract_opts)
+            zvals.append((z, 1))

-        mvals = []
         for i, j in self.edges:
-            mvals.append(
-                (self.messages[i, j] & self.messages[j, i]).contract(
-                    all,
-                    optimize=self.optimize,
-                    **self.contract_opts,
-                )
+            z = (self.messages[i, j] & self.messages[j, i]).contract(
+                all,
+                optimize=self.optimize,
+                **self.contract_opts,
             )
+            # power / counting factor is -1 for messages, i.e. divide
+            zvals.append((z, -1))

         return combine_local_contractions(
-            tvals, mvals, self.backend, strip_exponent=strip_exponent
+            zvals,
+            backend=self.backend,
+            strip_exponent=strip_exponent,
+            check_zero=check_zero,
         )

     def partial_trace(
diff --git a/quimb/linalg/approx_spectral.py b/quimb/linalg/approx_spectral.py
index 2e645c66..d199370b 100644
--- a/quimb/linalg/approx_spectral.py
+++ b/quimb/linalg/approx_spectral.py
@@ -14,24 +14,9 @@
 from ..core import divide_update_, dot, njit, prod, ptr, subtract_update_, vdot
 from ..gen.rand import rand_phase, rand_rademacher, randn, seed_rand
 from ..linalg.mpi_launcher import get_mpi_pool
-from ..utils import (
-    default_to_neutral_style,
-    find_library,
-    format_number_with_error,
-    int2tup,
-    raise_cant_find_library_function,
-)
+from ..utils import format_number_with_error, int2tup
 from ..utils import progbar as Progbar
-
-if find_library("cotengra") and find_library("autoray"):
-    from ..tensor.tensor_1d import MatrixProductOperator
-    from ..tensor.tensor_approx_spectral import construct_lanczos_tridiag_MPO
-    from ..tensor.tensor_core import Tensor
-else:
-    reqs = "[cotengra,autoray]"
-    Tensor = raise_cant_find_library_function(reqs)
-    construct_lanczos_tridiag_MPO = raise_cant_find_library_function(reqs)
-
+from ..utils_plot import default_to_neutral_style

 # --------------------------------------------------------------------------- #
 #                  'Lazy' representation tensor contractions                  #
@@ -67,6 +52,8 @@ def lazy_ptr_linop(psi_ab, dims, sysa, **linop_opts):
     sysa : int or sequence of int, optional
         Index(es) of the 'a' subsystem(s) to keep.
""" + from .tensor.tensor_core import Tensor + sysa = int2tup(sysa) Kab = Tensor( @@ -126,6 +113,8 @@ def lazy_ptr_ppt_linop(psi_abc, dims, sysa, sysb, **linop_opts): Index(es) of the 'b' subsystem(s) to keep, with respect to all the dimensions, ``dims``, (i.e. pre-partial trace). """ + from .tensor.tensor_core import Tensor + sysa, sysb = int2tup(sysa), int2tup(sysb) sys_ab = sorted(sysa + sysb) @@ -517,8 +506,14 @@ def single_random_estimate( info=None, **lanczos_opts, ): + from ..tensor.tensor_1d import MatrixProductOperator + # choose normal (any LinearOperator) or MPO lanczos tridiag construction if isinstance(A, MatrixProductOperator): + from ..tensor.tensor_approx_spectral import ( + construct_lanczos_tridiag_MPO, + ) + lanc_fn = construct_lanczos_tridiag_MPO else: lanc_fn = construct_lanczos_tridiag diff --git a/quimb/tensor/optimize.py b/quimb/tensor/optimize.py index 1b4eb9f1..220324a3 100644 --- a/quimb/tensor/optimize.py +++ b/quimb/tensor/optimize.py @@ -14,12 +14,12 @@ from ..core import prod from ..utils import ( - default_to_neutral_style, ensure_dict, tree_flatten, tree_map, tree_unflatten, ) +from ..utils_plot import default_to_neutral_style from .interface import get_jax from .tensor_core import ( TensorNetwork, @@ -1416,7 +1416,7 @@ def get_tn_opt(self): """Extract the optimized tensor network, this is a three part process: 1. inject the current optimized vector into the target tensor - network, + network or pytree, 2. run it through ``norm_fn``, 3. drop any tags used to identify variables. diff --git a/quimb/utils.py b/quimb/utils.py index 91ad9d03..0d299b97 100644 --- a/quimb/utils.py +++ b/quimb/utils.py @@ -802,50 +802,6 @@ def tree_apply_dict(f, tree, is_leaf): tree_register_container(dict, tree_map_dict, tree_iter_dict, tree_apply_dict) -# a style to use for matplotlib that works with light and dark backgrounds -NEUTRAL_STYLE = { - "axes.edgecolor": (0.5, 0.5, 0.5), - "axes.facecolor": (0, 0, 0, 0), - "axes.grid": True, - "axes.labelcolor": (0.5, 0.5, 0.5), - "axes.spines.right": False, - "axes.spines.top": False, - "figure.facecolor": (0, 0, 0, 0), - "grid.alpha": 0.1, - "grid.color": (0.5, 0.5, 0.5), - "legend.frameon": False, - "text.color": (0.5, 0.5, 0.5), - "xtick.color": (0.5, 0.5, 0.5), - "xtick.minor.visible": True, - "ytick.color": (0.5, 0.5, 0.5), - "ytick.minor.visible": True, -} - - -def default_to_neutral_style(fn): - """Wrap a function or method to use the neutral style by default.""" - - @functools.wraps(fn) - def wrapper(*args, style="neutral", show_and_close=True, **kwargs): - import matplotlib.pyplot as plt - - if style == "neutral": - style = NEUTRAL_STYLE - elif not style: - style = {} - - with plt.style.context(style): - out = fn(*args, **kwargs) - - if show_and_close: - plt.show() - plt.close() - - return out - - return wrapper - - def autocorrect_kwargs(func=None, valid_kwargs=None): """A decorator that suggests the right keyword arguments if you get them wrong. Useful for functions with many specific options. 
diff --git a/quimb/utils_plot.py b/quimb/utils_plot.py new file mode 100644 index 00000000..3694f555 --- /dev/null +++ b/quimb/utils_plot.py @@ -0,0 +1,234 @@ +import functools +import math + +# a style to use for matplotlib that works with light and dark backgrounds +NEUTRAL_STYLE = { + "axes.edgecolor": (0.5, 0.5, 0.5), + "axes.facecolor": (0, 0, 0, 0), + "axes.grid": True, + "axes.labelcolor": (0.5, 0.5, 0.5), + "axes.spines.right": False, + "axes.spines.top": False, + "figure.facecolor": (0, 0, 0, 0), + "grid.alpha": 0.1, + "grid.color": (0.5, 0.5, 0.5), + "legend.frameon": False, + "text.color": (0.5, 0.5, 0.5), + "xtick.color": (0.5, 0.5, 0.5), + "xtick.minor.visible": True, + "ytick.color": (0.5, 0.5, 0.5), + "ytick.minor.visible": True, +} + + +def default_to_neutral_style(fn): + """Wrap a function or method to use the neutral style by default.""" + + @functools.wraps(fn) + def wrapper( + *args, + style="neutral", + show_and_close=True, + clear_previous=False, + **kwargs + ): + import matplotlib.pyplot as plt + + if clear_previous: + from IPython import display + + # clear old plots + display.clear_output(wait=True) + + if style == "neutral": + style = NEUTRAL_STYLE + elif not style: + style = {} + + with plt.style.context(style): + out = fn(*args, **kwargs) + + if show_and_close: + plt.show() + plt.close() + + return out + + return wrapper + + +def _ensure_dict(k, v): + import numpy as np + from .schematic import hash_to_color, get_color + + # ensure is a dictionaty + if not isinstance(v, dict): + v = {"y": v} + v["y"] = np.asarray(v["y"]) + + if v["y"].size == 0: + return None + + # make sure x-coords exists explicitly + if "x" not in v: + v["x"] = np.arange(v["y"].size) + else: + v["x"] = np.asarray(v["x"]) + + # set label as data name by default + v.setdefault("label", k) + + if v.get("color", None) is None: + label = v["label"] + if label is None: + v["color"] = get_color("blue") + else: + v["color"] = hash_to_color(k, vmin=0.75, vmax=0.85) + + return v + + +@default_to_neutral_style +def plot_multi_series_zoom( + data, + zoom="auto", + zoom_max=100, + zoom_marker="|", + zoom_markersize=3, + xlabel="Iteration", + figsize=None, + **kwargs, +): + """Plot possibly multiple series of data, using the asinh scale for an + overview and a linear scale for a zoomed in final section. + + Parameters + ---------- + data : dict[dict], dict[array], dict, array, optional + The data to plot. 
+ """ + import matplotlib as mpl + import matplotlib.pyplot as plt + + if isinstance(data, dict) and "y" not in data: + # multiple plain, or configured, sequences supplied + data = [_ensure_dict(k, v) for k, v in data.items()] + else: + # single plain, or configured, sequence supplied + data = [_ensure_dict(None, data)] + + # remove any empty data + data = [d for d in data if d is not None] + + nrows = len(data) + + if figsize is None: + figsize = (8, 2 * nrows) + + fig, axs = plt.subplots( + nrows=nrows, + ncols=2, + figsize=figsize, + width_ratios=(3, 2), + gridspec_kw={"wspace": 0.05, "hspace": 0.10}, + squeeze=False, + ) + + n = max(d["x"][-1] for d in data) + if zoom is not None: + if zoom == "auto": + zoom = min(zoom_max, n // 2) + nz = n - zoom + + for i, d in enumerate(data): + # get data and correct zoomed range + x = d.pop("x") + y = d.pop("y") + iz = min(range(x.size), key=lambda i: x[i] < nz) + + label = d.pop("label") + color = d.pop("color") + yscale = d.pop("yscale", kwargs.get("yscale", "linear")) + + # plot overview + ax = axs[i, 0] + ax.plot( + x, + y, + color=color, + linewidth=1, + ) + if label is not None: + ax.text( + 0.05, + 1.0, + label, + color=color, + transform=ax.transAxes, + ha="left", + va="top", + ) + # x props + ax.set_xscale("asinh", linear_width=20) + ax.xaxis.set_major_locator( + mpl.ticker.AsinhLocator(20, numticks=6, subs=range(10)) + ) + # y props + ax.tick_params(axis="y", colors=color, which="both") + if yscale == "linear": + ax.yaxis.set_major_formatter( + mpl.ticker.ScalarFormatter(useOffset=False) + ) + else: + ax.set_yscale(yscale) + + # highlight zoomed range + ax.axvspan(nz, n, alpha=0.15, color=(0.5, 0.5, 0.5)) + + # plot zoom + ax = axs[i, 1] + ax.plot( + x[iz:], + y[iz:], + color=color, + marker=zoom_marker, + markersize=zoom_markersize, + ) + # y props + ax.yaxis.tick_right() + ax.spines["left"].set_visible(False) + ax.spines["right"].set_visible(True) + ax.tick_params(axis="y", colors=color, which="both") + if yscale == "linear": + ax.yaxis.set_major_formatter( + mpl.ticker.ScalarFormatter(useOffset=False) + ) + else: + ax.set_yscale(yscale) + + # remove ticklabels on all but last row + for i in range(nrows - 1): + axs[i, 0].tick_params(axis="x", labelbottom=False) + axs[i, 1].tick_params(axis="x", labelbottom=False) + + # set x-limits to just cover full range of data + for i in range(nrows): + axs[i, 0].set_xlim(0.0 - 0.5, n + 0.5) + axs[i, 1].set_xlim(nz - 0.5, n + 0.5) + + # make the xticklabels appear like [0, 1, 10, 100, ...] 
+ axs[-1, 0].xaxis.set_minor_formatter(mpl.ticker.NullFormatter()) + axs[-1, 0].xaxis.set_major_formatter( + mpl.ticker.FuncFormatter( + lambda x, _: f"{int(x):,}" + if math.isclose(x, 0) + or (x > 1 and math.isclose(math.log10(x) % 1, 0)) + else "" + ) + ) + + # set x-labels + axs[-1, 0].set_xlabel(f"{xlabel} (full)") + axs[-1, 1].set_xlabel(f"{xlabel} (zoom)") + + return fig, axs diff --git a/tests/test_tensor/test_belief_propagation/test_d1bp.py b/tests/test_tensor/test_belief_propagation/test_d1bp.py index 30347ca9..9d23b87a 100644 --- a/tests/test_tensor/test_belief_propagation/test_d1bp.py +++ b/tests/test_tensor/test_belief_propagation/test_d1bp.py @@ -8,21 +8,27 @@ ) -def test_contract_tree_exact(): +@pytest.mark.parametrize("local_convergence", [False, True]) +def test_contract_tree_exact(local_convergence): tn = qtn.TN_rand_tree(20, 3) Z = tn.contract() info = {} - Z_bp = contract_d1bp(tn, info=info, progbar=True) + Z_bp = contract_d1bp( + tn, info=info, local_convergence=local_convergence, progbar=True + ) assert info["converged"] assert Z == pytest.approx(Z_bp, rel=1e-12) @pytest.mark.parametrize("damping", [0.0, 0.1]) -def test_contract_normal(damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_normal(damping, diis): tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) Z = tn.contract() info = {} - Z_bp = contract_d1bp(tn, damping=damping, info=info, progbar=True) + Z_bp = contract_d1bp( + tn, damping=damping, diis=diis, info=info, progbar=True + ) assert info["converged"] assert Z == pytest.approx(Z_bp, rel=1e-1) diff --git a/tests/test_tensor/test_belief_propagation/test_d2bp.py b/tests/test_tensor/test_belief_propagation/test_d2bp.py index 649c32bb..e6c8afcf 100644 --- a/tests/test_tensor/test_belief_propagation/test_d2bp.py +++ b/tests/test_tensor/test_belief_propagation/test_d2bp.py @@ -10,37 +10,52 @@ @pytest.mark.parametrize("damping", [0.0, 0.1]) @pytest.mark.parametrize("dtype", ["float32", "complex64"]) -def test_contract(damping, dtype): +@pytest.mark.parametrize("diis", [True, False]) +def test_contract(damping, dtype, diis): peps = qtn.PEPS.rand(3, 4, 3, seed=42, dtype=dtype) # normalize exactly peps /= (peps.H @ peps) ** 0.5 info = {} - N_ap = contract_d2bp(peps, damping=damping, info=info, progbar=True) + N_ap = contract_d2bp( + peps, damping=damping, diis=diis, info=info, progbar=True + ) assert info["converged"] assert N_ap == pytest.approx(1.0, rel=0.3) @pytest.mark.parametrize("dtype", ["float32", "complex64"]) -def test_tree_exact(dtype): +@pytest.mark.parametrize("local_convergence", [True, False]) +def test_tree_exact(dtype, local_convergence): psi = qtn.TN_rand_tree(20, 3, 2, dtype=dtype, seed=42) norm2 = psi.H @ psi info = {} - norm2_bp = contract_d2bp(psi, info=info, progbar=True) + norm2_bp = contract_d2bp( + psi, info=info, local_convergence=local_convergence, progbar=True + ) assert info["converged"] assert norm2_bp == pytest.approx(norm2, rel=1e-4) @pytest.mark.parametrize("damping", [0.0, 0.1]) +@pytest.mark.parametrize("diis", [True, False]) @pytest.mark.parametrize("dtype", ["float32", "complex64"]) -def test_compress(damping, dtype): +def test_compress(damping, dtype, diis): peps = qtn.PEPS.rand(3, 4, 3, seed=42, dtype=dtype) # test that using the BP compression gives better fidelity than purely # local, naive compression scheme peps_c1 = peps.compress_all(max_bond=2) info = {} - peps_c2 = compress_d2bp( - peps, max_bond=2, damping=damping, info=info, progbar=True + peps_c2 = peps.copy() + 
compress_d2bp( + peps_c2, + max_bond=2, + damping=damping, + diis=diis, + info=info, + inplace=True, + progbar=True, ) + assert peps_c2.max_bond() == 2 assert info["converged"] fid1 = peps_c1.H @ peps_c2 fid2 = peps_c2.H @ peps_c2 diff --git a/tests/test_tensor/test_belief_propagation/test_hd1bp.py b/tests/test_tensor/test_belief_propagation/test_hd1bp.py index 1d8de8d8..2c010a5e 100644 --- a/tests/test_tensor/test_belief_propagation/test_hd1bp.py +++ b/tests/test_tensor/test_belief_propagation/test_hd1bp.py @@ -20,21 +20,30 @@ def test_contract_hyper(damping): assert num_solutions == pytest.approx(309273226, rel=0.1) -def test_contract_tree_exact(): +@pytest.mark.parametrize("normalize", ["L1", "L2", "Linf"]) +def test_contract_tree_exact(normalize): tn = qtn.TN_rand_tree(20, 3) Z = tn.contract() info = {} - Z_bp = contract_hd1bp(tn, info=info, progbar=True) + Z_bp = contract_hd1bp( + tn, + info=info, + normalize=normalize, + progbar=True, + ) assert info["converged"] assert Z == pytest.approx(Z_bp, rel=1e-12) @pytest.mark.parametrize("damping", [0.0, 0.1]) -def test_contract_normal(damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_normal(damping, diis): tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) Z = tn.contract() info = {} - Z_bp = contract_hd1bp(tn, damping=damping, info=info, progbar=True) + Z_bp = contract_hd1bp( + tn, damping=damping, diis=diis, info=info, progbar=True + ) assert info["converged"] assert Z == pytest.approx(Z_bp, rel=1e-1) diff --git a/tests/test_tensor/test_belief_propagation/test_l1bp.py b/tests/test_tensor/test_belief_propagation/test_l1bp.py index 705be708..697b919b 100644 --- a/tests/test_tensor/test_belief_propagation/test_l1bp.py +++ b/tests/test_tensor/test_belief_propagation/test_l1bp.py @@ -7,22 +7,37 @@ @pytest.mark.parametrize("dtype", ["float32", "complex64"]) -def test_contract_tree_exact(dtype): +@pytest.mark.parametrize("local_convergence", [False, True]) +@pytest.mark.parametrize("normalize", ["L1", "L2", "Linf"]) +def test_contract_tree_exact(dtype, local_convergence, normalize): tn = qtn.TN_rand_tree(10, 3, seed=42, dtype=dtype) Z_ex = tn.contract() info = {} - Z_bp = contract_l1bp(tn, info=info, progbar=True) + Z_bp = contract_l1bp( + tn, + info=info, + normalize=normalize, + local_convergence=local_convergence, + progbar=True, + ) assert info["converged"] assert Z_ex == pytest.approx(Z_bp, rel=5e-6) @pytest.mark.parametrize("dtype", ["float32", "complex64"]) @pytest.mark.parametrize("damping", [0.0, 0.1]) -def test_contract_loopy_approx(dtype, damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_loopy_approx(dtype, damping, diis): tn = qtn.TN2D_rand(3, 4, 5, dtype=dtype, dist="uniform") Z_ex = tn.contract() info = {} - Z_bp = contract_l1bp(tn, damping=damping, info=info, progbar=True) + Z_bp = contract_l1bp( + tn, + damping=damping, + diis=diis, + info=info, + progbar=True, + ) assert info["converged"] assert Z_ex == pytest.approx(Z_bp, rel=0.1) @@ -36,13 +51,17 @@ def test_contract_double_loopy_approx(dtype, damping, update): Z_ex = tn.contract() info = {} Z_bp1 = contract_l1bp( - tn, damping=damping, update=update, info=info, progbar=True + tn, + damping=damping, + update=update, + info=info, + progbar=True, ) assert info["converged"] assert Z_bp1 == pytest.approx(Z_ex, rel=0.3) # compare with 2-norm BP on the peps directly Z_bp2 = contract_d2bp(peps) - assert Z_bp1 == pytest.approx(Z_bp2, rel=5e-6) + assert Z_bp1 == pytest.approx(Z_bp2, rel=5e-5) 
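For reference, the inplace call pattern exercised by `test_compress` in `test_d2bp.py` above looks roughly as follows when used directly; the lattice size and options mirror the test, and the import path is assumed to match the experimental module layout shown in the diff headers:

import quimb.tensor as qtn
from quimb.experimental.belief_propagation.d2bp import compress_d2bp

# a small random PEPS, as in the test above
peps = qtn.PEPS.rand(3, 4, 3, seed=42)

info = {}
peps_c = peps.copy()
compress_d2bp(
    peps_c,
    max_bond=2,    # target bond dimension
    damping=0.1,   # mix some of the old messages into the new ones
    diis=True,     # new keyword also exercised by the tests above
    info=info,     # filled with convergence details, e.g. info["converged"]
    inplace=True,  # modify peps_c directly rather than returning a new network
    progbar=False,
)

assert peps_c.max_bond() == 2

Copying first and then compressing inplace, as the updated test does, leaves the original `peps` untouched for the comparison against the naive local compression.
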
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) @@ -82,7 +101,8 @@ def test_contract_tree_triple_sandwich_exact(dtype): @pytest.mark.parametrize("dtype", ["float32", "complex64"]) @pytest.mark.parametrize("damping", [0.0, 0.1]) -def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping): +@pytest.mark.parametrize("diis", [False, True]) +def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping, diis): edges = qtn.edges_2d_hexagonal(2, 3) ket = qtn.TN_from_edges_rand( edges, @@ -100,7 +120,13 @@ def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping): tn = ket.H | G_ket Z_ex = tn.contract() info = {} - Z_bp = contract_l1bp(tn, damping=damping, info=info, progbar=True) + Z_bp = contract_l1bp( + tn, + damping=damping, + diis=diis, + info=info, + progbar=True, + ) assert info["converged"] assert Z_bp == pytest.approx(Z_ex, rel=0.5)
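
Finally, a small self-contained usage sketch of the `plot_multi_series_zoom` helper added in `utils_plot.py` above. The series names and values here are synthetic stand-ins for whatever per-iteration quantities one records:

import numpy as np

from quimb.utils_plot import plot_multi_series_zoom

its = np.arange(300)
data = {
    # a dict entry can carry per-series options such as 'yscale'
    "max_mdiff": {"y": np.exp(-its / 60.0), "yscale": "log"},
    # or simply be a plain array of y-values
    "estimate": 1.0 + 0.1 * np.exp(-its / 40.0) * np.cos(its / 5.0),
}

# overview on an asinh-scaled x-axis, plus a zoom of the last 50 iterations
fig, axs = plot_multi_series_zoom(data, zoom=50, xlabel="Iteration")

Each series gets its own row: an overview panel using the asinh x-scale with the zoom window highlighted, and a right-hand panel restricted to that final window.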