From b714a1c10fb83246e195f32ba414eb5d9d5dba80 Mon Sep 17 00:00:00 2001
From: rickychen-infinirc
Date: Sat, 3 Jan 2026 12:40:59 +0800
Subject: [PATCH 1/4] [Bugfix] Fix CPU backend gibberish output with --enforce-eager

Add explicit forward_cpu methods to CustomOp subclasses that delegate
to forward_native, ensuring CPU backend uses PyTorch native
implementation instead of buggy CPU C++ kernels when custom_ops='all'.

Classes fixed:
- RotaryEmbedding (rotary_embedding/base.py)
- RMSNorm (layernorm.py)
- GemmaRMSNorm (layernorm.py)
- RMSNormGated (layernorm.py)

Fixes #31626

Signed-off-by: rickychen-infinirc
---
 vllm/model_executor/layers/layernorm.py      | 23 +++++++++++++++++++
 .../layers/rotary_embedding/base.py          | 10 ++++++++
 2 files changed, 33 insertions(+)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 8cc374ac9155..017d4f1332bf 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -247,6 +247,15 @@ def forward_xpu(
             self.variance_epsilon,
         )
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        # Use PyTorch native implementation for CPU to ensure correctness
+        # The CPU C++ kernel may have inconsistencies with the native impl
+        return self.forward_native(x, residual)
+
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
@@ -320,6 +329,14 @@ def forward_cuda(
             self._is_compiled = True
         return self.forward_native(x, residual)
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        # Use PyTorch native implementation for CPU to ensure correctness
+        return self.forward_native(x, residual)
+
 
 @CustomOp.register("rms_norm_gated")
 class RMSNormGated(CustomOp):
@@ -423,6 +440,12 @@ def forward_cuda(
             norm_before_gate=self.norm_before_gate,
         )
 
+    def forward_cpu(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        # Use PyTorch native implementation for CPU to ensure correctness
+        return self.forward_native(x, z)
+
 
 class LayerNorm(nn.Module):
     """
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 7e83ea9a1355..3dd75140a0af 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -250,6 +250,16 @@ def forward_xpu(
         )
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Use PyTorch native implementation for CPU to ensure correctness
+        # The CPU C++ kernel may have inconsistencies with the native impl
+        return self.forward_native(positions, query, key)
+
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
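Repro sketch for #31626 (not committed material; assumes a CPU build of
vLLM, and the model name and the compilation_config spelling are
illustrative and may differ across vLLM versions). Forcing the custom-op
path together with eager mode is what routed RMSNorm and RotaryEmbedding
through the C++ CPU kernels and produced the garbled output this series
addresses:

    from vllm import LLM, SamplingParams

    # custom_ops=["all"] opts every CustomOp into its custom kernel;
    # combined with enforce_eager=True this reproduced the bad CPU output.
    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model
        enforce_eager=True,
        compilation_config={"custom_ops": ["all"]},
    )
    params = SamplingParams(temperature=0.0, max_tokens=32)
    out = llm.generate(["The capital of France is"], params)
    print(out[0].outputs[0].text)  # gibberish before the fix
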
From b75b49ab63967a4ffaa1c5f74cc696e8a45018f0 Mon Sep 17 00:00:00 2001
From: rickychen-infinirc
Date: Sat, 3 Jan 2026 18:57:35 +0800
Subject: [PATCH 2/4] [Bugfix] Fix CPU backend gibberish output with --enforce-eager

Add explicit forward_cpu methods to CustomOp subclasses that delegate
to forward_native, ensuring CPU backend uses PyTorch native
implementation instead of buggy CPU C++ kernels when custom_ops='all'.

Classes fixed:
- RotaryEmbedding (rotary_embedding/base.py)
- RMSNorm (layernorm.py)
- GemmaRMSNorm (layernorm.py)
- RMSNormGated (layernorm.py)

Fixes #31626

Signed-off-by: rickychen-infinirc
---
 vllm/model_executor/layers/rotary_embedding/base.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 3dd75140a0af..aa41dbad01dd 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -257,7 +257,6 @@ def forward_cpu(
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         # Use PyTorch native implementation for CPU to ensure correctness
-        # The CPU C++ kernel may have inconsistencies with the native impl
         return self.forward_native(positions, query, key)
 
     def extra_repr(self) -> str:

From 0bd69effc359114af1c511e3d019850bf604b840 Mon Sep 17 00:00:00 2001
From: rickychen-infinirc
Date: Mon, 5 Jan 2026 17:05:08 +0800
Subject: [PATCH 3/4] [Bugfix] Refactor CPU forward dispatch to use native impl by default

- Change CustomOp.forward_cpu() default to forward_native instead of
  forward_cuda, as most CPU custom kernels are not performance-critical
  and can have compatibility issues
- Remove redundant forward_cpu() from RMSNorm, GemmaRMSNorm,
  RMSNormGated since they now inherit the base class behavior

Signed-off-by: rickychen-infinirc
---
 vllm/model_executor/custom_op.py             |  5 ++--
 vllm/model_executor/layers/layernorm.py      | 25 -------------------
 .../layers/rotary_embedding/base.py          | 17 +++++++++++--
 3 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 66250f816f45..a80768c33a51 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -67,8 +67,9 @@ def forward_xpu(self, *args, **kwargs):
         return self.forward_native(*args, **kwargs)
 
     def forward_cpu(self, *args, **kwargs):
-        # By default, we assume that CPU ops are compatible with CUDA ops.
-        return self.forward_cuda(*args, **kwargs)
+        # By default, we assume that CPU ops are compatible with the
+        # PyTorch-native implementation.
+        return self.forward_native(*args, **kwargs)
 
     def forward_tpu(self, *args, **kwargs):
         # By default, we assume that TPU ops are compatible with the
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 017d4f1332bf..627421eecfdc 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -247,15 +247,6 @@ def forward_xpu(
             self.variance_epsilon,
         )
 
-    def forward_cpu(
-        self,
-        x: torch.Tensor,
-        residual: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # Use PyTorch native implementation for CPU to ensure correctness
-        # The CPU C++ kernel may have inconsistencies with the native impl
-        return self.forward_native(x, residual)
-
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
@@ -329,15 +320,6 @@ def forward_cuda(
             self._is_compiled = True
         return self.forward_native(x, residual)
 
-    def forward_cpu(
-        self,
-        x: torch.Tensor,
-        residual: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # Use PyTorch native implementation for CPU to ensure correctness
-        return self.forward_native(x, residual)
-
-
 @CustomOp.register("rms_norm_gated")
 class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.
@@ -440,13 +422,6 @@ def forward_cuda(
             norm_before_gate=self.norm_before_gate,
         )
 
-    def forward_cpu(
-        self, x: torch.Tensor, z: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        # Use PyTorch native implementation for CPU to ensure correctness
-        return self.forward_native(x, z)
-
-
 class LayerNorm(nn.Module):
     """
     Layer Normalization.
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index aa41dbad01dd..bd82728ed15f 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -256,8 +256,21 @@ def forward_cpu(
         query: torch.Tensor,
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        # Use PyTorch native implementation for CPU to ensure correctness
-        return self.forward_native(positions, query, key)
+        from vllm import _custom_ops as ops
+
+        self._match_cos_sin_cache_dtype(query)
+
+        # ops.rotary_embedding() is an in-place operation
+        # that updates the query and key tensors.
+        ops.rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )
+        return query, key
 
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
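A minimal sketch of the dispatch pattern after PATCH 3, using hypothetical
stand-in classes rather than vLLM's actual ones (the real selection logic
lives in vllm/model_executor/custom_op.py): subclasses that define nothing
now inherit the native fallback on CPU, and only ops whose C++ CPU kernel
is trusted, like RotaryEmbedding in the hunk above, opt back in by
overriding forward_cpu.

    # Hypothetical stand-ins, not vLLM's actual classes.
    class OpSketch:
        def forward_native(self, *args, **kwargs):
            raise NotImplementedError

        def forward_cpu(self, *args, **kwargs):
            # New base-class default from PATCH 3: CPU falls back to the
            # PyTorch-native implementation instead of forward_cuda().
            return self.forward_native(*args, **kwargs)


    class RMSNormSketch(OpSketch):
        # No forward_cpu() override: inherits the native fallback on CPU,
        # which is why PATCH 3 can delete the methods PATCH 1 added.
        def forward_native(self, x):
            return x  # placeholder for the real RMS normalization


    class RotarySketch(OpSketch):
        # Explicit override: keeps the C++ CPU kernel, mirroring
        # RotaryEmbedding.forward_cpu() in the hunk above.
        def forward_cpu(self, *args, **kwargs):
            return self._cpp_kernel(*args, **kwargs)

        def _cpp_kernel(self, *args, **kwargs):
            ...  # stands in for ops.rotary_embedding(...)


    op = RMSNormSketch()
    assert op.forward_cpu(42) == op.forward_native(42)  # same code path
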
From 86b557a2ffcc9501631921d0e6ab526ce42f67ec Mon Sep 17 00:00:00 2001
From: rickychen-infinirc
Date: Mon, 5 Jan 2026 23:11:48 +0800
Subject: [PATCH 4/4] style: add blank lines for PEP8 compliance

Signed-off-by: rickychen-infinirc
---
 vllm/model_executor/layers/layernorm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 627421eecfdc..8cc374ac9155 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -320,6 +320,7 @@ def forward_cuda(
             self._is_compiled = True
         return self.forward_native(x, residual)
 
+
 @CustomOp.register("rms_norm_gated")
 class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.
@@ -422,6 +423,7 @@ def forward_cuda(
             norm_before_gate=self.norm_before_gate,
         )
 
+
 class LayerNorm(nn.Module):
     """
     Layer Normalization.
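A quick way to check the net effect of the series (a sketch, assuming a
CPU machine with vLLM installed and the RMSNorm constructor shown in
layernorm.py): after PATCH 3, RMSNorm no longer defines forward_cpu, so
the CustomOp base class routes CPU calls straight to forward_native and
the two must agree exactly.

    import torch
    from vllm.model_executor.layers.layernorm import RMSNorm

    layer = RMSNorm(hidden_size=64)
    x = torch.randn(2, 64)

    # forward_cpu() is now the inherited native fallback, so this is
    # literally the same code path as forward_native().
    ref = layer.forward_native(x)
    out = layer.forward_cpu(x)
    torch.testing.assert_close(out, ref)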