[vLLM IR][RMSNorm] Port Mixer2RMSNormGated to vLLM IR Ops #39262
wxsIcey wants to merge 5 commits into
Conversation
Signed-off-by: Icey <1790571317@qq.com>
Code Review
This pull request introduces the mixer2_rms_norm_gated IR operator and its Triton implementation, refactoring the Mamba Mixer2 layer to utilize this new operator. The changes include updating kernel configurations, platform-specific priorities, and the underlying Triton kernel for gated layer normalization. A potential TypeError was identified in the native implementation of mixer2_rms_norm_gated when the weight is None, and a code suggestion was provided to handle this case.
Signed-off-by: Icey <1790571317@qq.com>
This PR has the same issue as #38798: when using … I hope to get some help. cc @zou3519
Signed-off-by: Icey <1790571317@qq.com>
@wxsIcey I think you're missing vllm/vllm/compilation/passes/ir/lowering_pass.py, lines 95 to 97 in 06de5e1
It seems to have no effect.
tomeras91 left a comment:
Left a nit comment
Also: do you plan to post benchmark results before/after this change? I understand we don't expect any perf diff (IR ops still go through torch.compile), but I would like to verify that, since this PR changes the code path significantly.
| """Priority list for vllm.ir.ops.rms_norm""" | ||
|
|
||
| mixer2_rms_norm_gated: list[str] = Field(default_factory=list) | ||
| """Priority list for vllm.ir.ops.rms_norm_gated""" |
nit: docstring should have vllm.ir.ops.mixer2_rms_norm_gated instead of vllm.ir.ops.rms_norm_gated.
(Or change the op name to rms_norm_gated)
Thanks for your review. I will change it.
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks for the review. This PR is currently on low priority. I will add benchmark tests once it's ready.
Signed-off-by: Chaojun Zhang <chaojun.zhang@intel.com>
@wxsIcey I added this IR on XPU, please review wxsIcey#12
Thanks for your work, I will merge it. |
@wxsIcey, I found that the Triton provider fails to pass IR lowering. I attempted to fix this by registering the Triton kernel as a torch custom op. Please take a look at wxsIcey#14, thanks.
This is a known issue. |
Purpose
Register Mixer2RMSNormGated as a vLLM IR op and rewrite Mixer2RMSNormGated.forward_native to dispatch correctly across all tensor-parallel configurations. The implementation handles four cases:

1. n_groups = 1, tp_size > 1
2. n_groups = 1, tp_size = 1
3. n_groups > 1, n_groups % tp_size != 0
4. n_groups > 1, n_groups % tp_size == 0

Cases 2 and 4 require no collective communication and are handled by the IR op. Cases 1 and 3 require cross-rank communication that cannot be fused into a single kernel, so they are handled with explicit AllReduce / AllGather before calling into the local computation.
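The case split above can be sketched as a small dispatch helper. This is a reading of the description, not code from the PR — the function name and return labels are made up for illustration:

```python
def select_norm_path(n_groups: int, tp_size: int) -> str:
    """Decide how forward_native should compute the gated RMSNorm.

    Returns "ir_op" when group statistics are local to each rank
    (cases 2 and 4) and "collective" when cross-rank communication
    is needed first (cases 1 and 3).
    """
    if tp_size == 1:
        return "ir_op"       # case 2: everything lives on one rank
    if n_groups == 1:
        return "collective"  # case 1: variance spans all ranks -> AllReduce
    if n_groups % tp_size != 0:
        return "collective"  # case 3: groups straddle ranks -> AllGather
    return "ir_op"           # case 4: each rank owns whole groups
```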
Because forward_native now covers all cases (including the optimized IR op paths for cases 2 and 4), forward_cuda is fully redundant and can be removed.

Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.