[ROCm] Guard group quant RMS norm fusion patterns #30239
yeqcharlotte merged 1 commit into vllm-project:main
Conversation
Summary: Fix AMD compilation failure for DeepSeek models introduced in vllm-project#27883.

The issue was that RMSNormQuantFusionPass unconditionally creates FusedAddRMSNormGroupQuantPattern and RMSNormGroupQuantPattern for group quantization (GroupShape 64 and 128), but the underlying C++ operation per_token_group_fp8_quant is only available on CUDA (wrapped in #ifndef USE_ROCM in torch_bindings.cpp). On AMD platforms, this caused an assertion failure:

AssertionError: unsupported quantization scheme QuantKey(f8e4m3fnuz,scale(f32,dynamic,GroupShape(row=1, col=128)),symmetric)

The fix guards the creation of group quant patterns with current_platform.is_cuda(), matching the guard used for registering these keys in QUANT_OPS.

Test Plan: Waiting for this DeepSeek job on AMD to complete: https://www.internalfb.com/vanguard/serving_test_cases/1967790977283741. Will also wait for external CI.

Differential Revision: D88608586

Privacy Context Container: L1370295
Code Review
This pull request correctly fixes a compilation failure on AMD platforms by guarding CUDA-specific group quantization fusion patterns. The change is logical and aligns with existing patterns in the codebase for platform-specific operations. I've added one suggestion to refactor the new code to reduce duplication and improve maintainability.
```python
FusedAddRMSNormGroupQuantPattern(
    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)

# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)

FusedAddRMSNormGroupQuantPattern(
    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)

# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
```
While this correctly guards the CUDA-specific patterns, there's an opportunity to reduce code duplication. The logic for group shapes 128 and 64 is identical. You can use a loop to register these patterns, which will make the code more concise and easier to maintain.
```python
for group_size in [128, 64]:
    group_shape = GroupShape(1, group_size)
    FusedAddRMSNormGroupQuantPattern(
        epsilon, FP8_DTYPE, group_shape=group_shape
    ).register(self.patterns)
    # Fuse rms_norm + fp8 group quant
    RMSNormGroupQuantPattern(
        epsilon, FP8_DTYPE, group_shape=group_shape
    ).register(self.patterns)
```
Waiting for the test results from @yeqcharlotte before merging.
Added test results.
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Summary:
Fix AMD compilation failure for DeepSeek models introduced in #27883.
The issue was that RMSNormQuantFusionPass unconditionally creates
FusedAddRMSNormGroupQuantPattern and RMSNormGroupQuantPattern for
group quantization (GroupShape 64 and 128), but the underlying C++
operation per_token_group_fp8_quant is only available on CUDA
(wrapped in #ifndef USE_ROCM in torch_bindings.cpp).
On AMD platforms, this caused an assertion failure:
AssertionError: unsupported quantization scheme QuantKey(f8e4m3fnuz,scale(f32,dynamic,GroupShape(row=1, col=128)),symmetric)
The fix guards the creation of group quant patterns with
current_platform.is_cuda(), matching the guard used for registering
these keys in QUANT_OPS.
Test Plan:
Ran e2e correctness/perf tests for DeepSeek on AMD. Will paste results after they complete.
Perf
Correctness
Will also wait for external CI
Differential Revision:
D88608586