
fix: int32 overflow in trtllm_fp4_block_scale_moe causing "Unsupported hidden state scale shape" for EP32+ configs#2853

Merged
yzh119 merged 3 commits into flashinfer-ai:main from qiching:fix/int32-overflow-fp4-moe
Mar 24, 2026
Conversation

@qiching
Contributor

@qiching qiching commented Mar 22, 2026

📌 Description

Fix int32 overflow in trtllm_fp4_block_scale_moe that causes a misleading NotImplementedError: Unsupported hidden state scale shape when deploying large Expert Parallel configurations (e.g., EP32 with DeepSeek-R1 NVFP4).

Step 1, NVFP4 activation quantization (per EP rank)

Each of the 32 EP ranks quantizes its local activations via vllm.ops.scaled_fp4_quant with is_sf_swizzled_layout=False. From nvfp4_quant_entry.cu:

output_sf = torch::empty(
    {m, n / CVT_FP4_SF_VEC_SIZE},
    torch::TensorOptions().device(device).dtype(torch::kUInt8));

For m=10240 (max_num_batched_tokens), n=7168 (hidden_size):

hidden_states: [10240, 3584] uint8 (FP4 packed, 2 values per byte)
hidden_states_scale: [10240, 448] uint8 → viewed as float8_e4m3fn
No padding is applied in the non-swizzled layout. Scale numel = 10240 × 448 = 4,587,520.
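The shape arithmetic in this step can be checked with a few lines of plain Python (the constants mirror the values quoted above; no GPU needed):

```python
# NVFP4: two 4-bit values packed per byte; one FP8 scale per 16-element block.
CVT_FP4_SF_VEC_SIZE = 16

m, n = 10240, 7168  # max_num_batched_tokens, hidden_size

packed_hidden = (m, n // 2)                   # uint8, FP4 packed
scale_shape = (m, n // CVT_FP4_SF_VEC_SIZE)   # uint8, viewed as float8_e4m3fn

print(packed_hidden)                     # (10240, 3584)
print(scale_shape)                       # (10240, 448)
print(scale_shape[0] * scale_shape[1])   # 4587520
```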

Step 2, EP allgather via dispatch()

MoEPrepareAndFinalizeNaiveDPEPModular.prepare() in naive_dp_ep.py calls get_ep_group().dispatch(), which allgathers both hidden_states and hidden_states_scale (passed as extra_tensors) across all 32 EP ranks:

hidden_states: 32 × [10240, 3584] → [327680, 3584]
hidden_states_scale: 32 × [10240, 448] → [327680, 448]

Step 3, Scale reshape in vLLM wrapper

In trtllm_nvfp4_moe.py, the scale is reshaped before being passed to FlashInfer:

hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
    *hidden_states.shape[:-1], -1)  # → [327680, 448]

At this point hidden_states_scale.numel() = 327680 × 448 = 146,800,640.
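The same tally for the post-allgather tensors (pure arithmetic, values from Steps 2 and 3):

```python
ep_ranks, tokens_per_rank = 32, 10240
hidden_packed_cols, scale_cols = 3584, 448

total_tokens = ep_ranks * tokens_per_rank        # tokens after EP allgather
hidden_numel = total_tokens * hidden_packed_cols
scale_numel = total_tokens * scale_cols

print(total_tokens)   # 327680
print(scale_numel)    # 146800640
```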

Step 4, int32 overflow in FlashInfer C++ kernel

In csrc/trtllm_fused_moe_kernel_launcher.cu, the scale vector size is computed as:

int const num_tokens = hidden_states.size(0);   // int (32-bit) = 327680
int hidden_size = hidden_states.size(1);          // int (32-bit) = 3584
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;  // hidden_size = 7168
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();
//   ^^^^^^^^^^^^^^^^^^^^^^^^
//   int * int = int → OVERFLOW before promotion to int64 for division

The overflow:
327680 × 7168 = 2,348,810,240
INT_MAX = 2,147,483,647
2,348,810,240 > INT_MAX, so the signed int32 multiplication overflows (undefined behavior in C++; on two's-complement hardware it wraps to -1,946,157,056)

vec_size = -1,946,157,056 / 146,800,640 = -13 (C++ integer division truncates toward zero)
Since -13 is neither 16 nor 32, the launcher throws "Unsupported hidden state scale shape".
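The wraparound can be reproduced outside C++ by emulating two's-complement int32 arithmetic and C-style truncating division (a sketch of what the expression effectively computes on typical hardware; Python's native // floors, so a small helper is needed):

```python
def wrap_int32(x: int) -> int:
    """Emulate two's-complement int32 wraparound."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

def c_div(a: int, b: int) -> int:
    """C/C++ integer division truncates toward zero; Python's // floors."""
    q = abs(a) // abs(b)
    return -q if (a < 0) != (b < 0) else q

num_tokens, hidden_size = 327680, 7168   # EP32 values from Step 4
scale_numel = 146800640                  # hidden_states_scale.numel()

product = wrap_int32(num_tokens * hidden_size)
vec_size = c_div(product, scale_numel)
print(product)   # -1946157056
print(vec_size)  # -13
```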

Step 5, why some configs work and others crash

Overflow threshold for DeepSeek-R1 (hidden_size=7168):
Max safe total tokens: INT_MAX / 7168 = 299,593
EP32 per-rank limit: 299,593 / 32 ≈ 9,362
Any max_num_batched_tokens > 9,362 with EP32 triggers the overflow.

We confirmed the overflow boundary on an 8-node GB200 cluster (32 GPUs, EP32, DP32) with --all2all-backend allgather_reducescatter:

| max_num_batched_tokens | Total tokens (×32) | M × 7168 vs INT_MAX | Result |
|---|---|---|---|
| 9360 | 299,520 | 2,146,959,360 < 2,147,483,647 | ✅ Success |
| 9370 | 299,840 | 2,149,253,120 > 2,147,483,647 | ❌ Crash |
| 8192 (workaround) | 262,144 | 1,879,048,192 < 2,147,483,647 | ✅ Success |
| 10240 (original) | 327,680 | 2,348,810,240 > 2,147,483,647 | ❌ Crash |
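The boundary follows directly from INT_MAX; this sketch recomputes the threshold and classifies the four tested configurations using exact arithmetic:

```python
INT_MAX = 2**31 - 1   # 2147483647
HIDDEN = 7168         # DeepSeek-R1 hidden_size
EP = 32

max_safe_total = INT_MAX // HIDDEN      # 299593 total tokens
per_rank_limit = max_safe_total // EP   # 9362 tokens per EP rank

for m, expect_ok in [(9360, True), (9370, False), (8192, True), (10240, False)]:
    # Exact math here; the C++ int32 product would wrap past INT_MAX.
    overflows = m * EP * HIDDEN > INT_MAX
    assert overflows != expect_ok
    print(m, "crash" if overflows else "ok")
```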

Reproduction
vLLM serve with EP32:

vllm serve nvidia/DeepSeek-R1-NVFP4 \
    --tensor-parallel-size 1 \
    --data-parallel-size 32 \
    --enable-expert-parallel \
    --all2all-backend allgather_reducescatter \
    --max-num-batched-tokens 10240 \
    --kv-cache-dtype fp8 \
    --trust-remote-code

Crashes during engine initialization with:
NotImplementedError: Unsupported hidden state scale shape. (Also reported in vllm-project/vllm#36022 (comment))

Fix

Promote the multiplication operands to int64_t before the division:
hidden_states_scale_vec_size: cast num_tokens to int64_t so the multiplication chain executes in 64-bit.
weight_scale_vec_size: apply the same pattern with local_num_experts cast to int64_t, and declare the variable as int64_t for consistency.

// In csrc/trtllm_fused_moe_kernel_launcher.cu
// Before (overflow-prone):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();

// After (safe):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (static_cast<int64_t>(num_tokens) * hidden_size) /
    hidden_states_scale.value().numel();

The same pattern should also be applied to weight_scale_vec_size for safety:

int64_t weight_scale_vec_size =
    (static_cast<int64_t>(local_num_experts) * intermediate_size
     * intermediate_size_factor * hidden_size) /
    gemm1_weights_scale.numel();
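With 64-bit operands the division recovers the intended vector size; this check replays the Step 4 numbers (Python integers are arbitrary-precision, so plain multiplication models the int64 path):

```python
num_tokens, hidden_size = 327680, 7168
scale_numel = 146800640  # hidden_states_scale.numel() from Step 3

# Widened arithmetic: the product 2,348,810,240 no longer wraps.
vec_size = (num_tokens * hidden_size) // scale_numel
print(vec_size)  # 16, a supported FP4 block-scale vector size
```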

Impact
  • Zero performance impact: these are CPU-side setup computations executed once before the GPU kernel launch.
  • Zero API change: no function signatures are modified.
  • Unblocks EP32+ deployments of large-hidden-size models (DeepSeek-R1, etc.) with max_num_batched_tokens above the int32 threshold.

Environment
Model: DeepSeek-R1-0528-FP4 (NVFP4, hidden_size=7168)
Hardware: 8× GB200 nodes (32 GPUs), disaggregated prefill-decode
Configuration: DP=32, EP=32, TP=1, PP=1
vLLM: v0.17.2rc1 (bundled FlashInfer)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Fixed integer overflow in internal size calculations that could cause crashes or incorrect behavior with very large models or batch sizes, improving stability and reliability for large-scale inference.

@coderabbitai
Contributor

coderabbitai bot commented Mar 22, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between 07d7aaf and 69e587e.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu

📝 Walkthrough


This change widens two size computations in the TRTLLM fused MoE kernel launcher from 32-bit to 64-bit integers to avoid overflow when computing scale vector sizes for hidden states and weights.

Changes

| Cohort / File(s) | Summary |
|---|---|
| MoE kernel int widening (csrc/trtllm_fused_moe_kernel_launcher.cu) | Changed hidden_states_scale_vec_size and weight_scale_vec_size from int to int64_t; added explicit int64_t casts on operands (num_tokens * hidden_size, local_num_experts * intermediate_size * ... * hidden_size) and updated checks to use 64-bit values. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

run-ci, op: moe

Suggested reviewers

  • cyx-6
  • yzh119
  • jiahanc
  • jimmyzho
  • djmmoss
  • nvmbreughe

Poem

🐰
I hopped through code with careful paws,
Swapped thirty-two for sixty-four because,
Tokens and experts now have room,
No overflow to bring the doom,
Hooray — safe math reclaims the cause!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main change: fixing the int32 overflow in trtllm_fp4_block_scale_moe that caused errors in EP32+ configurations. |
| Description check | ✅ Passed | The PR description is comprehensive and well structured, with detailed root-cause analysis, reproduction steps, code changes, impact assessment, and environment details, fully addressing the template requirements. |



@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical integer overflow bug that prevented the successful deployment of large Expert Parallel configurations in vLLM, specifically when using models with substantial hidden state sizes. By promoting key arithmetic operations to 64-bit integers, the change ensures that calculations involving large token counts and hidden dimensions are handled correctly, thereby resolving misleading NotImplementedError messages and enabling broader compatibility for high-scale inference setups without impacting performance or API signatures.

Highlights

  • Fixes int32 overflow: Resolved an int32 overflow issue within the trtllm_fp4_block_scale_moe function, which previously led to a NotImplementedError for large Expert Parallel (EP) configurations.
  • Enables large EP deployments: This fix unblocks deployments of large EP32+ configurations, particularly for models with large hidden sizes like DeepSeek-R1, by correctly handling large token counts that previously exceeded int32 limits.
  • Type promotion for calculations: Implemented type promotion by casting multiplication operands to int64_t in critical calculations for hidden_states_scale_vec_size and weight_scale_vec_size to prevent overflow.


@qiching qiching changed the title fix: int32 overflow in trtllm_fp4_block_scale_moe for large EP config… fix: int32 overflow in trtllm_fp4_block_scale_moe causing "Unsupported hidden state scale shape" for EP32+ configs Mar 22, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a critical int32 overflow bug in trtllm_fp4_block_scale_moe by promoting the multiplication to 64-bit. The proactive fix for weight_scale_vec_size is also a good addition for safety. I have one minor suggestion to improve consistency between the two related variable declarations.

@qiching qiching marked this pull request as ready for review March 23, 2026 20:09
Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx aleozlx enabled auto-merge (squash) March 23, 2026 22:50
@aleozlx
Collaborator

aleozlx commented Mar 23, 2026

/bot run

@aleozlx aleozlx disabled auto-merge March 23, 2026 22:50
@aleozlx aleozlx added the run-ci label Mar 23, 2026
@aleozlx aleozlx self-assigned this Mar 23, 2026
@flashinfer-bot
Collaborator

GitLab MR !449 has been created, and the CI pipeline #46822943 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46822943: 13/20 passed

Collaborator

@yzh119 yzh119 left a comment


LGTM, thanks for the bugfix.

@yzh119 yzh119 merged commit 76790d8 into flashinfer-ai:main Mar 24, 2026
40 of 60 checks passed