grid_sample: support bfloat16 (pytorch#112331)
This adds bfloat16 support to `torch.nn.functional.grid_sample`. This is particularly important for feature sampling, such as in the rendering techniques used by PyTorch3D, or for projecting camera features into voxel grids as in SimpleBEV.

Related to pytorch#57707
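
As a quick illustration (not part of this commit), here is a minimal sketch of the kind of bfloat16 feature sampling this enables on CUDA; the tensor names and shapes are made up for the example:

```python
import torch
import torch.nn.functional as F

# A small bfloat16 feature map and normalized sampling locations in [-1, 1].
features = torch.randn(1, 64, 32, 32, device="cuda", dtype=torch.bfloat16)       # (N, C, H, W)
points = torch.rand(1, 16, 16, 2, device="cuda", dtype=torch.bfloat16) * 2 - 1   # (N, H_out, W_out, 2)

sampled = F.grid_sample(features, points, mode="bilinear",
                        padding_mode="zeros", align_corners=False)
print(sampled.shape, sampled.dtype)  # torch.Size([1, 64, 16, 16]) torch.bfloat16
```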

Test plan:

```
pytest test/test_nn.py -k grid_sample
pytest test/test_ops.py -k grid_sample
```
Pull Request resolved: pytorch#112331
Approved by: https://github.com/zou3519
d4l3k authored and pytorchmergebot committed Oct 30, 2023
1 parent 3b58755 commit 013f622
Showing 3 changed files with 34 additions and 6 deletions.
aten/src/ATen/native/cuda/GridSampler.cu (12 additions & 4 deletions)

```diff
@@ -760,7 +760,9 @@ void launch_grid_sampler_2d_forward_kernel(
   auto W = grid.size(2);
   int64_t count = N * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        input.scalar_type(), "grid_sampler_2d_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(output)) {
         grid_sampler_2d_kernel<scalar_t>
@@ -803,7 +805,9 @@ void launch_grid_sampler_3d_forward_kernel(
   auto W = grid.size(3);
   int64_t count = N * D * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        input.scalar_type(), "grid_sampler_3d_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(output)) {
         grid_sampler_3d_kernel<scalar_t>
@@ -856,7 +860,9 @@ void launch_grid_sampler_2d_backward_kernel(

   int64_t count = N * H * W;
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(grad_output)) {
         grid_sampler_2d_backward_kernel<scalar_t>
@@ -913,7 +919,9 @@ void launch_grid_sampler_3d_backward_kernel(
   int64_t count = N * D * H * W;
   auto input_requires_grad = output_mask[0];
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
       if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
           canUse32BitIndexMath(grad_output)) {
         grid_sampler_3d_backward_kernel<scalar_t>
```
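
Since both the forward and backward launchers now dispatch on BFloat16, autograd through `grid_sample` runs end-to-end in bfloat16 on CUDA. A minimal sketch (illustrative only, not from the commit):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(2, 8, 12, 12, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grid = (torch.rand(2, 6, 6, 2, device="cuda", dtype=torch.bfloat16) * 2 - 1).requires_grad_()

out = F.grid_sample(inp, grid, mode="bilinear", align_corners=True)
out.sum().backward()  # exercises the grid_sampler_2d backward CUDA kernel

assert inp.grad.dtype == torch.bfloat16
assert grid.grad.dtype == torch.bfloat16
```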
test/test_nn.py (20 additions & 0 deletions)

```diff
@@ -10506,6 +10506,26 @@ def helper(shape_in, shape_out, align_corners):
         helper((32, 64, 16, 16), (32, 8, 8, 2), False)
         helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
 
+    @onlyCUDA
+    def test_grid_sample_bfloat16_precision(self):
+        def helper(shape_in, shape_out, align_corners):
+            for mode in ('bilinear', 'nearest', 'bicubic'):
+                if len(shape_in) != 4 and mode == 'bicubic':
+                    continue
+                data = torch.randn(shape_in, device='cuda', dtype=torch.bfloat16)
+                grid = torch.rand(shape_out, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0
+
+                out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
+                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
+                                           align_corners=align_corners)
+
+                self.assertEqual(out_half, out_double.bfloat16(), msg=f"grid_sample with mode = {mode} doesn't match")
+
+        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
+        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
+        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
+        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
+
     def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
         logits = torch.randn(shape, dtype=torch.float, device=device)
         logits = logits.to(dtype)
```
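
The new test checks the bfloat16 result against a float64 reference downcast to bfloat16. A standalone sketch of the same comparison (assuming the default bfloat16 tolerances are sufficient, as they are in the test):

```python
import torch
import torch.nn.functional as F

data = torch.randn(32, 64, 16, 16, device="cuda", dtype=torch.bfloat16)
grid = torch.rand(32, 8, 8, 2, device="cuda", dtype=torch.bfloat16) * 2.0 - 1.0

out_bf16 = F.grid_sample(data, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
out_ref = F.grid_sample(data.double(), grid.double(), mode="bilinear",
                        padding_mode="zeros", align_corners=True)

# Compare in bfloat16, mirroring the assertEqual in the test above.
torch.testing.assert_close(out_bf16, out_ref.bfloat16())
```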
torch/testing/_internal/common_methods_invocations.py (2 additions & 2 deletions)

```diff
@@ -18050,7 +18050,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo(
         "nn.functional.grid_sample",
         dtypes=floating_types(),
-        dtypesIfCUDA=floating_types_and(torch.float16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_grid_sample,
         reference_inputs_func=reference_inputs_grid_sample,
@@ -18060,7 +18060,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo(
         "grid_sampler_2d",
         dtypes=floating_types(),
-        dtypesIfCUDA=floating_types_and(torch.float16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_grid_sampler_2d,
         supports_gradgrad=False,
```
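
With the OpInfo entries updated, the OpInfo-driven tests (e.g. those run by `pytest test/test_ops.py -k grid_sample`) now generate bfloat16 samples for these ops on CUDA. A quick sanity check, assuming `op_db` and `OpInfo.supported_dtypes` behave as in recent PyTorch versions:

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

for op in op_db:
    if op.name in ("nn.functional.grid_sample", "grid_sampler_2d"):
        # supported_dtypes("cuda") reflects dtypesIfCUDA from the OpInfo entry.
        assert torch.bfloat16 in op.supported_dtypes("cuda"), op.name
```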
