diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py index d5fbade5f..7fcc5f256 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend_test.py +++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py @@ -621,16 +621,12 @@ def test_trace(dtype, offset, axis1, axis2): shape = (5, 5, 5, 5) backend = pytorch_backend.PyTorchBackend() array = backend.randn(shape, dtype=dtype, seed=10) - if offset != 0: - with pytest.raises(NotImplementedError): - actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - - elif axis1 == axis2: - with pytest.raises(ValueError): + if axis1 == axis2: + with pytest.raises(RuntimeError): actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) else: actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - expected = np.trace(array, axis1=axis1, axis2=axis2) + expected = np.trace(array, offset=offset, axis1=axis1, axis2=axis2) np.testing.assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)