diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7b222f9c431b..37dc4e132fc2 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -207,11 +207,15 @@ def forward_static( x = x + residual residual = x.to(orig_dtype) - if x.shape[-1] != hidden_size: + if weight is not None and x.shape[-1] != hidden_size: raise ValueError( f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" ) + # When weight is None (weightless norm), use input dim for validation + if weight is None: + hidden_size = x.shape[-1] + if variance_size_override is None: x_var = x else: @@ -266,7 +270,10 @@ def forward_cuda( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if residual is None and not envs.VLLM_BATCH_INVARIANT: return ir.ops.rms_norm( - x, self.weight.data, self.variance_epsilon, self.variance_size_override + x, + self.weight.data if self.has_weight else None, + self.variance_epsilon, + self.variance_size_override, ) if self.variance_size_override is not None: @@ -313,13 +320,21 @@ def forward_cuda( ) return x, residual + # When has_weight is False, fall back to forward_native which + # handles weightless norms via forward_static without allocating + # a dummy ones tensor on every call. + if not self.has_weight: + return self.forward_native(x, residual) + if residual is not None: return fused_add_rms_norm( x, residual, self.weight.data, self.variance_epsilon ) else: assert envs.VLLM_BATCH_INVARIANT - return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon) + return rms_norm_batch_invariant( + x, self.weight.data, self.variance_epsilon + ) def forward_hip( self, @@ -328,19 +343,30 @@ def forward_hip( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if residual is None and not envs.VLLM_BATCH_INVARIANT: return ir.ops.rms_norm( - x, self.weight.data, self.variance_epsilon, self.variance_size_override + x, + self.weight.data if self.has_weight else None, + self.variance_epsilon, + self.variance_size_override, ) if self.variance_size_override is not None: return self.forward_native(x, residual) + # When has_weight is False, fall back to forward_native which + # handles weightless norms via forward_static without allocating + # a dummy ones tensor on every call. + if not self.has_weight: + return self.forward_native(x, residual) + if residual is not None: return self.rocm_norm_func_with_add( x, residual, self.weight.data, self.variance_epsilon ) else: assert envs.VLLM_BATCH_INVARIANT - return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon) + return rms_norm_batch_invariant( + x, self.weight.data, self.variance_epsilon + ) def forward_xpu( self,