[ROCm] Guard group quant RMS norm fusion patterns#30239

Merged
yeqcharlotte merged 1 commit into vllm-project:main from yeqcharlotte:export-D88608586
Dec 8, 2025

Conversation


@yeqcharlotte yeqcharlotte commented Dec 8, 2025

Summary:
Fix AMD compilation failure for DeepSeek models introduced in #27883.

The issue was that RMSNormQuantFusionPass unconditionally creates
FusedAddRMSNormGroupQuantPattern and RMSNormGroupQuantPattern for
group quantization (GroupShape 64 and 128), but the underlying C++
operation per_token_group_fp8_quant is only available on CUDA
(wrapped in #ifndef USE_ROCM in torch_bindings.cpp).

On AMD platforms, this caused an assertion failure:

AssertionError: unsupported quantization scheme QuantKey(f8e4m3fnuz,scale(f32,dynamic,GroupShape(row=1, col=128)),symmetric)

The fix guards the creation of group quant patterns with
current_platform.is_cuda(), matching the guard used for registering
these keys in QUANT_OPS.

Test Plan:
Ran e2e correctness/perf tests for DeepSeek on AMD; results below.
Perf

Ran 40/40 requests in 48.11s
Success rate:        100.00%
QPS:                 0.83
Avg latency:         4.566s
Avg TTFT (client):   553.11ms
P50 TTFT (client):   551.61ms
P99 TTFT (client):   569.35ms
Avg TTIT (client):   40.13ms
P50 TTIT (client):   40.78ms
P99 TTIT (client):   41.43ms
Avg TTFT (server):   825.17ms
Avg TTIT (server):   36.83ms
Avg prefill len:     5823.10 tokens
P50 prefill len:     5824.00 tokens
P99 prefill len:     5886.00 tokens
Avg decode len:      100.00 tokens
P50 decode len:      100.00 tokens
P99 decode len:      100.00 tokens
Peak TPGS: 9.875

Correctness

[2025-12-07 22:38:16,439] [rank 0] [INFO] Evaluation results on task gsm8k.8_shot.1_gen: em: 0.980000 | f1: 0.980000 | em_maj1@1: 0.980000 | f1_maj1@1: 0.980000

Will also wait for external CI

Differential Revision:
D88608586

@mergify mergify bot added the nvidia label Dec 8, 2025
@yeqcharlotte yeqcharlotte changed the title Guard group quant RMS norm fusion patterns on CUDA platforms [ROCm] Guard group quant RMS norm fusion patterns Dec 8, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Dec 8, 2025
@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a compilation failure on AMD platforms by guarding CUDA-specific group quantization fusion patterns. The change is logical and aligns with existing patterns in the codebase for platform-specific operations. I've added one suggestion to refactor the new code to reduce duplication and improve maintainability.

Comment on lines +495 to +511

                FusedAddRMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
                ).register(self.patterns)

                # Fuse rms_norm + fp8 group quant
                RMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
                ).register(self.patterns)

                FusedAddRMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
                ).register(self.patterns)

                # Fuse rms_norm + fp8 group quant
                RMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
                ).register(self.patterns)
Severity: high

While this correctly guards the CUDA-specific patterns, there's an opportunity to reduce code duplication. The logic for group shapes 128 and 64 is identical. You can use a loop to register these patterns, which will make the code more concise and easier to maintain.

                for group_size in [128, 64]:
                    group_shape = GroupShape(1, group_size)
                    FusedAddRMSNormGroupQuantPattern(
                        epsilon, FP8_DTYPE, group_shape=group_shape
                    ).register(self.patterns)

                    # Fuse rms_norm + fp8 group quant
                    RMSNormGroupQuantPattern(
                        epsilon, FP8_DTYPE, group_shape=group_shape
                    ).register(self.patterns)


@tjtanaa tjtanaa left a comment


LGTM.

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 8, 2025
@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 8, 2025

tjtanaa commented Dec 8, 2025

Waiting for the test results from @yeqcharlotte before merging.

@yeqcharlotte
Collaborator Author

> Waiting for the test results from @yeqcharlotte before merging.

added test results

@yeqcharlotte yeqcharlotte enabled auto-merge (squash) December 8, 2025 07:22
@yeqcharlotte yeqcharlotte merged commit eb1051f into vllm-project:main Dec 8, 2025
56 of 60 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 8, 2025
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

meta-exported, nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)

Projects

Status: Done

3 participants