From 69f28e5f3998917cb2e10a26749d04bb602964ab Mon Sep 17 00:00:00 2001
From: Johnnie Gray
Date: Tue, 17 Dec 2024 18:17:23 -0800
Subject: [PATCH] TN equalize_norms add check_zero option and turn on for
 simplifying

---
 quimb/tensor/tensor_core.py | 89 ++++++++++++++++++++++++++++++-------
 1 file changed, 72 insertions(+), 17 deletions(-)

diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py
index 80d4af49..72554f0f 100644
--- a/quimb/tensor/tensor_core.py
+++ b/quimb/tensor/tensor_core.py
@@ -9011,7 +9011,8 @@ def insert_compressor_between_regions(
 
         # then form the 'oblique' projectors
         Pl, Pr = decomp.compute_oblique_projectors(
-            Rl, Rr,
+            Rl,
+            Rr,
             max_bond=max_bond,
             cutoff=cutoff,
             **compress_opts,
@@ -9389,7 +9390,7 @@ def randomize(self, dtype=None, seed=None, inplace=False, **randn_opts):
 
     randomize_ = functools.partialmethod(randomize, inplace=True)
 
-    def strip_exponent(self, tid_or_tensor, value=None):
+    def strip_exponent(self, tid_or_tensor, value=None, check_zero=False):
         """Scale the elements of tensor corresponding to ``tid`` so that the
         norm of the array is some value, which defaults to ``1``. The log of
         the scaling factor, base 10, is then accumulated in the ``exponent``
@@ -9401,6 +9402,11 @@ def strip_exponent(self, tid_or_tensor, value=None):
             The tensor identifier or actual tensor.
         value : None or float, optional
             The value to scale the norm of the tensor to.
+        check_zero : bool, optional
+            Whether to check if the tensor has zero norm and in that case do
+            nothing, since the `exponent` would be -inf. Off by default to
+            avoid data dependent computational graphs when tracing and
+            computing gradients etc.
         """
         if (value is None) or (value is True):
             value = 1.0
@@ -9411,6 +9417,10 @@ def strip_exponent(self, tid_or_tensor, value=None):
             t = self.tensor_map[tid_or_tensor]
 
         stripped_factor = t.norm() / value
+
+        if check_zero and (stripped_factor == 0.0):
+            return
+
         t.modify(apply=lambda data: data / stripped_factor)
         self.exponent = self.exponent + do("log10", stripped_factor)
 
@@ -9425,7 +9435,7 @@ def distribute_exponent(self):
         # reset the exponent to zero
         self.exponent = 0.0
 
-    def equalize_norms(self, value=None, inplace=False):
+    def equalize_norms(self, value=None, check_zero=False, inplace=False):
         """Make the Frobenius norm of every tensor in this TN equal without
         changing the overall value if ``value=None``, or set the norm of
         every tensor to ``value`` by scalar multiplication only.
@@ -9436,6 +9446,11 @@ def equalize_norms(self, value=None, inplace=False):
             Set the norm of each tensor to this value specifically. If supplied
             the change in overall scaling will be accumulated in ``tn.exponent``
             in the form of a base 10 power.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the norm equalization inplace or not.
 
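The guard above changes the failure mode for exactly-zero tensors: without
`check_zero`, `t.norm()` is `0.0`, the division produces nan/inf data, and
`do("log10", 0.0)` pushes `-inf` into `tn.exponent`. A minimal usage sketch,
assuming this patch is applied (the random network builder `TN_rand_reg` is
just for illustration):

    import quimb.tensor as qtn

    tn = qtn.TN_rand_reg(n=6, reg=3, D=2, seed=42)

    # scale every tensor to unit Frobenius norm; the total factor
    # removed is accumulated as a base-10 power in tn.exponent
    tn1 = tn.equalize_norms(value=1.0, check_zero=True)
    print(tn1.exponent)

    # with check_zero=True an all-zero tensor is simply skipped,
    # rather than poisoning tn.exponent with -inf
    tid = next(iter(tn1.tensor_map))
    tn1.tensor_map[tid].modify(apply=lambda x: 0.0 * x)
    tn1.equalize_norms_(value=1.0, check_zero=True)

The default remains `check_zero=False` at this level so that traced or
differentiated code does not acquire a data-dependent branch.
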
@@ -9446,7 +9461,7 @@ def equalize_norms(self, value=None, inplace=False):
         tn = self if inplace else self.copy()
 
         for tid in tn.tensor_map:
-            tn.strip_exponent(tid, value=value)
+            tn.strip_exponent(tid, value=value, check_zero=check_zero)
 
         if value is None:
             tn.distribute_exponent()
@@ -9591,6 +9606,7 @@ def rank_simplify(
         equalize_norms=False,
         cache=None,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
     ):
         """Simplify this tensor network by performing contractions that don't
@@ -9607,6 +9623,11 @@ def rank_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the rank reduction inplace.
 
@@ -9752,7 +9773,7 @@ def rank_weight(ind):
             tn |= tab
 
             if equalize_norms:
-                tn.strip_exponent(tab, equalize_norms)
+                tn.strip_exponent(tab, equalize_norms, check_zero=check_zero)
 
             for ix in out_ab:
                 # now we need to check output indices again
@@ -9760,10 +9781,16 @@
 
         if scalars:
             if equalize_norms:
+                # move overall scaling factor into exponent, absorb phase
                 signs = []
                 for s in scalars:
-                    signs.append(s / do("abs", s))
-                    tn.exponent += do("log10", do("abs", s))
+                    sa = do("abs", s)
+                    if check_zero and (sa == 0.0):
+                        # whole contraction is zero
+                        signs = [0.0]
+                        break
+                    signs.append(s / sa)
+                    tn.exponent += do("log10", sa)
                 scalars = signs
 
             if tn.num_tensors:
@@ -10023,6 +10050,7 @@ def split_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10039,6 +10067,11 @@ def split_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the split simplification inplace.
         """
@@ -10075,8 +10108,12 @@ def split_simplify(
                 tn |= tr
 
                 if equalize_norms:
-                    tn.strip_exponent(tl, equalize_norms)
-                    tn.strip_exponent(tr, equalize_norms)
+                    tn.strip_exponent(
+                        tl, equalize_norms, check_zero=check_zero
+                    )
+                    tn.strip_exponent(
+                        tr, equalize_norms, check_zero=check_zero
+                    )
 
             else:
                 cache.add(cache_key)
@@ -10093,6 +10130,7 @@ def pair_simplify(
         cache=None,
         equalize_norms=False,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10180,8 +10218,8 @@ def gen_pairs():
            tensor_fuse_squeeze(tl, tr)
 
            if equalize_norms:
-               tn.strip_exponent(tl, equalize_norms)
-               tn.strip_exponent(tr, equalize_norms)
+               tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+               tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)
 
            queue.extend(tl.inds)
            queue.extend(tr.inds)
@@ -10199,6 +10237,7 @@ def loop_simplify(
         loops=None,
         cache=None,
         equalize_norms=False,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
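The scalar branch above is the subtle part of `rank_simplify`: when
contractions reduce some tensors to scalars, each scalar's magnitude is moved
into `tn.exponent` and only its phase is kept, with a single `0.0` recorded if
any magnitude vanishes. A standalone sketch of that logic for plain floats
(the `absorb_scalars` helper is hypothetical; in the patch, `do("abs", ...)`
and `do("log10", ...)` dispatch to the array backend):

    import math

    def absorb_scalars(scalars, exponent, check_zero=False):
        signs = []
        for s in scalars:
            sa = abs(s)
            if check_zero and sa == 0.0:
                # the whole contraction is zero: collapse to one factor
                signs = [0.0]
                break
            signs.append(s / sa)        # unit-modulus phase/sign
            exponent += math.log10(sa)  # magnitude into the exponent
        return signs, exponent

    print(absorb_scalars([2.0, -3.0], 0.0))
    # ([1.0, -1.0], 0.778...), i.e. 10**0.778 ~= 6 = |2 * -3|
    print(absorb_scalars([2.0, 0.0], 0.0, check_zero=True))
    # ([0.0], 0.0)

Without the `check_zero` branch, the second call would instead raise a
division-by-zero error (or produce nan under an array backend), which is
exactly the case the patch guards against.
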
@@ -10218,6 +10257,11 @@ def loop_simplify(
         cache : set, optional
             For performance reasons can supply a cache for already checked
             loops.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to replace the loops inplace.
         split_opts
@@ -10298,8 +10342,8 @@ def loop_simplify(
             tensor_fuse_squeeze(tl, tr)
 
             if equalize_norms:
-                tn.strip_exponent(tl, equalize_norms)
-                tn.strip_exponent(tr, equalize_norms)
+                tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+                tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)
 
         return tn
 
@@ -10312,13 +10356,14 @@ def full_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
-        inplace=False,
-        progbar=False,
         rank_simplify_opts=None,
         loop_simplify_opts=None,
         split_simplify_opts=None,
         custom_methods=(),
         split_method="svd",
+        check_zero=True,
+        inplace=False,
+        progbar=False,
     ):
         """Perform a series of tensor network 'simplifications' in a loop until
         there is no more reduction in the number of tensors or indices. Note
@@ -10357,6 +10402,9 @@ def full_simplify(
         cache : None or set
             A persistent cache for each simplification process to mark already
             processed tensors.
+        check_zero : bool, optional
+            Whether to check if tensors have zero norm and in that case do
+            nothing if and when equalizing norms, rather than generating a NaN.
         progbar : bool, optional
             Show a live progress bar of the simplification process.
         inplace : bool, optional
@@ -10422,6 +10470,7 @@ def full_simplify(
                     output_inds=ix_o,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **rank_simplify_opts,
                 )
             elif meth == "A":
@@ -10435,6 +10484,7 @@ def full_simplify(
                     atol=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **split_simplify_opts,
                 )
             elif meth == "L":
@@ -10443,6 +10493,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             elif meth == "P":
@@ -10451,6 +10502,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             else:
@@ -10462,9 +10514,10 @@ def full_simplify(
         if equalize_norms:
             if equalize_norms is True:
                 # this also redistributes the collected exponents
-                tn.equalize_norms_()
+                value = None
             else:
-                tn.equalize_norms_(value=equalize_norms)
+                value = equalize_norms
+            tn.equalize_norms_(value=value, check_zero=check_zero)
 
         if progbar:
             pbar.close()
@@ -10594,6 +10647,7 @@ def compress_simplify(
         max_simplification_iterations=100,
         converged_tol=0.01,
         equalize_norms=True,
+        check_zero=True,
         progbar=False,
         inplace=False,
         **full_simplify_opts,
     ):
@@ -10606,6 +10660,7 @@ def compress_simplify(
         simplify_opts = {
             "atol": atol,
             "equalize_norms": equalize_norms,
+            "check_zero": check_zero,
             "progbar": progbar,
             "output_inds": output_inds,
             "cache": set(),
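Taken together, the user-facing effect is that the high-level simplification
entry points now default to guarding against exactly-zero tensors
(`check_zero=True` in `full_simplify` and `compress_simplify`), while the
lower-level methods keep `check_zero=False` for trace-friendliness. A rough
usage sketch, assuming this patch is applied (`TN_rand_reg` again just
provides a test network):

    import quimb.tensor as qtn

    tn = qtn.TN_rand_reg(n=10, reg=3, D=2, seed=0)

    # full_simplify now passes check_zero (default True) through to the
    # rank/split/loop/pair passes and the final norm equalization
    tns = tn.full_simplify("ADCR", equalize_norms=True)

    # recombine the stripped scale with the contracted value
    value = tns.contract(...) * 10**tns.exponent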