
[Bugfix] Fix fused MoE int32 overflow in stride*offset without perf regression#34507

Merged

vllm-bot merged 2 commits into vllm-project:main from haosdent:fix/fused-moe-int64-stride-perf-regression on Feb 17, 2026

Conversation

@haosdent
Contributor

@haosdent haosdent commented Feb 13, 2026

Purpose

Fixes #34413

PR #34279 annotated all stride parameters as tl.int64 to fix an int32 overflow crash, but this caused ~60x perf regression on small GPUs (e.g. NVIDIA GB10) due to register pressure. PR #34530 reverted that fix.

This patch prevents the overflow with minimal register impact by casting offs_token to int64 after loading instead of widening all strides. When chunking is disabled and M is large, stride_cm * offs_token (where stride_cm = N = w1.size(1) and offs_token up to M*topk) can exceed int32 max. The cast leverages Triton type promotion (int32 * int64 -> int64) following the existing pattern used for off_experts and offs_bn.

Adds a regression test that disables chunking with M=100000, n=2048, topk=6 (product = 4096 * 600000 = 2.46B > int32 max) and validates correctness against the torch_moe reference.
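The wraparound can be illustrated with a small pure-Python sketch using the 4096 * 600000 product from the test case above (a hedged illustration, not kernel code; `to_int32` is a hypothetical helper simulating two's-complement truncation):

```python
def to_int32(x: int) -> int:
    """Simulate two's-complement int32 wraparound (hypothetical helper)."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

stride_cm = 4096         # w1.size(1) for n=2048
max_offs_token = 600000  # M * topk = 100000 * 6

# int32 * int32 stays int32: the product exceeds 2**31 - 1 and wraps
# negative, which is the illegal address the kernel crashed on.
wrapped = to_int32(to_int32(stride_cm) * to_int32(max_offs_token))

# Widening the offset to int64 first (what the fix does via Triton's
# int32 * int64 -> int64 promotion) keeps the product exact.
exact = stride_cm * max_offs_token

print(wrapped, exact)  # wrapped is negative; exact is 2457600000
```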

Overflow safety audit (all stride*offset products verified):

┌───────────────────────────────────────┬──────────────────┬─────────────┬────────────────────┬───────┐
│              Expression               │   Offset type    │ Stride type │       Result       │ Safe? │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ offs_token // top_k * stride_am       │ int64            │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ offs_k * stride_ak                    │ int32            │ int32       │ int32 (tiny: ~128) │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ off_experts * stride_be               │ int64 (existing) │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ offs_bn * stride_bn                   │ int64 (existing) │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ stride_cm * offs_token                │ int64            │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ BLOCK_SIZE_K * stride_ak (inner loop) │ constexpr (~128) │ int32       │ int32 (tiny)       │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ offs_token // top_k * stride_asm      │ int64            │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ off_experts * stride_bse              │ int64 (existing) │ int32       │ int64              │ Yes   │
├───────────────────────────────────────┼──────────────────┼─────────────┼────────────────────┼───────┤
│ offs_bn * stride_bsn                  │ int64 (existing) │ int32       │ int64              │ Yes   │
└───────────────────────────────────────┴──────────────────┴─────────────┴────────────────────┴───────┘

Test Plan

Add the test case test_fused_moe_int64_overflow and run python -m pytest tests/kernels/moe/test_moe.py::test_fused_moe_int64_overflow -v 2>&1

Test Result

The new test case passes on this branch and fails on main.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the bug Something isn't working label Feb 13, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a significant performance regression in the fused MoE kernels on register-constrained GPUs. The previous approach of casting all stride parameters to tl.int64 was too broad and caused excessive register pressure. The new approach is more targeted: it reverts the strides to their default integer type and instead casts offs_token to tl.int64 just before its use in pointer arithmetic. This correctly prevents overflow in stride-offset products by leveraging Triton's type promotion, while minimizing the impact on register usage. The change is applied consistently across both fused_moe_kernel and fused_moe_kernel_gptq_awq and aligns with existing patterns in the code. The fix is well-reasoned and should restore the expected performance. I approve these changes.

@eugr

eugr commented Feb 13, 2026

@haosdent - thanks, this fixes the performance regression on Spark.
@johnnynunez - fyi.

@mgoin
Member

mgoin commented Feb 13, 2026

Thanks a lot @haosdent for the analysis. We will revert the original PR for now to keep things simple; see #34530.

@tlrmchlsmth
Member

Sorry, I went a bit scorched earth with that one. Have you tried the repro script in #34279 on this, by any chance? Hoping this more selective fix still resolves the IMAs (illegal memory accesses).

@haosdent
Contributor Author

I ran the script as well; this fix works for me:

Testing m=1000, n=1024, k=1024, e=8, topk=2
  stride_cm=2048, max_offs_token=2000, max_c_offset=4096000, int32_max=2147483647
  stride_am=1024, max_a_offset=1024000
  SUCCESS - output shape: torch.Size([1000, 1024])

Testing m=40000, n=1024, k=1024, e=8, topk=6
  stride_cm=6144, max_offs_token=240000, max_c_offset=1474560000, int32_max=2147483647
  stride_am=1024, max_a_offset=40960000
  SUCCESS - output shape: torch.Size([40000, 1024])

Testing m=100000, n=2048, k=1024, e=8, topk=6
  stride_cm=12288, max_offs_token=600000, max_c_offset=7372800000, int32_max=2147483647
  stride_am=1024, max_a_offset=102400000
  ** C OFFSET WILL OVERFLOW INT32 **
  SUCCESS - output shape: torch.Size([100000, 1024])

but when switched to main, it crashes:

Testing m=1000, n=1024, k=1024, e=8, topk=2
  stride_cm=2048, max_offs_token=2000, max_c_offset=4096000, int32_max=2147483647
  stride_am=1024, max_a_offset=1024000
  SUCCESS - output shape: torch.Size([1000, 1024])

Testing m=40000, n=1024, k=1024, e=8, topk=6
  stride_cm=6144, max_offs_token=240000, max_c_offset=1474560000, int32_max=2147483647
  stride_am=1024, max_a_offset=40960000
  SUCCESS - output shape: torch.Size([40000, 1024])

Testing m=100000, n=2048, k=1024, e=8, topk=6
  stride_cm=12288, max_offs_token=600000, max_c_offset=7372800000, int32_max=2147483647
  stride_am=1024, max_a_offset=102400000
  ** C OFFSET WILL OVERFLOW INT32 **
  SUCCESS - output shape: torch.Size([100000, 1024])
  CRASH: AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
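The overflow condition flagged in the log can be checked with a small pure-Python helper (a sketch built from the stride_cm values printed above; `check_c_offset` is a hypothetical name, not part of the repro script):

```python
INT32_MAX = 2**31 - 1

def check_c_offset(m: int, topk: int, stride_cm: int) -> bool:
    """Return True when stride_cm * max_offs_token would overflow int32."""
    max_offs_token = m * topk
    return stride_cm * max_offs_token > INT32_MAX

# Cases from the log above: only the last one overflows.
for m, topk, stride_cm in [(1000, 2, 2048), (40000, 6, 6144), (100000, 6, 12288)]:
    print(m, topk, stride_cm, check_c_offset(m, topk, stride_cm))
```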

@haosdent
Contributor Author

@SouthWest7 thank you for this

The test_moe test cases do not cover the overflow scenario.

I referenced the reproduction script and added a new test case, test_fused_moe_int64_overflow, to cover it.

@haosdent haosdent force-pushed the fix/fused-moe-int64-stride-perf-regression branch from 8762017 to 63008c1 Compare February 14, 2026 13:43
…egression

PR vllm-project#34279 annotated all stride parameters as tl.int64 to fix an int32
overflow crash, but this caused ~60x perf regression on small GPUs (e.g.
NVIDIA GB10) due to register pressure. PR vllm-project#34530 reverted that fix.

This patch prevents the overflow with minimal register impact by casting
offs_token to int64 after loading instead of widening all strides. When
chunking is disabled and M is large, stride_cm * offs_token (where
stride_cm = N = w1.size(1) and offs_token up to M*topk) can exceed
int32 max. The cast leverages Triton type promotion (int32 * int64 ->
int64) following the existing pattern used for off_experts and offs_bn.

Adds a regression test that disables chunking with M=100000, n=2048,
topk=6 (product = 4096 * 600000 = 2.46B > int32 max) and validates
correctness against the torch_moe reference.

Fixes vllm-project#34413

Signed-off-by: haosdent <haosdent@gmail.com>
@haosdent haosdent force-pushed the fix/fused-moe-int64-stride-perf-regression branch from 63008c1 to 54ba644 Compare February 14, 2026 13:49
@haosdent haosdent changed the title [Bugfix] Fix fused MoE perf regression on small GPUs from int64 strides [Bugfix] Fix fused MoE int32 overflow in stride*offset without perf regression Feb 14, 2026
@haosdent
Contributor Author

Hi @mgoin, I have updated the PR. Could you take a look and check whether this is still necessary?

Member

@mgoin mgoin left a comment


Great work, glad you set up the regression test as well. LGTM, but I will let @tlrmchlsmth sign off. It would be good to get AMD perf validation as well.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 16, 2026
Member

@tlrmchlsmth tlrmchlsmth left a comment


Thank you! This has been annoying me for a long time so glad to see it get fixed

@vllm-bot vllm-bot merged commit b68fd89 into vllm-project:main Feb 17, 2026
53 of 59 checks passed
@mgehre-amd
Contributor

mgehre-amd commented Feb 17, 2026

Unfortunately, I still see the perf degradation:

Benchmark: cyankiwi/Qwen3-30B-A3B-Instruct-2507-AWQ-4bit on Strix Halo (gfx1151)

Just d7982da (int64 strides) by itself gave a 5% regression.
Adding 54ba644 on top does not recover the original performance.

Setup: --input-len 512 --max-model-len 4096 --dtype float16 --target-gpu-memory-gb 20 --num-prompts 5

| Metric | Baseline (neither d7982da nor 54ba644) | Modified (both d7982da and 54ba644) | Delta |
|--------|----------------------------------------|-------------------------------------|-------|
| Decode | 46.5 tok/s                             | 43.7 tok/s                          | -6.0% |
| TPOT   | 21.48 ms                               | 22.89 ms                            | +6.5% |
| TTFT   | 275 ms                                 | 434 ms                              | +58%  |

Each configuration was run twice with consistent results.

@haosdent
Contributor Author

haosdent commented Feb 17, 2026

> and the targeted offs_token cast in 54ba644 does not compensate for it on this workload.

@mgehre-amd Thanks for your test; performance has indeed degraded a bit after switching to int64.
I'm not sure whether to accept this compromise to fix the overflow issue or to revert the change. I can't judge it from my side (either option is OK with me until we find a better approach), so we may need the vLLM experts' help to decide. @mgoin @tlrmchlsmth

@mgehre-amd
Contributor

> If you revert both d7982da and 54ba644, does the degradation happen? @mgehre-amd

With both commits reverted, I don't see the regression. Updated my previous message to be clearer.

wzhao18 pushed a commit to wzhao18/vllm that referenced this pull request Feb 18, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Feb 19, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
ZJY0516 pushed a commit to ZJY0516/vllm that referenced this pull request Feb 23, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…egression (vllm-project#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
## Summary

Cherry-pick upstream bug fixes for RHAIIS 3.3.1 onto `rhai/0.13.0`. All
fixes are from upstream vLLM `main` and address critical bugs affecting
RHAIIS 3.3.0. Other releases (3.2.2, EAx) will be done separately.

**Jira Epic:**
[INFERENG-4743](https://issues.redhat.com/browse/INFERENG-4743)

## Cherry-picked commits (chronological order)

| # | Upstream PR | Jira | Summary |
|---|------------|------|---------|
| 1 | [vllm-project#30550](vllm-project#30550) | [INFERENG-5106](https://issues.redhat.com/browse/INFERENG-5106) | Support using chat template as custom score template for reranking models |
| 2 | [vllm-project#31406](vllm-project#31406) | [INFERENG-4800](https://issues.redhat.com/browse/INFERENG-4800) | Add encoder-only/cross attention support to Triton Attention backend |
| 3 | [vllm-project#34243](vllm-project#34243) | [INFERENG-4746](https://issues.redhat.com/browse/INFERENG-4746) | Fix Llama-4 attn quantization by correctly permuting scales for rope (int8, fp8) |
| 4 | [vllm-project#34454](vllm-project#34454) | [INFERENG-5032](https://issues.redhat.com/browse/INFERENG-5032) | Fix structured output in multi-turn GPT-OSS (content:null with json_object) |
| 5 | [vllm-project#34507](vllm-project#34507) | [INFERENG-5038](https://issues.redhat.com/browse/INFERENG-5038) | Fix fused MoE int32 overflow in stride*offset for large models |
| 6 | [vllm-project#35085](vllm-project#35085) | [INFERENG-5028](https://issues.redhat.com/browse/INFERENG-5028) | Gracefully disable AllReduceFusionPass on GPUs without multicast support |
| 7 | [vllm-project#35456](vllm-project#35456) | [INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Replace assert with ValueError for response_format validation (completions) |
| 8 | [vllm-project#35510](vllm-project#35510) | [INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Add response_format validation to chat completions endpoint |

## Conflict resolutions

<details>
<summary><b>#1 — llama-nemotron-embed / score-template support
(vllm-project#30550)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#2 — Triton Attention (vllm-project#31406)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#3 — Llama-4 attn quant (vllm-project#34243)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly. 4 intermediate upstream commits touch `llama4.py` but
the fix targets a self-contained block.
</details>

<details>
<summary><b>#4 — GPT-OSS multi-turn (vllm-project#34454)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly despite 3 intermediate upstream commits that refactored
imports in `gptoss_reasoning_parser.py`. The fix logic (adding
`eom_token_id` early-exit check in `is_reasoning_end`) was independent
of the import changes.
</details>

<details>
<summary><b>#5 — Fused MoE int32 overflow (vllm-project#34507)</b>: Conflicts in 2 files</summary>

**`vllm/model_executor/layers/fused_moe/fused_moe.py`**: ~30
intermediate upstream commits refactored `fused_moe_kernel` with
conditional `naive_block_assignment` logic that doesn't exist in
`rhai/0.13.0`. Resolved by keeping our simpler code and applying only
the int64 cast fix:
- `fused_moe_kernel_gptq_awq`: added `.to(tl.int64)` to `tl.load()`
result
- `fused_moe_kernel`: added `offs_token = offs_token.to(tl.int64)`
before `token_mask`

**`tests/kernels/moe/test_moe.py`**: Upstream test changes depend on
`make_dummy_moe_config()` from intermediate refactors. Resolved by
keeping our existing test code (no test changes).
</details>

<details>
<summary><b>#6 — AllReduceFusionPass multicast (vllm-project#35085)</b>: Conflict due to file rename + API change</summary>

Upstream moved `collective_fusion.py` →
`compilation/passes/fusion/allreduce_rms_fusion.py` and changed the API
from `trtllm_create_ipc_workspace_for_all_reduce_fusion()` to
`create_allreduce_fusion_workspace()`. Resolved by applying the
try/except wrapper around our existing
`trtllm_create_ipc_workspace_for_all_reduce_fusion()` call in
`collective_fusion.py`. The error handling logic (catching RuntimeError
with "multicast" in message, logging warning, returning early) is
identical to upstream.
</details>

<details>
<summary><b>#7 — response_format validation for completions (vllm-project#35456)</b>: Conflict due to file restructuring</summary>

Upstream split `protocol.py` into `completion/protocol.py` and
`chat_completion/protocol.py`. Our branch still has the monolithic
`protocol.py`. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`CompletionRequest` in our `protocol.py`
- Using `ValueError` instead of upstream's `VLLMValidationError` (which
doesn't exist in our branch; `ValueError` is already handled as 400 Bad
Request in `serving_engine.py`)
- Test additions from upstream applied cleanly to
`test_completion_error.py`
</details>

<details>
<summary><b>#8 — response_format validation for chat completions (vllm-project#35510)</b>: Conflict due to file restructuring</summary>

Same file restructuring issue as #7. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/chat_completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`ChatCompletionRequest` in our `protocol.py`
- Only accepting the `test_json_schema_response_format_missing_schema`
test from the conflict (discarding ~140 lines of intermediate upstream
tests that reference non-existent paths in our branch)
</details>

## Test plan

- [ ] Verify `llama-nemotron-embed-1b-v2` works correctly with the
backported score-template / bidirectional model support
- [ ] Verify Llama-4 quantized model loads correctly with int8/fp8
attention quantization
- [ ] Verify GPT-OSS multi-turn chat with `json_object` response_format
returns valid content
- [ ] Verify large MoE models (e.g. Qwen3.5-397B) don't crash with int32
overflow
- [ ] Verify MoE model loading on H200 GPUs (without multicast)
gracefully falls back
- [ ] Verify `response_format: {type: "json_schema"}` without
`json_schema` field returns 400 (not 500) for both `/v1/completions` and
`/v1/chat/completions`
- [ ] Verify encoder models (e.g. Whisper) work with Triton attention
backend on ROCm

