Merged
19 changes: 18 additions & 1 deletion vllm_ascend/ops/layernorm.py
@@ -37,11 +37,28 @@ def __init__(
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        vllm_config = get_current_vllm_config()
        self.bias = None
        self.bias_loaded = False
Contributor (severity: high):

While adding the bias_loaded flag is a good optimization, it's not used consistently throughout the forward_oot method. The branch that handles a non-None residual (lines 70-79) still checks if self.bias is not None and passes self.bias to the custom op unconditionally. This can lead to unnecessary bias additions with a zero tensor, which this PR aims to prevent. To make the optimization effective in all cases, this logic should also be updated to use self.bias_loaded.
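The rationale for gating on the flag rather than on `self.bias is not None` can be illustrated in isolation. A minimal stand-in in plain Python, with no NPU ops; `NormBias`, `load_bias`, and `apply` are illustrative names, not vLLM APIs:

```python
class NormBias:
    """Minimal stand-in showing why bias_loaded, not `bias is not None`, is the gate."""

    def __init__(self, size, expects_bias):
        # The placeholder is allocated eagerly when quantization may produce a
        # norm bias, but it stays all-zero until a checkpoint actually loads one.
        self.bias = [0.0] * size if expects_bias else None
        self.bias_loaded = False

    def load_bias(self, values):
        assert len(values) == len(self.bias)
        self.bias = list(values)
        self.bias_loaded = True

    def apply(self, x):
        # Gating on bias_loaded skips the wasted add of an all-zero placeholder
        # when the checkpoint carried no norm bias.
        if self.bias_loaded:
            return [xi + bi for xi, bi in zip(x, self.bias)]
        return x
```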


        # quantization with anti_method m4 will generate a non-zero norm bias
        if vllm_config.quant_config is not None and any(
            "norm.bias" in name for name in vllm_config.quant_config.quant_description
        ):
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
            self.bias.weight_loader = self._bias_weight_loader

    def _bias_weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
Contributor:

It would be better to wrap the original weight loader so we don't have to reimplement the details here.

Collaborator:

I think the custom op has no weight loader function? I don't get your point. If you still have any questions about this, please feel free to open a new PR. This PR is ready for merge now.

        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Scalar values are sometimes not treated as tensors with shapes,
            # so if both param and loaded_weight are scalars,
            # "broadcast" (fill) instead of copying
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()})"
            )
            param.data.copy_(loaded_weight)
        self.bias_loaded = True
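The loader's branch logic (a scalar-to-scalar load is "broadcast" via `fill_`, matching shapes are copied via `copy_`, and anything else is rejected) can be sketched without torch; `load_mode` is a hypothetical helper for illustration, not part of vLLM:

```python
def load_mode(param_shape, weight_shape):
    """Decide how a loaded weight maps onto a parameter.

    Mirrors the branches in _bias_weight_loader: a scalar loaded into a
    scalar parameter is "broadcast" (fill_), matching shapes are copied
    (copy_), and any other combination is a shape mismatch.
    """
    def numel(shape):
        n = 1
        for dim in shape:
            n *= dim
        return n

    if numel(param_shape) == 1 and numel(weight_shape) == 1:
        return "broadcast"
    if tuple(param_shape) != tuple(weight_shape):
        raise ValueError(
            f"Attempted to load weight ({tuple(weight_shape)}) "
            f"into parameter ({tuple(param_shape)})"
        )
    return "copy"
```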

    def forward_oot(
        self,
@@ -62,7 +79,7 @@ def forward_oot(
            return x, residual

        x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-        if self.bias is not None:
+        if self.bias_loaded:
            x.add_(self.bias)

        weight_prefetch_method = get_weight_prefetch_method()