add channel last 3d support for batch_norm on CPU (pytorch#97774)
CaoE authored and pytorchmergebot committed Aug 3, 2023
1 parent 719c493 commit f82e6ff
Showing 3 changed files with 21 additions and 14 deletions.
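For orientation, a minimal usage sketch (not part of the commit) of the behavior being added: a 5-D CPU input kept in torch.channels_last_3d is now handled by the channels-last batch-norm kernels, and the output keeps that layout. The shape and module mirror the new test case below; variable names are illustrative.

import torch
import torch.nn as nn

# 5-D NCDHW input converted to the channels-last-3d (NDHWC) layout on CPU.
x = torch.randn(4, 8, 2, 10, 10).to(memory_format=torch.channels_last_3d)
bn = nn.BatchNorm3d(8)

out = bn(x)

# The updated test asserts that the output preserves the requested memory format.
print(out.is_contiguous(memory_format=torch.channels_last_3d))  # expected: True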
6 changes: 4 additions & 2 deletions aten/src/ATen/native/Normalization.cpp
@@ -114,14 +114,16 @@ struct Var {
 };

 static inline bool is_contiguous(const Tensor& t) {
-  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
+  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) || t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
 }

 // For some ambiguous cases, it is possible a channels last contiguous Tensor has
 // `suggest_memory_format` of Contiguous.
 // See https://github.com/pytorch/pytorch/issues/63224 for details.
 static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
-  return t.is_contiguous() ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  return t.is_contiguous() ?
+      at::MemoryFormat::Contiguous : (t.is_contiguous(at::MemoryFormat::ChannelsLast3d) ?
+      at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast);
 }

 template<typename scalar_t, typename param_t>
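The ChannelsLast3d branch added to suggest_memory_format_contig keeps the existing tie-breaking: for degenerate shapes such as NC111 a tensor can be contiguous under both the default and the channels-last-3d interpretation, and the helper then prefers Contiguous, per the ambiguity noted in pytorch/pytorch#63224. A small Python illustration of that ambiguity (a sketch, not from the commit):

import torch

# NC111: all spatial sizes are 1, so the stride checks for both layouts pass.
x = torch.randn(4, 9, 1, 1, 1).contiguous(memory_format=torch.channels_last_3d)

print(x.is_contiguous())                                      # True
print(x.is_contiguous(memory_format=torch.channels_last_3d))  # True
# suggest_memory_format() may still report torch.contiguous_format for such
# tensors, which is why the C++ helper resolves the tie toward Contiguous.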
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/batch_norm_kernel.cpp
@@ -1236,7 +1236,7 @@ void batch_norm_cpu_kernel(Tensor& output, const Tensor& input,
             save_mean, save_invstd, running_mean, running_var, train, eps);
       }
     });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast) || input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "batch_norm_cpu_channels_last", [&] {
       batch_norm_cpu_channels_last_impl<scalar_t>(output, input, weight, bias,
           save_mean, save_invstd, running_mean, running_var, train, eps);
@@ -1257,7 +1257,7 @@ void batch_norm_cpu_collect_stats_kernel(
         batch_norm_cpu_collect_stats_contiguous_impl<scalar_t>(mean, var_sum, input);
       }
     });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast) || input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] {
       batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
     });
@@ -1281,7 +1281,7 @@ void batch_norm_cpu_backward_kernel(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias,
             grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
       }
     });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast) || input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] {
       batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
           grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
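With the three dispatch sites above extended, forward, stat collection, and backward all reach the channels-last implementations for ChannelsLast3d inputs. A training-mode sketch that exercises forward and backward against a contiguous reference, loosely mirroring the helper in the updated test (shapes and tolerances are illustrative):

import torch
import torch.nn as nn

shape, channels = (4, 8, 2, 10, 10), 8

x = torch.randn(shape).to(memory_format=torch.channels_last_3d).requires_grad_(True)
x_ref = x.detach().clone().contiguous().requires_grad_(True)

bn = nn.BatchNorm3d(channels)
bn_ref = nn.BatchNorm3d(channels)
bn_ref.load_state_dict(bn.state_dict())

grad = torch.randn(shape).to(memory_format=torch.channels_last_3d)

# Training-mode forward exercises the collect-stats kernel; backward() the backward kernel.
out = bn(x)
out.backward(grad)
out_ref = bn_ref(x_ref)
out_ref.backward(grad.contiguous())

print(torch.allclose(out, out_ref, atol=1e-5))        # outputs should agree
print(torch.allclose(x.grad, x_ref.grad, atol=1e-5))  # input gradients should agree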
23 changes: 14 additions & 9 deletions test/test_nn.py
@@ -5083,20 +5083,20 @@ def test_hardtanh_backward(self):
         self.assertEqual(x.grad, x_grad_ref)

     def test_batchnorm_nhwc_cpu(self):
-        def helper(self, size, dtype, mixed_dtype=False):
+        def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last):
             channels = size[1]
             input = torch.randn(size, dtype=dtype, device='cpu', requires_grad=True)
-            input = input.contiguous(memory_format=torch.channels_last).to(dtype)
+            input = input.contiguous(memory_format=format).to(dtype)
             input.retain_grad()
             grad = torch.randn(size, dtype=dtype, device='cpu')
-            grad = grad.contiguous(memory_format=torch.channels_last)
-            bn = nn.BatchNorm2d(channels).cpu().to(dtype)
+            grad = grad.contiguous(memory_format=format)
+            bn = mod(channels).cpu().to(dtype)
             bn.weight.data.uniform_()
             bn.bias.data.uniform_()

             ref_input = input.detach().clone().contiguous().requires_grad_(True)
             ref_grad = grad.detach().clone().contiguous()
-            ref_bn = nn.BatchNorm2d(channels).cpu().to(dtype)
+            ref_bn = mod(channels).cpu().to(dtype)
             ref_bn.load_state_dict(bn.state_dict())

             if mixed_dtype:
@@ -5108,7 +5108,7 @@ def helper(self, size, dtype, mixed_dtype=False):
             ref_out = ref_bn(ref_input)
             ref_out.backward(ref_grad)

-            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            self.assertTrue(out.is_contiguous(memory_format=format))
             self.assertTrue(ref_out.is_contiguous())
             self.assertEqual(out, ref_out)
             self.assertEqual(bn.weight.grad, ref_bn.weight.grad)
@@ -5117,9 +5117,14 @@ def helper(self, size, dtype, mixed_dtype=False):

         # test NC11 and N1HW; test mixed dtype
         for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
-            helper(self, shape, torch.float, False)
-            helper(self, shape, torch.bfloat16, False)
-            helper(self, shape, torch.bfloat16, True)
+            helper(self, nn.BatchNorm2d, shape, torch.float, False, torch.channels_last)
+            helper(self, nn.BatchNorm2d, shape, torch.bfloat16, False, torch.channels_last)
+            helper(self, nn.BatchNorm2d, shape, torch.bfloat16, True, torch.channels_last)
+
+        for shape in [(4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
+            helper(self, nn.BatchNorm3d, shape, torch.float, False, torch.channels_last_3d)
+            helper(self, nn.BatchNorm3d, shape, torch.bfloat16, False, torch.channels_last_3d)
+            helper(self, nn.BatchNorm3d, shape, torch.bfloat16, True, torch.channels_last_3d)

     @parametrize_test(
         'bn_module',
