Skip to content

Commit

Permalink
[ROCm] Update ROCm skip decorators (pytorch#106138)
Browse files Browse the repository at this point in the history
This PR adds an optional `msg` keyword argument to the `skipIfRocm` and `skipCUDAIfRocm` decorators so callers can supply a custom skip reason.

Pull Request resolved: pytorch#106138
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/albanD
  • Loading branch information
lcskrishna authored and pytorchmergebot committed Aug 18, 2023
1 parent 28be2c6 commit bc662ff
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def test_streaming_backwards_sync(self):
self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)

# Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
@skipIfRocm
@skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190")
def test_streaming_backwards_multiple_streams(self):
MultiplyInStream = self._make_multiply_in_stream()

Expand Down
2 changes: 1 addition & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11224,7 +11224,7 @@ def ctc_after_softmax(x):
gradcheck(ctc_after_softmax, [x])

@onlyCUDA
@skipCUDAIfRocm
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
@skipCUDAIfCudnnVersionLessThan(7600)
def test_ctc_loss_cudnn(self, device):
batch_size = 16
Expand Down
9 changes: 7 additions & 2 deletions torch/testing/_internal/common_device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,8 +1355,13 @@ def skipCUDAIfNoMagmaAndNoLinalgsolver(fn):
return skipCUDAIfNoMagma(fn)

# Skips a test on CUDA when using ROCm.
def skipCUDAIfRocm(fn):
    """Decorator: skip the wrapped CUDA test when running on the ROCm stack.

    Delegates to ``skipCUDAIf`` with the module-level ``TEST_WITH_ROCM`` flag,
    so the skip decision is made at test-collection time.
    """
    return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)
def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
    """Skip a CUDA test when running on the ROCm stack.

    Supports both the bare form (``@skipCUDAIfRocm``) and the parameterized
    form (``@skipCUDAIfRocm(msg="...")``). The final skip reason is prefixed
    with ``skipCUDAIfRocm:`` so the source of the skip is visible in reports.

    Args:
        func: The test function when used as a bare decorator; ``None`` when
            called with keyword arguments to produce a decorator.
        msg: Human-readable reason shown when the test is skipped.
    """
    def dec_fn(fn):
        reason = f"skipCUDAIfRocm: {msg}"
        return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn)

    # Explicit None sentinel check (not truthiness) — the correct idiom for
    # a decorator usable with or without arguments.
    if func is not None:
        return dec_fn(func)
    return dec_fn

# Skips a test on CUDA when not using ROCm.
def skipCUDAIfNotRocm(fn):
Expand Down
22 changes: 14 additions & 8 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,14 +1267,20 @@ def has_corresponding_torch_dtype(np_dtype):
torch.complex32: np.complex64
})

def skipIfRocm(fn):
    """Decorator: skip the wrapped test entirely when running on the ROCm stack.

    The skip decision is made at call time by raising ``unittest.SkipTest``
    when the module-level ``TEST_WITH_ROCM`` flag is set.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        else:
            # Propagate fn's result — the original dropped the return value,
            # which breaks any wrapped callable whose result is consumed.
            return fn(*args, **kwargs)
    return wrapper
def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
    """Skip a test when running on the ROCm stack.

    Supports both the bare form (``@skipIfRocm``) and the parameterized form
    (``@skipIfRocm(msg="...")``). The skip decision is made at call time by
    raising ``unittest.SkipTest`` when ``TEST_WITH_ROCM`` is set; the reason
    is prefixed with ``skipIfRocm:`` so the skip's origin is visible.

    Args:
        func: The test function when used as a bare decorator; ``None`` when
            called with keyword arguments to produce a decorator.
        msg: Human-readable reason shown when the test is skipped.
    """
    def dec_fn(fn):
        reason = f"skipIfRocm: {msg}"

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if TEST_WITH_ROCM:
                raise unittest.SkipTest(reason)
            return fn(*args, **kwargs)
        return wrapper

    # Explicit None sentinel check (not truthiness) — the correct idiom for
    # a decorator usable with or without arguments.
    if func is not None:
        return dec_fn(func)
    return dec_fn

def runOnRocm(fn):
@wraps(fn)
Expand Down

0 comments on commit bc662ff

Please sign in to comment.