add Half support for BatchNorm on CPU (pytorch#102070)
Fixes pytorch#106543

### Testing

Single core:

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
(1, 4, 256, 256) | 0.7116 | 0.1427 | 0.1744 | 0.2638 | 0.2002 | 0.2556
(1, 32, 100, 100) | 0.8579 | 0.1725 | 0.2077 | 0.3023 | 0.2399 | 0.2995
(32, 16, 200, 200) | 57.3466 | 12.2179 | 13.1320 | 45.9524 | 24.1526 | 24.9882

28 cores:

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
(1, 4, 256, 256) | 0.2571 | 0.0713 | 0.0846 | 0.1140 | 0.0883 |  0.1043
(1, 32, 100, 100) | 0.1077 | 0.0510 | 0.0548 | 0.0700 | 0.0645 | 0.0713
(32, 16, 200, 200) | 5.5060 | 1.4195 | 1.4663 | 6.773 | 3.0886 | 3.1343

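The benchmark script itself is not part of the commit; a minimal timing sketch along the following lines (assumed shapes, iteration counts, and `torch.set_num_threads` usage — not the author's actual script) would produce comparable numbers:

```python
# Hypothetical benchmark sketch for the tables above (not part of the commit).
import time
import torch
import torch.nn as nn

torch.set_num_threads(1)  # 1 for the single-core table, 28 for the 28-core table

def bench(shape, dtype, iters=100):
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    grad = torch.randn(shape, dtype=dtype)
    bn = nn.BatchNorm2d(shape[1]).to(dtype)

    for _ in range(10):                      # warm-up
        bn(x).backward(grad)

    t0 = time.perf_counter()
    for _ in range(iters):                   # forward only
        bn(x)
    fwd_ms = (time.perf_counter() - t0) / iters * 1e3

    t0 = time.perf_counter()
    for _ in range(iters):                   # forward + backward
        bn(x).backward(grad)
    bwd_ms = (time.perf_counter() - t0) / iters * 1e3 - fwd_ms  # rough backward-only estimate
    return fwd_ms, bwd_ms

for shape in [(1, 4, 256, 256), (1, 32, 100, 100), (32, 16, 200, 200)]:
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        fwd_ms, bwd_ms = bench(shape, dtype)
        print(f"{shape} {dtype}: forward {fwd_ms:.4f} ms, backward {bwd_ms:.4f} ms")
```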
Pull Request resolved: pytorch#102070
Approved by: https://github.com/jgong5, https://github.com/mikaylagawarecki
CaoE authored and pytorchmergebot committed Sep 13, 2023
1 parent f6d8ecf commit 6065e7a
Showing 8 changed files with 273 additions and 220 deletions.
20 changes: 12 additions & 8 deletions aten/src/ATen/native/Normalization.cpp
@@ -17,6 +17,7 @@
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -695,10 +696,11 @@ std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

const bool mixed_type = is_mixed_type(self, running_mean, running_var);
return AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "batch_norm_update_stats_cpu", [&] {
return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_update_stats_cpu", [&] {
using opmath_t = at::opmath_type<scalar_t>;
if (mixed_type) {
check_mixed_data_type(self, running_mean, running_var);
return batch_norm_cpu_update_stats_template<BFloat16, float, Var>(self, running_mean, running_var, momentum, 0);
return batch_norm_cpu_update_stats_template<scalar_t, opmath_t, Var>(self, running_mean, running_var, momentum, 0);
} else {
return batch_norm_cpu_update_stats_template<scalar_t, scalar_t, Var>(self, running_mean, running_var, momentum, 0);
}
@@ -719,17 +721,18 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cpu_out(const Tensor& self, con
at::native::resize_output(out, self.sizes());

const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var);
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "batch_norm", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm", [&] {
using opmath_t = at::opmath_type<scalar_t>;
if (mixed_type) {
check_mixed_data_type(self, weight, bias, running_mean, running_var);
if (!train) {
return batch_norm_cpu_transform_input_template<BFloat16, float>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
} else {
// Resize save_mean and save_var
at::native::resize_output(save_mean, {self.size(1)});
at::native::resize_output(save_var, {self.size(1)});
auto save_stats = batch_norm_cpu_update_stats_template<BFloat16, float, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
return batch_norm_cpu_transform_input_template<BFloat16, float>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, opmath_t, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
}
} else {
if (!train) {
@@ -836,10 +839,11 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_ou
const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();});

const bool mixed_type = is_mixed_type(self, weight, running_mean, running_var, save_mean, save_invstd);
return AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_cpu", [&] {
return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_backward_cpu", [&] {
using opmath_t = at::opmath_type<scalar_t>;
if (mixed_type) {
check_mixed_data_type(self, weight, running_mean, running_var, save_mean, save_invstd);
return batch_norm_backward_cpu_template<BFloat16, float>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
return batch_norm_backward_cpu_template<scalar_t, opmath_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
} else {
return batch_norm_backward_cpu_template<scalar_t, scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
}
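Taken together, these changes replace the hard-coded `<BFloat16, float>` instantiations with `<scalar_t, opmath_t>` under `AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, ...)`, so Half shares the same same-dtype and mixed-dtype paths as BFloat16. A hedged usage sketch of what this enables on CPU (illustrative only, not code from the commit):

```python
# Half BatchNorm on CPU: both the same-dtype and the mixed-dtype paths.
import torch
import torch.nn as nn

x = torch.randn(1, 4, 256, 256, dtype=torch.float16, requires_grad=True)

# Same-dtype path: fp16 input with fp16 parameters and running stats
# (the non-mixed branch, instantiated as <Half, Half>).
bn_half = nn.BatchNorm2d(4).to(torch.float16)
bn_half(x).sum().backward()

# Mixed-dtype path: fp16 input with fp32 parameters and running stats
# (the `mixed_type` branch, instantiated as <Half, float> via opmath_t).
x2 = torch.randn(1, 4, 256, 256, dtype=torch.float16, requires_grad=True)
bn_mixed = nn.BatchNorm2d(4)  # parameters stay float32
bn_mixed(x2).sum().backward()
```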
359 changes: 190 additions & 169 deletions aten/src/ATen/native/cpu/batch_norm_kernel.cpp

(Large diff not rendered here.)

8 changes: 8 additions & 0 deletions aten/src/ATen/native/cpu/utils.h
@@ -78,6 +78,10 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* p
return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
return convert_half_float(Vectorized<Half>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
using Vec = Vectorized<float>;
return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
@@ -87,6 +91,10 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* p
return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
return convert_half_float(Vectorized<Half>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
using Vec = Vectorized<float>;
if (count > Vec::size()) {
11 changes: 11 additions & 0 deletions test/onnx/test_fx_op_consistency.py
@@ -672,6 +672,12 @@ def _run_test_output_match(
# Relax atol and rtol for float32 based on empirical results
rtol = 1e-5
atol = 2e-5
elif (
dtype == torch.float16
and op.name in test_suite.fp16_low_precision_list
):
rtol = 1e-2
atol = 1e-3
else:
rtol = None
atol = None
@@ -707,6 +713,11 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
op_level_debug: bool = False
dynamic_shapes: bool = False

fp16_low_precision_list = [
"nn.functional.batch_norm",
"native_batch_norm",
]

@common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
10 changes: 5 additions & 5 deletions test/test_meta.py
@@ -691,8 +691,8 @@ def run_meta_crossref(
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures['cpu'] = {
torch.native_batch_norm: {bf16},
torch._native_batch_norm_legit: {bf16},
torch.native_batch_norm: {bf16, f16},
torch._native_batch_norm_legit: {bf16, f16},
torch.native_layer_norm: {bf16},
}

@@ -836,9 +836,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cpu'] = {
aten.native_batch_norm.default: {bf16},
aten._native_batch_norm_legit.default: {bf16},
aten._native_batch_norm_legit.no_stats: {bf16},
aten.native_batch_norm.default: {bf16, f16},
aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16},
aten.native_layer_norm.default: {bf16},
}

4 changes: 4 additions & 0 deletions test/test_mps.py
@@ -10806,10 +10806,14 @@ class TestConsistency(TestCaseMPS):
'nn.functional.normalize',
'nn.functional.triplet_margin_loss',
'nn.functional.triplet_margin_with_distance_loss',
'nn.functional.batch_norm',
'nn.functional.instance_norm',
'round', 'xlogy', 'addcmul',
'nn.functional.max_pool2d',
'nn.functional.gelu',
'nn.functional.glu',
'_native_batch_norm_legit',
'native_batch_norm',

# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
69 changes: 39 additions & 30 deletions test/test_nn.py
@@ -5049,7 +5049,7 @@ def func(root):
gradcheck(func, [v])
gradgradcheck(func, [v])

# test hardtanh backward froo large tensor
# test hardtanh backward for large tensor
def test_hardtanh_backward(self):
x = torch.randn(128, 10000, requires_grad=True)
grad = torch.randn(128, 10000)
@@ -5062,7 +5062,7 @@ def test_hardtanh_backward(self):
self.assertEqual(x.grad, x_grad_ref)

def test_batchnorm_nhwc_cpu(self):
def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last):
def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last, precision=None):
channels = size[1]
input = torch.randn(size, dtype=dtype, device='cpu', requires_grad=True)
input = input.contiguous(memory_format=format).to(dtype)
@@ -5090,20 +5090,25 @@ def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last
self.assertTrue(out.is_contiguous(memory_format=format))
self.assertTrue(ref_out.is_contiguous())
self.assertEqual(out, ref_out)
self.assertEqual(bn.weight.grad, ref_bn.weight.grad)
self.assertEqual(bn.weight.grad, ref_bn.weight.grad, atol=precision, rtol=precision)
self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
self.assertEqual(input.grad, ref_input.grad)

# test NC11 and N1HW; test mixed dtype
for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
helper(self, nn.BatchNorm2d, shape, torch.float, False, torch.channels_last)
helper(self, nn.BatchNorm2d, shape, torch.bfloat16, False, torch.channels_last)
helper(self, nn.BatchNorm2d, shape, torch.bfloat16, True, torch.channels_last)
for dtype in [torch.float, torch.bfloat16, torch.float16]:
for mixed_dtype in [False, True]:
if dtype == torch.float:
mixed_dtype = False
helper(self, nn.BatchNorm2d, shape, dtype, mixed_dtype, torch.channels_last)

precisons = {torch.float: 1e-4, torch.bfloat16: None, torch.float16: None}
for shape in [(4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
helper(self, nn.BatchNorm3d, shape, torch.float, False, torch.channels_last_3d)
helper(self, nn.BatchNorm3d, shape, torch.bfloat16, False, torch.channels_last_3d)
helper(self, nn.BatchNorm3d, shape, torch.bfloat16, True, torch.channels_last_3d)
for dtype in [torch.float, torch.bfloat16, torch.float16]:
for mixed_dtype in [False, True]:
if dtype == torch.float:
mixed_dtype = False
helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisons[dtype])

@parametrize_test(
'bn_module',
@@ -5113,32 +5118,36 @@ def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last
],
)
def test_batchnorm_non_contig_cpu(self, bn_module):
input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
input = input.permute(0, 2, 1, 3)
def helper(self, dtype):
input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
input = input.permute(0, 2, 1, 3)

bn = bn_module(2).cpu().float().eval()
bn.weight.data.uniform_()
bn.bias.data.uniform_()
bn = bn_module(2).cpu().float().eval()
bn.weight.data.uniform_()
bn.bias.data.uniform_()

ref_input = input.detach().clone().contiguous()
ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
ref_bn.load_state_dict(bn.state_dict())
ref_input = input.detach().clone().contiguous()
ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
ref_bn.load_state_dict(bn.state_dict())

out = bn(input)
ref_out = ref_bn(ref_input)
out = bn(input)
ref_out = ref_bn(ref_input)

self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_out.is_contiguous())
self.assertEqual(out, ref_out)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_out.is_contiguous())
self.assertEqual(out, ref_out)

input_bf = torch.arange(24, dtype=torch.bfloat16).reshape(1, 3, 2, 4)
input_bf = input_bf.permute(0, 2, 1, 3)
input_f = input_bf.float()
bn_mix = bn_module(2).float().eval()
ref_bn_f = deepcopy(bn_mix)
out_bf = bn_mix(input_bf)
ref_out_bf = ref_bn_f(input_f)
self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)
input_bf = torch.arange(24, dtype=dtype).reshape(1, 3, 2, 4)
input_bf = input_bf.permute(0, 2, 1, 3)
input_f = input_bf.float()
bn_mix = bn_module(2).float().eval()
ref_bn_f = deepcopy(bn_mix)
out_bf = bn_mix(input_bf)
ref_out_bf = ref_bn_f(input_f)
self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)

helper(self, torch.bfloat16)
helper(self, torch.float16)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
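The extended `test_batchnorm_nhwc_cpu` and `test_batchnorm_non_contig_cpu` above now sweep `torch.float16` in addition to `torch.bfloat16`. A condensed, hypothetical version of the channels-last consistency check they perform (shapes and tolerances chosen here for illustration, mirroring the 0.05 atol/rtol used elsewhere in these tests):

```python
# Condensed sketch of a channels-last fp16 vs. fp32 consistency check (illustrative).
import torch
import torch.nn as nn

x = torch.randn(4, 8, 10, 10)
x_half_cl = x.to(torch.float16).contiguous(memory_format=torch.channels_last)

bn_half = nn.BatchNorm2d(8).to(torch.float16)  # fp16 module, channels-last fp16 input
bn_ref = nn.BatchNorm2d(8)                     # fp32 reference with the same default init

out = bn_half(x_half_cl)
out_ref = bn_ref(x)

assert out.is_contiguous(memory_format=torch.channels_last)
torch.testing.assert_close(out.float(), out_ref, atol=0.05, rtol=0.05)
```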
12 changes: 4 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
@@ -12302,8 +12302,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)),
OpInfo('native_batch_norm',
aten_name='native_batch_norm',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
@@ -12333,8 +12332,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
),
OpInfo('_native_batch_norm_legit',
aten_name='_native_batch_norm_legit',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
@@ -12787,8 +12785,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_expanded_weight=True,),
OpInfo('nn.functional.instance_norm',
# no ref because instance_norm will often have numerical instability (large numbers or nan)
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -13953,8 +13950,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
OpInfo('nn.functional.batch_norm',
aten_name='batch_norm',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
