vllm/model_executor/layers/layernorm.py (36 changes: 31 additions & 5 deletions)
@@ -207,11 +207,15 @@ def forward_static(
             x = x + residual
             residual = x.to(orig_dtype)

-        if x.shape[-1] != hidden_size:
+        if weight is not None and x.shape[-1] != hidden_size:
             raise ValueError(
                 f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
             )
Comment on lines +210 to 213
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Severity: high

While this change prevents the ValueError when weight is None, it leaves the logic inconsistent for the rest of the function. Specifically, hidden_size is used later at line 218 to validate variance_size_override. If hidden_size is incorrect (e.g., 5376 instead of 72), the check if hidden_size < variance_size_override will be performed against the wrong value, which could lead to an IndexError during slicing at line 224 if an invalid override is provided.
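
As a toy illustration of that inconsistency (the values and the standalone guard below are assumptions for the sketch, not the actual vLLM source):

import torch

# Toy illustration: the override guard compares against the declared
# hidden_size, not the real input width, so an override that is invalid
# for the actual input can pass the check.
hidden_size = 5376                        # size the layer was constructed with
variance_size_override = 128
x = torch.randn(2, 4, 72)                 # weightless norm sees a 72-wide input

if hidden_size < variance_size_override:  # passes, since 5376 >= 128
    raise ValueError("override larger than hidden size")
x_var = x[..., :variance_size_override]   # silently operates on only 72 columns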

Additionally, this fix only addresses the forward_static path. The forward_cuda and forward_hip methods still pass self.weight.data to the kernels even when has_weight is False. Since self.weight is initialized using the (potentially incorrect) hidden_size, those kernels will likely crash due to a size mismatch. You should consider applying a similar check or passing None in those paths, for example:

# In forward_cuda / forward_hip
return ir.ops.rms_norm(
    x, 
    self.weight.data if self.has_weight else None, 
    self.variance_epsilon, 
    self.variance_size_override
)

For this function, updating the local hidden_size to match the actual input dimension when weight is None ensures all subsequent logic remains consistent.

        if weight is not None:
            if x.shape[-1] != hidden_size:
                raise ValueError(
                    f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
                )
        else:
            # For weightless norms, use the actual input dimension for validation
            # of variance_size_override below.
            hidden_size = x.shape[-1]


+        # When weight is None (weightless norm), use input dim for validation
+        if weight is None:
+            hidden_size = x.shape[-1]
+
         if variance_size_override is None:
             x_var = x
         else:
@@ -266,7 +270,10 @@ def forward_cuda(
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is None and not envs.VLLM_BATCH_INVARIANT:
             return ir.ops.rms_norm(
-                x, self.weight.data, self.variance_epsilon, self.variance_size_override
+                x,
+                self.weight.data if self.has_weight else None,
+                self.variance_epsilon,
+                self.variance_size_override,
             )

         if self.variance_size_override is not None:
@@ -313,13 +320,21 @@ def forward_cuda(
             )
             return x, residual

+        # When has_weight is False, fall back to forward_native which
+        # handles weightless norms via forward_static without allocating
+        # a dummy ones tensor on every call.
+        if not self.has_weight:
+            return self.forward_native(x, residual)
+
         if residual is not None:
             return fused_add_rms_norm(
                 x, residual, self.weight.data, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
+            return rms_norm_batch_invariant(
+                x, self.weight.data, self.variance_epsilon
+            )

     def forward_hip(
         self,
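
The comment added above explains why the fallback exists: a weightless RMS norm only rescales by the root-mean-square, so forward_native can compute it directly instead of fabricating a ones weight on every call. A minimal plain-PyTorch sketch of that computation (the weighted-kernel call in the trailing comment is hypothetical, not vLLM's API):

import torch

def rms_norm_weightless(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS norm with no learned scale: divide by the root-mean-square
    # of the last dimension, then cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(orig_dtype)

# The alternative the comment refers to would be to fabricate a weight of
# ones on every call just to satisfy a weighted kernel, e.g.:
#   w = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
#   y = some_weighted_rms_norm_kernel(x, w, eps)   # hypothetical kernel name
# which allocates a throwaway tensor per forward pass.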
@@ -328,19 +343,30 @@ def forward_hip(
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is None and not envs.VLLM_BATCH_INVARIANT:
             return ir.ops.rms_norm(
-                x, self.weight.data, self.variance_epsilon, self.variance_size_override
+                x,
+                self.weight.data if self.has_weight else None,
+                self.variance_epsilon,
+                self.variance_size_override,
             )

         if self.variance_size_override is not None:
             return self.forward_native(x, residual)

+        # When has_weight is False, fall back to forward_native which
+        # handles weightless norms via forward_static without allocating
+        # a dummy ones tensor on every call.
+        if not self.has_weight:
+            return self.forward_native(x, residual)
+
         if residual is not None:
             return self.rocm_norm_func_with_add(
                 x, residual, self.weight.data, self.variance_epsilon
             )
         else:
             assert envs.VLLM_BATCH_INVARIANT
-            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
+            return rms_norm_batch_invariant(
+                x, self.weight.data, self.variance_epsilon
+            )

     def forward_xpu(
         self,
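
A hedged usage sketch of the weightless path this change fixes; the class name RMSNorm and the has_weight keyword are inferred from the attributes used in this file, and the real constructor signature may differ:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

# Assumed constructor: a weightless per-head norm over a 72-wide input.
norm = RMSNorm(72, has_weight=False)
x = torch.randn(8, 72, dtype=torch.float16, device="cuda")

# With this change, the CUDA/HIP paths either pass None as the weight to
# ir.ops.rms_norm or fall back to forward_native, instead of feeding a
# mismatched self.weight.data to the kernels.
y = norm(x)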