
Fix int32 overflow in CUDA Gather kernel for large tensors#28108

Merged

justinchuby merged 2 commits into main from justinchu/fix-gather-int32-overflow on Apr 17, 2026

Conversation

@justinchuby
Contributor

Description

The _GatherKernel in gather_impl.cu uses CUDA_LONG (int32_t) for input_index. When the input tensor has more than INT32_MAX (~2.1 billion) elements, the offset computation overflows, causing an illegal memory access (CUDA error 700).

Concrete example: Gemma4's per-layer embedding table is [262144, 8960] = 2.35 billion elements. Any token ID ≥ 239674 triggers the overflow because:

239674 × 8960 + 8959 = 2,147,487,999 > INT32_MAX (2,147,483,647)
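
This arithmetic can be verified with a small standalone sketch (the `rows`/`cols` values below are taken from the Gemma4 example above; the variable names are illustrative, not from the kernel):

```python
# Dimensions from the Gemma4 per-layer embedding example above.
rows, cols = 262144, 8960
INT32_MAX = 2**31 - 1

# The kernel computes an element offset of row * cols + col; the largest
# offset touched by a given row is row * cols + (cols - 1).
# First row whose last element's offset exceeds INT32_MAX:
first_bad_row = (INT32_MAX - (cols - 1)) // cols + 1
print(first_bad_row)                     # 239674
print(first_bad_row * cols + cols - 1)   # 2147487999, which is > INT32_MAX
```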

Fix

Change input_index from CUDA_LONG (int32_t) to int64_t, and explicitly cast input_block_index to int64_t before multiplication. The other operands (input_block_size, idx) are already int64_t, so the full expression evaluates in 64-bit arithmetic.
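
The effect of the narrowed type can be simulated on the host. This sketch uses `ctypes.c_int32` to emulate `CUDA_LONG`'s 32-bit wraparound (it is not the actual kernel code; the row value 255999 is the `<|image|>` token ID from the issue):

```python
import ctypes

cols = 8960    # input_block_size
row = 255999   # input_block_index, e.g. the <|image|> token ID
col = cols - 1 # worst-case offset within the row

# Old behaviour: CUDA_LONG (int32_t) silently wraps on overflow.
wrapped = ctypes.c_int32(row * cols + col).value
# New behaviour: Python ints model int64_t, which does not wrap at this scale.
correct = row * cols + col

print(wrapped)  # -2001207297: the offset wrapped negative, hence the illegal access
print(correct)  # 2293759999: the intended 64-bit offset
```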

Reproduction

See the minimal repro script in issue #28107. On any CUDA GPU with ORT 1.24.x:

import numpy as np, onnxruntime as ort, onnx
from onnx import helper, TensorProto

rows, cols = 262144, 8960
data = np.zeros((rows, cols), dtype=np.float32)
indices = np.array([255999], dtype=np.int64)  # >= 239674, past the overflow threshold

graph = helper.make_graph(
    [helper.make_node("Gather", ["data", "indices"], ["out"], axis=0)],
    "g",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT, [rows, cols]),
     helper.make_tensor_value_info("indices", TensorProto.INT64, [1])],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, cols])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
onnx.save(model, "/tmp/gather.onnx")

sess = ort.InferenceSession("/tmp/gather.onnx", providers=["CUDAExecutionProvider"])
sess.run(None, {"data": data, "indices": indices})  # CRASH: illegal memory access
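
Note that the repro allocates a sizable buffer; a quick footprint check (assuming float32 data, as in the script above):

```python
rows, cols = 262144, 8960
n_elements = rows * cols
bytes_needed = n_elements * 4  # 4 bytes per float32 element

print(n_elements)          # 2348810240 elements, which exceeds INT32_MAX
print(bytes_needed / 1e9)  # ~9.4 GB needed on both host and GPU
```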

Motivation and Context

This affects any model with an embedding table exceeding ~2.1B elements. It currently blocks Gemma4 multimodal inference on the CUDA EP, since special tokens such as <|image|> (ID 255999) are above the overflow threshold of 239674.

Fixes #28107

The _GatherKernel used CUDA_LONG (int32_t) for input_index, which
overflows when the input tensor has more than INT32_MAX (~2.1B)
elements. For example, Gemma4's embed_tokens_per_layer embedding
table is [262144, 8960] = 2.35B elements. Token IDs >= 239674
cause the offset computation to exceed INT32_MAX, resulting in an
illegal memory access on CUDA.

Fix: use int64_t for input_index and explicitly cast input_block_index
to int64_t before multiplication to ensure the entire expression is
evaluated in 64-bit arithmetic.

Fixes #28107

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>

github-actions bot left a comment


You can commit the suggested changes from lintrunner.

Comment thread on onnxruntime/core/providers/cuda/tensor/gather_impl.cu (outdated)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Copilot AI left a comment


Pull request overview

This PR fixes an integer overflow in the CUDA EP Gather kernel that could compute a wrapped/negative input offset for very large input tensors (> INT32_MAX elements), leading to CUDA illegal memory accesses.

Changes:

  • Switch input_index computation in _GatherKernel from CUDA_LONG (int32_t) to int64_t.
  • Ensure 64-bit arithmetic by explicitly casting input_block_index to int64_t before multiplication.
  • Add an inline comment documenting the overflow scenario and a concrete large-embedding example.




Development

Successfully merging this pull request may close these issues.

CUDA Gather kernel crashes with illegal memory access on tensors with >2^31 elements (int32 overflow)

3 participants