
Commit

Reland: Remove remaining global set_default_dtype calls from tests (pytorch#108088)

Fixes pytorch#68972

Relands pytorch#107246

To avoid causing Meta-internal CI failures, this PR does not unconditionally assert that the default dtype is float in the `TestCase.setUp`/`tearDown` methods. Instead, the assert runs only if `TestCase._default_dtype_check_enabled == True`. `_default_dtype_check_enabled` is set to True in the `if __name__ == "__main__":` blocks of all the test files that required changes for this issue.

Pull Request resolved: pytorch#108088
Approved by: https://github.com/ezyang
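
For readers unfamiliar with the opt-in check, the mechanism described in the commit message amounts to roughly the sketch below. This is an illustrative approximation only: the real `TestCase` lives in torch.testing._internal.common_utils and its setUp/tearDown do much more than this; only `_default_dtype_check_enabled`, `setUp`/`tearDown`, and the float default-dtype assertion come from the commit message, the rest is assumed structure.

# Hedged sketch of the opt-in default-dtype check; not the actual PyTorch implementation.
import unittest

import torch


class TestCase(unittest.TestCase):
    # Off by default, so runs that do not go through a test file's
    # __main__ block (e.g. internal CI runners) are unaffected.
    _default_dtype_check_enabled = False

    def setUp(self):
        if self._default_dtype_check_enabled:
            # Fail fast if a previous test leaked a non-float default dtype.
            assert torch.get_default_dtype() == torch.float

    def tearDown(self):
        if self._default_dtype_check_enabled:
            assert torch.get_default_dtype() == torch.float

Each migrated test file then opts in from its `if __name__ == "__main__":` block, as in the last hunk of test_data_parallel.py below, so a test that leaks a non-float default dtype fails loudly when the file is run directly.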
kurtamohler authored and pytorchmergebot committed Sep 7, 2023
1 parent 54e7327 commit 3f88e31
Showing 20 changed files with 922 additions and 879 deletions.
23 changes: 11 additions & 12 deletions test/distributed/test_data_parallel.py
@@ -19,8 +19,6 @@
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
 import torch.nn.functional as F
 
-torch.set_default_dtype(torch.double)
-
 NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
 
 # batched grad doesn't support data parallel
@@ -40,11 +38,11 @@ def __init__(self, t):
             def forward(self, x):
                 return x * self.t_rg + self.t_not_rg
 
-        m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
+        m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
         self.assertTrue(m.t_rg.requires_grad)
 
         dpm = nn.DataParallel(m, [0, 1])
-        inp = torch.randn(2, 100, device='cuda')
+        inp = torch.randn(2, 100, device='cuda', dtype=torch.double)
 
         def fn(t):
             return dpm(inp)
@@ -512,11 +510,11 @@ def _test_scatter(self, tensor):
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_scatter_cpu(self):
-        self._test_scatter(torch.randn((4, 4)))
+        self._test_scatter(torch.randn((4, 4), dtype=torch.double))
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
-        self._test_scatter(torch.randn((4, 4)).cuda())
+        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
     @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
@@ -539,8 +537,8 @@ def forward(self, x):
 
     def _test_gather(self, output_device):
         inputs = (
-            torch.randn(2, 4, device='cuda:0', requires_grad=True),
-            torch.randn(2, 4, device='cuda:1', requires_grad=True),
+            torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
+            torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -550,7 +548,7 @@ def _test_gather(self, output_device):
             self.assertEqual(result.get_device(), output_device)
         else:
             self.assertFalse(result.is_cuda)
-        grad = torch.randn((4, 4))
+        grad = torch.randn((4, 4), dtype=torch.double)
         if output_device != -1:
             grad = grad.cuda(output_device)
         result.backward(grad)
@@ -560,8 +558,8 @@ def _test_gather(self, output_device):
 
         # test scalar inputs, should stack into a vector in this case
         inputs = (
-            torch.randn((), device='cuda:0', requires_grad=True),
-            torch.randn((), device='cuda:1', requires_grad=True),
+            torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
+            torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([2]))
@@ -571,7 +569,7 @@ def _test_gather(self, output_device):
             self.assertEqual(result.get_device(), output_device)
         else:
             self.assertFalse(result.is_cuda)
-        grad = torch.randn(2)
+        grad = torch.randn(2, dtype=torch.double)
         if output_device != -1:
             grad = grad.cuda(output_device)
         result.backward(grad)
@@ -878,4 +876,5 @@ def forward(self, input):
 instantiate_device_type_tests(TestDataParallelDeviceType, globals())
 
 if __name__ == '__main__':
+    TestCase._default_dtype_check_enabled = True
     run_tests()