
[fp8_blockwise] Fix int32 overflow in TRTLLM fused MoE activation kernel #2642

Merged
IwakuraRein merged 2 commits into flashinfer-ai:main from charlotte12l:ima_fp8_blockwise
Mar 5, 2026

Conversation

@charlotte12l
Contributor

@charlotte12l charlotte12l commented Feb 26, 2026

📌 Description

Fix CUDA Illegal Memory Access (IMA) caused by int32 overflow in activationKernel and activationDeepSeekKernel in the TRTLLM fused MoE pipeline.

Root cause: The index computation permutedIdx * params.innerDim + hiddenIdx uses int32 arithmetic. With large MoE configurations (e.g. 256 global experts, topK=8, DP=2, EP=2), the values can exceed INT32_MAX:

  • num_tokens = 65,536 (max_num_batched_tokens * DP)
  • totalNumPaddedTokens up to 524,288 (65,536 * 8, worst case: all tokens route to local experts)
  • innerDim = 2 * intermediate_size; suppose it is > 5k
  • 524,287 * innerDim can then exceed INT32_MAX (2,147,483,647)

The overflow produces a negative index, causing out-of-bounds memory access.
Fix: Cast permutedIdx to int64_t before the multiplication in both
activationKernel (line 82) and activationDeepSeekKernel (line 337).

The overflow may also cause issues elsewhere, e.g. #2643, but I haven't had time to validate #2643 yet.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • [x] I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • [x] I have installed the hooks with pre-commit install.
  • [x] I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

Verified locally with the same model configuration; the fix works.

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Fixed integer overflow issues in tensor indexing calculations, enabling proper support for larger tensor dimensions without overflow errors. Improves stability for large-scale tensor processing operations.

@coderabbitai
Contributor

coderabbitai bot commented Feb 26, 2026

📝 Walkthrough

Walkthrough

The pull request modifies a CUDA kernel file to convert 32-bit index calculations to 64-bit (int64_t) across multiple index computations in activationKernel and activationDeepSeekKernel functions. This prevents integer overflow when handling larger tensor dimensions and permuted indices.

Changes

Index Type Conversions for Overflow Prevention
  • csrc/trtllm_fused_moe_dev_kernel.cu — Converted baseIdx, outIdx, scale1Idx, scale2Idx, and scaleOut_idx from 32-bit int to int64_t across activationKernel and activationDeepSeekKernel to prevent overflow in large-tensor indexing with permutedIdx and token/scale offsets.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

  • fix: DeepSeek activation uninitialized data #2128: Both PRs modify activationDeepSeekKernel in the same kernel file; this PR addresses index overflow prevention while the related PR introduces thread-constant and per-token CTA zero-initialization logic.

Suggested reviewers

  • wenscarl
  • cyx-6
  • djmmoss
  • yzh119

Poem

🐰 Overflow be gone, we've grown so tall,
Indices that once fit now overflow it all,
From thirty-two to sixty-four we climb,
Larger tensors handled just in time! 🎯

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The title clearly and specifically describes the main change: fixing int32 overflow in TRTLLM fused MoE activation kernels, which aligns with the primary purpose of this PR.
  • Description check ✅ Passed — The PR description comprehensively documents the issue, root cause, and fix with concrete examples and line references.


@gemini-code-assist
Contributor

Summary of Changes

Hello @charlotte12l, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical int32 integer overflow bug in the TRTLLM fused MoE activation kernels, specifically affecting activationKernel and activationDeepSeekKernel. The overflow, which occurred during index computations with large MoE configurations, led to negative indices and subsequent CUDA Illegal Memory Access errors. By casting problematic int variables to int64_t for index calculations, the change prevents these overflows, ensuring the stability and correct operation of the MoE pipeline under high load.

Highlights

  • Integer Overflow Fix: Corrected int32 overflow issues in index calculations within activationKernel and activationDeepSeekKernel in the TRTLLM fused MoE activation kernels.
  • Memory Access Error Prevention: Addressed CUDA Illegal Memory Access (IMA) errors that occurred due to negative indices resulting from the int32 overflow.
  • Enhanced Robustness: Improved the stability and correctness of the MoE pipeline for large configurations by ensuring index calculations can handle larger values through the use of int64_t.


