Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ __global__ void _UnaryElementWise(
const InT* input_data,
OutT* output_data,
const FuncT functor,
CUDA_LONG N) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
int64_t N) {
int64_t start = static_cast<int64_t>(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
InT value[NumElementsPerThread];

CUDA_LONG id = start;
int64_t id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
Expand Down Expand Up @@ -47,8 +47,8 @@ void UnaryElementWiseImpl(
if (count == 0) // special case where there's a dim value of 0 in the shape
return;

int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
int blocksPerGrid = static_cast<int>(CeilDiv(count, static_cast<size_t>(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread));
int64_t N = static_cast<int64_t>(count);
_UnaryElementWise<InT, OutT, FuncT, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_data,
Expand Down
28 changes: 14 additions & 14 deletions onnxruntime/core/providers/cuda/tensor/cast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ struct CastStd<Float4E2M1x2, float> {
#endif // DISABLE_FLOAT4_TYPES

template <int NumThreadsPerBlock, int NumElementsPerThread, typename OutT, typename InT>
__global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd<OutT, InT> cast) {
CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
__global__ void CastKernelStd(const InT* input, OutT* output, int64_t N, CastStd<OutT, InT> cast) {
int64_t id = static_cast<int64_t>(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x;

#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
Expand All @@ -237,11 +237,11 @@ Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t n
if (num_of_elements <= 0)
return Status::OK();

int blocksPerGrid = static_cast<int>(CeilDiv(num_of_elements, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
int blocksPerGrid = static_cast<int>(CeilDiv(num_of_elements, static_cast<size_t>(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread));
CastKernelStd<GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread, OutT, InT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input,
output,
static_cast<int>(num_of_elements),
static_cast<int64_t>(num_of_elements),
CastStd<OutT, InT>());
return Status::OK();
}
Expand All @@ -251,10 +251,10 @@ Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t n
template <int NumThreadsPerBlock, int NumElementsPerThread, bool is_odd, typename OutPairType, typename InPairType,
typename OutSingleType, typename InSingleType>
__global__ void CudaCastPairwiseKernel(const InPairType* input, OutPairType* output,
CUDA_LONG pair_count,
int64_t pair_count,
CastStd<OutPairType, InPairType> pair_caster,
CastStd<OutSingleType, InSingleType> singleton_caster) {
CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
int64_t id = static_cast<int64_t>(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x;

#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
Expand Down Expand Up @@ -284,9 +284,9 @@ Status CudaCastPairwise(cudaStream_t stream, const Float4E2M1x2* input, float* o

bool is_odd = (num_of_elements & 0x01) != 0;

int pair_count = static_cast<int>(num_of_elements / 2);
int64_t pair_count = static_cast<int64_t>(num_of_elements / 2);
Comment thread
tianleiwu marked this conversation as resolved.
Outdated

int blocksPerGrid = static_cast<int>(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
int blocksPerGrid = static_cast<int>(CeilDiv(static_cast<size_t>(pair_count), static_cast<size_t>(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread));

if (pair_count == 0) {
blocksPerGrid = 1;
Expand Down Expand Up @@ -318,9 +318,9 @@ Status CudaCastPairwise(cudaStream_t stream, const float* input, Float4E2M1x2* o

bool is_odd = (num_of_elements & 0x01) != 0;

int pair_count = static_cast<int>(num_of_elements / 2);
int64_t pair_count = static_cast<int64_t>(num_of_elements / 2);

int blocksPerGrid = static_cast<int>(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
int blocksPerGrid = static_cast<int>(CeilDiv(static_cast<size_t>(pair_count), static_cast<size_t>(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread));

if (pair_count == 0) {
blocksPerGrid = 1;
Expand Down Expand Up @@ -353,8 +353,8 @@ template Status CudaCastPairwise<Float4E2M1x2, float>(cudaStream_t stream, const
#if !defined(DISABLE_FLOAT8_TYPES)

template <int NumThreadsPerBlock, int NumElementsPerThread, typename OutT, typename InT>
__global__ void CastKernelSat(const InT* input, OutT* output, CUDA_LONG N, CastSat<OutT, InT> cast, bool saturate) {
CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
__global__ void CastKernelSat(const InT* input, OutT* output, int64_t N, CastSat<OutT, InT> cast, bool saturate) {
int64_t id = static_cast<int64_t>(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x;

#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
Expand All @@ -370,11 +370,11 @@ Status CudaCastSat(cudaStream_t stream, const InT* input, OutT* output, size_t n
if (num_of_element <= 0)
return Status::OK();

int blocksPerGrid = static_cast<int>(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
int blocksPerGrid = static_cast<int>(CeilDiv(num_of_element, static_cast<size_t>(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread));
CastKernelSat<GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread, OutT, InT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input,
output,
static_cast<int>(num_of_element),
static_cast<int64_t>(num_of_element),
CastSat<OutT, InT>(),
saturate);
return Status::OK();
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3127,5 +3127,29 @@ TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) {
}
}

// Regression test for CUDA Cast kernel int32 overflow (same family as
// https://github.com/microsoft/onnxruntime/issues/28107).
// The CUDA Cast kernel used CUDA_LONG (int32_t) for element indices, which caused
// int32 overflow and illegal memory access on tensors with >2^31 elements.
// This test validates that a large tensor cast works correctly. The full reproducer
// requires >2^31 elements (>8GB for float), so this test uses a moderately large
// tensor to exercise the same code path and validate correctness.
TEST(CastOpTest, LargeTensorCastNoCrash) {
// Use a tensor large enough to be meaningful but not require excessive memory.
// 2^24 = 16M elements is enough to exercise the kernel grid calculation while
// staying within typical CI GPU memory limits.
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
constexpr int64_t num_elements = 1 << 24; // 16M elements
const std::vector<int64_t> shape = {num_elements};

std::vector<float> input(num_elements);
std::vector<int32_t> expected(num_elements);
for (int64_t i = 0; i < num_elements; ++i) {
input[i] = static_cast<float>(i % 1000);
expected[i] = static_cast<int32_t>(i % 1000);
}

TestCastOp<float, int32_t>(gsl::make_span(input), gsl::make_span(expected), shape);
}

} // namespace test
} // namespace onnxruntime
Loading