diff --git a/tensornetwork/backends/pytorch/pytorch_backend.py b/tensornetwork/backends/pytorch/pytorch_backend.py index e0b1dc743..2f48e973e 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend.py +++ b/tensornetwork/backends/pytorch/pytorch_backend.py @@ -410,45 +410,16 @@ def trace(self, tensor: Tensor, offset: int = 0, axis1: int = -2, axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is summed. - In the PyTorch backend the trace is always over the main diagonal of the - last two entries. - Args: tensor: A tensor. offset: Offset of the diagonal from the main diagonal. - This argument is not supported by the PyTorch - backend and an error will be raised if they are - specified. axis1, axis2: Axis to be used as the first/second axis of the 2D sub-arrays from which the diagonals should be taken. - Defaults to first/second axis. - These arguments are not supported by the PyTorch - backend and an error will be raised if they are - specified. + Defaults to second-last/last axis. Returns: array_of_diagonals: The batched summed diagonals. """ - if offset != 0: - errstr = (f"offset = {offset} must be 0 (the default)" - f"with PyTorch backend.") - raise NotImplementedError(errstr) - if axis1 == axis2: - raise ValueError(f"axis1 = {axis1} cannot equal axis2 = {axis2}") - N = len(tensor.shape) - if N > 25: - raise ValueError(f"Currently only tensors with ndim <= 25 can be traced" - f"in the PyTorch backend (yours was {N})") - - if axis1 < 0: - axis1 = N+axis1 - if axis2 < 0: - axis2 = N+axis2 - - inds = list(map(chr, range(98, 98+N))) - indsout = [i for n, i in enumerate(inds) if n not in (axis1, axis2)] - inds[axis1] = 'a' - inds[axis2] = 'a' - return torchlib.einsum(''.join(inds) + '->' +''.join(indsout), tensor) + return torchlib.sum(torchlib.diagonal(tensor, offset=offset, dim1=axis1, dim2=axis2), dim=-1) def abs(self, tensor: Tensor) -> Tensor: """ diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py index d5fbade5f..cbedf9952 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend_test.py +++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py @@ -621,27 +621,15 @@ def test_trace(dtype, offset, axis1, axis2): shape = (5, 5, 5, 5) backend = pytorch_backend.PyTorchBackend() array = backend.randn(shape, dtype=dtype, seed=10) - if offset != 0: - with pytest.raises(NotImplementedError): - actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - - elif axis1 == axis2: - with pytest.raises(ValueError): + if axis1 == axis2: + with pytest.raises(RuntimeError): actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) else: actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - expected = np.trace(array, axis1=axis1, axis2=axis2) + expected = np.trace(array, offset=offset, axis1=axis1, axis2=axis2) np.testing.assert_allclose(actual, expected, atol=1e-6, rtol=1e-6) -def test_trace_raises(): - shape = tuple([1] * 30) - backend = pytorch_backend.PyTorchBackend() - array = backend.randn(shape, seed=10) - with pytest.raises(ValueError): - _ = backend.trace(array) - - @pytest.mark.parametrize("pivot_axis", [-1, 1, 2]) @pytest.mark.parametrize("dtype", torch_randn_dtypes) def test_pivot(dtype, pivot_axis):