-
Notifications
You must be signed in to change notification settings - Fork 453
Add L2NormHook and use it in megatron.py #599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
659742d
60d98c6
3242dd0
85b0229
5bdb08b
b5de20a
889eb4b
675dca4
9526a0d
dedc036
f5b85bf
839ba74
594127e
1b70f3d
992058c
674e823
f6fc88b
0ba14fd
5aaaf1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,7 @@ | |
| from megatron.core.transformer.transformer_layer import TransformerLayer | ||
|
|
||
| from modelopt.torch.nas.modules import DynamicModuleList | ||
| from modelopt.torch.nas.plugins.megatron_hooks import L2NormHook | ||
| from modelopt.torch.opt.dynamic import DynamicModule | ||
| from modelopt.torch.opt.hparam import HPType | ||
| from modelopt.torch.opt.searcher import ConstraintsDict | ||
|
|
@@ -265,39 +266,19 @@ def _setup(self): | |
| # can be discarded. | ||
| # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator) | ||
| # where we separate the importance estimation from the dynamic module. | ||
| self._register_temp_attribute("_activations", None) | ||
| self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook) | ||
| max_ffn_size = int(self.get_hparam(self.hparam_name).max) # type: ignore[arg-type] | ||
| activation_hook = L2NormHook(max_size=max_ffn_size) | ||
| self._register_temp_attribute("_activation_hook", activation_hook) | ||
| # TODO: Clarify why hook_handle is removed manually in export() instead of being registered via _register_temp_attribute. | ||
| self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook) | ||
| ffn_hidden_size.register_importance(self._estimate_importance) | ||
|
|
||
| def _linear_fc2_forward_hook(self, module, input, output): | ||
| """Hook to collect activations for importance estimation. | ||
|
|
||
| Activations are computed as mean over seq_len and then squared and summed over batch_size. | ||
| Later we take the square root of the sum to get the L2 norm. | ||
| """ | ||
| # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions | ||
| # NOTE: This is not used at the moment since we restrict to TP=1 | ||
| input = gather_from_tensor_model_parallel_region(input[0]).detach() | ||
| if input.dim() == 2: | ||
| # For sparse experts, there is no batch dimension. | ||
| input = input[:, None, :] | ||
| # Dont aggregate activations from non-max subnets (e.g. from profiling) | ||
| if input.shape[-1] != self.get_hparam(self.hparam_name).max: | ||
| return | ||
|
|
||
| input = input.to(torch.float32) # use full precision to avoid overflow | ||
| activations = input.abs().mean(dim=0) # [batch_size, ffn_hidden_size] | ||
| activations = activations.pow(2).sum(dim=0) # [ffn_hidden_size] | ||
| if self._activations is None: | ||
| self._activations = activations | ||
| else: | ||
| self._activations += activations | ||
|
|
||
| def _estimate_importance(self) -> TracedHp.Importance: | ||
| """Return the activation magnitude-based importance of the ffn_hidden_size.""" | ||
| assert self._activations is not None, "No activations collected for importance estimation." | ||
| # Convert squared sum to L2 norm | ||
| return self._activations.pow(0.5) | ||
| assert self._activation_hook._activations is not None, ( | ||
| "No activations collected for importance estimation." | ||
| ) | ||
| return self._activation_hook.accumulate() | ||
|
|
||
| def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: | ||
| """Set hidden size for shared expert.""" | ||
|
|
@@ -612,46 +593,26 @@ def _setup(self): | |
| ) | ||
|
|
||
| # register importance estimator for linear_qkv.output_size and linear_proj.input_size | ||
| self._register_temp_attribute("_activations", None) | ||
| self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook) | ||
| num_heads_per_group_max = int(self.get_hparam("num_heads_per_group").max) # type: ignore[arg-type] | ||
| num_query_groups_max = int(self.get_hparam("num_query_groups").max) # type: ignore[arg-type] | ||
| max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels | ||
| activation_hook = L2NormHook(max_size=max_size) | ||
| self._register_temp_attribute("_activation_hook", activation_hook) | ||
| # TODO: Clarify why hook_handle is removed manually in export() instead of being registered via _register_temp_attribute. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Even if we register hook_handle as a temp attribute, we still need to call [inline code missing from the page extract; presumably `hook_handle.remove()`].
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I understand now. |
||
| self.hook_handle = self.linear_proj.register_forward_hook(activation_hook) | ||
| # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads, | ||
| # otherwise we would only have aggregated importance of heads per group. | ||
| # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order | ||
| num_heads_per_group.register_importance(self._estimate_all_head_importance) | ||
| num_query_groups.register_importance(self._estimate_query_group_importance) | ||
|
|
||
| def _linear_proj_forward_hook(self, module, input, output): | ||
| """Hook to collect activations for importance estimation. | ||
|
|
||
| Activations are computed as mean over seq_len and then squared and summed over batch_size. | ||
| Later we take the square root of the sum to get the L2 norm. | ||
| """ | ||
| # Gather input [seq_len, batch_size, query_projection_size] over all TP regions | ||
| # NOTE: This is not used at the moment since we restrict to TP=1 | ||
| input = gather_from_tensor_model_parallel_region(input[0]).detach() | ||
|
|
||
| # Dont aggregate activations from non-max subnets (e.g. from profiling) | ||
| if ( | ||
| input.shape[-1] | ||
| != self.get_hparam("num_heads_per_group").max | ||
| * self.get_hparam("num_query_groups").max | ||
| * self.config.kv_channels | ||
| ): | ||
| return | ||
|
|
||
| input = input.to(torch.float32) # use full precision to avoid overflow | ||
| activations = input.abs().mean(dim=0) | ||
| activations = activations.pow(2).sum(dim=0) # [query_projection_size] | ||
| if self._activations is None: | ||
| self._activations = activations | ||
| else: | ||
| self._activations += activations | ||
|
|
||
| def _estimate_all_head_importance(self) -> TracedHp.Importance: | ||
| """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups).""" | ||
| assert self._activations is not None, "No activations collected for importance estimation." | ||
| assert self._activation_hook._activations is not None, ( | ||
| "No activations collected for importance estimation." | ||
| ) | ||
| # Convert squared sum to L2 norm | ||
| scores = self._activations.pow(0.5) | ||
| scores = self._activation_hook.accumulate() | ||
| attn_head_importance = torch.linalg.vector_norm( | ||
| scores.view( | ||
| self.get_hparam("num_heads_per_group").max | ||
|
|
@@ -665,9 +626,11 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance: | |
|
|
||
| def _estimate_query_group_importance(self) -> TracedHp.Importance: | ||
| """Return the importance of the ``num_query_groups`` hparam.""" | ||
| assert self._activations is not None, "No activations collected for importance estimation." | ||
| assert self._activation_hook._activations is not None, ( | ||
| "No activations collected for importance estimation." | ||
| ) | ||
| # Convert squared sum to L2 norm | ||
| scores = self._activations.pow(0.5) | ||
| scores = self._activation_hook.accumulate() | ||
| group_importance = torch.linalg.vector_norm( | ||
| scores.view( | ||
| self.get_hparam("num_heads_per_group").max, | ||
|
|
@@ -1594,8 +1557,11 @@ def get_activations_and_layer_scores( | |
| """Get the per-rank activations and layer scores from the module.""" | ||
| local_activations = {} | ||
| for n, m in self.named_modules(): | ||
| # TODO: Remove legacy _activations check once all modules use _activation_hook | ||
| if hasattr(m, "_activations"): | ||
| local_activations[n] = m._activations | ||
| elif hasattr(m, "_activation_hook") and m._activation_hook._activations is not None: | ||
|
danielkorzekwa marked this conversation as resolved.
Outdated
|
||
| local_activations[n] = m._activation_hook._activations | ||
| activations_per_rank = dist.allgather( | ||
| local_activations, group=get_pipeline_model_parallel_group() | ||
| ) | ||
|
|
@@ -1624,8 +1590,11 @@ def set_activations_and_layer_scores( | |
| for layer in self.decoder.layers: | ||
| layer._scores = layer_scores[layer.layer_number] | ||
| for n, m in self.named_modules(): | ||
| # TODO: Remove legacy _activations check once all modules use _activation_hook | ||
| if hasattr(m, "_activations"): | ||
| m._activations = activations_per_rank[rank][n] | ||
| elif hasattr(m, "_activation_hook"): | ||
| m._activation_hook._activations = activations_per_rank[rank][n] | ||
|
|
||
|
|
||
| def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
AAnoosheh marked this conversation as resolved.
|
||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Forward hooks for activation-based importance estimation (megatron NAS plugin).""" | ||
|
|
||
| from abc import ABC, abstractmethod | ||
|
|
||
| import torch | ||
| from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region | ||
| from torch import nn | ||
|
|
||
|
|
||
class ForwardHook(ABC):
    """Abstract callable that can be registered as a PyTorch forward hook.

    The call signature mirrors the PyTorch forward hook API: the second
    parameter, ``args``, is the tuple of positional arguments that were
    passed to the module's ``forward()``.

    Usage:
        hook = MyHook()
        module.register_forward_hook(hook)
    """

    @abstractmethod
    def __call__(
        self,
        module: nn.Module,
        args: tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        """Run after the module's forward pass has finished.

        Args:
            module: The module this hook is registered on.
            args: Tuple of positional arguments passed to ``module.forward()``.
            output: The value returned by ``module.forward()``.

        Returns:
            None, so the module output is left unmodified.
        """
|
|
||
|
|
||
class L2NormHook(ForwardHook):
    """Forward hook that keeps a running per-channel activation statistic.

    For each call, the absolute input is averaged over dim 0 (seq_len), then
    squared and summed over the following dim (batch_size). The result is
    added to a running total; ``accumulate()`` takes the square root of that
    total to produce the L2 norm.

    Args:
        max_size: Optional expected size of the last input dimension. Calls
            whose input has a different last dimension are ignored, which is
            useful for skipping non-max subnets during profiling.
    """

    def __init__(self, max_size: int | None = None):
        """Create a hook that has not collected any statistics yet."""
        self._activations: torch.Tensor | None = None
        self.max_size = max_size

    def __call__(
        self,
        module: nn.Module,
        args: tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        """Fold the statistics of one forward pass into the running total."""
        # Gather input [seq_len, batch_size, hidden_size] over all TP regions
        # NOTE: This is not used at the moment since we restrict to TP=1
        gathered = gather_from_tensor_model_parallel_region(args[0]).detach()

        # Don't aggregate activations from non-max subnets (e.g. from profiling).
        # The last dimension is unaffected by the batch-dim insertion below, so
        # it is safe to check it first.
        expected_size = self.max_size
        if expected_size is not None and gathered.shape[-1] != expected_size:
            return

        if gathered.dim() == 2:
            # For sparse experts, there is no batch dimension: insert one.
            gathered = gathered.unsqueeze(1)

        # Use full precision to avoid overflow.
        seq_mean = gathered.to(torch.float32).abs().mean(dim=0)  # [batch_size, hidden_size]
        squared_sum = seq_mean.pow(2).sum(dim=0)  # [hidden_size]

        if self._activations is not None:
            # In-place add keeps the identity of the stored tensor unchanged.
            self._activations += squared_sum
            return
        self._activations = squared_sum

    def accumulate(self) -> torch.Tensor:
        """Return the L2 norm computed from the statistics collected so far.

        Returns:
            Tensor of accumulated scores, one per channel.

        Raises:
            AssertionError: If no activations have been collected yet.
        """
        squared_sums = self._activations
        assert squared_sums is not None, "No activations collected for importance estimation."
        # Convert squared sum to L2 norm
        return squared_sums.pow(0.5)
Uh oh!
There was an error while loading. Please reload this page.