[CUDA] 64-bit indexing fixes for cross-entropy kernels (pytorch#112096)
For pytorch#108345, pytorch#111484

This addresses the forward kernels implicated in the issues; the backward kernels will get another look (in follow-up PRs if necessary).

The spatial softmax kernel is changed to use signed rather than unsigned integer indexing, since `ScalarType` currently declares only signed integer types; this should be a minor change.

CC @ptrblck @crcrpar (who landed a few related PRs recently).
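
As a back-of-the-envelope check on why 32-bit indexing no longer suffices here, a minimal standalone C++ snippet, assuming only the logits shapes used by the two new tests in test_nn.py below:

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Logits shapes from test_cross_entropy_64bit and test_softmax_forward_64bit_indexing.
  const int64_t nll_numel     = 190LL * 229000LL * 50LL;   // 2'175'500'000
  const int64_t softmax_numel = 70LL * 2047LL * 50000LL;   // 7'164'500'000
  std::printf("nll loss logits: %lld elements (INT_MAX    = %d)\n",
              (long long)nll_numel, INT_MAX);
  std::printf("softmax logits:  %lld elements (UINT32_MAX = %u)\n",
              (long long)softmax_numel, UINT32_MAX);
  // Both counts exceed the 32-bit limits (2'147'483'647 and 4'294'967'295), which is
  // what drives the canUse32BitIndexMath(input, INT_MAX) dispatch added to both launchers.
  return 0;
}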

Pull Request resolved: pytorch#112096
Approved by: https://github.com/mikaylagawarecki
eqy authored and pytorchmergebot committed Nov 6, 2023
1 parent a50f6d3 commit e396687
Showing 3 changed files with 99 additions and 61 deletions.
61 changes: 33 additions & 28 deletions aten/src/ATen/native/cuda/NLLLoss2d.cu
@@ -9,6 +9,7 @@
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/block_reduce.cuh>

@@ -74,7 +75,7 @@ __global__ void nll_loss2d_forward_no_reduce_kernel(
}
}

template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_kernel(
scalar_t* output,
@@ -91,14 +92,14 @@ __global__ void nll_loss2d_forward_kernel(
accscalar_t input_sum = 0;
accscalar_t acc_weight = 0;

int sample = blockIdx.x / blocks_per_sample;
int toffset = sample * map_nelem;
int ioffset = sample * map_nelem * n_classes;
index_t sample = blockIdx.x / blocks_per_sample;
index_t toffset = sample * map_nelem;
index_t ioffset = sample * map_nelem * n_classes;
int step = blockDim.x * blocks_per_sample;
for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
int64_t t = target[toffset + i];
index_t t = target[toffset + i];
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
@@ -318,29 +319,33 @@ void nll_loss2d_forward_out_cuda_template(
"nll_loss2d_forward_kernel",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
nll_loss2d_forward_kernel<scalar_t, accscalar_t>
<<<total_blocks,
CUDA_NUM_THREADS,
0,
at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
total_weight.mutable_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
optional_data<scalar_t>(weight_),
input_.size(1),
input_.size(2) * input_.size(3),
blocks_per_sample,
ignore_index);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Divide by total_weight
if (reduction == at::Reduction::Mean) {
nll_loss2d_forward_size_average_kernel<scalar_t>
<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
total_weight.const_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
AT_DISPATCH_INDEX_TYPES(
at::native::canUse32BitIndexMath(input_, INT_MAX) ? ScalarType::Int : ScalarType::Long,
"nll_loss2d_forward_launcher", [&] {
nll_loss2d_forward_kernel<scalar_t, accscalar_t, index_t>
<<<total_blocks,
CUDA_NUM_THREADS,
0,
at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
total_weight.mutable_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
optional_data<scalar_t>(weight_),
input_.size(1),
input_.size(2) * input_.size(3),
blocks_per_sample,
ignore_index);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Divide by total_weight
if (reduction == at::Reduction::Mean) {
nll_loss2d_forward_size_average_kernel<scalar_t>
<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
total_weight.const_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
});
}
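Both launchers now template their kernels on an index_t chosen at dispatch time. As a rough sketch of the pattern in isolation (a toy grid-stride kernel and a hypothetical launch_scale helper, not the PyTorch kernels themselves):

#include <cstdint>
#include <limits>
#include <cuda_runtime.h>

template <typename index_t>
__global__ void scale_kernel(float* out, const float* in, index_t n) {
  // All offset arithmetic is done in index_t, so it can be int32_t on the fast path
  // and int64_t when the tensor is too large for 32-bit offsets.
  const index_t stride = static_cast<index_t>(blockDim.x) * gridDim.x;
  for (index_t i = static_cast<index_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    out[i] = in[i] * 2.0f;
  }
}

void launch_scale(float* out, const float* in, int64_t n, cudaStream_t stream) {
  // Host-side mirror of the dispatch decision made by the launchers in this PR.
  if (n <= std::numeric_limits<int32_t>::max()) {
    scale_kernel<int32_t><<<128, 256, 0, stream>>>(out, in, static_cast<int32_t>(n));
  } else {
    scale_kernel<int64_t><<<128, 256, 0, stream>>>(out, in, n);
  }
}

int main() {
  const int64_t n = 1 << 20;  // small demo size; the int64_t branch matters for n > INT_MAX
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));
  launch_scale(out, in, n, nullptr);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Keeping index_t as a template parameter preserves the cheaper 32-bit arithmetic on the common path while still allowing int64_t offsets for tensors with more than INT_MAX elements.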
71 changes: 38 additions & 33 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -14,6 +14,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/IndexingUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -211,20 +212,20 @@ T spatialBlockReduceX(T *shared, T val) {
return shared[0];
}

template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
template <typename scalar_t, typename accscalar_t, typename outscalar_t, typename index_t, template<typename, typename, typename> class Epilogue>
__global__ void cunn_SpatialSoftMaxForward(
outscalar_t *output, const scalar_t *input,
uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
index_t outer_size, index_t dim_size, index_t inner_size)
{
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
const uint32_t outer_stride = inner_size * dim_size;
const uint32_t dim_stride = inner_size;
const index_t outer_stride = inner_size * dim_size;
const index_t dim_stride = inner_size;

for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
const uint32_t outer_offset = outer_index * outer_stride;
for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
const uint32_t data_offset = outer_offset + inner_index;
for (index_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
const index_t outer_offset = outer_index * outer_stride;
for (index_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
const index_t data_offset = outer_offset + inner_index;
////////////////////////////////////////////////////////////
// These two blocks are really equivalent, but specializing on
// blockDim.x == 1 makes the kernel faster when it's unused.
@@ -234,33 +235,33 @@ __global__ void cunn_SpatialSoftMaxForward(

if (blockDim.x > 1) {
accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
max_input = Max<accscalar_t>()(max_input, value);
}
max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input);

accscalar_t sum = 0;
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
- max_input);
sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);

Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
} else {
accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
max_input = Max<accscalar_t>()(max_input, value);
}
accscalar_t sum = 0;
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
- max_input);
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
}
}
@@ -791,25 +792,29 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float
dim3 grid, block;
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] {
using accscalar_t = acc_type<scalar_t, true>;
if (!half_to_float) {
SpatialSoftMax_getLaunchSizes<accscalar_t>(
&cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue>,
outer_size, dim_size, inner_size,
grid, block, smem_size);
cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_size, stream>>>(
output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
SpatialSoftMax_getLaunchSizes<accscalar_t>(
&cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue>,
outer_size, dim_size, inner_size,
grid, block, smem_size);
cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_size, stream>>>(
output.mutable_data_ptr<accscalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
AT_DISPATCH_INDEX_TYPES(
at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int : ScalarType::Long,
"host_softmax_launcher", [&] {
if (!half_to_float) {
SpatialSoftMax_getLaunchSizes<accscalar_t>(
&cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>,
outer_size, dim_size, inner_size,
grid, block, smem_size);
cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>
<<<grid, block, smem_size, stream>>>(
output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
SpatialSoftMax_getLaunchSizes<accscalar_t>(
&cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>,
outer_size, dim_size, inner_size,
grid, block, smem_size);
cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>
<<<grid, block, smem_size, stream>>>(
output.mutable_data_ptr<accscalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
});
}
}
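For the spatial softmax kernel, the failure mode being guarded against is a silent wraparound of the old uint32_t offsets. A minimal check, assuming the kernel sees outer_size = 70, dim_size = 50000, inner_size = 2047 for the permuted logits in test_softmax_forward_64bit_indexing below:

#include <cstdint>
#include <cstdio>

int main() {
  // outer_stride = inner_size * dim_size, as in cunn_SpatialSoftMaxForward.
  const uint32_t outer_stride_u32 = 2047u * 50000u;          // 102'350'000, still fits in 32 bits
  const uint32_t wrapped_offset   = 69u * outer_stride_u32;  // wraps modulo 2^32
  const int64_t  correct_offset   = int64_t{69} * 2047 * 50000;
  std::printf("uint32_t outer_offset: %u\n", wrapped_offset);              // 2'767'182'704 (wrong)
  std::printf("int64_t  outer_offset: %lld\n", (long long)correct_offset); // 7'062'150'000
  return 0;
}

The wrapped offset still lands inside the tensor, so the old kernel reads and writes the wrong rows rather than failing loudly, which is why the kernel is now templated on a (signed) index_t instead of hard-coding uint32_t.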
28 changes: 28 additions & 0 deletions test/test_nn.py
@@ -11563,6 +11563,19 @@ def test_nll_loss_large_tensor(self, device, reduction):
with torch.no_grad():
self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))

# Ref: https://github.com/pytorch/pytorch/issues/108345
@onlyCUDA
@largeTensorTest("20GB", "cpu")
@largeTensorTest("20GB", "cuda")
@parametrize_test("reduction", ("none", "mean", "sum"))
def test_cross_entropy_64bit(self, device, reduction):
labels = torch.zeros(190, 50, dtype=torch.long, device=device)
logits = torch.ones(190, 229000, 50, dtype=torch.float, device=device)
loss = torch.nn.functional.cross_entropy(logits, labels)
loss_cpu = torch.nn.functional.cross_entropy(logits.cpu(), labels.cpu())
print(logits.numel(), labels.numel(), loss.numel())
self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4))

def _nll_loss_helper(self, input_size, reduction, expected, device):
input = torch.rand(input_size, requires_grad=True, device=device)
num_channels = input_size[1]
@@ -12739,6 +12752,21 @@ def compare_scaling(grads):
clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
self.assertEqual(p1.grad, p2.grad)

# reference issue: https://github.com/pytorch/pytorch/issues/111484
@onlyCUDA
@largeTensorTest("30GB", "cuda")
def test_softmax_forward_64bit_indexing(self, device):
batch_size = 70
seq_len = 2048
vocab_size = 50000

shift_labels = torch.zeros(batch_size, seq_len - 1, dtype=torch.long, device=device)
logits = torch.ones(batch_size, seq_len - 1, vocab_size, dtype=torch.float16, device=device)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
nll = loss_fct(logits.permute(0, 2, 1), shift_labels).float()
rtol, atol = torch.testing._comparison.get_tolerances(torch.float16, rtol=None, atol=None)
self.assertEqual(nll, torch.ones_like(nll) * torch.log(torch.tensor(vocab_size)), rtol=rtol, atol=atol)

@onlyCUDA
@largeTensorTest("20GB", "cuda")
def test_softmax_backward_64bit_indexing(self, device):
