Fix int32 overflow in CUDA Cast and UnaryElementWise kernels for tensors with >2^31 elements #28386
Conversation
…ors with >2^31 elements

Switch per-thread element index from CUDA_LONG (int32_t) to int64_t in:
- _UnaryElementWise kernel (cu_inc/unary_elementwise_impl.cuh)
- CastKernelStd kernel (tensor/cast_op.cu)
- CastKernelSat kernel (tensor/cast_op.cu)
- CudaCastPairwiseKernel (tensor/cast_op.cu)

Also fix the launch functions to pass the element count as int64_t instead of truncating via static_cast<int>, and fix the blocksPerGrid calculation to avoid int32 overflow in the intermediate multiplication. Add a regression test for large tensor cast.

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/0b1e04ca-17bd-4f26-aaec-728240d54577
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
tianleiwu left a comment
Review Summary
Correct and well-scoped fix for a real int32 overflow bug in CUDA Cast and UnaryElementWise kernels. The changes consistently replace CUDA_LONG (int32_t) with int64_t across kernel parameters and index calculations, matching the same fix pattern applied to Gather in PR #28108.
Positives:
- The `static_cast<int64_t>(NumElementsPerThread)` correctly anchors the multiplication chain in 64-bit arithmetic before multiplying with `blockIdx.x`, preventing intermediate overflow (see the sketch after this list).
- The `unary_elementwise_impl.cuh` header change propagates the fix to all unary elementwise ops (Abs, Neg, Sqrt, Log, Exp, Erf, etc.) in a single edit.
- All three cast kernel variants (`CastKernelStd`, `CastKernelSat`, `CudaCastPairwiseKernel`) are consistently updated; no kernel was missed.
- Removal of the `static_cast<int>(num_of_elements)` truncation in the launch functions is the most important part, since that's where the `size_t` to `int32_t` conversion silently lost high bits.
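To make the first bullet concrete, here is a minimal sketch (assumed names and constants, not the actual onnxruntime source) of the per-thread indexing pattern with the 64-bit anchor on both the kernel and launch side:

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Sketch of a fixed-elements-per-thread unary kernel. Casting NumElementsPerThread to
// int64_t first keeps the product NumElementsPerThread * NumThreadsPerBlock * blockIdx.x
// in 64-bit arithmetic, so the flattened index cannot wrap at 2^31.
template <typename T, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void UnaryCopySketch(const T* input, T* output, int64_t N) {
  int64_t id = static_cast<int64_t>(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; ++i) {
    if (id < N) {
      output[id] = input[id];  // a real elementwise op would apply its functor here
      id += NumThreadsPerBlock;
    }
  }
}

// Host-side launch: blocksPerGrid is also computed in 64-bit, so a large element
// count is never truncated before the kernel sees it.
template <typename T>
void LaunchUnaryCopySketch(const T* input, T* output, int64_t N, cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;   // assumed values, not onnxruntime's GridDim constants
  constexpr int kElementsPerThread = 4;
  const int64_t elements_per_block = static_cast<int64_t>(kThreadsPerBlock) * kElementsPerThread;
  const int64_t blocksPerGrid = (N + elements_per_block - 1) / elements_per_block;
  UnaryCopySketch<T, kThreadsPerBlock, kElementsPerThread>
      <<<static_cast<unsigned int>(blocksPerGrid), kThreadsPerBlock, 0, stream>>>(input, output, N);
}
```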
Broader concern (out of scope): The `CALCULATE_ELEMENTWISE_INDEX_OR_EXIT` macro still uses `CUDA_LONG` and is used by dozens of other CUDA kernels (expand, tile, scatter_nd, resize, upsample, etc.), so they have the same int32 overflow vulnerability. Consider filing a follow-up issue to track the systemic fix.
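If a follow-up is filed, a 64-bit-safe macro could look roughly like the sketch below, assuming the existing macro follows the usual `blockDim.x * blockIdx.x + threadIdx.x` shape; the `_I64` name is hypothetical, not an existing onnxruntime identifier:

```cuda
#include <cstdint>

// Hypothetical 64-bit variant of the elementwise index macro. Anchoring the
// multiplication in int64_t keeps the blockDim.x * blockIdx.x product from
// wrapping once the flattened index exceeds INT32_MAX.
#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT_I64(id, N)                             \
  const int64_t id = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;  \
  if (id >= (N)) return;
```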
TEST(CastOpTest, LargeTensorCastNoCrash) {
  // Use a tensor large enough to be meaningful but not require excessive memory.
  // 2^24 = 16M elements is enough to exercise the kernel grid calculation while
  // staying within typical CI GPU memory limits.
16M elements (2^24) is far below `INT32_MAX` (2^31). The old code with `CUDA_LONG` indices would also pass this test, so it does not provide regression protection against someone accidentally reverting the index type back to `CUDA_LONG`.
The test name `LargeTensorCastNoCrash` and the comment "Regression test for CUDA Cast kernel int32 overflow" overstate what the test validates. It's a useful correctness smoke test, but not an overflow regression test.
Options to improve:
- Rename to `CastKernelCorrectness_ModerateSize` to reflect what it actually tests.
- Add a separate test gated on available GPU memory (e.g., skip if < 10 GB free) that allocates >2^31 elements.
- Add a host-side unit test that verifies the grid launch calculation (`blocksPerGrid`, `N`) uses 64-bit arithmetic for counts > INT32_MAX (a sketch of this option follows the list).
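A rough sketch of the third option, assuming the grid math were factored into a testable helper such as `ComputeBlocksPerGrid`; the helper and test names below are assumptions, not existing onnxruntime code:

```cpp
#include <cstdint>
#include <gtest/gtest.h>

// Hypothetical helper mirroring the launch-side grid calculation; in the real code
// this would be the factored-out blocksPerGrid math from the Cast launch functions.
inline int64_t ComputeBlocksPerGrid(int64_t num_elements, int64_t threads_per_block, int64_t elements_per_thread) {
  const int64_t per_block = threads_per_block * elements_per_thread;
  return (num_elements + per_block - 1) / per_block;  // 64-bit CeilDiv, no int32 intermediate
}

TEST(CastOpTest, GridCalculationUses64BitArithmetic) {
  const int64_t n = (int64_t{1} << 31) + 123;  // just past INT32_MAX
  const int64_t blocks = ComputeBlocksPerGrid(n, 256, 4);
  EXPECT_GT(blocks, 0);             // a count wrapped to int32 would give a negative or tiny result
  EXPECT_GE(blocks * 256 * 4, n);   // coverage: the grid must span every element
}
```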
  bool is_odd = (num_of_elements & 0x01) != 0;

- int pair_count = static_cast<int>(num_of_elements / 2);
+ int64_t pair_count = static_cast<int64_t>(num_of_elements / 2);
Minor: `num_of_elements` is already `size_t`, so `pair_count` could stay `size_t` instead of converting `size_t` to `int64_t` here and then back to `size_t` in the `CeilDiv` call below (`static_cast<size_t>(pair_count)`). The current code is correct but the double conversion is mildly surprising.
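A minimal sketch of that suggestion, with a local `CeilDiv` standing in for the CUDA EP helper and an assumed threads-per-block value, just to show `pair_count` staying `size_t` end to end:

```cpp
#include <cstddef>

// Local stand-in for the CeilDiv helper used by the launch code.
inline size_t CeilDiv(size_t n, size_t d) { return (n + d - 1) / d; }

void PairwiseLaunchSketch(size_t num_of_elements) {
  const bool is_odd = (num_of_elements & 0x01) != 0;
  const size_t pair_count = num_of_elements / 2;          // stays size_t, no int64_t detour
  const size_t blocksPerGrid = CeilDiv(pair_count, 256);  // 256 threads per block, assumed
  (void)is_odd;
  (void)blocksPerGrid;  // the real code would launch the pairwise cast kernel here
}
```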
Description
Switch per-thread element indices from `CUDA_LONG` (int32_t) to `int64_t` in CUDA Cast and UnaryElementWise kernels to prevent illegal memory access on tensors exceeding 2^31 elements.

- `cu_inc/unary_elementwise_impl.cuh`: Change the `N` parameter and loop index in the `_UnaryElementWise` kernel from `CUDA_LONG` to `int64_t`. Fix the `blocksPerGrid` intermediate multiplication to use `size_t`. This fixes the overflow for all unary elementwise ops (Cast, Abs, Neg, Sqrt, Log, Exp, Erf, etc.).
- `tensor/cast_op.cu`: Same fix for `CastKernelStd`, `CastKernelSat`, and `CudaCastPairwiseKernel`. Remove the `static_cast<int>(num_of_elements)` truncation in the launch functions.
- `cast_op_test.cc`: Add the `LargeTensorCastNoCrash` regression test.

Before (crashes):
After:
Motivation and Context
Same class of bug fixed in Gather by #28108 (response to #28107). The Cast kernel uses `CUDA_LONG = int32_t` for its element index, which wraps negative once the element count crosses `INT32_MAX`. This hits any causal LM ONNX export where `vocab_size × seq_length > 2^31`; practically every long-context HF model on the CUDA EP at `seq_length ≥ 16K–32K`.
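For a concrete sense of scale, with assumed example numbers rather than any specific model:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t vocab_size = 32000;                 // assumed example vocabulary size
  const int64_t seq_length = 128 * 1024;            // 131072 tokens
  const int64_t logits = vocab_size * seq_length;   // 4,194,304,000 elements in the logits tensor
  std::cout << logits << " > INT32_MAX (" << INT32_MAX << "): "
            << std::boolalpha << (logits > INT32_MAX) << "\n";  // prints true
  return 0;
}
```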