From f09d137c5f8488618088d8d242922071a7339183 Mon Sep 17 00:00:00 2001
From: "Charlotte (Ye) Qi"
Date: Sun, 7 Dec 2025 22:15:33 -0800
Subject: [PATCH] Guard group quant RMS norm fusion patterns on CUDA platforms

Summary:
Fix AMD compilation failure for DeepSeek models introduced in
https://github.com/vllm-project/vllm/pull/27883.

The issue was that RMSNormQuantFusionPass unconditionally creates
FusedAddRMSNormGroupQuantPattern and RMSNormGroupQuantPattern for group
quantization (GroupShape 64 and 128), but the underlying C++ operation
per_token_group_fp8_quant is only available on CUDA (wrapped in
#ifndef USE_ROCM in torch_bindings.cpp).

On AMD platforms, this caused an assertion failure:
AssertionError: unsupported quantization scheme
QuantKey(f8e4m3fnuz,scale(f32,dynamic,GroupShape(row=1, col=128)),symmetric)

The fix guards the creation of group quant patterns with
current_platform.is_cuda(), matching the guard used for registering these
keys in QUANT_OPS.

Test Plan:
Waiting for this deepseek job on amd to complete:
https://www.internalfb.com/vanguard/serving_test_cases/1967790977283741

Will also wait for external CI

Differential Revision: D88608586

Privacy Context Container: L1370295
---
 vllm/compilation/fusion.py | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index de083a2e5e3c..a7e6a69e64c9 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -490,23 +490,25 @@ def __init__(self, config: VllmConfig):
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
             # Fuse fused_add_rms_norm + fp8 group quant
-            FusedAddRMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-            ).register(self.patterns)
-
-            # Fuse rms_norm + fp8 group quant
-            RMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-            ).register(self.patterns)
-
-            FusedAddRMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-            ).register(self.patterns)
-
-            # Fuse rms_norm + fp8 group quant
-            RMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-            ).register(self.patterns)
+            # Only register group quant patterns on CUDA where the C++ op exists
+            if current_platform.is_cuda():
+                FusedAddRMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+                ).register(self.patterns)
+
+                # Fuse rms_norm + fp8 group quant
+                RMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+                ).register(self.patterns)
+
+                FusedAddRMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+                ).register(self.patterns)
+
+                # Fuse rms_norm + fp8 group quant
+                RMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+                ).register(self.patterns)
 
         # Fuse fused_add_rms_norm + static fp8 quant
         FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(