From 20f58132b6257e5c0fb0a06357ebd2579dc25ab7 Mon Sep 17 00:00:00 2001 From: Tofani-Kanudo Date: Sat, 20 Aug 2022 17:24:59 -0400 Subject: [PATCH] Bug Fix For Numpy Sum --- tensornetwork/backends/numpy/numpy_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensornetwork/backends/numpy/numpy_backend.py b/tensornetwork/backends/numpy/numpy_backend.py index 2d3753911..9f1a9220a 100644 --- a/tensornetwork/backends/numpy/numpy_backend.py +++ b/tensornetwork/backends/numpy/numpy_backend.py @@ -604,7 +604,9 @@ def sum(self, tensor: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False) -> Tensor: - return np.sum(tensor, axis=tuple(axis), keepdims=keepdims) + if axis is not None and type(axis) == list: + return np.sum(tensor, axis=tuple(axis), keepdims=keepdims) + return np.sum(tensor, axis=axis, keepdims=keepdims) def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: if (tensor1.ndim <= 1) or (tensor2.ndim <= 1):