diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py
index 65c879be3d77d..5349019791979 100644
--- a/test/distributed/test_data_parallel.py
+++ b/test/distributed/test_data_parallel.py
@@ -19,8 +19,6 @@ from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
 import torch.nn.functional as F
 
-torch.set_default_dtype(torch.double)
-
 NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
 
 # batched grad doesn't support data parallel
@@ -40,11 +38,11 @@ def __init__(self, t):
             def forward(self, x):
                 return x * self.t_rg + self.t_not_rg
 
-        m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
+        m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
         self.assertTrue(m.t_rg.requires_grad)
 
         dpm = nn.DataParallel(m, [0, 1])
-        inp = torch.randn(2, 100, device='cuda')
+        inp = torch.randn(2, 100, device='cuda', dtype=torch.double)
 
         def fn(t):
             return dpm(inp)
@@ -512,11 +510,11 @@ def _test_scatter(self, tensor):
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_scatter_cpu(self):
-        self._test_scatter(torch.randn((4, 4)))
+        self._test_scatter(torch.randn((4, 4), dtype=torch.double))
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_scatter_gpu(self):
-        self._test_scatter(torch.randn((4, 4)).cuda())
+        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
     @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
@@ -539,8 +537,8 @@ def forward(self, x):
 
     def _test_gather(self, output_device):
         inputs = (
-            torch.randn(2, 4, device='cuda:0', requires_grad=True),
-            torch.randn(2, 4, device='cuda:1', requires_grad=True),
+            torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
+            torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -550,7 +548,7 @@ def _test_gather(self, output_device):
            self.assertEqual(result.get_device(), output_device)
         else:
            self.assertFalse(result.is_cuda)
-        grad = torch.randn((4, 4))
+        grad = torch.randn((4, 4), dtype=torch.double)
         if output_device != -1:
            grad = grad.cuda(output_device)
         result.backward(grad)
@@ -560,8 +558,8 @@ def _test_gather(self, output_device):
 
         # test scalar inputs, should stack into a vector in this case
         inputs = (
-            torch.randn((), device='cuda:0', requires_grad=True),
-            torch.randn((), device='cuda:1', requires_grad=True),
+            torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
+            torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([2]))
@@ -571,7 +569,7 @@
            self.assertEqual(result.get_device(), output_device)
         else:
            self.assertFalse(result.is_cuda)
-        grad = torch.randn(2)
+        grad = torch.randn(2, dtype=torch.double)
         if output_device != -1:
            grad = grad.cuda(output_device)
         result.backward(grad)
@@ -878,4 +876,5 @@ def forward(self, input):
 instantiate_device_type_tests(TestDataParallelDeviceType, globals())
 
 if __name__ == '__main__':
+    TestCase._default_dtype_check_enabled = True
     run_tests()
diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 047ed77289389..5ad72b3eff898 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -38,14 +38,10 @@
 import torch
 
-# TODO: remove this global setting
-# Distributions tests use double as the default dtype
-torch.set_default_dtype(torch.double)
-
 from torch import inf, nan
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, set_rng_seed, load_tests,
-     gradcheck, skipIfTorchDynamo)
+     gradcheck, skipIfTorchDynamo, set_default_dtype)
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.autograd import grad
 import torch.autograd.forward_ad as fwAD
@@ -103,694 +99,697 @@ def is_all_nan(tensor):
     return (tensor != tensor).all()
 
 
-# Register all distributions for generic tests.
 Example = namedtuple('Example', ['Dist', 'params'])
-EXAMPLES = [
-    Example(Bernoulli, [
-        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
-        {'probs': torch.tensor([0.3], requires_grad=True)},
-        {'probs': 0.3},
-        {'logits': torch.tensor([0.], requires_grad=True)},
-    ]),
-    Example(Geometric, [
-        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
-        {'probs': torch.tensor([0.3], requires_grad=True)},
-        {'probs': 0.3},
-    ]),
-    Example(Beta, [
-        {
-            'concentration1': torch.randn(2, 3).exp().requires_grad_(),
-            'concentration0': torch.randn(2, 3).exp().requires_grad_(),
-        },
-        {
-            'concentration1': torch.randn(4).exp().requires_grad_(),
-            'concentration0': torch.randn(4).exp().requires_grad_(),
-        },
-    ]),
-    Example(Categorical, [
-        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
-        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
-    ]),
-    Example(Binomial, [
-        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
-         'total_count': torch.tensor([[10., 8.], [5., 3.]])},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
-         'total_count': torch.tensor(0.)},
-    ]),
-    Example(NegativeBinomial, [
-        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
-        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': 10},
-        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10])},
-        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
-        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
-         'total_count': torch.tensor([[10., 8.], [5., 3.]])},
-        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
-         'total_count': torch.tensor(0.)},
-    ]),
-    Example(Multinomial, [
-        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
-        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
-    ]),
-    Example(Cauchy, [
-        {'loc': 0.0, 'scale': 1.0},
-        {'loc': torch.tensor([0.0]), 'scale': 1.0},
-        {'loc': torch.tensor([[0.0], [0.0]]),
-         'scale': torch.tensor([[1.0], [1.0]])}
-    ]),
-    Example(Chi2, [
-        {'df': torch.randn(2, 3).exp().requires_grad_()},
-        {'df': 
torch.randn(1).exp().requires_grad_()}, - ]), - Example(StudentT, [ - {'df': torch.randn(2, 3).exp().requires_grad_()}, - {'df': torch.randn(1).exp().requires_grad_()}, - ]), - Example(Dirichlet, [ - {'concentration': torch.randn(2, 3).exp().requires_grad_()}, - {'concentration': torch.randn(4).exp().requires_grad_()}, - ]), - Example(Exponential, [ - {'rate': torch.randn(5, 5).abs().requires_grad_()}, - {'rate': torch.randn(1).abs().requires_grad_()}, - ]), - Example(FisherSnedecor, [ - { - 'df1': torch.randn(5, 5).abs().requires_grad_(), - 'df2': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'df1': torch.randn(1).abs().requires_grad_(), - 'df2': torch.randn(1).abs().requires_grad_(), - }, - { - 'df1': torch.tensor([1.0]), - 'df2': 1.0, - } - ]), - Example(Gamma, [ - { - 'concentration': torch.randn(2, 3).exp().requires_grad_(), - 'rate': torch.randn(2, 3).exp().requires_grad_(), - }, - { - 'concentration': torch.randn(1).exp().requires_grad_(), - 'rate': torch.randn(1).exp().requires_grad_(), - }, - ]), - Example(Gumbel, [ - { - 'loc': torch.randn(5, 5, requires_grad=True), - 'scale': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'loc': torch.randn(1, requires_grad=True), - 'scale': torch.randn(1).abs().requires_grad_(), - }, - ]), - Example(HalfCauchy, [ - {'scale': 1.0}, - {'scale': torch.tensor([[1.0], [1.0]])} - ]), - Example(HalfNormal, [ - {'scale': torch.randn(5, 5).abs().requires_grad_()}, - {'scale': torch.randn(1).abs().requires_grad_()}, - {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)} - ]), - Example(Independent, [ - { - 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), - torch.randn(2, 3).abs().requires_grad_()), - 'reinterpreted_batch_ndims': 0, - }, - { - 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), - torch.randn(2, 3).abs().requires_grad_()), - 'reinterpreted_batch_ndims': 1, - }, - { - 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), - torch.randn(2, 3).abs().requires_grad_()), - 'reinterpreted_batch_ndims': 2, - }, - { - 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), - torch.randn(2, 3, 5).abs().requires_grad_()), - 'reinterpreted_batch_ndims': 2, - }, - { - 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), - torch.randn(2, 3, 5).abs().requires_grad_()), - 'reinterpreted_batch_ndims': 3, - }, - ]), - Example(Kumaraswamy, [ - { - 'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), - 'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), - }, - { - 'concentration1': torch.rand(4).uniform_(1, 2).requires_grad_(), - 'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(), - }, - ]), - Example(LKJCholesky, [ - { - 'dim': 2, - 'concentration': 0.5 - }, - { - 'dim': 3, - 'concentration': torch.tensor([0.5, 1., 2.]), - }, - { - 'dim': 100, - 'concentration': 4. 
- }, - ]), - Example(Laplace, [ - { - 'loc': torch.randn(5, 5, requires_grad=True), - 'scale': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'loc': torch.randn(1, requires_grad=True), - 'scale': torch.randn(1).abs().requires_grad_(), - }, - { - 'loc': torch.tensor([1.0, 0.0], requires_grad=True), - 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), - }, - ]), - Example(LogNormal, [ - { - 'loc': torch.randn(5, 5, requires_grad=True), - 'scale': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'loc': torch.randn(1, requires_grad=True), - 'scale': torch.randn(1).abs().requires_grad_(), - }, - { - 'loc': torch.tensor([1.0, 0.0], requires_grad=True), - 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), - }, - ]), - Example(LogisticNormal, [ - { - 'loc': torch.randn(5, 5).requires_grad_(), - 'scale': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'loc': torch.randn(1).requires_grad_(), - 'scale': torch.randn(1).abs().requires_grad_(), - }, - { - 'loc': torch.tensor([1.0, 0.0], requires_grad=True), - 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), - }, - ]), - Example(LowRankMultivariateNormal, [ - { - 'loc': torch.randn(5, 2, requires_grad=True), - 'cov_factor': torch.randn(5, 2, 1, requires_grad=True), - 'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True), - }, - { - 'loc': torch.randn(4, 3, requires_grad=True), - 'cov_factor': torch.randn(3, 2, requires_grad=True), - 'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True), - } - ]), - Example(MultivariateNormal, [ - { - 'loc': torch.randn(5, 2, requires_grad=True), - 'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True), - }, - { - 'loc': torch.randn(2, 3, requires_grad=True), - 'precision_matrix': torch.tensor([[2.0, 0.1, 0.0], - [0.1, 0.25, 0.0], - [0.0, 0.0, 0.3]], requires_grad=True), - }, - { - 'loc': torch.randn(5, 3, 2, requires_grad=True), - 'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]], - [[2.0, 0.0], [0.3, 0.25]], - [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True), - }, - { - 'loc': torch.tensor([1.0, -1.0]), - 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), - }, - ]), - Example(Normal, [ - { - 'loc': torch.randn(5, 5, requires_grad=True), - 'scale': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'loc': torch.randn(1, requires_grad=True), - 'scale': torch.randn(1).abs().requires_grad_(), - }, - { - 'loc': torch.tensor([1.0, 0.0], requires_grad=True), - 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), - }, - ]), - Example(OneHotCategorical, [ - {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, - {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, - {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, - ]), - Example(OneHotCategoricalStraightThrough, [ - {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, - {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, - {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, - ]), - Example(Pareto, [ - { - 'scale': 1.0, - 'alpha': 1.0 - }, - { - 'scale': torch.randn(5, 5).abs().requires_grad_(), - 'alpha': torch.randn(5, 5).abs().requires_grad_() - }, - { - 'scale': torch.tensor([1.0]), - 'alpha': 1.0 - } - ]), - Example(Poisson, [ - { - 'rate': torch.randn(5, 5).abs().requires_grad_(), - }, - { - 'rate': torch.randn(3).abs().requires_grad_(), - }, - { - 'rate': 0.2, - }, - { - 'rate': torch.tensor([0.0], requires_grad=True), - }, - { - 'rate': 0.0, - } 
- ]), - Example(RelaxedBernoulli, [ - { - 'temperature': torch.tensor([0.5], requires_grad=True), - 'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True), - }, - { - 'temperature': torch.tensor([2.0]), - 'probs': torch.tensor([0.3]), - }, - { - 'temperature': torch.tensor([7.2]), - 'logits': torch.tensor([-2.0, 2.0, 1.0, 5.0]) - } - ]), - Example(RelaxedOneHotCategorical, [ - { - 'temperature': torch.tensor([0.5], requires_grad=True), - 'probs': torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True) - }, - { - 'temperature': torch.tensor([2.0]), - 'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]]) - }, - { - 'temperature': torch.tensor([7.2]), - 'logits': torch.tensor([[-2.0, 2.0], [1.0, 5.0]]) - } - ]), - Example(TransformedDistribution, [ - { - 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), - torch.randn(2, 3).abs().requires_grad_()), - 'transforms': [], - }, - { - 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), - torch.randn(2, 3).abs().requires_grad_()), - 'transforms': ExpTransform(), - }, - { - 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), - torch.randn(2, 3, 5).abs().requires_grad_()), - 'transforms': [AffineTransform(torch.randn(3, 5), torch.randn(3, 5)), - ExpTransform()], - }, - { - 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), - torch.randn(2, 3, 5).abs().requires_grad_()), - 'transforms': AffineTransform(1, 2), - }, - { - 'base_distribution': Uniform(torch.tensor(1e8).log(), torch.tensor(1e10).log()), - 'transforms': ExpTransform(), - }, - ]), - Example(Uniform, [ - { - 'low': torch.zeros(5, 5, requires_grad=True), - 'high': torch.ones(5, 5, requires_grad=True), - }, - { - 'low': torch.zeros(1, requires_grad=True), - 'high': torch.ones(1, requires_grad=True), - }, - { - 'low': torch.tensor([1.0, 1.0], requires_grad=True), - 'high': torch.tensor([2.0, 3.0], requires_grad=True), - }, - ]), - Example(Weibull, [ - { - 'scale': torch.randn(5, 5).abs().requires_grad_(), - 'concentration': torch.randn(1).abs().requires_grad_() - } - ]), - Example(Wishart, [ - { - 'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True), - 'df': torch.tensor([3.], requires_grad=True), - }, - { - 'precision_matrix': torch.tensor([[2.0, 0.1, 0.0], - [0.1, 0.25, 0.0], - [0.0, 0.0, 0.3]], requires_grad=True), - 'df': torch.tensor([5., 4], requires_grad=True), - }, - { - 'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]], - [[2.0, 0.0], [0.3, 0.25]], - [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True), - 'df': torch.tensor([5., 3.5, 3], requires_grad=True), - }, - { - 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), - 'df': torch.tensor([3.0]), - }, - { - 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), - 'df': 3.0, - }, - ]), - Example(MixtureSameFamily, [ - { - 'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)), - 'component_distribution': Normal(torch.randn(5, requires_grad=True), - torch.rand(5, requires_grad=True)), - }, - { - 'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)), - 'component_distribution': MultivariateNormal( - loc=torch.randn(5, 2, requires_grad=True), - covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)), - }, - ]), - Example(VonMises, [ - { - 'loc': torch.tensor(1.0, requires_grad=True), - 'concentration': torch.tensor(10.0, requires_grad=True) - }, - { - 'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True), - 'concentration': torch.tensor([1.0, 10.0], 
requires_grad=True) - }, - ]), - Example(ContinuousBernoulli, [ - {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)}, - {'probs': torch.tensor([0.3], requires_grad=True)}, - {'probs': 0.3}, - {'logits': torch.tensor([0.], requires_grad=True)}, - ]) -] - -BAD_EXAMPLES = [ - Example(Bernoulli, [ - {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, - {'probs': torch.tensor([-0.5], requires_grad=True)}, - {'probs': 1.00001}, - ]), - Example(Beta, [ - { - 'concentration1': torch.tensor([0.0], requires_grad=True), - 'concentration0': torch.tensor([0.0], requires_grad=True), - }, - { - 'concentration1': torch.tensor([-1.0], requires_grad=True), - 'concentration0': torch.tensor([-2.0], requires_grad=True), - }, - ]), - Example(Geometric, [ - {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, - {'probs': torch.tensor([-0.3], requires_grad=True)}, - {'probs': 1.00000001}, - ]), - Example(Categorical, [ - {'probs': torch.tensor([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, - {'probs': torch.tensor([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)}, - ]), - Example(Binomial, [ - {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), - 'total_count': 10}, - {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True), - 'total_count': 10}, - ]), - Example(NegativeBinomial, [ - {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), - 'total_count': 10}, - {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True), - 'total_count': 10}, - ]), - Example(Cauchy, [ - {'loc': 0.0, 'scale': -1.0}, - {'loc': torch.tensor([0.0]), 'scale': 0.0}, - {'loc': torch.tensor([[0.0], [-2.0]]), - 'scale': torch.tensor([[-0.000001], [1.0]])} - ]), - Example(Chi2, [ - {'df': torch.tensor([0.], requires_grad=True)}, - {'df': torch.tensor([-2.], requires_grad=True)}, - ]), - Example(StudentT, [ - {'df': torch.tensor([0.], requires_grad=True)}, - {'df': torch.tensor([-2.], requires_grad=True)}, - ]), - Example(Dirichlet, [ - {'concentration': torch.tensor([0.], requires_grad=True)}, - {'concentration': torch.tensor([-2.], requires_grad=True)} - ]), - Example(Exponential, [ - {'rate': torch.tensor([0., 0.], requires_grad=True)}, - {'rate': torch.tensor([-2.], requires_grad=True)} - ]), - Example(FisherSnedecor, [ - { - 'df1': torch.tensor([0., 0.], requires_grad=True), - 'df2': torch.tensor([-1., -100.], requires_grad=True), - }, - { - 'df1': torch.tensor([1., 1.], requires_grad=True), - 'df2': torch.tensor([0., 0.], requires_grad=True), - } - ]), - Example(Gamma, [ - { - 'concentration': torch.tensor([0., 0.], requires_grad=True), - 'rate': torch.tensor([-1., -100.], requires_grad=True), - }, - { - 'concentration': torch.tensor([1., 1.], requires_grad=True), - 'rate': torch.tensor([0., 0.], requires_grad=True), - } - ]), - Example(Gumbel, [ - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([0., 1.], requires_grad=True), - }, - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([1., -1.], requires_grad=True), - }, - ]), - Example(HalfCauchy, [ - {'scale': -1.0}, - {'scale': 0.0}, - {'scale': torch.tensor([[-0.000001], [1.0]])} - ]), - Example(HalfNormal, [ - {'scale': torch.tensor([0., 1.], requires_grad=True)}, - {'scale': torch.tensor([1., -1.], requires_grad=True)}, - ]), - Example(LKJCholesky, [ - { - 'dim': -2, - 'concentration': 0.1 - }, - { - 'dim': 1, - 'concentration': 2., - }, - { - 'dim': 2, - 'concentration': 0., - }, - ]), - 
Example(Laplace, [ - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([0., 1.], requires_grad=True), - }, - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([1., -1.], requires_grad=True), - }, - ]), - Example(LogNormal, [ - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([0., 1.], requires_grad=True), - }, - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([1., -1.], requires_grad=True), - }, - ]), - Example(MultivariateNormal, [ - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True), - }, - ]), - Example(Normal, [ - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([0., 1.], requires_grad=True), - }, - { - 'loc': torch.tensor([1., 1.], requires_grad=True), - 'scale': torch.tensor([1., -1.], requires_grad=True), - }, - { - 'loc': torch.tensor([1.0, 0.0], requires_grad=True), - 'scale': torch.tensor([1e-5, -1e-5], requires_grad=True), - }, - ]), - Example(OneHotCategorical, [ - {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, - {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, - ]), - Example(OneHotCategoricalStraightThrough, [ - {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, - {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, - ]), - Example(Pareto, [ - { - 'scale': 0.0, - 'alpha': 0.0 - }, - { - 'scale': torch.tensor([0.0, 0.0], requires_grad=True), - 'alpha': torch.tensor([-1e-5, 0.0], requires_grad=True) - }, - { - 'scale': torch.tensor([1.0]), - 'alpha': -1.0 - } - ]), - Example(Poisson, [ - { - 'rate': torch.tensor([-0.1], requires_grad=True), - }, - { - 'rate': -1.0, - } - ]), - Example(RelaxedBernoulli, [ - { - 'temperature': torch.tensor([1.5], requires_grad=True), - 'probs': torch.tensor([1.7, 0.2, 0.4], requires_grad=True), - }, - { - 'temperature': torch.tensor([2.0]), - 'probs': torch.tensor([-1.0]), - } - ]), - Example(RelaxedOneHotCategorical, [ - { - 'temperature': torch.tensor([0.5], requires_grad=True), - 'probs': torch.tensor([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True) - }, - { - 'temperature': torch.tensor([2.0]), - 'probs': torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]) - } - ]), - Example(TransformedDistribution, [ - { - 'base_distribution': Normal(0, 1), - 'transforms': lambda x: x, - }, - { - 'base_distribution': Normal(0, 1), - 'transforms': [lambda x: x], - }, - ]), - Example(Uniform, [ - { - 'low': torch.tensor([2.0], requires_grad=True), - 'high': torch.tensor([2.0], requires_grad=True), - }, - { - 'low': torch.tensor([0.0], requires_grad=True), - 'high': torch.tensor([0.0], requires_grad=True), - }, - { - 'low': torch.tensor([1.0], requires_grad=True), - 'high': torch.tensor([0.0], requires_grad=True), - } - ]), - Example(Weibull, [ - { - 'scale': torch.tensor([0.0], requires_grad=True), - 'concentration': torch.tensor([0.0], requires_grad=True) - }, - { - 'scale': torch.tensor([1.0], requires_grad=True), - 'concentration': torch.tensor([-1.0], requires_grad=True) - } - ]), - Example(Wishart, [ - { - 'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True), - 'df': torch.tensor([1.5], requires_grad=True), - }, - { - 'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True), - 'df': torch.tensor([3.], requires_grad=True), - }, - { - 'covariance_matrix': 
torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True), - 'df': 3., - }, - ]), - Example(ContinuousBernoulli, [ - {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, - {'probs': torch.tensor([-0.5], requires_grad=True)}, - {'probs': 1.00001}, - ]) -] + +# Register all distributions for generic tests. +def _get_examples(): + return [ + Example(Bernoulli, [ + {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([0.3], requires_grad=True)}, + {'probs': 0.3}, + {'logits': torch.tensor([0.], requires_grad=True)}, + ]), + Example(Geometric, [ + {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([0.3], requires_grad=True)}, + {'probs': 0.3}, + ]), + Example(Beta, [ + { + 'concentration1': torch.randn(2, 3).exp().requires_grad_(), + 'concentration0': torch.randn(2, 3).exp().requires_grad_(), + }, + { + 'concentration1': torch.randn(4).exp().requires_grad_(), + 'concentration0': torch.randn(4).exp().requires_grad_(), + }, + ]), + Example(Categorical, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, + {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), + Example(Binomial, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), + 'total_count': torch.tensor([[10., 8.], [5., 3.]])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), + 'total_count': torch.tensor(0.)}, + ]), + Example(NegativeBinomial, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10])}, + {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10, 8])}, + {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), + 'total_count': torch.tensor([[10., 8.], [5., 3.]])}, + {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), + 'total_count': torch.tensor(0.)}, + ]), + Example(Multinomial, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10}, + ]), + Example(Cauchy, [ + {'loc': 0.0, 'scale': 1.0}, + {'loc': torch.tensor([0.0]), 'scale': 1.0}, + {'loc': torch.tensor([[0.0], [0.0]]), + 'scale': torch.tensor([[1.0], [1.0]])} + ]), + Example(Chi2, [ + {'df': torch.randn(2, 3).exp().requires_grad_()}, + {'df': torch.randn(1).exp().requires_grad_()}, + ]), + Example(StudentT, [ + {'df': torch.randn(2, 3).exp().requires_grad_()}, + {'df': torch.randn(1).exp().requires_grad_()}, + ]), + Example(Dirichlet, [ + {'concentration': torch.randn(2, 3).exp().requires_grad_()}, + {'concentration': torch.randn(4).exp().requires_grad_()}, + ]), + Example(Exponential, [ + {'rate': torch.randn(5, 5).abs().requires_grad_()}, + {'rate': 
torch.randn(1).abs().requires_grad_()}, + ]), + Example(FisherSnedecor, [ + { + 'df1': torch.randn(5, 5).abs().requires_grad_(), + 'df2': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'df1': torch.randn(1).abs().requires_grad_(), + 'df2': torch.randn(1).abs().requires_grad_(), + }, + { + 'df1': torch.tensor([1.0]), + 'df2': 1.0, + } + ]), + Example(Gamma, [ + { + 'concentration': torch.randn(2, 3).exp().requires_grad_(), + 'rate': torch.randn(2, 3).exp().requires_grad_(), + }, + { + 'concentration': torch.randn(1).exp().requires_grad_(), + 'rate': torch.randn(1).exp().requires_grad_(), + }, + ]), + Example(Gumbel, [ + { + 'loc': torch.randn(5, 5, requires_grad=True), + 'scale': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'loc': torch.randn(1, requires_grad=True), + 'scale': torch.randn(1).abs().requires_grad_(), + }, + ]), + Example(HalfCauchy, [ + {'scale': 1.0}, + {'scale': torch.tensor([[1.0], [1.0]])} + ]), + Example(HalfNormal, [ + {'scale': torch.randn(5, 5).abs().requires_grad_()}, + {'scale': torch.randn(1).abs().requires_grad_()}, + {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)} + ]), + Example(Independent, [ + { + 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), + torch.randn(2, 3).abs().requires_grad_()), + 'reinterpreted_batch_ndims': 0, + }, + { + 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), + torch.randn(2, 3).abs().requires_grad_()), + 'reinterpreted_batch_ndims': 1, + }, + { + 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), + torch.randn(2, 3).abs().requires_grad_()), + 'reinterpreted_batch_ndims': 2, + }, + { + 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), + torch.randn(2, 3, 5).abs().requires_grad_()), + 'reinterpreted_batch_ndims': 2, + }, + { + 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), + torch.randn(2, 3, 5).abs().requires_grad_()), + 'reinterpreted_batch_ndims': 3, + }, + ]), + Example(Kumaraswamy, [ + { + 'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), + 'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), + }, + { + 'concentration1': torch.rand(4).uniform_(1, 2).requires_grad_(), + 'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(), + }, + ]), + Example(LKJCholesky, [ + { + 'dim': 2, + 'concentration': 0.5 + }, + { + 'dim': 3, + 'concentration': torch.tensor([0.5, 1., 2.]), + }, + { + 'dim': 100, + 'concentration': 4. 
+ }, + ]), + Example(Laplace, [ + { + 'loc': torch.randn(5, 5, requires_grad=True), + 'scale': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'loc': torch.randn(1, requires_grad=True), + 'scale': torch.randn(1).abs().requires_grad_(), + }, + { + 'loc': torch.tensor([1.0, 0.0], requires_grad=True), + 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), + }, + ]), + Example(LogNormal, [ + { + 'loc': torch.randn(5, 5, requires_grad=True), + 'scale': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'loc': torch.randn(1, requires_grad=True), + 'scale': torch.randn(1).abs().requires_grad_(), + }, + { + 'loc': torch.tensor([1.0, 0.0], requires_grad=True), + 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), + }, + ]), + Example(LogisticNormal, [ + { + 'loc': torch.randn(5, 5).requires_grad_(), + 'scale': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'loc': torch.randn(1).requires_grad_(), + 'scale': torch.randn(1).abs().requires_grad_(), + }, + { + 'loc': torch.tensor([1.0, 0.0], requires_grad=True), + 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), + }, + ]), + Example(LowRankMultivariateNormal, [ + { + 'loc': torch.randn(5, 2, requires_grad=True), + 'cov_factor': torch.randn(5, 2, 1, requires_grad=True), + 'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True), + }, + { + 'loc': torch.randn(4, 3, requires_grad=True), + 'cov_factor': torch.randn(3, 2, requires_grad=True), + 'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True), + } + ]), + Example(MultivariateNormal, [ + { + 'loc': torch.randn(5, 2, requires_grad=True), + 'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True), + }, + { + 'loc': torch.randn(2, 3, requires_grad=True), + 'precision_matrix': torch.tensor([[2.0, 0.1, 0.0], + [0.1, 0.25, 0.0], + [0.0, 0.0, 0.3]], requires_grad=True), + }, + { + 'loc': torch.randn(5, 3, 2, requires_grad=True), + 'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]], + [[2.0, 0.0], [0.3, 0.25]], + [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True), + }, + { + 'loc': torch.tensor([1.0, -1.0]), + 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), + }, + ]), + Example(Normal, [ + { + 'loc': torch.randn(5, 5, requires_grad=True), + 'scale': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'loc': torch.randn(1, requires_grad=True), + 'scale': torch.randn(1).abs().requires_grad_(), + }, + { + 'loc': torch.tensor([1.0, 0.0], requires_grad=True), + 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), + }, + ]), + Example(OneHotCategorical, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, + {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), + Example(OneHotCategoricalStraightThrough, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, + {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), + Example(Pareto, [ + { + 'scale': 1.0, + 'alpha': 1.0 + }, + { + 'scale': (torch.randn(5, 5).abs() + 0.1).requires_grad_(), + 'alpha': (torch.randn(5, 5).abs() + 0.1).requires_grad_() + }, + { + 'scale': torch.tensor([1.0]), + 'alpha': 1.0 + } + ]), + Example(Poisson, [ + { + 'rate': torch.randn(5, 5).abs().requires_grad_(), + }, + { + 'rate': torch.randn(3).abs().requires_grad_(), + }, + { + 'rate': 0.2, + }, + { + 'rate': torch.tensor([0.0], requires_grad=True), + }, + { + 
'rate': 0.0, + } + ]), + Example(RelaxedBernoulli, [ + { + 'temperature': torch.tensor([0.5], requires_grad=True), + 'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True), + }, + { + 'temperature': torch.tensor([2.0]), + 'probs': torch.tensor([0.3]), + }, + { + 'temperature': torch.tensor([7.2]), + 'logits': torch.tensor([-2.0, 2.0, 1.0, 5.0]) + } + ]), + Example(RelaxedOneHotCategorical, [ + { + 'temperature': torch.tensor([0.5], requires_grad=True), + 'probs': torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True) + }, + { + 'temperature': torch.tensor([2.0]), + 'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + }, + { + 'temperature': torch.tensor([7.2]), + 'logits': torch.tensor([[-2.0, 2.0], [1.0, 5.0]]) + } + ]), + Example(TransformedDistribution, [ + { + 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), + torch.randn(2, 3).abs().requires_grad_()), + 'transforms': [], + }, + { + 'base_distribution': Normal(torch.randn(2, 3, requires_grad=True), + torch.randn(2, 3).abs().requires_grad_()), + 'transforms': ExpTransform(), + }, + { + 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), + torch.randn(2, 3, 5).abs().requires_grad_()), + 'transforms': [AffineTransform(torch.randn(3, 5), torch.randn(3, 5)), + ExpTransform()], + }, + { + 'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True), + torch.randn(2, 3, 5).abs().requires_grad_()), + 'transforms': AffineTransform(1, 2), + }, + { + 'base_distribution': Uniform(torch.tensor(1e8).log(), torch.tensor(1e10).log()), + 'transforms': ExpTransform(), + }, + ]), + Example(Uniform, [ + { + 'low': torch.zeros(5, 5, requires_grad=True), + 'high': torch.ones(5, 5, requires_grad=True), + }, + { + 'low': torch.zeros(1, requires_grad=True), + 'high': torch.ones(1, requires_grad=True), + }, + { + 'low': torch.tensor([1.0, 1.0], requires_grad=True), + 'high': torch.tensor([2.0, 3.0], requires_grad=True), + }, + ]), + Example(Weibull, [ + { + 'scale': torch.randn(5, 5).abs().requires_grad_(), + 'concentration': torch.randn(1).abs().requires_grad_() + } + ]), + Example(Wishart, [ + { + 'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True), + 'df': torch.tensor([3.], requires_grad=True), + }, + { + 'precision_matrix': torch.tensor([[2.0, 0.1, 0.0], + [0.1, 0.25, 0.0], + [0.0, 0.0, 0.3]], requires_grad=True), + 'df': torch.tensor([5., 4], requires_grad=True), + }, + { + 'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]], + [[2.0, 0.0], [0.3, 0.25]], + [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True), + 'df': torch.tensor([5., 3.5, 3], requires_grad=True), + }, + { + 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), + 'df': torch.tensor([3.0]), + }, + { + 'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]), + 'df': 3.0, + }, + ]), + Example(MixtureSameFamily, [ + { + 'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)), + 'component_distribution': Normal(torch.randn(5, requires_grad=True), + torch.rand(5, requires_grad=True)), + }, + { + 'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)), + 'component_distribution': MultivariateNormal( + loc=torch.randn(5, 2, requires_grad=True), + covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)), + }, + ]), + Example(VonMises, [ + { + 'loc': torch.tensor(1.0, requires_grad=True), + 'concentration': torch.tensor(10.0, requires_grad=True) + }, + { + 'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True), + 'concentration': 
torch.tensor([1.0, 10.0], requires_grad=True) + }, + ]), + Example(ContinuousBernoulli, [ + {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([0.3], requires_grad=True)}, + {'probs': 0.3}, + {'logits': torch.tensor([0.], requires_grad=True)}, + ]) + ] + +def _get_bad_examples(): + return [ + Example(Bernoulli, [ + {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([-0.5], requires_grad=True)}, + {'probs': 1.00001}, + ]), + Example(Beta, [ + { + 'concentration1': torch.tensor([0.0], requires_grad=True), + 'concentration0': torch.tensor([0.0], requires_grad=True), + }, + { + 'concentration1': torch.tensor([-1.0], requires_grad=True), + 'concentration0': torch.tensor([-2.0], requires_grad=True), + }, + ]), + Example(Geometric, [ + {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([-0.3], requires_grad=True)}, + {'probs': 1.00000001}, + ]), + Example(Categorical, [ + {'probs': torch.tensor([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)}, + ]), + Example(Binomial, [ + {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), + 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True), + 'total_count': 10}, + ]), + Example(NegativeBinomial, [ + {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), + 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True), + 'total_count': 10}, + ]), + Example(Cauchy, [ + {'loc': 0.0, 'scale': -1.0}, + {'loc': torch.tensor([0.0]), 'scale': 0.0}, + {'loc': torch.tensor([[0.0], [-2.0]]), + 'scale': torch.tensor([[-0.000001], [1.0]])} + ]), + Example(Chi2, [ + {'df': torch.tensor([0.], requires_grad=True)}, + {'df': torch.tensor([-2.], requires_grad=True)}, + ]), + Example(StudentT, [ + {'df': torch.tensor([0.], requires_grad=True)}, + {'df': torch.tensor([-2.], requires_grad=True)}, + ]), + Example(Dirichlet, [ + {'concentration': torch.tensor([0.], requires_grad=True)}, + {'concentration': torch.tensor([-2.], requires_grad=True)} + ]), + Example(Exponential, [ + {'rate': torch.tensor([0., 0.], requires_grad=True)}, + {'rate': torch.tensor([-2.], requires_grad=True)} + ]), + Example(FisherSnedecor, [ + { + 'df1': torch.tensor([0., 0.], requires_grad=True), + 'df2': torch.tensor([-1., -100.], requires_grad=True), + }, + { + 'df1': torch.tensor([1., 1.], requires_grad=True), + 'df2': torch.tensor([0., 0.], requires_grad=True), + } + ]), + Example(Gamma, [ + { + 'concentration': torch.tensor([0., 0.], requires_grad=True), + 'rate': torch.tensor([-1., -100.], requires_grad=True), + }, + { + 'concentration': torch.tensor([1., 1.], requires_grad=True), + 'rate': torch.tensor([0., 0.], requires_grad=True), + } + ]), + Example(Gumbel, [ + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([0., 1.], requires_grad=True), + }, + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([1., -1.], requires_grad=True), + }, + ]), + Example(HalfCauchy, [ + {'scale': -1.0}, + {'scale': 0.0}, + {'scale': torch.tensor([[-0.000001], [1.0]])} + ]), + Example(HalfNormal, [ + {'scale': torch.tensor([0., 1.], requires_grad=True)}, + {'scale': torch.tensor([1., -1.], requires_grad=True)}, + ]), + Example(LKJCholesky, [ + { + 'dim': -2, + 'concentration': 0.1 + }, + { + 'dim': 1, + 'concentration': 2., + }, + { + 'dim': 
2, + 'concentration': 0., + }, + ]), + Example(Laplace, [ + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([0., 1.], requires_grad=True), + }, + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([1., -1.], requires_grad=True), + }, + ]), + Example(LogNormal, [ + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([0., 1.], requires_grad=True), + }, + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([1., -1.], requires_grad=True), + }, + ]), + Example(MultivariateNormal, [ + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True), + }, + ]), + Example(Normal, [ + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([0., 1.], requires_grad=True), + }, + { + 'loc': torch.tensor([1., 1.], requires_grad=True), + 'scale': torch.tensor([1., -1.], requires_grad=True), + }, + { + 'loc': torch.tensor([1.0, 0.0], requires_grad=True), + 'scale': torch.tensor([1e-5, -1e-5], requires_grad=True), + }, + ]), + Example(OneHotCategorical, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), + Example(OneHotCategoricalStraightThrough, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), + Example(Pareto, [ + { + 'scale': 0.0, + 'alpha': 0.0 + }, + { + 'scale': torch.tensor([0.0, 0.0], requires_grad=True), + 'alpha': torch.tensor([-1e-5, 0.0], requires_grad=True) + }, + { + 'scale': torch.tensor([1.0]), + 'alpha': -1.0 + } + ]), + Example(Poisson, [ + { + 'rate': torch.tensor([-0.1], requires_grad=True), + }, + { + 'rate': -1.0, + } + ]), + Example(RelaxedBernoulli, [ + { + 'temperature': torch.tensor([1.5], requires_grad=True), + 'probs': torch.tensor([1.7, 0.2, 0.4], requires_grad=True), + }, + { + 'temperature': torch.tensor([2.0]), + 'probs': torch.tensor([-1.0]), + } + ]), + Example(RelaxedOneHotCategorical, [ + { + 'temperature': torch.tensor([0.5], requires_grad=True), + 'probs': torch.tensor([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True) + }, + { + 'temperature': torch.tensor([2.0]), + 'probs': torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]) + } + ]), + Example(TransformedDistribution, [ + { + 'base_distribution': Normal(0, 1), + 'transforms': lambda x: x, + }, + { + 'base_distribution': Normal(0, 1), + 'transforms': [lambda x: x], + }, + ]), + Example(Uniform, [ + { + 'low': torch.tensor([2.0], requires_grad=True), + 'high': torch.tensor([2.0], requires_grad=True), + }, + { + 'low': torch.tensor([0.0], requires_grad=True), + 'high': torch.tensor([0.0], requires_grad=True), + }, + { + 'low': torch.tensor([1.0], requires_grad=True), + 'high': torch.tensor([0.0], requires_grad=True), + } + ]), + Example(Weibull, [ + { + 'scale': torch.tensor([0.0], requires_grad=True), + 'concentration': torch.tensor([0.0], requires_grad=True) + }, + { + 'scale': torch.tensor([1.0], requires_grad=True), + 'concentration': torch.tensor([-1.0], requires_grad=True) + } + ]), + Example(Wishart, [ + { + 'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True), + 'df': torch.tensor([1.5], requires_grad=True), + }, + { + 'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True), + 'df': torch.tensor([3.], requires_grad=True), + }, + { 
+ 'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True), + 'df': 3., + }, + ]), + Example(ContinuousBernoulli, [ + {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)}, + {'probs': torch.tensor([-0.5], requires_grad=True)}, + {'probs': 1.00001}, + ]) + ] class DistributionsTestCase(TestCase): @@ -898,13 +897,13 @@ def _check_enumerate_support(self, dist, examples): self.assertEqual(actual, expected_with_expand) def test_repr(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: dist = Dist(**param) self.assertTrue(repr(dist).startswith(dist.__class__.__name__)) def test_sample_detached(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): variable_params = [p for p in param.values() if getattr(p, 'requires_grad', False)] if not variable_params: @@ -916,7 +915,7 @@ def test_sample_detached(self): @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_rsample_requires_grad(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): if not any(getattr(p, 'requires_grad', False) for p in param.values()): continue @@ -928,7 +927,7 @@ def test_rsample_requires_grad(self): msg=f'{Dist.__name__} example {i + 1}/{len(params)}, .rsample() does not require grad') def test_enumerate_support_type(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(**param) try: @@ -964,15 +963,15 @@ def test(): self.assertIsNotNone(cov.grad) def test_has_examples(self): - distributions_with_examples = {e.Dist for e in EXAMPLES} + distributions_with_examples = {e.Dist for e in _get_examples()} for Dist in globals().values(): if isinstance(Dist, type) and issubclass(Dist, Distribution) \ and Dist is not Distribution and Dist is not ExponentialFamily: self.assertIn(Dist, distributions_with_examples, - f"Please add {Dist.__name__} to the EXAMPLES list in test_distributions.py") + f"Please add {Dist.__name__} to the _get_examples list in test_distributions.py") def test_support_attributes(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: d = Dist(**param) event_dim = len(d.event_shape) @@ -989,7 +988,7 @@ def test_support_attributes(self): def test_distribution_expand(self): shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))] - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: for shape in shapes: d = Dist(**param) @@ -1014,7 +1013,7 @@ def test_distribution_expand(self): def test_distribution_subclass_expand(self): expand_by = torch.Size((2,)) - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): class SubClass(Dist): pass @@ -1032,6 +1031,7 @@ class SubClass(Dist): self.assertEqual(expanded.log_prob(sample), d.log_prob(sample)) self.assertEqual(actual_shape, expected_shape) + @set_default_dtype(torch.double) def test_bernoulli(self): p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True) r = torch.tensor(0.3, requires_grad=True) @@ -1077,6 +1077,7 @@ def test_bernoulli_3d(self): (2, 5, 2, 3, 5)) self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5)) + @set_default_dtype(torch.double) def test_geometric(self): p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True) r = torch.tensor(0.3, requires_grad=True) @@ -1097,6 +1098,7 @@ def test_geometric(self): self._check_forward_ad(lambda x: x.geometric_(0.2)) @unittest.skipIf(not 
TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_geometric_log_prob_and_entropy(self): p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True) s = 0.3 @@ -1120,6 +1122,7 @@ def test_geometric_sample(self): scipy.stats.geom(p=prob, loc=-1), f'Geometric(prob={prob})') + @set_default_dtype(torch.double) def test_binomial(self): p = torch.arange(0.05, 1, 0.1).requires_grad_() for total_count in [1, 2, 10]: @@ -1137,6 +1140,7 @@ def test_binomial_sample(self): f'Binomial(total_count={count}, probs={prob})') @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_binomial_log_prob_and_entropy(self): probs = torch.arange(0.05, 1, 0.1) for total_count in [1, 2, 10]: @@ -1168,6 +1172,7 @@ def test_binomial_stable(self): self.assertEqual(grad(y, x)[0], torch.tensor(-0.5)) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_binomial_log_prob_vectorized_count(self): probs = torch.tensor([0.2, 0.7, 0.9]) for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])), @@ -1184,6 +1189,7 @@ def test_binomial_enumerate_support(self): ] self._check_enumerate_support(Binomial, examples) + @set_default_dtype(torch.double) def test_binomial_extreme_vals(self): total_count = 100 bin0 = Binomial(total_count, 0) @@ -1199,6 +1205,7 @@ def test_binomial_extreme_vals(self): self.assertEqual(bin2.sample(), zero_counts) self.assertEqual(bin2.log_prob(zero_counts), zero_counts) + @set_default_dtype(torch.double) def test_binomial_vectorized_count(self): set_rng_seed(1) # see Note [Randomized statistical tests] total_count = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64) @@ -1210,6 +1217,7 @@ def test_binomial_vectorized_count(self): self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0) self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0) + @set_default_dtype(torch.double) def test_negative_binomial(self): p = torch.arange(0.05, 1, 0.1).requires_grad_() for total_count in [1, 2, 10]: @@ -1233,6 +1241,7 @@ def ref_log_prob(idx, x, log_prob): self._check_log_prob(NegativeBinomial(total_count, logits=logits), ref_log_prob) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_negative_binomial_log_prob_vectorized_count(self): probs = torch.tensor([0.2, 0.7, 0.9]) for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])), @@ -1258,6 +1267,7 @@ def test_zero_excluded_binomial(self): assert (vals == 0.0).sum() > 4000 assert (vals == 1.0).sum() > 4000 + @set_default_dtype(torch.double) def test_multinomial_1d(self): total_count = 10 p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) @@ -1269,6 +1279,7 @@ def test_multinomial_1d(self): self.assertRaises(NotImplementedError, Multinomial(10, p).rsample) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_multinomial_1d_log_prob_and_entropy(self): total_count = 10 p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) @@ -1287,6 +1298,7 @@ def test_multinomial_1d_log_prob_and_entropy(self): expected = scipy.stats.multinomial.entropy(total_count, dist.probs.detach().numpy()) self.assertEqual(dist.entropy(), expected, atol=1e-3, rtol=0) + @set_default_dtype(torch.double) def test_multinomial_2d(self): total_count = 10 probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] @@ -1304,6 +1316,7 @@ def test_multinomial_2d(self): self.assertEqual(Multinomial(total_count, s).sample(), torch.tensor([[total_count, 0], [0, 
total_count]], dtype=torch.float64)) + @set_default_dtype(torch.double) def test_categorical_1d(self): p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) self.assertTrue(is_all_nan(Categorical(p).mean)) @@ -1315,6 +1328,7 @@ def test_categorical_1d(self): self._gradcheck_log_prob(Categorical, (p,)) self.assertRaises(NotImplementedError, Categorical(p).rsample) + @set_default_dtype(torch.double) def test_categorical_2d(self): probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities_1 = [[1.0, 0.0], [0.0, 1.0]] @@ -1357,6 +1371,7 @@ def test_categorical_enumerate_support(self): ] self._check_enumerate_support(Categorical, examples) + @set_default_dtype(torch.double) def test_one_hot_categorical_1d(self): p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) self.assertEqual(OneHotCategorical(p).sample().size(), (3,)) @@ -1366,6 +1381,7 @@ def test_one_hot_categorical_1d(self): self._gradcheck_log_prob(OneHotCategorical, (p,)) self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample) + @set_default_dtype(torch.double) def test_one_hot_categorical_2d(self): probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities_1 = [[1.0, 0.0], [0.0, 1.0]] @@ -1400,6 +1416,7 @@ def test_poisson_shape(self): self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,)) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @set_default_dtype(torch.double) def test_poisson_log_prob(self): rate = torch.randn(2, 3).abs().requires_grad_() rate_1d = torch.randn(1).abs().requires_grad_() @@ -1442,6 +1459,7 @@ def test_poisson_gpu_sample(self): f'Poisson(lambda={rate}, cuda)', failure_rate=1e-3) + @set_default_dtype(torch.double) def test_relaxed_bernoulli(self): p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True) r = torch.tensor(0.3, requires_grad=True) @@ -1483,6 +1501,7 @@ def sample(self, *args, **kwargs): s = dist.rsample() self.assertEqual(equal_probs, s) + @set_default_dtype(torch.double) def test_relaxed_one_hot_categorical_1d(self): p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) temp = torch.tensor(0.67, requires_grad=True) @@ -1492,6 +1511,7 @@ def test_relaxed_one_hot_categorical_1d(self): self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(), (1, 3)) self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)) + @set_default_dtype(torch.double) def test_relaxed_one_hot_categorical_2d(self): probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities_1 = [[1.0, 0.0], [0.0, 1.0]] @@ -1541,6 +1561,7 @@ def pmf(self, samples): s = dist.rsample() self.assertEqual(equal_probs, s) + @set_default_dtype(torch.double) def test_uniform(self): low = torch.zeros(5, 5, requires_grad=True) high = (torch.ones(5, 5) * 3).requires_grad_() @@ -1597,6 +1618,7 @@ def test_vonmises_logprob(self): norm = prob.mean().item() * 2 * math.pi self.assertLess(abs(norm - 1), 1e-3) + @set_default_dtype(torch.double) def test_cauchy(self): loc = torch.zeros(5, 5, requires_grad=True) scale = torch.ones(5, 5, requires_grad=True) @@ -1627,6 +1649,7 @@ def test_cauchy(self): self._check_forward_ad(lambda x: x.cauchy_()) + @set_default_dtype(torch.double) def test_halfcauchy(self): scale = torch.ones(5, 5, requires_grad=True) scale_1d = torch.ones(1, requires_grad=True) @@ -1650,6 +1673,7 @@ def test_halfcauchy(self): self.assertEqual(scale.grad, eps) scale.grad.zero_() + @set_default_dtype(torch.double) def test_halfnormal(self): std = torch.randn(5, 5).abs().requires_grad_() std_1d = torch.randn(1).abs().requires_grad_() @@ 
-1694,6 +1718,7 @@ def test_halfnormal_sample(self): scipy.stats.halfnorm(scale=std), f'HalfNormal(scale={std})') + @set_default_dtype(torch.double) def test_lognormal(self): mean = torch.randn(5, 5, requires_grad=True) std = torch.randn(5, 5).abs().requires_grad_() @@ -1746,6 +1771,7 @@ def test_lognormal_sample(self): scipy.stats.lognorm(scale=math.exp(mean), s=std), f'LogNormal(loc={mean}, scale={std})') + @set_default_dtype(torch.double) def test_logisticnormal(self): set_rng_seed(1) # see Note [Randomized statistical tests] mean = torch.randn(5, 5).requires_grad_() @@ -1901,6 +1927,7 @@ def rvs(self, n_sample): f'''MixtureSameFamily(Categorical(probs={probs}), Normal(loc={loc}, scale={scale}))''') + @set_default_dtype(torch.double) def test_normal(self): loc = torch.randn(5, 5, requires_grad=True) scale = torch.randn(5, 5).abs().requires_grad_() @@ -1958,6 +1985,7 @@ def test_normal_sample(self): scipy.stats.norm(loc=loc, scale=scale), f'Normal(mean={loc}, std={scale})') + @set_default_dtype(torch.double) def test_lowrank_multivariate_normal_shape(self): mean = torch.randn(5, 3, requires_grad=True) mean_no_batch = torch.randn(3, requires_grad=True) @@ -2077,6 +2105,7 @@ def test_lowrank_multivariate_normal_moments(self): empirical_var = samples.var(0) self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0) + @set_default_dtype(torch.double) def test_multivariate_normal_shape(self): mean = torch.randn(5, 3, requires_grad=True) mean_no_batch = torch.randn(3, requires_grad=True) @@ -2136,6 +2165,7 @@ def gradcheck_func(samples, mu, sigma, prec, scale_tril): multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril) multivariate_normal_log_prob_gradcheck(mean_no_batch, None, None, scale_tril_batched) + @set_default_dtype(torch.double) def test_multivariate_normal_stable_with_precision_matrix(self): x = torch.randn(10) P = torch.exp(-(x - x.unsqueeze(-1)) ** 2) # RBF kernel @@ -2200,6 +2230,7 @@ def test_multivariate_normal_sample(self): f'MultivariateNormal(loc={mean}, scale_tril={scale_tril})', multivariate=True) + @set_default_dtype(torch.double) def test_multivariate_normal_properties(self): loc = torch.randn(5) scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5)) @@ -2208,6 +2239,7 @@ def test_multivariate_normal_properties(self): self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])) self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix)) + @set_default_dtype(torch.double) def test_multivariate_normal_moments(self): set_rng_seed(0) # see Note [Randomized statistical tests] mean = torch.randn(5) @@ -2220,6 +2252,7 @@ def test_multivariate_normal_moments(self): self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0) # We applied same tests in Multivariate Normal distribution for Wishart distribution + @set_default_dtype(torch.double) def test_wishart_shape(self): set_rng_seed(0) # see Note [Randomized statistical tests] ndim = 3 @@ -2289,6 +2322,7 @@ def test_wishart_stable_with_precision_matrix(self): Wishart(torch.tensor(ndim), precision_matrix=P) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @set_default_dtype(torch.double) def test_wishart_log_prob(self): set_rng_seed(0) # see Note [Randomized statistical tests] ndim = 3 @@ -2334,6 +2368,7 @@ def test_wishart_log_prob(self): self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_wishart_sample(self): set_rng_seed(0) # see 
Note [Randomized statistical tests] ndim = 3 @@ -2383,6 +2418,7 @@ def test_wishart_moments(self): empirical_var = samples.var(0) self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0) + @set_default_dtype(torch.double) def test_exponential(self): rate = torch.randn(5, 5).abs().requires_grad_() rate_1d = torch.randn(1).abs().requires_grad_() @@ -2431,6 +2467,7 @@ def test_exponential_sample(self): scipy.stats.expon(scale=1. / rate), f'Exponential(rate={rate})') + @set_default_dtype(torch.double) def test_laplace(self): loc = torch.randn(5, 5, requires_grad=True) scale = torch.randn(5, 5).abs().requires_grad_() @@ -2475,6 +2512,7 @@ def ref_log_prob(idx, x, log_prob): self._check_log_prob(Laplace(loc, scale), ref_log_prob) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_laplace_sample(self): set_rng_seed(1) # see Note [Randomized statistical tests] for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]): @@ -2609,6 +2647,7 @@ def ref_log_prob(idx, x, log_prob): self._check_log_prob(Gumbel(loc, scale), ref_log_prob) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_gumbel_sample(self): set_rng_seed(1) # see note [Randomized statistical tests] for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]): @@ -2733,6 +2772,7 @@ def ref_log_prob(idx, x, log_prob): self._check_log_prob(StudentT(df), ref_log_prob) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @set_default_dtype(torch.double) def test_studentT_sample(self): set_rng_seed(11) # see Note [Randomized statistical tests] for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]): @@ -2761,6 +2801,7 @@ def test_dirichlet_shape(self): self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4)) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_dirichlet_log_prob(self): num_samples = 10 alpha = torch.exp(torch.randn(5)) @@ -2830,6 +2871,7 @@ def test_beta_log_prob(self): self.assertEqual(float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @set_default_dtype(torch.double) def test_beta_sample(self): set_rng_seed(1) # see Note [Randomized statistical tests] for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]): @@ -2874,6 +2916,7 @@ def test_beta_underflow_gpu(self): self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0) self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0) + @set_default_dtype(torch.double) def test_continuous_bernoulli(self): p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True) r = torch.tensor(0.3, requires_grad=True) @@ -2938,7 +2981,7 @@ def tril_cholesky_to_tril_corr(x): self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample)) def test_independent_shape(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: base_dist = Dist(**param) x = base_dist.sample() @@ -2966,7 +3009,7 @@ def test_independent_shape(self): pass def test_independent_expand(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: base_dist = Dist(**param) for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1): @@ -2982,9 +3025,10 @@ def test_independent_expand(self): self.assertEqual(expanded.event_shape, indep_dist.event_shape) self.assertEqual(expanded.batch_shape, expanded_shape) + @set_default_dtype(torch.double) def test_cdf_icdf_inverse(self): # 
Tests the invertibility property on the distributions - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(**param) samples = dist.sample(sample_shape=(20,)) @@ -3009,9 +3053,10 @@ def test_gamma_log_prob_at_boundary(self): self.assertAlmostEqual(dist.log_prob(0), log_prob) self.assertAlmostEqual(dist.log_prob(0), scipy_dist.logpdf(0)) + @set_default_dtype(torch.double) def test_cdf_log_prob(self): # Tests if the differentiation of the CDF gives the PDF at a given value - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): # We do not need grads wrt params here, e.g. shape of gamma distribution. param = {key: value.detach() if isinstance(value, torch.Tensor) else value @@ -3243,6 +3288,7 @@ def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinit ordering = (delta > -1e-12).all(axis=0) self.assertTrue(ordering[batch_isfinite].all()) + @set_default_dtype(torch.double) def test_mode(self): discrete_distributions = ( Bernoulli, Binomial, Categorical, Geometric, NegativeBinomial, OneHotCategorical, Poisson, @@ -3252,7 +3298,7 @@ def test_mode(self): RelaxedBernoulli, RelaxedOneHotCategorical, ) - for dist_cls, params in EXAMPLES: + for dist_cls, params in _get_examples(): for param in params: dist = dist_cls(**param) if isinstance(dist, no_mode_available) or type(dist) is TransformedDistribution: @@ -3450,6 +3496,7 @@ def test_dirichlet_multivariate(self): "error = %.2g" % torch.abs(expected_grad - actual_grad).max(), ])) + @set_default_dtype(torch.double) def test_dirichlet_tangent_field(self): num_samples = 20 alpha_grid = [0.5, 1.0, 2.0] @@ -3496,7 +3543,7 @@ def setUp(self): self.tensor_sample_2 = torch.ones(3, 2, 3) def test_entropy_shape(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(validate_args=False, **param) try: @@ -4276,7 +4323,7 @@ def test_kl_edgecases(self): self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0) def test_kl_shape(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(**param) try: @@ -4300,9 +4347,10 @@ def test_kl_transformed(self): self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,)) self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,)) + @set_default_dtype(torch.double) def test_entropy_monte_carlo(self): set_rng_seed(0) # see Note [Randomized statistical tests] - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(**param) try: @@ -4320,8 +4368,9 @@ def test_entropy_monte_carlo(self): f'max error = {torch.abs(actual - expected).max()}', ])) + @set_default_dtype(torch.double) def test_entropy_exponential_family(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): if not issubclass(Dist, ExponentialFamily): continue for i, param in enumerate(params): @@ -4352,7 +4401,7 @@ def test_params_constraints(self): RelaxedOneHotCategorical ) - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for i, param in enumerate(params): dist = Dist(**param) for name, value in param.items(): @@ -4379,7 +4428,7 @@ def test_params_constraints(self): self.assertTrue(constraint.check(value).all(), msg=message) def test_support_constraints(self): - for Dist, params in EXAMPLES: + for Dist, params in 
_get_examples(): self.assertIsInstance(Dist.support, Constraint) for i, param in enumerate(params): dist = Dist(**param) @@ -4615,7 +4664,7 @@ def setUp(self): super().setUp() # ContinuousBernoulli is not tested because log_prob is not computed simply # from 'logits', but 'probs' is also needed - self.examples = [e for e in EXAMPLES if e.Dist in + self.examples = [e for e in _get_examples() if e.Dist in (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)] def test_lazy_logits_initialization(self): @@ -4661,11 +4710,11 @@ def test_lazy_probs_initialization(self): class TestAgainstScipy(DistributionsTestCase): def setUp(self): super().setUp() - positive_var = torch.randn(20).exp() - positive_var2 = torch.randn(20).exp() - random_var = torch.randn(20) - simplex_tensor = softmax(torch.randn(20), dim=-1) - cov_tensor = torch.randn(20, 20) + positive_var = torch.randn(20, dtype=torch.double).exp() + positive_var2 = torch.randn(20, dtype=torch.double).exp() + random_var = torch.randn(20, dtype=torch.double) + simplex_tensor = softmax(torch.randn(20, dtype=torch.double), dim=-1) + cov_tensor = torch.randn(20, 20, dtype=torch.double) cov_tensor = cov_tensor @ cov_tensor.mT self.distribution_pairs = [ ( @@ -4726,7 +4775,7 @@ def setUp(self): scipy.stats.lognorm(s=positive_var.clamp(max=3), scale=random_var.exp()) ), ( - LowRankMultivariateNormal(random_var, torch.zeros(20, 1), positive_var2), + LowRankMultivariateNormal(random_var, torch.zeros(20, 1, dtype=torch.double), positive_var2), scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)) ), ( @@ -4813,6 +4862,7 @@ def test_variance_stddev(self): self.assertEqual(pytorch_dist.variance, scipy_dist.var(), msg=pytorch_dist) self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, msg=pytorch_dist) + @set_default_dtype(torch.double) def test_cdf(self): for pytorch_dist, scipy_dist in self.distribution_pairs: samples = pytorch_dist.sample((5,)) @@ -4824,7 +4874,7 @@ def test_cdf(self): def test_icdf(self): for pytorch_dist, scipy_dist in self.distribution_pairs: - samples = torch.rand((5,) + pytorch_dist.batch_shape) + samples = torch.rand((5,) + pytorch_dist.batch_shape, dtype=torch.double) try: icdf = pytorch_dist.icdf(samples) except NotImplementedError: @@ -4952,14 +5002,15 @@ def test_stack_transform(self): class TestValidation(DistributionsTestCase): def test_valid(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: Dist(validate_args=True, **param) + @set_default_dtype(torch.double) def test_invalid_log_probs_arg(self): # Check that validation errors are indeed disabled, # but they might raise another error - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): if Dist == TransformedDistribution: # TransformedDistribution has a distribution instance # as the argument, so we cannot do much about that @@ -5004,8 +5055,9 @@ def test_invalid_log_probs_arg(self): fail_string.format(Dist.__name__, i + 1, len(params)) ) from e + @set_default_dtype(torch.double) def test_invalid(self): - for Dist, params in BAD_EXAMPLES: + for Dist, params in _get_bad_examples(): for i, param in enumerate(params): try: with self.assertRaises(ValueError): @@ -5040,7 +5092,7 @@ def log_prob(self, value): class TestJit(DistributionsTestCase): def _examples(self): - for Dist, params in EXAMPLES: + for Dist, params in _get_examples(): for param in params: keys = param.keys() values = tuple(param[key] for key in keys) @@ -5079,6 +5131,7 @@ def _perturb(self, Dist, keys, values, 
sample): sample = Dist(**param).sample() return values, sample + @set_default_dtype(torch.double) def test_sample(self): for Dist, keys, values, sample in self._examples(): @@ -5138,6 +5191,7 @@ def f(*values): if Dist not in xfail: self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes())) + @set_default_dtype(torch.double) def test_log_prob(self): for Dist, keys, values, sample in self._examples(): # FIXME traced functions produce incorrect results @@ -5229,6 +5283,7 @@ def f(*values): self.assertEqual(expected, actual, msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}') + @set_default_dtype(torch.double) def test_entropy(self): for Dist, keys, values, sample in self._examples(): # FIXME traced functions produce incorrect results @@ -5253,6 +5308,7 @@ def f(*values): self.assertEqual(expected, actual, msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}') + @set_default_dtype(torch.double) def test_cdf(self): for Dist, keys, values, sample in self._examples(): @@ -5276,4 +5332,5 @@ def f(sample, *values): if __name__ == '__main__' and torch._C.has_lapack: + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index c8c1441adbf72..b207860dd0de4 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -1995,8 +1995,8 @@ def setUp(self): torch.set_default_dtype(torch.double) def tearDown(self): - super().tearDown() torch.set_default_dtype(self.default_dtype) + super().tearDown() def test_conv_bn_folding(self): conv_bias = [True, False] diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index c18b345a94ce2..20d3dd4163103 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -1881,15 +1881,11 @@ def forward(self, x, y): x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) - prev_default = torch.get_default_dtype() + with common_utils.set_default_dtype(torch.float): + self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y)) - torch.set_default_dtype(torch.float) - self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y)) - - torch.set_default_dtype(torch.double) - self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y)) - - torch.set_default_dtype(prev_default) + with common_utils.set_default_dtype(torch.double): + self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y)) # In scripting x, y do not carry shape and dtype info. # The following test only works when onnx shape inference is enabled. @@ -1905,23 +1901,20 @@ def forward(self, x, y): x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) - prev_default = torch.get_default_dtype() - # 1. x,y are int, and output is float. # This can be handled by the default case, where both are cast to float. # It works even if type of x, y are unknown. - torch.set_default_dtype(torch.float) - self.run_test(torch.jit.script(DivModule()), (x, y)) + with common_utils.set_default_dtype(torch.float): + self.run_test(torch.jit.script(DivModule()), (x, y)) # 2. x,y are int, and output is double. # This can be handled by the default case, where both are cast to double. # It works even if type of x, y are unknown. - torch.set_default_dtype(torch.double) - self.run_test(torch.jit.script(DivModule()), (x, y)) + with common_utils.set_default_dtype(torch.double): + self.run_test(torch.jit.script(DivModule()), (x, y)) # 3. 
x is int, y is double, and output is double. # This can only be handled when both type of x and y are known. - torch.set_default_dtype(prev_default) x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double) self.run_test(torch.jit.script(DivModule()), (x, y)) @@ -13542,4 +13535,5 @@ def test_rnn(self, *args, **kwargs): if __name__ == "__main__": + common_utils.TestCase._default_dtype_check_enabled = True common_utils.run_tests() diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 1bbe492e237e3..71595892be26f 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -73,6 +73,8 @@ from torch.testing._internal.jit_utils import get_forward from torch.testing._internal.jit_utils import get_forward_graph +from torch.testing._internal.common_utils import set_default_dtype + from torch.jit._recursive import wrap_cpp_module # Standard library @@ -315,12 +317,12 @@ def forward(self, x): m = fuse_conv_bn_jit(m) FileCheck().check_count("prim::CallMethod", 2, exactly=True).run(m.graph) + @set_default_dtype(torch.double) def test_foldbn_complex_cases(self): # This test case attempt to try combinations of conv2d/conv3d with bias/nobias # as well as BatchNorm with affine/no-affine along with varying the # number of layers. # this only works when default dtype is double - torch.set_default_dtype(torch.double) bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} @@ -374,8 +376,6 @@ def forward(self, x): self.assertEqual(eager(x), scripted_or_traced(x)) - torch.set_default_dtype(torch.float) - def test_fuse_linear(self): class FunctionalLinear(torch.nn.Module): def __init__(self, weight, bias): diff --git a/test/test_complex.py b/test/test_complex.py index 36f17b3b01f6a..e48d29213e3b0 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -6,7 +6,7 @@ dtypes, onlyCPU, ) -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype from torch.testing._internal.common_dtype import complex_types devices = (torch.device('cpu'), torch.device('cuda:0')) @@ -21,10 +21,8 @@ def test_to_list(self, device, dtype): @dtypes(torch.float32, torch.float64) def test_dtype_inference(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/36834 - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - x = torch.tensor([3., 3. + 5.j], device=device) - torch.set_default_dtype(default_dtype) + with set_default_dtype(dtype): + x = torch.tensor([3., 3. 
+ 5.j], device=device) self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat) @onlyCPU @@ -168,4 +166,5 @@ def test_ne(self, device, dtype): instantiate_device_type_tests(TestComplexTensor, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index 107c902e24268..f14a0973e03a2 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -1,8 +1,5 @@ # Owner(s): ["module: cpp"] -import torch -# NN tests use double as the default dtype -torch.set_default_dtype(torch.double) import os @@ -59,4 +56,5 @@ class TestCppApiParity(common.TestCase): functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE) if __name__ == "__main__": + common.TestCase._default_dtype_check_enabled = True common.run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index cb71f5c888a83..6ffca95e33f1d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -16259,6 +16259,7 @@ def test_version(self): add_nn_module_test(**test) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() import jit.test_module_interface suite = unittest.findTestCases(jit.test_module_interface) diff --git a/test/test_linalg.py b/test/test_linalg.py index ad8dd36a0bdf0..3bbc537b9c595 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -35,8 +35,6 @@ import torch.backends.opt_einsum as opt_einsum # Protects against includes accidentally setting the default dtype -# NOTE: jit_metaprogramming_utils sets the default dtype to double! -torch.set_default_dtype(torch.float32) assert torch.get_default_dtype() is torch.float32 if TEST_SCIPY: @@ -7585,4 +7583,5 @@ def test(): instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 37cb5f8d8352e..cfa34077ecf22 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -26,8 +26,6 @@ ) # Protects against includes accidentally setting the default dtype -# NOTE: jit_metaprogramming_utils sets the default dtype to double! 
-torch.set_default_dtype(torch.float32) assert torch.get_default_dtype() is torch.float32 @@ -263,4 +261,5 @@ def test_float32_output_errors_with_bias(self, device) -> None: instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu") if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 7e798c9442650..faebb160e3a95 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12810,4 +12810,5 @@ def test_fuse_linear_bn_requires_grad(self): instantiate_parametrized_tests(TestNN) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 8830a52e29250..ae3355f7dff7b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2183,4 +2183,5 @@ def test_fake_crossref_backward_amp(self, device, dtype, op): instantiate_device_type_tests(TestTags, globals()) if __name__ == "__main__": + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_ops_fwd_gradients.py b/test/test_ops_fwd_gradients.py index 4b7b1c785d5f0..bec2725822e47 100644 --- a/test/test_ops_fwd_gradients.py +++ b/test/test_ops_fwd_gradients.py @@ -4,14 +4,11 @@ import torch from torch.testing._internal.common_utils import ( - TestGradients, run_tests, skipIfTorchInductor, IS_MACOS) + TestGradients, run_tests, skipIfTorchInductor, IS_MACOS, TestCase) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) -# TODO: fixme https://github.com/pytorch/pytorch/issues/68972 -torch.set_default_dtype(torch.float32) - # TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033 # AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. 
The # issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33 @@ -73,4 +70,5 @@ def test_inplace_forward_mode_AD(self, device, dtype, op): instantiate_device_type_tests(TestFwdGradients, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_ops_gradients.py b/test/test_ops_gradients.py index 39caa9e7c4ec8..93db89ab7dd8f 100644 --- a/test/test_ops_gradients.py +++ b/test/test_ops_gradients.py @@ -3,16 +3,13 @@ from functools import partial import torch -from torch.testing._internal.common_utils import TestGradients, run_tests +from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) -# TODO: fixme https://github.com/pytorch/pytorch/issues/68972 -torch.set_default_dtype(torch.float32) - # gradcheck requires double precision _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]) @@ -90,4 +87,5 @@ def test_inplace_gradgrad(self, device, dtype, op): instantiate_device_type_tests(TestBwdGradients, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_ops_jit.py b/test/test_ops_jit.py index c01f6d36f3a50..d8e80048bb217 100644 --- a/test/test_ops_jit.py +++ b/test/test_ops_jit.py @@ -7,7 +7,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_utils import \ - (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample) + (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, TestCase) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference @@ -15,9 +15,6 @@ from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda -# TODO: fixme https://github.com/pytorch/pytorch/issues/68972 -torch.set_default_dtype(torch.float32) - # variant testing is only done with torch.float and torch.cfloat to avoid # excessive test times and maximize signal to noise ratio _variant_ops = partial(ops, dtypes=OpDTypes.supported, @@ -297,4 +294,5 @@ def _fn(*sample_args, **sample_kwargs): instantiate_device_type_tests(TestJit, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 493a10f674cdb..a6253b0cbcedb 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -15,6 +15,7 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest, + set_default_dtype, set_default_tensor_type, TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo) from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, @@ -1965,37 +1966,36 @@ def test_ones(self, device): # TODO: this test should be updated @onlyCPU def 
test_constructor_dtypes(self, device): - default_type = torch.tensor([]).type() self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype()) self.assertIs(torch.uint8, torch.ByteTensor.dtype) self.assertIs(torch.float32, torch.FloatTensor.dtype) self.assertIs(torch.float64, torch.DoubleTensor.dtype) - torch.set_default_tensor_type('torch.FloatTensor') - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.FloatStorage, torch.Storage) + with set_default_tensor_type('torch.FloatTensor'): + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.FloatStorage, torch.Storage) # only floating-point types are supported as the default type self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor')) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.get_default_dtype()) - self.assertIs(torch.DoubleStorage, torch.Storage) + with set_default_dtype(torch.float64): + self.assertIs(torch.float64, torch.get_default_dtype()) + self.assertIs(torch.DoubleStorage, torch.Storage) - torch.set_default_tensor_type(torch.FloatTensor) - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.FloatStorage, torch.Storage) + with set_default_tensor_type(torch.FloatTensor): + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.FloatStorage, torch.Storage) if torch.cuda.is_available(): - torch.set_default_tensor_type(torch.cuda.FloatTensor) - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype) - self.assertIs(torch.cuda.FloatStorage, torch.Storage) + with set_default_tensor_type(torch.cuda.FloatTensor): + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype) + self.assertIs(torch.cuda.FloatStorage, torch.Storage) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.get_default_dtype()) - self.assertIs(torch.cuda.DoubleStorage, torch.Storage) + with set_default_dtype(torch.float64): + self.assertIs(torch.float64, torch.get_default_dtype()) + self.assertIs(torch.cuda.DoubleStorage, torch.Storage) # don't allow passing dtype to set_default_tensor_type self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32)) @@ -2008,12 +2008,11 @@ def test_constructor_dtypes(self, device): torch.float, torch.double, torch.bfloat16): - torch.set_default_dtype(t) + with set_default_dtype(t): + pass else: self.assertRaises(TypeError, lambda: torch.set_default_dtype(t)) - torch.set_default_tensor_type(default_type) - # TODO: this test should be updated @onlyCPU def test_constructor_device_legacy(self, device): @@ -2049,14 +2048,10 @@ def test_constructor_device_legacy(self, device): self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu')) self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu')) - default_type = torch.Tensor().type() - torch.set_default_tensor_type(torch.cuda.FloatTensor) - self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu')) - torch.set_default_tensor_type(torch.cuda.FloatTensor) - torch.set_default_tensor_type(default_type) - + with set_default_tensor_type(torch.cuda.FloatTensor): + self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu')) + self.assertRaises(RuntimeError, lambda: 
torch.Tensor(torch.Size([2, 3, 4]), device='cpu')) + self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu')) x = torch.randn((3,), device='cuda') self.assertRaises(RuntimeError, lambda: x.new(device='cpu')) self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu')) @@ -2158,8 +2153,6 @@ def check_copy(copy, is_leaf, requires_grad, data_ptr=None): @onlyCPU def test_tensor_factory_type_inference(self, device): def test_inference(default_dtype): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(default_dtype) default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128 self.assertIs(default_dtype, torch.tensor(()).dtype) self.assertIs(default_dtype, torch.tensor(5.).dtype) @@ -2181,10 +2174,10 @@ def test_inference(default_dtype): self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) - torch.set_default_dtype(saved_dtype) - test_inference(torch.float64) - test_inference(torch.float32) + for dtype in [torch.float64, torch.float32]: + with set_default_dtype(dtype): + test_inference(dtype) # TODO: this test should be updated @suppress_warnings @@ -2471,8 +2464,6 @@ def test_arange(self, device): @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") @onlyCPU def test_arange_inference(self, device): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float32) # end only self.assertIs(torch.float32, torch.arange(1.).dtype) self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype) @@ -2501,7 +2492,6 @@ def test_arange_inference(self, device): torch.arange(torch.tensor(1), torch.tensor(3), torch.tensor(1, dtype=torch.int16)).dtype) - torch.set_default_dtype(saved_dtype) # cannot call storage() on meta tensor @skipMeta @@ -2818,28 +2808,24 @@ def test_tensor_factories_empty(self, device): @onlyCUDA def test_tensor_factory_gpu_type_inference(self, device): - saved_type = torch.tensor([]).type() - torch.set_default_tensor_type(torch.cuda.DoubleTensor) - torch.set_default_dtype(torch.float32) - self.assertIs(torch.float32, torch.tensor(0.).dtype) - self.assertEqual(torch.device(device), torch.tensor(0.).device) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.tensor(0.).dtype) - self.assertEqual(torch.device(device), torch.tensor(0.).device) - torch.set_default_tensor_type(saved_type) + with set_default_tensor_type(torch.cuda.DoubleTensor): + with set_default_dtype(torch.float32): + self.assertIs(torch.float32, torch.tensor(0.).dtype) + self.assertEqual(torch.device(device), torch.tensor(0.).device) + with set_default_dtype(torch.float64): + self.assertIs(torch.float64, torch.tensor(0.).dtype) + self.assertEqual(torch.device(device), torch.tensor(0.).device) @onlyCUDA def test_tensor_factory_gpu_type(self, device): - saved_type = torch.tensor([]).type() - torch.set_default_tensor_type(torch.cuda.FloatTensor) - x = torch.zeros((5, 5)) - self.assertIs(torch.float32, x.dtype) - self.assertTrue(x.is_cuda) - torch.set_default_tensor_type(torch.cuda.DoubleTensor) - x = torch.zeros((5, 5)) - self.assertIs(torch.float64, x.dtype) - self.assertTrue(x.is_cuda) - torch.set_default_tensor_type(saved_type) + with set_default_tensor_type(torch.cuda.FloatTensor): + x = torch.zeros((5, 5)) + self.assertIs(torch.float32, x.dtype) + 
self.assertTrue(x.is_cuda) + with set_default_tensor_type(torch.cuda.DoubleTensor): + x = torch.zeros((5, 5)) + self.assertIs(torch.float64, x.dtype) + self.assertTrue(x.is_cuda) @skipCPUIf(True, 'compares device with cpu') @dtypes(torch.int, torch.long, torch.float, torch.double) @@ -3081,27 +3067,23 @@ def test_logspace(self, device, dtype): def test_full_inference(self, device, dtype): size = (2, 2) - prev_default = torch.get_default_dtype() - torch.set_default_dtype(dtype) - - # Tests bool fill value inference - t = torch.full(size, True) - self.assertEqual(t.dtype, torch.bool) - - # Tests integer fill value inference - t = torch.full(size, 1) - self.assertEqual(t.dtype, torch.long) + with set_default_dtype(dtype): + # Tests bool fill value inference + t = torch.full(size, True) + self.assertEqual(t.dtype, torch.bool) - # Tests float fill value inference - t = torch.full(size, 1.) - self.assertEqual(t.dtype, dtype) + # Tests integer fill value inference + t = torch.full(size, 1) + self.assertEqual(t.dtype, torch.long) - # Tests complex inference - t = torch.full(size, (1 + 1j)) - ctype = torch.complex128 if dtype is torch.double else torch.complex64 - self.assertEqual(t.dtype, ctype) + # Tests float fill value inference + t = torch.full(size, 1.) + self.assertEqual(t.dtype, dtype) - torch.set_default_dtype(prev_default) + # Tests complex inference + t = torch.full(size, (1 + 1j)) + ctype = torch.complex128 if dtype is torch.double else torch.complex64 + self.assertEqual(t.dtype, ctype) def test_full_out(self, device): size = (5,) @@ -4070,4 +4052,5 @@ def test_device_without_index(self, device): instantiate_device_type_tests(TestAsArray, globals()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index ffcf3bfcf07cb..c13a71911f75a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -32,7 +32,7 @@ TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf, - TEST_WITH_CROSSREF, skipIfTorchDynamo, + TEST_WITH_CROSSREF, skipIfTorchDynamo, set_default_dtype, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, @@ -6010,6 +6010,7 @@ def test_index_add_all_dtypes(self): self.assertEqual(added, -tensor) @skipIfTorchInductor("AssertionError: RuntimeError not raised by ") + @set_default_dtype(torch.double) def test_index_add_correctness(self): # Check whether index_add can get correct result when # alpha is 1, and dtype of index is torch.long, @@ -7273,21 +7274,21 @@ def test_print(self): self.assertExpectedInline(str(y), expected_str) # test dtype - torch.set_default_dtype(torch.float) - x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) - self.assertEqual(x.__repr__(), str(x)) - expected_str = '''\ + with set_default_dtype(torch.float): + x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) + self.assertEqual(x.__repr__(), str(x)) + expected_str = '''\ tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, inf], dtype=torch.float64)''' - self.assertExpectedInline(str(x), expected_str) + self.assertExpectedInline(str(x), expected_str) # test changing default dtype - torch.set_default_dtype(torch.float64) - 
self.assertEqual(x.__repr__(), str(x)) - expected_str = '''\ + with set_default_dtype(torch.float64): + self.assertEqual(x.__repr__(), str(x)) + expected_str = '''\ tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, inf])''' - self.assertExpectedInline(str(x), expected_str) + self.assertExpectedInline(str(x), expected_str) # test summary x = torch.zeros(10000) @@ -9267,4 +9268,5 @@ class TestTensorDeviceOps(TestCase): instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/test/test_type_info.py b/test/test_type_info.py index de63e3dc591ae..3d0e35f6050dc 100644 --- a/test/test_type_info.py +++ b/test/test_type_info.py @@ -1,6 +1,6 @@ # Owner(s): ["module: typing"] -from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests +from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests, set_default_dtype # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -38,7 +38,6 @@ def test_iinfo(self): @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_finfo(self): - initial_default_type = torch.get_default_dtype() for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]: x = torch.zeros((2, 2), dtype=dtype) xinfo = torch.finfo(x.dtype) @@ -52,8 +51,8 @@ def test_finfo(self): self.assertEqual(xinfo.resolution, xninfo.resolution) self.assertEqual(xinfo.dtype, xninfo.dtype) if not dtype.is_complex: - torch.set_default_dtype(dtype) - self.assertEqual(torch.finfo(dtype), torch.finfo()) + with set_default_dtype(dtype): + self.assertEqual(torch.finfo(dtype), torch.finfo()) # Special test case for BFloat16 type x = torch.zeros((2, 2), dtype=torch.bfloat16) @@ -66,11 +65,9 @@ def test_finfo(self): self.assertEqual(xinfo.tiny, xinfo.smallest_normal) self.assertEqual(xinfo.resolution, 0.01) self.assertEqual(xinfo.dtype, "bfloat16") - torch.set_default_dtype(x.dtype) - self.assertEqual(torch.finfo(x.dtype), torch.finfo()) - - # Restore the default type to ensure that the test has no side effect - torch.set_default_dtype(initial_default_type) + with set_default_dtype(x.dtype): + self.assertEqual(torch.finfo(x.dtype), torch.finfo()) if __name__ == '__main__': + TestCase._default_dtype_check_enabled = True run_tests() diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index bddde2e0f1899..6344e460a2955 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -1785,14 +1785,14 @@ def unsqueeze_inp(inp): dict( fullname='EmbeddingBag_sparse', constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), - cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)', + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', input_fn=lambda: torch.randperm(2).repeat(1, 2), check_gradgrad=False, has_sparse_gradients=True, ), dict( constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), - cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)', + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', input_fn=lambda: torch.randperm(2).repeat(1, 2), fullname='Embedding_sparse', check_gradgrad=False, @@ -3168,7 +3168,7 @@ def 
ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 ), dict( module_name='MSELoss', - input_size=(2, 3, 4, 5), + input_fn=lambda: torch.rand((2, 3, 4, 5), dtype=torch.double), target_fn=lambda: torch.randn((2, 3, 4, 5), dtype=torch.double, requires_grad=True), reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel() if get_reduction(m) == 'mean' else 1)), @@ -3314,9 +3314,9 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 dict( module_name='MultiMarginLoss', constructor_args=(1, 1., torch.rand(10, dtype=torch.double)), - cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))', + cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10).to(torch::kFloat64))', legacy_constructor_args=(1, torch.rand(10, dtype=torch.double)), - input_size=(5, 10), + input_fn=lambda: torch.rand(5, 10, dtype=torch.double), target_fn=lambda: torch.rand(5).mul(8).floor().long(), reference_fn=lambda i, t, m: multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)), @@ -3403,7 +3403,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 dict( module_name='BCEWithLogitsLoss', constructor_args=(torch.rand(10, dtype=torch.double),), - cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))', + cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10).to(torch::kFloat64))', input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.randn(15, 10).gt(0).to(torch.get_default_dtype()), desc='weights', @@ -3412,7 +3412,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 dict( module_name='BCEWithLogitsLoss', constructor_args=(torch.rand((), dtype=torch.double),), - cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))', + cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}).to(torch::kFloat64))', input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.randn(()).gt(0).to(torch.get_default_dtype()), desc='scalar_weights', @@ -3826,7 +3826,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 ), dict( module_name='MSELoss', - input_size=(), + input_fn=lambda: torch.rand((), dtype=torch.double), target_fn=lambda: torch.randn((), requires_grad=True, dtype=torch.double), reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel() if get_reduction(m) == 'mean' else 1)), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 447f25259bc4e..61f72026ae39a 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1653,6 +1653,15 @@ def set_default_dtype(dtype): finally: torch.set_default_dtype(saved_dtype) +@contextlib.contextmanager +def set_default_tensor_type(tensor_type): + saved_tensor_type = torch.tensor([]).type() + torch.set_default_tensor_type(tensor_type) + try: + yield + finally: + torch.set_default_tensor_type(saved_tensor_type) + def iter_indices(tensor): if tensor.dim() == 0: return range(0) @@ -2241,6 +2250,10 @@ class TestCase(expecttest.TestCase): _precision: float = 0 _rel_tol: float = 0 + # Toggles whether to assert that `torch.get_default_dtype()` returns + # `torch.float` when `setUp` and `tearDown` are called. 
+ _default_dtype_check_enabled: bool = False + # checker to early terminate test suite if unrecoverable failure occurs. def _should_stop_test_suite(self): if torch.cuda.is_initialized(): @@ -2557,6 +2570,9 @@ def setUp(self): # decorator to disable the invariant checks. torch.sparse.check_sparse_tensor_invariants.enable() + if self._default_dtype_check_enabled: + assert torch.get_default_dtype() == torch.float + def tearDown(self): # There exists test cases that override TestCase.setUp # definition, so we cannot assume that _check_invariants @@ -2568,6 +2584,9 @@ def tearDown(self): else: torch.sparse.check_sparse_tensor_invariants.disable() + if self._default_dtype_check_enabled: + assert torch.get_default_dtype() == torch.float + @staticmethod def _make_crow_indices(n_rows, n_cols, nnz, *, device, dtype, random=True):
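
Note on the pattern used throughout this patch: `set_default_dtype` from `torch.testing._internal.common_utils` (and the `set_default_tensor_type` helper added in the `common_utils.py` hunk) follows the save/override/restore shape sketched below. Because `contextlib.contextmanager` objects also act as decorators, one helper covers both the `with set_default_dtype(...)` blocks and the `@set_default_dtype(torch.double)` method decorators seen above. This is a minimal standalone sketch rather than the exact library code, and the trailing assert assumes the ambient default dtype is the stock `torch.float32`.

    import contextlib
    import torch

    @contextlib.contextmanager
    def set_default_dtype(dtype):
        # Save the current process-wide default, install the override, and
        # restore the original even if the wrapped code raises.
        saved_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(saved_dtype)

    # Context-manager form: the override is scoped to the block.
    with set_default_dtype(torch.double):
        assert torch.tensor(1.0).dtype is torch.float64
    assert torch.tensor(1.0).dtype is torch.float32  # assumes stock default

    # Decorator form: the override is scoped to a single function/test method.
    @set_default_dtype(torch.double)
    def check_scoped_override():
        assert torch.get_default_dtype() is torch.float64

    check_scoped_override()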
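
The `_default_dtype_check_enabled` flag added to `TestCase` in the same hunk is opted into per test file by setting it on the class just before `run_tests()`, as the `if __name__ == '__main__'` changes above do. With the flag enabled, `setUp` and `tearDown` assert that the process-wide default dtype is still `torch.float`, so a test that leaks a different default fails loudly instead of silently skewing later tests. A hypothetical opted-in test file (the class name and test below are illustrative, not part of the patch) would look like:

    import torch
    from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype

    class ExampleDtypeHygieneTest(TestCase):
        @set_default_dtype(torch.double)
        def test_scoped_override(self):
            # The decorator restores torch.float before tearDown runs its check.
            self.assertIs(torch.get_default_dtype(), torch.float64)

    if __name__ == '__main__':
        TestCase._default_dtype_check_enabled = True
        run_tests()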