From 196b78c7b750592078c677841c865a04791beafc Mon Sep 17 00:00:00 2001
From: Taksh
Date: Mon, 6 Apr 2026 17:33:08 +0530
Subject: [PATCH 1/3] Fix RMSNorm hidden_size validation crash for weightless norms

When `replace_rms_norm_class` replaces RMSNorm modules, it passes
`hidden_size` from the LM config even for vision encoder norms that have a
different dimension. For norms like Gemma4's `v_norm` (`with_scale=False`),
no `weight` tensor is registered, so the hidden_size correction code (which
reads `weight.shape`) is skipped, leaving the wrong hidden_size. The
`forward_static` validation then raises
`ValueError: Expected hidden_size to be 5376, but found: 72`.

Skip the hidden_size validation when `weight is None`, since a weightless
RMSNorm just computes `x / sqrt(mean(x^2) + eps)` and does not depend on
hidden_size.

Fixes #39061

Co-authored-by: Claude Opus 4.6 (1M context)
---
 vllm/model_executor/layers/layernorm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 7b222f9c431b..62e7f49d9948 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -207,7 +207,7 @@ def forward_static(
             x = x + residual
             residual = x.to(orig_dtype)
 
-        if x.shape[-1] != hidden_size:
+        if weight is not None and x.shape[-1] != hidden_size:
             raise ValueError(
                 f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
             )

From bb3779173439aab09fa809a9b27b134dd527c931 Mon Sep 17 00:00:00 2001
From: Taksh
Date: Tue, 7 Apr 2026 14:54:53 +0530
Subject: [PATCH 2/3] Fix weightless RMSNorm in forward_cuda and forward_hip methods

Address remaining issues with has_weight=False norms:

- forward_static: update the local hidden_size from the input shape when
  weight is None, so the variance_size_override validation uses the correct
  value
- forward_cuda/forward_hip: pass None to ir.ops.rms_norm when has_weight is
  False, and create a ones tensor for fused_add_rms_norm and
  rms_norm_batch_invariant, which require a weight tensor

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 vllm/model_executor/layers/layernorm.py | 26 +++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 62e7f49d9948..8ced9cfdfcd2 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -212,6 +212,10 @@ def forward_static(
                 f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
             )
 
+        # When weight is None (weightless norm), use input dim for validation
+        if weight is None:
+            hidden_size = x.shape[-1]
+
         if variance_size_override is None:
             x_var = x
         else:
@@ -266,7 +270,10 @@ def forward_cuda(
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is None and not envs.VLLM_BATCH_INVARIANT:
             return ir.ops.rms_norm(
-                x, self.weight.data, self.variance_epsilon, self.variance_size_override
+                x,
+                self.weight.data if self.has_weight else None,
+                self.variance_epsilon,
+                self.variance_size_override,
             )
 
         if self.variance_size_override is not None:
@@ -313,13 +320,15 @@ def forward_cuda(
             )
             return x, residual
 
+        weight = self.weight.data if self.has_weight else torch.ones(
+            x.shape[-1], device=x.device, dtype=x.dtype)
         if residual is not None:
             return fused_add_rms_norm(
-                x, residual, self.weight.data, self.variance_epsilon
+                x, residual, weight, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
+            return rms_norm_batch_invariant(x, weight, self.variance_epsilon)
 
     def forward_hip(
         self,
@@ -328,19 +337,24 @@ def forward_hip(
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is None and not envs.VLLM_BATCH_INVARIANT:
             return ir.ops.rms_norm(
-                x, self.weight.data, self.variance_epsilon, self.variance_size_override
+                x,
+                self.weight.data if self.has_weight else None,
+                self.variance_epsilon,
+                self.variance_size_override,
             )
 
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
+        weight = self.weight.data if self.has_weight else torch.ones(
+            x.shape[-1], device=x.device, dtype=x.dtype)
         if residual is not None:
             return self.rocm_norm_func_with_add(
-                x, residual, self.weight.data, self.variance_epsilon
+                x, residual, weight, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
+            return rms_norm_batch_invariant(x, weight, self.variance_epsilon)
 
     def forward_xpu(
         self,

From c62c3e58eacd4dbc050e405c56aa3ee2301f03dd Mon Sep 17 00:00:00 2001
From: Taksh
Date: Wed, 8 Apr 2026 13:09:50 +0530
Subject: [PATCH 3/3] Fix forward_cuda/forward_hip to fall back to forward_native for weightless norms

Instead of allocating a dummy torch.ones tensor on every call when
has_weight is False, fall back to forward_native which properly handles
weightless norms via forward_static (where hidden_size is derived from the
input tensor shape).

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 vllm/model_executor/layers/layernorm.py | 28 ++++++++++++++++++-------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 8ced9cfdfcd2..37dc4e132fc2 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -320,15 +320,21 @@ def forward_cuda(
             )
             return x, residual
 
-        weight = self.weight.data if self.has_weight else torch.ones(
-            x.shape[-1], device=x.device, dtype=x.dtype)
+        # When has_weight is False, fall back to forward_native which
+        # handles weightless norms via forward_static without allocating
+        # a dummy ones tensor on every call.
+        if not self.has_weight:
+            return self.forward_native(x, residual)
+
         if residual is not None:
             return fused_add_rms_norm(
-                x, residual, weight, self.variance_epsilon
+                x, residual, self.weight.data, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, weight, self.variance_epsilon)
+            return rms_norm_batch_invariant(
+                x, self.weight.data, self.variance_epsilon
+            )
 
     def forward_hip(
         self,
@@ -346,15 +352,21 @@ def forward_hip(
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
-        weight = self.weight.data if self.has_weight else torch.ones(
-            x.shape[-1], device=x.device, dtype=x.dtype)
+        # When has_weight is False, fall back to forward_native which
+        # handles weightless norms via forward_static without allocating
+        # a dummy ones tensor on every call.
+        if not self.has_weight:
+            return self.forward_native(x, residual)
+
         if residual is not None:
             return self.rocm_norm_func_with_add(
-                x, residual, weight, self.variance_epsilon
+                x, residual, self.weight.data, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, weight, self.variance_epsilon)
+            return rms_norm_batch_invariant(
+                x, self.weight.data, self.variance_epsilon
+            )
 
     def forward_xpu(
         self,
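
Note on why skipping the hidden_size check is safe for weightless norms: with no learned weight, the norm only reduces over the last dimension and never consults a configured hidden_size. A minimal sketch in plain PyTorch (illustrative names only, not vLLM's RMSNorm implementation):

    import torch

    def weightless_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), reduced over the last dimension only;
        # there is no weight, hence no dependence on a configured hidden_size.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps)

    # Shape-agnostic: covers the mismatch from the issue, where hidden_size is
    # configured as 5376 while the vision-encoder input is 72-dimensional.
    assert weightless_rms_norm(torch.randn(2, 72)).shape == (2, 72)
    assert weightless_rms_norm(torch.randn(2, 5376)).shape == (2, 5376)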