[Bugfix] Temporarily disable group quant rms norm fusion#30273

Draft
ElizaWszola wants to merge 4 commits into vllm-project:main from
neuralmagic:temp-disable-rmsgroup-quant-fusion

Conversation

@ElizaWszola
Contributor

@ElizaWszola ElizaWszola commented Dec 8, 2025

Temporarily disable the change introduced by #27883 until we cleanly resolve the issue with model config assertions on VL models.

Testing:

Unit testing with tests/compile/test_fusion.py.

Tested e2e with:

  • Qwen/Qwen3-30B-A3B-FP8
  • Qwen/Qwen3-VL-4B-Instruct
  • Qwen/Qwen3-VL-2B-Instruct-FP8

All tests were run with both VLLM_USE_DEEP_GEMM=0 and VLLM_USE_DEEP_GEMM=1.

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify

mergify bot commented Dec 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 8, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request temporarily disables group quantization RMS norm fusion to address an issue with model config assertions on Vision-Language models. The changes involve commenting out the registration of certain fusion patterns and removing the logic that determines scale layout for group quantization.

My review identifies a potential incompleteness in disabling the feature. While FusedAddRMSNormGroupQuantPattern is disabled, RMSNormGroupQuantPattern for the same group shapes remains active, even though the corresponding tests are disabled. I've added comments to suggest disabling these patterns as well to prevent untested code paths from being active.

Comment on lines +480 to +482
```python
# FusedAddRMSNormGroupQuantPattern(
#     epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
# ).register(self.patterns)
```
Contributor


Severity: high

While FusedAddRMSNormGroupQuantPattern for group shape (1, 128) is correctly commented out, the corresponding RMSNormGroupQuantPattern on lines 485-487 remains active. Given that tests for this group shape are disabled in tests/compile/test_fusion.py, this leaves an untested code path. To fully disable the group quantization fusion as intended by this PR, RMSNormGroupQuantPattern for this group shape should also be commented out.

Comment on lines +490 to +492
```python
# FusedAddRMSNormGroupQuantPattern(
#     epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
# ).register(self.patterns)
```
Contributor


Severity: high

Similar to the case for group shape 128, FusedAddRMSNormGroupQuantPattern for group shape (1, 64) is commented out, but RMSNormGroupQuantPattern on lines 495-497 remains active. To ensure complete disabling of this feature and avoid untested code, this pattern should also be commented out.
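The reviewer's point can be illustrated with a minimal sketch (the registry class and the commented-out pattern names below are simplified stand-ins, not the actual vLLM API): fully disabling the fusion means leaving both pattern variants unregistered for every group shape.

```python
class PatternRegistry:
    """Stand-in for the fusion pass's pattern registry."""

    def __init__(self):
        self.patterns = []

    def register(self, pattern):
        self.patterns.append(pattern)


registry = PatternRegistry()

# Fully disabling group-quant fusion means BOTH registrations stay
# commented out, for every group shape:
#
# FusedAddRMSNormGroupQuantPattern(group_shape=(1, 128)).register(registry)
# RMSNormGroupQuantPattern(group_shape=(1, 128)).register(registry)
# FusedAddRMSNormGroupQuantPattern(group_shape=(1, 64)).register(registry)
# RMSNormGroupQuantPattern(group_shape=(1, 64)).register(registry)

# With nothing registered, the pass has no group-quant patterns to match.
assert registry.patterns == []
```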


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +128 to +130
```python
# TODO: add use_col_major_scales and use_e8m0 to MatcherQuantFP8
# after the issue with group quant rms fusion for VL models is fixed
self.quant_matcher = MatcherQuantFP8(key.quant)
```


P1: Group quant fusion ignores col-major/e8m0 settings

RMSNormQuantPattern now always builds MatcherQuantFP8 without forwarding the use_col_major_scales or use_e8m0 flags, yet the RMSNormGroupQuantPattern registrations for group sizes 64/128 are still active. On platforms where block FP8 paths expect column-major scales or e8m0 (CUTLASS or DeepGEMM), the fused rms_norm_per_block_quant kernels will now be invoked with row-major scales and e4m3, producing mis-quantized outputs on those GPUs rather than just disabling the fusion.
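To make the layout concern concrete, here is a plain-Python illustration of the mismatch the review describes (the helper functions are hypothetical, not the vLLM kernels): per-group quantization produces one scale per (row, group) pair, and a kernel that writes scales row-major while the consumer reads them column-major silently mis-scales every block.

```python
def group_scales(matrix, group_size):
    """One max-abs scale per contiguous group of `group_size` columns."""
    return [
        [max(abs(v) for v in row[g:g + group_size])
         for g in range(0, len(row), group_size)]
        for row in matrix
    ]

def flatten_row_major(scales):
    # Scales for row 0's groups first, then row 1's, ...
    return [s for row in scales for s in row]

def flatten_col_major(scales):
    # Scales for group 0 across all rows first, then group 1, ...
    n_groups = len(scales[0])
    return [scales[r][g] for g in range(n_groups) for r in range(len(scales))]

m = [[1.0, -2.0, 3.0, 4.0],
     [8.0, 0.5, -6.0, 1.0]]
s = group_scales(m, group_size=2)          # [[2.0, 4.0], [8.0, 6.0]]

# Same scales, two different memory orders; a reader expecting one layout
# but handed the other applies the wrong scale to most blocks.
assert flatten_row_major(s) == [2.0, 4.0, 8.0, 6.0]
assert flatten_col_major(s) == [2.0, 8.0, 4.0, 6.0]
```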


Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify bot removed the needs-rebase label Dec 8, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@ElizaWszola ElizaWszola changed the title from "Temporarily disable group quant rms norm fusion" to "[Bugfix] Temporarily disable group quant rms norm fusion" Dec 8, 2025
@ElizaWszola
Contributor Author

Note that when testing on Hopper, I currently replace

```python
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
    input_2d,
    group_size=self.act_quant_group_shape.col,
    use_ue8m0=True,
)
```

with

```python
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
```

in fp8_utils.py to circumvent a packed deepgemm bug.

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 11, 2025
@zou3519
Collaborator

zou3519 commented Dec 11, 2025

To check, @ElizaWszola, you do want this to go in first, right? (when compared with #30244)

@ElizaWszola
Contributor Author

ElizaWszola commented Dec 11, 2025

@zou3519 I'd rather merge #30244 at this point. I had made #30273 in case #30244 wouldn't be easy to implement, but #30244 is working well now. I can still keep #30273 as a draft for now, just to be safe.

Sorry for the confusion!

@ElizaWszola ElizaWszola marked this pull request as draft December 11, 2025 19:58
@zou3519 zou3519 removed the ready ONLY add when PR is ready to merge/full CI is needed label Dec 11, 2025
@mergify

mergify bot commented Dec 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 16, 2025
@mergify mergify bot added the bug Something isn't working label Jan 14, 2026
