Skip to content
This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

Commit

Permalink
Fix pytorch trace tests
Browse files Browse the repository at this point in the history
  • Loading branch information
merajhashemi committed Mar 25, 2021
1 parent aae89ae commit 5098980
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions tensornetwork/backends/pytorch/pytorch_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,16 +621,12 @@ def test_trace(dtype, offset, axis1, axis2):
shape = (5, 5, 5, 5)
backend = pytorch_backend.PyTorchBackend()
array = backend.randn(shape, dtype=dtype, seed=10)
if offset != 0:
with pytest.raises(NotImplementedError):
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)

elif axis1 == axis2:
with pytest.raises(ValueError):
if axis1 == axis2:
with pytest.raises(RuntimeError):
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)
else:
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)
expected = np.trace(array, axis1=axis1, axis2=axis2)
expected = np.trace(array, offset=offset, axis1=axis1, axis2=axis2)
np.testing.assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)


Expand Down

0 comments on commit 5098980

Please sign in to comment.