From 013f622dd2307109bbd2c2a68a578c1a6ced3cbe Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Mon, 30 Oct 2023 19:31:35 +0000
Subject: [PATCH] grid_sample: support bfloat16 (#112331)

This adds bfloat16 support to `torch.nn.functional.grid_sample`. This is
particularly important when doing feature sampling, such as for rendering
techniques used in PyTorch3D or for camera projections to voxel grids such
as in SimpleBEV.

Related to #57707

Test plan:

```
pytest test/test_nn.py -k grid_sample
pytest test/test_ops.py -k grid_sample
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112331
Approved by: https://github.com/zou3519
---
 aten/src/ATen/native/cuda/GridSampler.cu    | 16 ++++++++++++----
 test/test_nn.py                             | 20 ++++++++++++++++++++
 .../_internal/common_methods_invocations.py |  4 ++--
 3 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu
index cdba9f3ab971a..9d87cbc327114 100644
--- a/aten/src/ATen/native/cuda/GridSampler.cu
+++ b/aten/src/ATen/native/cuda/GridSampler.cu
@@ -760,7 +760,9 @@ void launch_grid_sampler_2d_forward_kernel(
   auto W = grid.size(2);
   int64_t count = N * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half, ScalarType::BFloat16,
+      input.scalar_type(), "grid_sampler_2d_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(output)) {
         grid_sampler_2d_kernel
@@ -803,7 +805,9 @@ void launch_grid_sampler_3d_forward_kernel(
   auto W = grid.size(3);
   int64_t count = N * D * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half, ScalarType::BFloat16,
+      input.scalar_type(), "grid_sampler_3d_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(output)) {
         grid_sampler_3d_kernel
@@ -856,7 +860,9 @@ void launch_grid_sampler_2d_backward_kernel(

   int64_t count = N * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half, ScalarType::BFloat16,
+      input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(grad_output)) {
         grid_sampler_2d_backward_kernel
@@ -913,7 +919,9 @@ void launch_grid_sampler_3d_backward_kernel(
   int64_t count = N * D * H * W;
   auto input_requires_grad = output_mask[0];
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half, ScalarType::BFloat16,
+      input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(grad_output)) {
         grid_sampler_3d_backward_kernel
diff --git a/test/test_nn.py b/test/test_nn.py
index d16f1377b5a8b..c34d160b2e224 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10506,6 +10506,26 @@ def helper(shape_in, shape_out, align_corners):
         helper((32, 64, 16, 16), (32, 8, 8, 2), False)
         helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)

+    @onlyCUDA
+    def test_grid_sample_bfloat16_precision(self):
+        def helper(shape_in, shape_out, align_corners):
+            for mode in ('bilinear', 'nearest', 'bicubic'):
+                if len(shape_in) != 4 and mode == 'bicubic':
+                    continue
+                data = torch.randn(shape_in, device='cuda', dtype=torch.bfloat16)
+                grid = torch.rand(shape_out, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0
+
+                out_bf16 = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
+                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
+                                           align_corners=align_corners)
+
+                self.assertEqual(out_bf16, out_double.bfloat16(), msg=f"grid_sample with mode = {mode} doesn't match")
+
+        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
+        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
+        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
+        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
+
     def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
         logits = torch.randn(shape, dtype=torch.float, device=device)
         logits = logits.to(dtype)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cd014768eea24..755ccf69d71b9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18050,7 +18050,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo(
         "nn.functional.grid_sample",
         dtypes=floating_types(),
-        dtypesIfCUDA=floating_types_and(torch.float16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_grid_sample,
         reference_inputs_func=reference_inputs_grid_sample,
@@ -18060,7 +18060,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo(
         "grid_sampler_2d",
         dtypes=floating_types(),
-        dtypesIfCUDA=floating_types_and(torch.float16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_grid_sampler_2d,
         supports_gradgrad=False,
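
For context, a minimal sketch of the usage this patch enables, assuming a CUDA
device and a build that includes the change; the tensor shapes below are
illustrative, not taken from the patch:

```
import torch
import torch.nn.functional as F

# Sample an (N, C, H, W) bfloat16 feature map at an (N, H_out, W_out, 2) grid
# of normalized [-1, 1] coordinates; shapes here are illustrative.
data = torch.randn(2, 8, 16, 16, device='cuda', dtype=torch.bfloat16)
grid = torch.rand(2, 4, 4, 2, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0

out = F.grid_sample(data, grid, mode='bilinear', padding_mode='zeros',
                    align_corners=False)
assert out.dtype == torch.bfloat16  # sampling stays in bfloat16 end to end
```

Before this change, the CUDA kernels dispatched only for float, double, and
half, so the call above failed with an error like `"grid_sampler_2d_cuda" not
implemented for 'BFloat16'`; the switch to `AT_DISPATCH_FLOATING_TYPES_AND2`
adds the bfloat16 instantiation for both the forward and backward kernels.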