
[vLLM IR][RMSNorm] Port Mixer2RMSNormGated to vLLM IR Ops #39262

Open

wxsIcey wants to merge 5 commits into vllm-project:main from wxsIcey:wxs/vllm-ir-mixer2-gated-rms-norm

Conversation

@wxsIcey (Contributor) commented Apr 8, 2026

Purpose

Register Mixer2RMSNormGated as a vLLM IR op and rewrite Mixer2RMSNormGated.forward_native to dispatch correctly across all tensor-parallel configurations.

The implementation handles four cases:

| Case | Condition | Issue | Solution |
|------|-----------|-------|----------|
| 1 | n_groups == 1, tp_size > 1 | Variance must be computed across all ranks (one global norm group; each rank holds only a slice) | AllReduce the local sum of squares, then compute the global variance |
| 2 | n_groups == 1, tp_size == 1 | No TP; local data is complete | Use the IR op directly |
| 3 | n_groups > 1, n_groups % tp_size != 0 | Group boundaries straddle rank boundaries (a rank may hold half a group), so a local norm is incorrect | AllGather the full tensor, normalize locally, then slice back to the local rank |
| 4 | n_groups > 1, n_groups % tp_size == 0 | Each rank holds an integer number of complete groups; variance can be computed independently | Use the IR op directly |

Cases 2 and 4 require no collective communication and are handled by the IR op. Cases 1 and 3 require cross-rank communication that cannot be fused into a single kernel, so they are handled with an explicit AllReduce / AllGather before calling into the local computation (see the sketch below).

Because forward_native now covers all cases (including the optimized IR op paths for cases 2 and 4), forward_cuda is fully redundant and can be removed.
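
For illustration, here is a minimal sketch of the dispatch described above. It is not the PR's actual code: `grouped_rms_norm` stands in for the IR op, the helper names and signatures are made up, and it assumes the hidden dimension (and the weight) is sharded contiguously across TP ranks.

```python
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_fn
import torch.nn.functional as F


def grouped_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                     n_groups: int, eps: float) -> torch.Tensor:
    """Reference grouped RMSNorm over the last dim (stands in for the IR op)."""
    g = x.view(*x.shape[:-1], n_groups, -1)
    var = g.pow(2).mean(dim=-1, keepdim=True)
    return (g * torch.rsqrt(var + eps)).reshape(x.shape) * weight


def forward_native_sketch(x, gate, weight, eps, n_groups, tp_size, tp_group):
    # Gate first, following the Mamba2 convention: y = rmsnorm(x * silu(z)).
    x = x * F.silu(gate)

    if n_groups == 1 and tp_size > 1:
        # Case 1: one global norm group; every rank holds only a slice.
        # AllReduce the local sum of squares to form the global variance.
        sumsq = dist_fn.all_reduce(x.pow(2).sum(dim=-1, keepdim=True),
                                   group=tp_group)
        var = sumsq / (x.shape[-1] * tp_size)
        return x * torch.rsqrt(var + eps) * weight

    if n_groups > 1 and n_groups % tp_size != 0:
        # Case 3: group boundaries straddle ranks. AllGather the full
        # hidden dimension, normalize with every group visible, slice back.
        full_x = torch.cat(dist_fn.all_gather(x, group=tp_group), dim=-1)
        full_w = torch.cat(dist_fn.all_gather(weight, group=tp_group), dim=-1)
        out = grouped_rms_norm(full_x, full_w, n_groups, eps)
        return out.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]

    # Cases 2 and 4: the local shard covers whole groups, so each rank
    # normalizes independently with no communication (the IR-op fast path).
    return grouped_rms_norm(x, weight, n_groups // tp_size, eps)
```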

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify (Bot) added the nvidia and rocm labels Apr 8, 2026
@github-project-automation (Bot) moved this to Todo in AMD Apr 8, 2026
@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the mixer2_rms_norm_gated IR operator and its Triton implementation, refactoring the Mamba Mixer2 layer to utilize this new operator. The changes include updating kernel configurations, platform-specific priorities, and the underlying Triton kernel for gated layer normalization. A potential TypeError was identified in the native implementation of mixer2_rms_norm_gated when the weight is None, and a code suggestion was provided to handle this case.

Review thread on vllm/ir/ops/layernorm.py (outdated)
@wxsIcey wxsIcey changed the title [vLLM IR] mixer2_gated_rms_norm [vLLM IR] Port Mixer2RMSNormGated to vLLM IR Ops Apr 8, 2026
@wxsIcey wxsIcey changed the title [vLLM IR] Port Mixer2RMSNormGated to vLLM IR Ops [vLLM IR][RMSNorm] Port Mixer2RMSNormGated to vLLM IR Ops Apr 8, 2026
@wxsIcey (Contributor, Author) commented Apr 8, 2026

This PR has the same issue as #38798: when using wrap_triton rather than a custom op, garbled characters are output. However, I found that setting enforce_eager=True results in normal output, so the issue seems related to torch.compile().

I hope to get some help. cc @zou3519
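
For readers following along, the contrast being described is roughly the one below (a self-contained toy on PyTorch ≥ 2.6; `mylib` and the scale kernel are made up, not this PR's op): triton_op + wrap_triton keeps the Triton launch visible to torch.compile, while a plain custom op hides it as an opaque call.

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)


# Style 1: triton_op + wrap_triton -- torch.compile can see into the kernel.
@triton_op("mylib::scale_traced", mutates_args={})
def scale_traced(x: torch.Tensor, scale: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    wrap_triton(_scale_kernel)[grid](x, out, scale, n, BLOCK=1024)
    return out


# Style 2: opaque custom op -- torch.compile treats it as a black box.
@torch.library.custom_op("mylib::scale_opaque", mutates_args=())
def scale_opaque(x: torch.Tensor, scale: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _scale_kernel[grid](x, out, scale, n, BLOCK=1024)
    return out


@scale_opaque.register_fake
def _(x, scale):
    # Shape/dtype propagation for tracing with fake tensors.
    return torch.empty_like(x)
```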

@ProExpertProg (Collaborator) commented

@wxsIcey I think you're missing run_functional_passes=False in the lowering pass - for some reason that flag removes the Triton kernel from the replacement if set to True. See

match.replace_by_example(
ir_op_impl.impl_fn, bound_args.args, run_functional_passes=False
)

@wxsIcey (Contributor, Author) commented Apr 10, 2026

> @wxsIcey I think you're missing run_functional_passes=False in the lowering pass - for some reason that flag removes the Triton kernel from the replacement if set to True. See
>
>     match.replace_by_example(
>         ir_op_impl.impl_fn, bound_args.args, run_functional_passes=False
>     )

It seems to have no effect.

@tomeras91 (Member) left a comment


Left a nit comment.
Also - do you plan to post benchmark results before/after this change? I understand we don't expect any perf diff (IR ops still go through torch.compile), but I'd like to verify that, since this PR changes the code path significantly.

Review thread on vllm/config/kernel.py
"""Priority list for vllm.ir.ops.rms_norm"""

mixer2_rms_norm_gated: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm_gated"""
@tomeras91 (Member) commented:
nit: docstring should have vllm.ir.ops.mixer2_rms_norm_gated instead of vllm.ir.ops.rms_norm_gated.
(Or change the op name to rms_norm_gated)

@wxsIcey (Author) replied:
Thanks for your review. I will change it.

@mergify (Bot) commented Apr 14, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wxsIcey.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify (Bot) added the needs-rebase label Apr 14, 2026
@wxsIcey (Contributor, Author) commented Apr 15, 2026

> Left a nit comment. Also - do you plan to post benchmark results before/after this change? I understand we don't expect any perf diff (IR ops still go through torch.compile), but would like to verify that since this PR changes the code path significantly.

Thanks for the review. This PR is currently low priority; I will add benchmark tests once it's ready.

@chaojun-zhang (Contributor) commented

@wxsIcey I added this IR on XPU, please review wxsIcey#12

@wxsIcey (Contributor, Author) commented Apr 21, 2026

> @wxsIcey I added this IR on XPU, please review wxsIcey#12

Thanks for your work, I will merge it.

@mergify (Bot) added the intel-gpu label Apr 21, 2026
@chaojun-zhang (Contributor) commented

> @wxsIcey I added this IR on XPU, please review wxsIcey#12
>
> Thanks for your work, I will merge it.

@wxsIcey, I found that the Triton provider fails to pass IR lowering. I attempted to fix this by registering the Triton kernel as a torch custom op. Please take a look at wxsIcey#14, thanks.

@wxsIcey (Contributor, Author) commented Apr 22, 2026

> @wxsIcey I added this IR on XPU, please review wxsIcey#12
>
> Thanks for your work, I will merge it.
>
> @wxsIcey, I found that the Triton provider fails to pass IR lowering. I attempted to fix this by registering the Triton kernel as a torch custom op. Please take a look at wxsIcey#14, thanks.

This is a known issue: make_fx does not handle the Triton operator correctly. See the discussion in #38798. We still need to figure out why wrap_triton cannot be used.
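
As an assumed starting point for debugging (not taken from the PR), one can trace a triton_op-based function with make_fx directly and inspect whether the Triton call survives, reusing the toy scale_traced op from the sketch earlier in this thread:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # scale_traced is the toy triton_op registered in the earlier sketch.
    return torch.ops.mylib.scale_traced(x, 2.0)


x = torch.randn(8, device="cuda")
gm = make_fx(f, tracing_mode="fake")(x)
gm.graph.print_tabular()  # check whether mylib.scale_traced appears intact
```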


Labels

intel-gpu, needs-rebase, nvidia, rocm

Projects

Status: Todo

4 participants