Fix int32 overflow in CUDA Gather kernel for large tensors#28108
Merged
justinchuby merged 2 commits intomainfrom Apr 17, 2026
Merged
Fix int32 overflow in CUDA Gather kernel for large tensors#28108justinchuby merged 2 commits intomainfrom
justinchuby merged 2 commits intomainfrom
Conversation
The _GatherKernel used CUDA_LONG (int32_t) for input_index, which overflows when the input tensor has more than INT32_MAX (~2.1B) elements. For example, Gemma4's embed_tokens_per_layer embedding table is [262144, 8960] = 2.35B elements. Token IDs >= 239674 cause the offset computation to exceed INT32_MAX, resulting in an illegal memory access on CUDA. Fix: use int64_t for input_index and explicitly cast input_block_index to int64_t before multiplication to ensure the entire expression is evaluated in 64-bit arithmetic. Fixes #28107 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR fixes an integer overflow in the CUDA EP Gather kernel that could compute a wrapped/negative input offset for very large input tensors (> INT32_MAX elements), leading to CUDA illegal memory accesses.
Changes:
- Switch
input_indexcomputation in_GatherKernelfromCUDA_LONG(int32_t) toint64_t. - Ensure 64-bit arithmetic by explicitly casting
input_block_indextoint64_tbefore multiplication. - Add an inline comment documenting the overflow scenario and a concrete large-embedding example.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tianleiwu
approved these changes
Apr 17, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
The
_GatherKernelingather_impl.cuusesCUDA_LONG(int32_t) forinput_index. When the input tensor has more thanINT32_MAX(~2.1 billion) elements, the offset computation overflows, causing an illegal memory access (CUDA error 700).Concrete example: Gemma4's per-layer embedding table is
[262144, 8960]= 2.35 billion elements. Any token ID ≥ 239674 triggers the overflow because:Fix
Change
input_indexfromCUDA_LONG(int32_t) toint64_t, and explicitly castinput_block_indextoint64_tbefore multiplication. The other operands (input_block_size,idx) are alreadyint64_t, so the full expression evaluates in 64-bit arithmetic.Reproduction
See the minimal repro script in issue #28107. On any CUDA GPU with ORT 1.24.x:
Motivation and Context
This affects any model with an embedding table exceeding ~2B elements. Currently blocks Gemma4 multimodal inference on CUDA EP since special tokens like
<|image|>(ID 255999) are above the overflow threshold.Fixes #28107