[mxfp8 etc.] Fix int32 overflow in TRTLLM fused MoE permute/finalize kernels #2643

Draft
charlotte12l wants to merge 1 commit into flashinfer-ai:main from charlotte12l:ima_fp8_all

Conversation

@charlotte12l
Contributor

@charlotte12l charlotte12l commented Feb 26, 2026

📌 Description

Extend the int32 overflow fix from #2642 to the remaining standalone
kernels in trtllm_fused_moe_dev_kernel.cu: permuteKernel, finalizeKernel,
and finalizeDeepSeekKernel.

The same root cause applies — permutedIdx * dimension can exceed INT32_MAX
with large MoE configurations. This PR fixes all remaining overflow sites:

  • permuteKernel: permutedIdx * hiddenDim (line 642)
  • permuteKernel (DeepSeek scale path): permutedIdx + totalNumPaddedTokens * scaleIdx (line 653)
  • finalizeKernel: permutedIdx * hiddenDimPadded (lines 718, 720)
  • finalizeDeepSeekKernel: permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128) (line 982)
  • finalizeDeepSeekKernel: permutedIdx * hiddenDimPadded (line 989)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

We suspect this may be the root cause of another IMA we have been seeing, but this has not been validated yet.

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@coderabbitai
Contributor

coderabbitai bot commented Feb 26, 2026

Important

Review skipped

Draft detected.


@gemini-code-assist
Contributor

Summary of Changes

Hello @charlotte12l, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses potential integer overflow issues in several TRTLLM fused MoE kernels by explicitly casting intermediate calculations to int64_t. This ensures correct indexing and memory access for large MoE configurations, preventing data corruption or unexpected behavior that could arise from int32 limits. The changes are a continuation of a previous fix, broadening its scope to cover all remaining vulnerable operations.

Highlights

  • Extended Integer Overflow Fix: The existing int32 overflow fix (from PR [fp8_blockwise]Fix int32 overflow in TRTLLM fused MoE activation kernel #2642) has been extended to cover additional standalone kernels within trtllm_fused_moe_dev_kernel.cu.
  • Affected Kernels and Expressions: The fix specifically targets permuteKernel, finalizeKernel, and finalizeDeepSeekKernel to prevent overflow in calculations involving permutedIdx * dimension or similar large intermediate products by casting to int64_t.
  • Specific Overflow Sites Addressed: The identified overflow sites include permutedIdx * hiddenDim in permuteKernel, permutedIdx + totalNumPaddedTokens * scaleIdx in permuteKernel (DeepSeek scale path), permutedIdx * hiddenDimPadded in finalizeKernel, and two similar expressions in finalizeDeepSeekKernel.


Changelog
  • csrc/trtllm_fused_moe_dev_kernel.cu
    • Updated permuteKernel to cast permutedIdx * params.hiddenDim to int64_t to prevent overflow.
    • Modified permuteKernel (DeepSeek scale path) to cast permutedIdx + params.totalNumPaddedTokens[0] * scaleIdx to int64_t for idx_out calculation.
    • Adjusted finalizeKernel to cast permutedIdx * params.hiddenDimPadded to int64_t in two locations to prevent overflow.
    • Updated finalizeDeepSeekKernel to cast permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128) to int64_t for scaleIdx calculation.
    • Modified finalizeDeepSeekKernel to cast permutedIdx * params.hiddenDimPadded to int64_t when accessing params.inPtr.
Activity
  • The author suspects this fix may resolve another internal issue (IMA), though it has not yet been validated.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses several potential int32 overflow issues in the MoE kernels by casting indices to int64_t before multiplication. The changes in permuteKernel, finalizeKernel, and finalizeDeepSeekKernel are consistent with the goal of preventing overflows with large tensor dimensions. I have one suggestion for an additional location where an overflow could occur, for the sake of completeness and to make the fix more robust.

@@ -597,7 +597,7 @@ __global__ void permuteKernel(KernelParams params) {
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];

int const idx_in = tokenIdx + params.numTokens * scaleIdx;

medium

For consistency and to prevent potential integer overflow with large numTokens and hiddenDim, it's safer to use int64_t for this index calculation as well, similar to the change for idx_out. The product params.numTokens * scaleIdx could overflow if both numTokens and hiddenDim (which determines the range of scaleIdx) are large.

          int64_t const idx_in = (int64_t)tokenIdx + (int64_t)params.numTokens * scaleIdx;

@charlotte12l charlotte12l changed the title Fix int32 overflow in TRTLLM fused MoE permute/finalize kernels (all quant paths) [mxfp8 etc.]Fix int32 overflow in TRTLLM fused MoE permute/finalize kernels Feb 26, 2026
IwakuraRein pushed a commit that referenced this pull request Mar 5, 2026
…el (#2642)

<!-- .github/pull_request_template.md -->

## 📌 Description

Fix CUDA Illegal Memory Access (IMA) caused by int32 overflow in
activationKernel and activationDeepSeekKernel in the TRTLLM fused MoE
pipeline.

Root cause: The index computation `permutedIdx * params.innerDim +
hiddenIdx` uses int32 arithmetic. With large MoE configurations (e.g.
256 global experts, topK=8, DP=2, EP=2), the values can exceed
INT32_MAX:
- num_tokens = 65,536 (max_num_batched_tokens * DP)
- totalNumPaddedTokens up to 524,288 (65,536 * 8, worst case: all tokens
route to local experts)
- innerDim = 2 * intermediate_size; suppose it is > 5k
- 524,287 * innerDim may then be > INT32_MAX (2,147,483,647)


The overflow produces a negative index, causing out-of-bounds memory
access.
Fix: Cast permutedIdx to int64_t before the multiplication in both
activationKernel (line 82) and activationDeepSeekKernel (line 337).

The overflow may also cause issue in other places, e.g.
#2643, but I don't have
time to validate #2643
yet.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

Verified locally with the same model; it works.

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Fixed integer overflow issues in tensor indexing calculations,
enabling proper support for larger tensor dimensions without overflow
errors. Improves stability for large-scale tensor processing operations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026