From abfa80b1c7bf21b158e5be10ddb0452c6ececc61 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 18 Jun 2024 14:15:57 -0700 Subject: [PATCH] small fixes --- quimb/tensor/fermion/fermion_core.py | 7 +++---- quimb/tensor/tensor_arbgeom.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/quimb/tensor/fermion/fermion_core.py b/quimb/tensor/fermion/fermion_core.py index 807b8b68..3ed29b6a 100644 --- a/quimb/tensor/fermion/fermion_core.py +++ b/quimb/tensor/fermion/fermion_core.py @@ -2,12 +2,11 @@ """ import copy import functools -from operator import add import contextlib import numpy as np import scipy.sparse.linalg as spla import opt_einsum as oe -from opt_einsum.contract import parse_backend, _tensordot, _transpose +from opt_einsum.contract import parse_backend from autoray import conj from ...utils import (oset, valmap, check_opt) @@ -467,7 +466,7 @@ def _launch_fermion_expression( right_pos.append(input_right.find(s)) # Contract! - new_view = _tensordot(Ta.data, Tb.data, axes=(tuple(left_pos), tuple(right_pos)), backend=backend) + new_view = np.tensordot(Ta.data, Tb.data, axes=(tuple(left_pos), tuple(right_pos))) global_phase += Ta.phase.get("global_flip", False) \ + Tb.phase.get("global_flip", False) @@ -480,7 +479,7 @@ def _launch_fermion_expression( # Build a new view if needed if (tensor_result != results_index): transpose = tuple(map(tensor_result.index, results_index)) - new_view = _transpose(new_view, axes=transpose, backend=backend) + new_view = np.transpose(new_view, axes=transpose) o_ix = [o_ix[ix] for ix in transpose] o_tags = oset.union(Ta.tags, Tb.tags) diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py index 621873cf..f8302edc 100644 --- a/quimb/tensor/tensor_arbgeom.py +++ b/quimb/tensor/tensor_arbgeom.py @@ -804,8 +804,8 @@ def local_expectation_cluster( tuple(map(self.site_tag, where)), 'any' ) - if len(tids) == 2: - tids = self._get_string_between_tids(*tids) + # if len(tids) == 2: + # tids = self._get_string_between_tids(*tids) k = self._select_local_tids( tids,