# add Half support for GroupNorm on CPU (pytorch#100234)
### Testing
Single socket (28 cores):

* Contiguous:

shape | forward fp32 (s) | forward mixed fp32/fp16 (s) | backward fp32 (s) | backward mixed fp32/fp16 (s)
-- | -- | -- | -- | --
[10, 128, 10, 10] | 2.45E-05 | 3.26E-05 | 6.87E-05 | 7.40E-05
[10, 128, 80, 80] | 0.000726 | 0.000606 | 0.002183 | 0.001112

* Channels Last:

shape | forward fp32 (s) | forward mixed fp32/fp16 (s) | backward fp32 (s) | backward mixed fp32/fp16 (s)
-- | -- | -- | -- | --
[10, 128, 10, 10] | 2.88E-05 | 2.72E-05 | 6.56E-05 | 6.63E-05
[10, 128, 80, 80] | 0.00076 | 0.000256 | 0.002385 | 0.000735

Single core:

* Contiguous:

shape | forward fp32 (s) | forward mixed fp32/fp16 (s) | backward fp32 (s) | backward mixed fp32/fp16 (s)
-- | -- | -- | -- | --
[10, 128, 10, 10] | 9.47E-05 | 1.90E-04 | 2.03E-04 | 3.10E-04
[10, 128, 80, 80] | 6.25E-03 | 8.98E-03 | 0.016485 | 0.01369

* Channels Last:

shape | forward fp32 (s) | forward mixed fp32/fp16 (s) | backward fp32 (s) | backward mixed fp32/fp16 (s)
-- | -- | -- | -- | --
[10, 128, 10, 10] | 8.66E-05 | 7.89E-05 | 1.95E-04 | 1.43E-04
[10, 128, 80, 80] | 5.97E-03 | 3.13E-03 | 0.01626 | 8.70E-03
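
The numbers above are wall-clock seconds per call. Below is a minimal sketch of how comparable measurements could be collected; it is not the harness used for this PR, and the iteration count, group count, and thread/core pinning are illustrative assumptions.

```python
import time
import torch

def bench_group_norm(shape, channels_last=False, mixed=False, iters=200):
    # Hypothetical harness: fp32 parameters, optionally fp16 input to hit the
    # mixed fp32/fp16 CPU path. Thread pinning (torch.set_num_threads / numactl)
    # is assumed for the single-core vs single-socket runs.
    dtype = torch.half if mixed else torch.float
    x = torch.randn(shape).to(dtype)
    grad = torch.randn(shape).to(dtype)
    if channels_last:
        x = x.to(memory_format=torch.channels_last)
        grad = grad.to(memory_format=torch.channels_last)
    x.requires_grad_(True)
    m = torch.nn.GroupNorm(num_groups=2, num_channels=shape[1])  # parameters stay fp32

    for _ in range(10):                      # warmup
        m(x).backward(grad)
        x.grad = None

    fwd = bwd = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        y = m(x)
        t1 = time.perf_counter()
        y.backward(grad)
        t2 = time.perf_counter()
        x.grad = None
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / iters, bwd / iters

for shape in [(10, 128, 10, 10), (10, 128, 80, 80)]:
    for mixed in (False, True):
        f, b = bench_group_norm(shape, channels_last=True, mixed=mixed)
        print(shape, "mixed" if mixed else "fp32", f"fwd {f:.2e}s bwd {b:.2e}s")
```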

Pull Request resolved: pytorch#100234
Approved by: https://github.com/jgong5, https://github.com/mikaylagawarecki
CaoE authored and pytorchmergebot committed Sep 1, 2023
1 parent 54dcb0e commit 8f02884
Showing 7 changed files with 479 additions and 474 deletions.
753 changes: 378 additions & 375 deletions aten/src/ATen/native/cpu/group_norm_kernel.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/mixed_data_type.h
@@ -21,9 +21,9 @@ inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
}

// currently on CPU, mixed data type is only supported
// when input is 'BFloat16' and parameters are 'Float'
// when input is 'BFloat16' or 'Half' and parameters are 'Float'
inline void check_mixed_data_type(const Tensor& input) {
TORCH_CHECK(input.scalar_type() == ScalarType::BFloat16,
TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
"mixed dtype (CPU): all inputs must share same datatype.");
}

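The relaxed check above is what lets a reduced-precision input (BFloat16, and now Half) pair with float32 parameters on CPU. A small illustrative sketch of that mixed-dtype usage; the shapes are arbitrary, and the tolerances mirror those used in the PR's tests:

```python
import torch
import torch.nn as nn

x_half = torch.randn(2, 8, 16, 16).to(torch.half)      # reduced-precision input
gn_fp32 = nn.GroupNorm(num_groups=4, num_channels=8)   # weight/bias stay float32

out = gn_fp32(x_half)        # mixed Half/float CPU path (previously BFloat16-only)
print(out.dtype)             # torch.float16

# rough numerical comparison against a full-float32 run
ref = gn_fp32(x_half.float())
torch.testing.assert_close(out.float(), ref, atol=5e-3, rtol=5e-3)
```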
64 changes: 33 additions & 31 deletions aten/src/ATen/native/cpu/moments_utils.h
@@ -17,7 +17,7 @@ namespace at {
namespace native {
inline namespace CPU_CAPABILITY {

template<typename T> using acc_t = at::opmath_type<T>;
template<typename T> using opmath_t = at::opmath_type<T>;

constexpr int64_t kChunkSize = 16;

@@ -56,14 +56,15 @@ C10_ALWAYS_INLINE void AddMomentsVec(
}

template <typename T>
inline void UpdateMomentsVec(
inline typename std::enable_if<std::is_same<T, opmath_t<T>>::value, void>::type
UpdateMomentsVec(
int64_t m0,
const T* X_ptr,
const std::array<vec::Vectorized<acc_t<T>>, kChunkSize>& c_vecs,
const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
int64_t& m0_stk0,
vec::Vectorized<acc_t<T>>& m1_stk0,
vec::Vectorized<acc_t<T>>& m2_stk0) {
using Vec = vec::Vectorized<acc_t<T>>;
vec::Vectorized<opmath_t<T>>& m1_stk0,
vec::Vectorized<opmath_t<T>>& m2_stk0) {
using Vec = vec::Vectorized<opmath_t<T>>;
Vec m1_vec(0);
Vec m2_vec(0);
for (const auto j : c10::irange(m0)) {
@@ -77,22 +78,23 @@ inline void UpdateMomentsVec(

// each bfloat16 vector will be converted to two float vectors,
// and accumulated successively on m1_stk0/m2_stk0.
template <>
inline void UpdateMomentsVec<BFloat16>(
template <typename T>
inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type
UpdateMomentsVec(
int64_t m0,
const BFloat16* X_ptr,
const std::array<vec::Vectorized<float>, kChunkSize>& c_vecs,
const T* X_ptr,
const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
int64_t& m0_stk0,
vec::Vectorized<float>& m1_stk0,
vec::Vectorized<float>& m2_stk0) {
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
vec::Vectorized<at::opmath_type<T>>& m1_stk0,
vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
using Vec = vec::Vectorized<T>;
using fVec = vec::Vectorized<at::opmath_type<T>>;
fVec m1_fvec0(0), m1_fvec1(0);
fVec m2_fvec0(0), m2_fvec1(0);
for (const auto j : c10::irange(m0)) {
const bVec x_bvec = bVec::loadu(X_ptr + j * bVec::size());
const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
m1_fvec0 += delta_fvec0 * c_vecs[j];
@@ -109,17 +111,17 @@ inline void UpdateMomentsVec<BFloat16>(
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
template <typename T, int64_t kMaxDepth>
std::pair<acc_t<T>, acc_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
using T_ACC = acc_t<T>;
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
using math_t = opmath_t<T>;

constexpr int64_t kVecSize = vec::Vectorized<T>::size();
constexpr int64_t kAccVecSize = vec::Vectorized<T_ACC>::size();
constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = utils::CeilLog2(m);

using Vec = vec::Vectorized<T_ACC>;
const Vec kZeroVec(T_ACC(0));
using Vec = vec::Vectorized<math_t>;
const Vec kZeroVec(math_t(0));
c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
@@ -130,7 +132,7 @@ std::pair<acc_t<T>, acc_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t
static std::array<Vec, kChunkSize> c_vecs = ([]() {
std::array<Vec, kChunkSize> result;
for (const auto i : c10::irange(kChunkSize)) {
result[i] = Vec(T_ACC(1) / static_cast<T_ACC>(i + 1));
result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
}
return result;
})();
@@ -156,19 +158,19 @@ std::pair<acc_t<T>, acc_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t
m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
}

std::array<T_ACC, kAccVecSize> m1_arr{};
std::array<T_ACC, kAccVecSize> m2_arr{};
std::array<math_t, kAccVecSize> m1_arr{};
std::array<math_t, kAccVecSize> m2_arr{};
m1_stk[0].store(m1_arr.data());
m2_stk[0].store(m2_arr.data());

int64_t m0 = 0;
T_ACC m1 = 0;
T_ACC m2 = 0;
math_t m1 = 0;
math_t m2 = 0;
for (int64_t i = n * kVecSize; i < N; ++i) {
T_ACC x = static_cast<T_ACC>(X[i]);
const T_ACC delta = x - m1;
math_t x = static_cast<math_t>(X[i]);
const math_t delta = x - m1;
++m0;
m1 += delta / static_cast<T_ACC>(m0);
m1 += delta / static_cast<math_t>(m0);
m2 += delta * (x - m1);
}
// for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
@@ -177,11 +179,11 @@ std::pair<acc_t<T>, acc_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t
AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
}

return std::make_pair(m1, m2 / static_cast<T_ACC>(N - ddof));
return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
}

template <typename T>
std::pair<acc_t<T>, acc_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
using Vec = vec::Vectorized<T>;
constexpr int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
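The template change above replaces the BFloat16-only specialization with one that covers any reduced floating type (BFloat16 or Half): each low-precision vector is widened to float and the running moments are accumulated in float. The scalar logic underneath is a Welford update; a simplified sketch of that idea follows, ignoring the SIMD, chunking, and pairwise-summation cascade of the real kernel:

```python
import torch

def rowwise_moments(row: torch.Tensor, ddof: int = 0):
    """Welford-style mean/variance of a 1-D tensor, accumulated in higher
    precision even when the input is torch.half or torch.bfloat16 (mirrors
    the opmath_t accumulation in the kernel, minus vectorization)."""
    m0 = 0       # running count
    m1 = 0.0     # running mean
    m2 = 0.0     # running sum of squared deviations
    for v in row.tolist():       # each element widened before accumulation
        m0 += 1
        delta = v - m1
        m1 += delta / m0
        m2 += delta * (v - m1)
    return m1, m2 / (m0 - ddof)

# sanity check on a Half row against PyTorch's float32 reference
row = torch.randn(1000).to(torch.half)
mean, var = rowwise_moments(row)
torch.testing.assert_close(torch.tensor(mean), row.float().mean(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(torch.tensor(var), row.float().var(unbiased=False), atol=1e-3, rtol=1e-3)
```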
7 changes: 6 additions & 1 deletion aten/src/ATen/native/group_norm.cpp
@@ -119,7 +119,12 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm_backward(
c10::MaybeOwned<Tensor> gamma_maybe_owned =
at::borrow_from_optional_tensor(gamma_opt);
const Tensor& gamma = *gamma_maybe_owned;

TORCH_CHECK(
X.suggest_memory_format() == dY.suggest_memory_format(),
"Expected memory formats of X and dY are same.");
TORCH_CHECK(
X.scalar_type() == dY.scalar_type(),
"Expected scalar types of X and dY are same.");
bool mixed_type = is_mixed_type(X, mean, rstd);
if (mixed_type) {
check_mixed_data_type(X, mean, rstd);
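The new checks assert that the incoming gradient matches X's memory format and scalar type before the CPU kernels are dispatched. A hedged end-to-end sketch of the configuration they expect (channels-last Half input with a matching gradient; the shapes here are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, 32, 32).to(torch.half) \
        .to(memory_format=torch.channels_last) \
        .requires_grad_(True)
gn = nn.GroupNorm(4, 8)          # float32 parameters -> mixed-dtype CPU path

out = gn(x)                      # Half forward on CPU, channels-last preserved

# gradient with the same dtype and memory format as the output, which is
# what native_group_norm_backward now checks for
grad = torch.randn_like(out).to(memory_format=torch.channels_last)
out.backward(grad)

print(x.grad.dtype, x.grad.is_contiguous(memory_format=torch.channels_last))
```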
115 changes: 54 additions & 61 deletions test/test_nn.py
@@ -8041,13 +8041,13 @@ def _test_GroupNorm_cuda_half(self):
self.assertEqualTypeString(output, input)

def _test_GroupNorm_cpu_mixed_dtype(self):
def helper(self, size, groups, memory_format):
def helper(self, size, groups, memory_format, dtype):
channels = size[1]
input = torch.randn(size, dtype=torch.bfloat16).cpu()
input = torch.randn(size).cpu().to(dtype=dtype)
input_bf1 = input.contiguous(memory_format=memory_format).detach().requires_grad_(True)
input_bf2 = input_bf1.clone().detach().requires_grad_(True)
input_f = input_bf1.float().detach().requires_grad_(True)
m_bf = nn.GroupNorm(groups, channels).cpu().bfloat16()
m_bf = nn.GroupNorm(groups, channels).cpu().to(dtype=dtype)
m_f = deepcopy(m_bf).float()
m_f2 = deepcopy(m_f)
# bfloat16 input and bfloat16 parameters
@@ -8058,49 +8058,48 @@ def helper(self, size, groups, memory_format):
out3 = m_f2(input_f)
self.assertEqual(out, out2, atol=5e-3, rtol=5e-3)
self.assertEqual(out2.float(), out3, atol=5e-3, rtol=5e-3)
grad_out = torch.randn(out2.shape, dtype=torch.bfloat16).cpu()
grad_out = torch.randn(out2.shape).cpu().to(dtype=dtype)
grad_out_bf1 = grad_out.contiguous(memory_format=memory_format).detach().requires_grad_(True)
grad_out_bf2 = grad_out_bf1.clone().detach().requires_grad_(True)
grad_out_f = grad_out_bf2.clone().float().detach().requires_grad_(True)
# bfloat16 input grad and float parameters
# bfloat16/half input grad and float parameters
out2.backward(grad_out_bf2, retain_graph=True)
# float input grad and float parameters
out3.backward(grad_out_f, retain_graph=True)
# bfloat16 input grad and bfloat16 parameters
# bfloat16/half input grad and bfloat16/half parameters
out.backward(grad_out_bf1, retain_graph=True)
self.assertEqual(m_f.weight.grad, m_f2.weight.grad, atol=1e-5, rtol=1e-5)
# Need higher tolerances atol=1e-4 and rtol=1e-4 on macos
self.assertEqual(m_f.weight.grad, m_f2.weight.grad, atol=1e-4, rtol=1e-4)
self.assertEqual(m_f.bias.grad, m_f2.bias.grad, atol=1e-5, rtol=1e-5)
self.assertEqual(input_bf2.grad.float(), input_f.grad, atol=5e-5, rtol=5e-3)
# Full bf16 has lower precision compared with mixed bf16 and fp32 .
# Full bf16/half has lower precision compared with mixed bf16/half and fp32.
# Use Amp to keep module parameters in acc dtype, i.e. float, for better numerical stability
self.assertEqual(m_bf.weight.grad.float(), m_f.weight.grad, atol=1e-3, rtol=1.2e-1)
self.assertEqual(m_bf.bias.grad.float(), m_f.bias.grad, atol=1e-3, rtol=1e-2)
self.assertEqual(input_bf1.grad, input_bf2.grad, atol=1e-2, rtol=1e-2)

helper(self, (1, 8, 4, 3), 2, torch.contiguous_format)
helper(self, (1, 8, 4, 3), 2, torch.channels_last)
helper(self, (1, 8, 3, 4), 4, torch.contiguous_format)
helper(self, (1, 8, 3, 4), 4, torch.channels_last)
helper(self, (4, 8, 40, 40), 4, torch.contiguous_format),
helper(self, (4, 8, 40, 40), 4, torch.channels_last),
helper(self, (4, 40, 40, 40), 2, torch.contiguous_format)
helper(self, (4, 40, 40, 40), 2, torch.channels_last)
helper(self, (1, 8, 40, 40), 4, torch.contiguous_format)
helper(self, (1, 8, 40, 40), 2, torch.channels_last)
helper(self, (1, 8, 40, 40), 2, torch.contiguous_format)
helper(self, (1, 8, 50, 50), 2, torch.channels_last)
helper(self, (1, 8, 50, 50), 4, torch.contiguous_format)
helper(self, (1, 8, 50, 50), 4, torch.channels_last)
helper(self, (1, 40, 50, 50), 2, torch.contiguous_format)
helper(self, (1, 40, 50, 50), 2, torch.channels_last)
helper(self, (1, 9, 3, 4, 5), 3, torch.contiguous_format)
helper(self, (1, 9, 3, 4, 5), 3, torch.channels_last_3d)
helper(self, (1, 60, 10, 10, 10), 3, torch.contiguous_format)
helper(self, (1, 60, 10, 10, 10), 3, torch.channels_last_3d)
helper(self, (1, 9, 10, 50, 50), 3, torch.contiguous_format)
helper(self, (1, 9, 10, 50, 50), 3, torch.channels_last_3d)
helper(self, (1, 60, 10, 50, 50), 3, torch.contiguous_format)
helper(self, (1, 60, 10, 50, 50), 3, torch.channels_last_3d)
atol = None
rtol = None
if dtype == torch.bfloat16:
atol = 1e-2
rtol = 1.2e-1
else:
assert dtype == torch.half
atol = 5e-3
rtol = 1.5e-2
self.assertEqual(m_bf.weight.grad, m_f.weight.grad.to(dtype=dtype), atol=atol, rtol=rtol)
self.assertEqual(m_bf.bias.grad, m_f.bias.grad.to(dtype=dtype), atol=atol, rtol=rtol)
self.assertEqual(input_bf1.grad, input_bf2.grad, atol=atol, rtol=rtol)

cl_formats = {4: torch.channels_last, 5: torch.channels_last_3d}
for dtype in [torch.bfloat16, torch.half]:
for shape, g in [((1, 8, 4, 3), 2), ((1, 8, 3, 4), 4),
((4, 40, 40, 40), 2), ((4, 8, 40, 40), 4),
((1, 8, 40, 40), 4), ((1, 8, 40, 40), 2),
((1, 8, 50, 50), 2), ((1, 8, 50, 50), 4),
((1, 40, 50, 50), 2), ((1, 9, 3, 4, 5), 3),
((1, 60, 10, 10, 10), 3), ((1, 9, 10, 50, 50), 3),
((1, 60, 10, 50, 50), 3), ((1, 8, 65, 55), 2),
((1, 3, 65, 55), 1), ((1, 3, 20, 20), 1)]:
for is_cl in [False, True]:
format = cl_formats[len(shape)] if is_cl else torch.contiguous_format
helper(self, shape, g, format, dtype)

def _test_module_empty_inputs(self, module, inputs):
for _inp in inputs:
@@ -8552,7 +8551,7 @@ def test_GroupNorm_empty(self, device):
_test_module_empty_input(self, mod, inp)

@onlyCPU
@dtypes(torch.float, torch.double, torch.bfloat16)
@dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
def test_groupnorm_nhwc(self, device, dtype):
def helper(self, size, groups, memory_format, is_mixed):
channels = size[1]
@@ -8583,30 +8582,24 @@ def helper(self, size, groups, memory_format, is_mixed):
self.assertTrue(out.is_contiguous(memory_format=memory_format))
self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
self.assertEqual(out, ref_out)
# parameters in bfloat16 is not recommended
self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=5e-4, rtol=5e-4)
self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=5e-4, rtol=5e-4)
self.assertEqual(input.grad, ref_input.grad, atol=5e-4, rtol=8e-3)

helper(self, (4, 8, 10, 10), 4, torch.channels_last, False)
helper(self, (2, 30, 9, 9), 3, torch.channels_last, False)
helper(self, (4, 8, 40, 40), 4, torch.channels_last, False)
helper(self, (4, 40, 40, 40), 2, torch.channels_last, False)
helper(self, (2, 30, 50, 50), 3, torch.channels_last, False)
helper(self, (2, 60, 50, 50), 3, torch.channels_last, False)
helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, False)
helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, False)
helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, False)

helper(self, (4, 8, 10, 10), 4, torch.channels_last, True)
helper(self, (2, 30, 9, 9), 3, torch.channels_last, True)
helper(self, (4, 8, 40, 40), 4, torch.channels_last, True)
helper(self, (4, 40, 40, 40), 2, torch.channels_last, True)
helper(self, (2, 30, 50, 50), 3, torch.channels_last, True)
helper(self, (2, 60, 50, 50), 3, torch.channels_last, True)
helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, True)
helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, True)
helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, True)
# parameters in bfloat16/Half is not recommended
atol = 5e-4
rtol = 8e-3

self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

for is_mixed in [True, False]:
helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)

@onlyNativeDeviceTypes
def test_GroupNorm_memory_format(self, device):
3 changes: 1 addition & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -12644,8 +12644,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
aten_name='group_norm',
aliases=('group_norm',),
ref=reference_group_norm,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
7 changes: 5 additions & 2 deletions torch/testing/_internal/common_modules.py
@@ -3131,13 +3131,16 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
),
ModuleInfo(torch.nn.GroupNorm,
module_inputs_func=module_inputs_torch_nn_GroupNorm,
dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=False),
dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True),
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64, torch.bfloat16]),
# Tracking at https://github.com/pytorch/pytorch/issues/98089
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
'TestModule', 'test_memory_format', device_type='cpu'),
# No channels_last support for GroupNorm currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'),
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'),
DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad",
active_if=TEST_WITH_ROCM, device_type='cuda'),)
),
