[Perf] Optimize bias handling in AscendRMSNorm #7226
Changes from all commits
```diff
@@ -37,11 +37,28 @@ def __init__(
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
         vllm_config = get_current_vllm_config()
         self.bias = None
+        self.bias_loaded = False
+
+        # quantization with anti_method m4 will generate non-zero norm bias
+        if vllm_config.quant_config is not None and any(
+                "norm.bias" in name
+                for name in vllm_config.quant_config.quant_description):
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                           requires_grad=False)
+            self.bias.weight_loader = self._bias_weight_loader
+
+    def _bias_weight_loader(self, param: torch.nn.Parameter,
+                            loaded_weight: torch.Tensor) -> None:
```
**Contributor:** It is better to wrap the original weight loader so that we don't need to re-implement the loading details here.

**Collaborator:** I think the custom op has no weight loader function, so I don't get your point. If you still have questions about this, please feel free to open a new PR; this one is ready to merge now.
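For illustration, the wrapping the Contributor suggests might look like the sketch below. This is a hypothetical pattern, not vLLM-Ascend code: `wrap_with_loaded_flag`, `Owner`, and `base_loader` are invented names, used only to show how the copy logic could be delegated while adding just the flag bookkeeping.

```python
# Hypothetical sketch of the Contributor's suggestion: delegate the actual
# copy to an existing loader and add only the bias_loaded bookkeeping.
def wrap_with_loaded_flag(base_loader, owner):
    """Return a loader that calls base_loader, then marks owner.bias_loaded."""
    def loader(param, loaded_weight):
        base_loader(param, loaded_weight)  # existing copy/broadcast logic
        owner.bias_loaded = True           # the only new behavior added here
    return loader


class Owner:
    """Stand-in for the norm module that carries the flag."""
    bias_loaded = False
```

The Collaborator's objection is that there is no pre-existing loader on the custom op to wrap, in which case this pattern has nothing to delegate to.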
```diff
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes,
+            # so if both param and loaded_weight are scalars,
+            # "broadcast" instead of copying
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) "
+                f"into parameter ({param.size()})")
+            param.data.copy_(loaded_weight)
+        self.bias_loaded = True
+
     def forward_oot(
         self,
```
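The scalar-versus-tensor handling in `_bias_weight_loader` can be exercised with a self-contained sketch. NumPy stands in for torch here, and `FakeParam` is an illustrative stand-in for `torch.nn.Parameter`, not the real class:

```python
import numpy as np


class FakeParam:
    """Illustrative stand-in for torch.nn.Parameter (numpy-backed)."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)


def bias_weight_loader(param, loaded_weight, state):
    """Mirror the loader above: broadcast scalars, copy matching shapes."""
    if param.data.size == 1 and loaded_weight.size == 1:
        # Scalars may arrive without a shape, so "broadcast" via fill
        # instead of an element-wise copy.
        param.data.fill(loaded_weight.item())
    else:
        assert param.data.shape == loaded_weight.shape, (
            f"Attempted to load weight {loaded_weight.shape} "
            f"into parameter {param.data.shape}")
        np.copyto(param.data, loaded_weight)
    state["bias_loaded"] = True  # bias now carries real checkpoint values
```

Setting the flag only after a successful load is what lets the forward path skip the bias add for parameters that were allocated but never populated.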
```diff
@@ -62,7 +79,7 @@ def forward_oot(
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-        if self.bias is not None:
+        if self.bias_loaded:
             x.add_(self.bias)
 
         weight_prefetch_method = get_weight_prefetch_method()
```
**Review comment:** While adding the `bias_loaded` flag is a good optimization, it is not used consistently throughout the `forward_oot` method. The branch that handles a non-`None` `residual` (lines 70-79) still checks `if self.bias is not None` and passes `self.bias` to the custom op unconditionally. This can lead to unnecessary additions of an all-zero bias tensor, which is exactly what this PR aims to prevent. To make the optimization effective in all cases, that logic should also be updated to use `self.bias_loaded`.
bias_loadedflag is a good optimization, it's not used consistently throughout theforward_ootmethod. The branch that handles a non-Noneresidual(lines 70-79) still checksif self.bias is not Noneand passesself.biasto the custom op unconditionally. This can lead to unnecessary bias additions with a zero tensor, which this PR aims to prevent. To make the optimization effective in all cases, this logic should also be updated to useself.bias_loaded.