Changelog
  • csrc/trtllm_fused_moe_dev_kernel.cu
    • Updated index calculations for baseIdx and outIdx in activationKernel to use int64_t to prevent integer overflow.
    • Modified index calculations for baseIdx, scale1Idx, scale2Idx, scaleOut_idx, and outIdx in activationDeepSeekKernel to use int64_t to prevent integer overflow.
Activity
  • No specific activity (comments, reviews, etc.) was recorded for this pull request.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes an int32 overflow bug in activationKernel and activationDeepSeekKernel by promoting index calculations to int64_t. The changes are accurate and address the reported issue. However, my review of csrc/trtllm_fused_moe_dev_kernel.cu indicates that several other kernels in the file suffer from the same potential for int32 overflow in index calculations. These kernels include permuteKernel, finalizeKernel, and finalizeDeepSeekKernel. For example, in permuteKernel, expressions like tokenIdx * params.hiddenDim and permutedIdx * params.hiddenDim can overflow with large numTokens and hiddenDim values. I strongly recommend applying similar int64_t casting to all such index computations throughout the file to ensure the fused MoE pipeline is fully robust against these critical memory access bugs.

@charlotte12l charlotte12l changed the title [flashinfer] Fix int32 overflow in TRTLLM fused MoE activation kernel… Fix int32 overflow in TRTLLM fused MoE activation kernel Feb 26, 2026
@charlotte12l charlotte12l changed the title Fix int32 overflow in TRTLLM fused MoE activation kernel [fp8_blockwise]Fix int32 overflow in TRTLLM fused MoE activation kernel Feb 26, 2026
@charlotte12l charlotte12l marked this pull request as ready for review February 27, 2026 01:26
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_dev_kernel.cu (1)

81-82: Please add a regression test for the >INT32_MAX indexing path.

Given this is a correctness fix for a prior IMA, a targeted large-shape test would help prevent regressions in future kernel refactors.

Also applies to: 265-273, 311-312, 329-329

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_dev_kernel.cu` around lines 81 - 82, Add a regression
test that exercises the >INT32_MAX indexing path by constructing input
tensors/shapes so that (permutedIdx * params.innerDim) > INT32_MAX and verifying
correctness; specifically, create a large-shape unit/integration test that
invokes the kernels in csrc/trtllm_fused_moe_dev_kernel.cu (the code paths using
baseIdx = (int64_t)permutedIdx * params.innerDim + hiddenIdx and the related
sections around lines 265-273, 311-312, 329) to ensure those int64_t indices are
exercised and results match the expected output (or a smaller reference
computation). Ensure the test runs in CI, is deterministic, and includes both
forward and backward (if applicable) checks to catch regressions in future
refactors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@csrc/trtllm_fused_moe_dev_kernel.cu`:
- Around line 81-82: Add a regression test that exercises the >INT32_MAX
indexing path by constructing input tensors/shapes so that (permutedIdx *
params.innerDim) > INT32_MAX and verifying correctness; specifically, create a
large-shape unit/integration test that invokes the kernels in
csrc/trtllm_fused_moe_dev_kernel.cu (the code paths using baseIdx =
(int64_t)permutedIdx * params.innerDim + hiddenIdx and the related sections
around lines 265-273, 311-312, 329) to ensure those int64_t indices are
exercised and results match the expected output (or a smaller reference
computation). Ensure the test runs in CI, is deterministic, and includes both
forward and backward (if applicable) checks to catch regressions in future
refactors.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1589ebb and 061e481.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_dev_kernel.cu

@aleozlx
Collaborator

aleozlx commented Feb 27, 2026

/bot run

@aleozlx aleozlx added the run-ci label Feb 27, 2026
@flashinfer-bot
Collaborator

GitLab MR !356 has been created, and the CI pipeline #44998402 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44998402: 10/20 passed

@charlotte12l
Contributor Author

#44998402

Not able to open the link?

@IwakuraRein
Collaborator

@ChristinaZ for viz

Collaborator

@IwakuraRein IwakuraRein left a comment


Thanks for the findings!

@IwakuraRein
Collaborator

@flashinfer-ci-bot run

@yongwww yongwww added run-ci and removed run-ci labels Mar 3, 2026
Contributor

@ChristinaZ ChristinaZ left a comment


Thanks for your work. The modification looks good to me.

@IwakuraRein IwakuraRein merged commit 1b02c56 into flashinfer-ai:main Mar 5, 2026
136 of 161 checks passed
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…el (flashinfer-ai#2642)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>