enable selective unfreeze (#7326)
* wip

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* wip

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid PTL method conflicts

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and yaoyu-33 committed Sep 5, 2023
1 parent e26c41b commit a47c1e3
Showing 3 changed files with 18 additions and 2 deletions.
9 changes: 9 additions & 0 deletions nemo/collections/common/parts/adapter_modules.py
@@ -61,6 +61,15 @@ def get_default_strategy_config(self) -> 'dataclass':
        """
        return adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()

    def adapter_unfreeze(self,):
        """
        Sets the requires grad for all parameters in the adapter to True.
        This method should be overridden for any custom unfreeze behavior that is required.
        For example, if not all params of the adapter should be unfrozen.
        """
        for param in self.parameters():
            param.requires_grad_(True)


class LinearAdapter(nn.Module, AdapterModuleUtil):
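
For context, here is a minimal sketch of how a downstream adapter could use this new hook. The class name, layer names, and sizes are hypothetical; only AdapterModuleUtil and adapter_unfreeze come from the diff above, and the import path is assumed from the changed file.

import torch.nn as nn

# Import path assumed from the file changed above (hypothetical example).
from nemo.collections.common.parts.adapter_modules import AdapterModuleUtil


class PartiallyTrainableAdapter(nn.Module, AdapterModuleUtil):
    """Hypothetical adapter that trains only its output projection."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.linear_in = nn.Linear(dim, hidden_dim)   # stays frozen during PEFT
        self.linear_out = nn.Linear(hidden_dim, dim)  # the only part to train

    def forward(self, x):
        return self.linear_out(self.linear_in(x))

    def adapter_unfreeze(self):
        # Override the default (which unfreezes everything) so that only
        # linear_out becomes trainable; linear_in keeps requires_grad=False.
        for param in self.linear_out.parameters():
            param.requires_grad_(True)

With the adapter_mixins.py change further below, unfreeze_enabled_adapters calls this override instead of flipping requires_grad on every parameter, so linear_in stays frozen even when the adapter is enabled.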
@@ -28,6 +28,7 @@
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, init_method_const, init_method_normal
from nemo.core.classes.mixins import adapter_mixin_strategies


try:
    from apex.normalization.fused_layer_norm import MixedFusedLayerNorm

@@ -207,6 +208,12 @@ def _get_init_fn(self, init_method: str):
            raise NotImplementedError("out_init_method should be zero, normal or xavier")
        return init_fn

    def adapter_unfreeze(self,):
        """
        Can be customized to allow for selective training of only some params in the PEFT.
        """
        super().adapter_unfreeze()

    def forward(self, x):

        if self.norm_position == 'pre':
4 changes: 2 additions & 2 deletions nemo/core/classes/mixins/adapter_mixins.py
@@ -409,12 +409,12 @@ def unfreeze_enabled_adapters(self, freeze_batchnorm: bool = True) -> None:

                # Check if adapter is enabled or not
                if self.adapter_cfg[name]['enabled'] and name in module.adapter_layer:

                    # Recursively set training mode of submodules
                    module.adapter_layer[name].train()

                    # Recursively set grad required for submodules
-                   for pname, param in module.adapter_layer[name].named_parameters():
-                       param.requires_grad_(True)
+                   module.adapter_layer[name].adapter_unfreeze()

                    # unfreeze batch norm if any in the adapter submodules
                    for mname, module_ in module.adapter_layer[name].named_modules():
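
The net effect of this last hunk is that the mixin now delegates the unfreezing decision to the adapter itself. Below is a small, self-contained sketch of that dispatch pattern (simplified, hypothetical classes; not NeMo code) showing how a selective adapter keeps part of itself frozen.

import torch.nn as nn


class BaseAdapter(nn.Module):
    def adapter_unfreeze(self):
        # Default behavior, mirroring the base method above: unfreeze all params.
        for param in self.parameters():
            param.requires_grad_(True)


class SelectiveAdapter(BaseAdapter):
    def __init__(self):
        super().__init__()
        self.frozen_proj = nn.Linear(8, 8)     # should stay frozen
        self.trainable_proj = nn.Linear(8, 8)  # the only trainable piece

    def adapter_unfreeze(self):
        # Selective behavior: only trainable_proj is re-enabled.
        for param in self.trainable_proj.parameters():
            param.requires_grad_(True)


adapter = SelectiveAdapter()
adapter.requires_grad_(False)  # everything starts frozen
adapter.adapter_unfreeze()     # the adapter decides what to re-enable

trainable = sum(p.numel() for p in adapter.parameters() if p.requires_grad)
total = sum(p.numel() for p in adapter.parameters())
print(f"trainable adapter params: {trainable} / {total}")  # only trainable_proj counts

Before this commit, the mixin unconditionally set requires_grad_(True) on every adapter parameter, so the frozen_proj equivalent here would have been trained as well.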
