
[Bugfix] Fix fusion for VL models #30244

Merged
DarkLight1337 merged 7 commits into vllm-project:main from
neuralmagic:fix-vl-models-fusion
Dec 14, 2025

Conversation

@ElizaWszola
Contributor

@ElizaWszola ElizaWszola commented Dec 8, 2025

A fix for #27883 so the fusion code doesn't break VL models.

Testing:

All tests have been run on a Hopper GPU with both VLLM_USE_DEEP_GEMM=0 and VLLM_USE_DEEP_GEMM=1. Note that DeepGemm runs currently require the changes from #30336.

E2E tests:

  • Qwen/Qwen3-30B-A3B-FP8
  • Qwen/Qwen3-VL-4B-Instruct
  • Qwen/Qwen3-VL-2B-Instruct-FP8
  • Qwen/Qwen3-VL-30B-A3B-Instruct-FP8

All FP8 models have been manually verified to produce the fused group quant rms norm kernel during compilation.

Unit test:

tests/compile/test_fusion.py

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@ElizaWszola ElizaWszola changed the title from "Fix fusion for VL models" to "[Bugfix] Fix fusion for VL models" Dec 8, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the FP8 quantization fusion logic to make it more robust for Vision-Language (VL) models. The change correctly moves the decision-making for using deepgemm and column-major scales from a static configuration-based approach to a dynamic one that uses the weight tensor's shape at runtime. This is a solid improvement. I've identified one area of code duplication introduced in this change that should be addressed to improve maintainability.

Comment on lines +272 to +276
using_deepgemm = should_use_deepgemm_for_fp8_linear(
    self.model_dtype,
    weight,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()

high

This logic to determine using_deepgemm and use_col_major_scales is duplicated in vllm/compilation/fusion.py in FusedAddRMSNormGroupQuantPattern.replacement (line 267) and RMSNormGroupQuantPattern.replacement (line 331). To improve maintainability and prevent potential bugs from inconsistent updates, consider centralizing this logic. A helper method within the MatcherQuantFP8 class could be a good way to encapsulate this logic, which can then be called from all three locations.
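The shape of the helper gemini-code-assist suggests could look like the sketch below: one method on the matcher class that all three replacement sites call. The stubs for `should_use_deepgemm_for_fp8_linear` and `cutlass_block_fp8_supported` are illustrative stand-ins (the real vLLM helpers consult the platform and weight layout), and weights are represented here by plain shape tuples so the sketch stays dependency-free.

```python
# Hypothetical centralization of the (using_deepgemm, use_col_major_scales)
# decision. Stub predicates below only mimic the real vLLM helpers.


def should_use_deepgemm_for_fp8_linear(model_dtype, weight_shape) -> bool:
    # Stub: the real helper checks platform support and the linear weight shape.
    return len(weight_shape) == 2 and weight_shape[1] % 128 == 0


def cutlass_block_fp8_supported() -> bool:
    # Stub: the real helper queries the CUDA device capability.
    return False


class MatcherQuantFP8:
    def __init__(self, model_dtype):
        self.model_dtype = model_dtype

    def _scale_layout(self, weight_shape) -> tuple[bool, bool]:
        """Single source of truth shared by all replacement sites."""
        using_deepgemm = should_use_deepgemm_for_fp8_linear(
            self.model_dtype, weight_shape
        )
        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
        return using_deepgemm, use_col_major_scales


matcher = MatcherQuantFP8("bfloat16")
print(matcher._scale_layout((256, 256)))  # (True, True) with these stubs
```

With this shape, an update to the deepgemm criterion only has to happen in one place.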


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +272 to +275
using_deepgemm = should_use_deepgemm_for_fp8_linear(
    self.model_dtype,
    weight,
)

P1 Badge Skip deepgemm check for 1D RMSNorm weights

During the FP8 group quantization path the matcher now feeds the RMSNorm weight tensor into should_use_deepgemm_for_fp8_linear, but that helper assumes a 2D linear weight and unconditionally accesses weight.shape[1]. RMSNorm weights are 1D, so when the pattern is traced (or the replacement runs) this call raises IndexError: tuple index out of range, preventing the fused RMSNorm+quant pattern from compiling for group-quantized models such as VL configs.
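The failure mode and a possible guard can be sketched as below. The function name mirrors vLLM's `should_use_deepgemm_for_fp8_linear`, but the body is a stub (weights are plain shape tuples), and the rank guard is a hypothetical fix, not necessarily what the PR does.

```python
# Sketch of the IndexError Codex flags: accessing weight_shape[1] on a 1D
# RMSNorm weight blows up, so guard on the rank first (illustrative fix only).


def should_use_deepgemm_for_fp8_linear(model_dtype, weight_shape) -> bool:
    if len(weight_shape) != 2:
        # 1D RMSNorm weights can never take the deepgemm linear path.
        return False
    return weight_shape[1] % 128 == 0


print(should_use_deepgemm_for_fp8_linear("bf16", (4096,)))      # False: 1D RMSNorm weight
print(should_use_deepgemm_for_fp8_linear("bf16", (4096, 4096)))  # True with this stub
```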


Collaborator

@ProExpertProg ProExpertProg left a comment


Can we instead just create multiple patterns, one for column scales and one for row scales?

@ElizaWszola
Contributor Author

Can we instead just create multiple patterns, one for column scales and one for row scales?

@ProExpertProg We will also need one for e8m0 if we don't want to check if we use deepgemm during matching. Should I put all this information in QuantKey (it will result in 4x combinations for groupwise keys in quant_utils.py) or pass it like we're passing epsilon?

@ProExpertProg
Collaborator

I think pass it like we're passing epsilon for now, it doesn't seem like something to go in QuantKey at least for now

@ElizaWszola
Contributor Author

I think pass it like we're passing epsilon for now, it doesn't seem like something to go in QuantKey at least for now

@ProExpertProg I'm running into some duplicate pattern errors with this approach. Because it's a breaking bug (it breaks all VL models), would it be ok to land this PR as is and then make a follow-up one with cleaner matching?
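The "pass it like epsilon" idea discussed above amounts to registering one pattern per (scale layout, e8m0) combination, with the flags baked into each variant. The sketch below is illustrative only: `PatternVariant` and `register` are hypothetical names, not vLLM's pattern-matcher API, and the duplicate check mimics the "duplicate pattern errors" mentioned in the comment.

```python
# Hypothetical per-variant pattern registration, parameterized like epsilon.
from dataclasses import dataclass


@dataclass(frozen=True)
class PatternVariant:
    epsilon: float
    use_col_major_scales: bool
    use_e8m0: bool


def register(variants):
    # Stand-in for pattern registration; registering the same variant twice
    # would trip the duplicate-pattern errors described above.
    seen = set()
    for v in variants:
        if v in seen:
            raise ValueError(f"duplicate pattern: {v}")
        seen.add(v)
    return list(seen)


variants = [
    PatternVariant(1e-6, col, e8m0)
    for col in (False, True)
    for e8m0 in (False, True)
]
print(len(register(variants)))  # 4 distinct groupwise variants
```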

@mergify

mergify bot commented Dec 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 9, 2025
@mergify mergify bot removed the needs-rebase label Dec 9, 2025
@cjackal
Contributor

cjackal commented Dec 9, 2025

I have just tried this PR on top of 67475a6 but got the following error for Qwen3 MoE:

...
  File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen3_vl.py", line 1651, in forward
    hidden_states = self.language_model.model(
...
  File "/app/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error

torch._dynamo.exc.Unsupported: can't handle functions not implemented in python

from user code:
  File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen3_vl_moe.py", line 116, in forward
    hidden_states, residual = layer(
...
  File "/app/.venv/lib/python3.12/site-packages/vllm/utils/deep_gemm.py", line 63, in is_deep_gemm_e8m0_used
    if not is_deep_gemm_supported():
...
  File "/app/.venv/lib/python3.12/site-packages/vllm/platforms/interface.py", line 295, in is_device_capability
    current_capability = cls.get_device_capability(device_id=device_id)

Launch script:

vllm serve qwen/qwen3-vl-235b-a22b-instruct-fp8 --tensor-parallel-size 8 --enable-expert-parallel --all2all-backend deepep_low_latency --mm-encoder-tp-mode data --async-scheduling

Qwen3-VL-32B (non-MoE) is working like a charm BTW.
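The traceback above shows `is_deep_gemm_e8m0_used` hitting a device-capability query inside a `torch.compile`d region, which dynamo cannot trace through. A common workaround for this class of error, sketched below, is to resolve such flags once outside the compiled forward so they are plain constants by trace time. This is a generic pattern, not the fix that #30336 actually applies, and `get_device_capability` here is a fake stand-in.

```python
# Generic workaround sketch: cache an untraceable platform probe so that,
# inside a compiled region, the call collapses to a constant boolean.
import functools


def get_device_capability():
    # Stand-in for the platform query dynamo cannot handle.
    return (9, 0)  # e.g. Hopper


@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    # Evaluated once; later calls return the cached constant.
    return get_device_capability() >= (9, 0)


flag = is_deep_gemm_e8m0_used()  # resolve eagerly, before compilation
print(flag)
```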

@ElizaWszola
Contributor Author

@cjackal Looking into this now. In the meantime, you can also try disabling the group quant rms norm fusion altogether: #30273

@ElizaWszola
Contributor Author

ElizaWszola commented Dec 9, 2025

@cjackal This looks like an unrelated issue, can you try applying changes from this PR on top of the current one? #30336 (I tested it with Qwen/Qwen3-VL-30B-A3B-Instruct-FP8, it let me run inference on it till completion)

@cjackal
Contributor

cjackal commented Dec 9, 2025

@cjackal This looks like an unrelated issue, can you try applying changes from this PR on top of the current one? #30336 (I tested it with Qwen/Qwen3-VL-30B-A3B-Instruct-FP8, it let me run inference on it till completion)

Yikes, I mistakenly tested the BF16 checkpoint for qwen3-vl-32b. #30336 indeed solves the dynamo compilation error, huge thanks!

@cjackal
Contributor

cjackal commented Dec 10, 2025

Hi, I'd just like to mention that current main (since #28480 merged) is incompatible with deepgemm, see #28480 (comment). I have checked that commit 00e5cbb with this PR & #30336 cherry-picked is working well for FP8 llama4 and qwen3vl moe.

@ElizaWszola
Contributor Author

ElizaWszola commented Dec 10, 2025

@cjackal I've been seeing the same error, thanks for identifying the root PR!

@cjackal
Contributor

cjackal commented Dec 10, 2025

@cjackal I've been seeing the same error, thanks for identifying the root PR!

It looks like I'm able to run the model with deepgemm when I replace m.weight_scale with m.weight_scale_inv in _extract_data_from_linear_base_module() in deep_gemm_warmup.py, but I'm not sure at this point if this is a safe solution to the problem.

Fortunately there's ongoing bugfix at #30399 🚀
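cjackal's workaround can be read as a tolerant attribute lookup: prefer `weight_scale_inv` when a module carries it, else fall back to `weight_scale`. The sketch below is a standalone illustration of that idea; the fake module class and `extract_weight_scale` are invented names, and whether such a fallback is safe across quant schemes is exactly the open question #30399 addresses.

```python
# Illustrative fallback lookup for the scale attribute discussed above.


class FakeBlockFP8Module:
    # Stand-in for a block-quantized linear that only has weight_scale_inv.
    weight_scale_inv = "scale-inv-tensor"


class FakeRowScaleModule:
    # Stand-in for a module with the plain per-row scale attribute.
    weight_scale = "row-scale-tensor"


def extract_weight_scale(m):
    # Prefer weight_scale_inv when present, else fall back to weight_scale.
    scale = getattr(m, "weight_scale_inv", None)
    if scale is None:
        scale = m.weight_scale
    return scale


print(extract_weight_scale(FakeBlockFP8Module()))  # scale-inv-tensor
print(extract_weight_scale(FakeRowScaleModule()))  # row-scale-tensor
```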

list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_rms_group_quant(
Collaborator


We don't need a new test, just enable rmsnorm+quant fusion in the other tests!

Contributor Author


My reasoning was to test a block-quant model separately, so I don't have to figure out the number of fusions for all models (with the current regex-based counting, block and non-blocks quant rms fusions are counted together), and also so I don't have to make the code inside tests more complicated (block quant rms fusions are supported only when fp8 is enabled).
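The `flat_product` call in the test diff crosses a model list with a custom-ops option list into flat parameter tuples for `pytest.mark.parametrize`. The version below is a local illustration of what such a helper does, not vLLM's actual implementation, and the sample option tuples are made up.

```python
# Illustrative flat_product: Cartesian product that flattens tuple elements,
# yielding one flat tuple per parametrized test case.
import itertools


def flat_product(*iterables):
    for combo in itertools.product(*iterables):
        flat = []
        for item in combo:
            if isinstance(item, tuple):
                flat.extend(item)  # splice option tuples into the case
            else:
                flat.append(item)
        yield tuple(flat)


MODELS_GROUP_FP8 = ["Qwen/Qwen3-VL-2B-Instruct-FP8"]
CUSTOM_OPS_QUANT_RMS_NORM = [  # hypothetical custom-op toggles
    ("+rms_norm", "+quant_fp8"),
    ("-rms_norm", "-quant_fp8"),
]

cases = list(flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM))
print(len(cases))  # 2 flat test cases, each (model, rms_norm op, quant op)
```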

Collaborator


We should address this in a follow-up, also to reduce total test running time.

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 12, 2025
Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

@DarkLight1337 DarkLight1337 merged commit 994acec into vllm-project:main Dec 14, 2025
50 checks passed
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Dec 15, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
joa-stdn pushed a commit to joa-stdn/vllm that referenced this pull request Dec 15, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Joachim Studnia <joachim@mistral.ai>
teddygood pushed a commit to teddygood/vllm that referenced this pull request Dec 16, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
DarkLight1337 added a commit to DarkLight1337/vllm that referenced this pull request Dec 16, 2025
…ect#30396

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
vllm-bot pushed a commit that referenced this pull request Dec 17, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Dec 17, 2025
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…ect#30396 (vllm-project#30787)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…ect#30396 (vllm-project#30787)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: ElizaWszola <ewszola@redhat.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed

5 participants