Adapter tuning accepts expanded language model dir (#6376)
* Set model_extracted_dir so that the save-restore connector can load expanded (pre-extracted) base models

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] committed Apr 5, 2023
1 parent 4e03e3b commit 7778fcc
Showing 1 changed file with 18 additions and 10 deletions.
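
In short: when cfg.language_model_path points at a directory rather than a packed .nemo archive, the base adapter model now tells the NLPSaveRestoreConnector to treat that directory as the already-extracted archive, and restores the frozen model config through it. A minimal sketch of the pattern, outside the class (the model path below is hypothetical):

    import os

    from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
    from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector

    # Hypothetical path: either a packed .nemo file or its untarred contents.
    language_model_path = '/checkpoints/megatron_gpt_extracted/'

    connector = NLPSaveRestoreConnector()
    if os.path.isdir(language_model_path):
        # Reuse the already-extracted contents instead of untarring a .nemo archive.
        connector.model_extracted_dir = language_model_path

    frozen_model_cfg = MegatronGPTModel.restore_from(
        language_model_path, return_config=True, save_restore_connector=connector,
    )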
@@ -17,6 +17,8 @@
 # Adapted by: @adithyare
 
 
+import os
+
 import torch
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer
@@ -34,6 +36,7 @@
     ParallelLinearAdapterConfig,
 )
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
+from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes.mixins import adapter_mixins
 from nemo.utils import logging, model_utils
@@ -42,6 +45,15 @@
 class MegatronGPTBaseAdapterModel(MegatronGPTPromptLearningModel):
     def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer)
+        save_restore_connector = NLPSaveRestoreConnector()
+        if os.path.isdir(cfg.get('language_model_path')):
+            save_restore_connector.model_extracted_dir = cfg.get('language_model_path')
+        self.frozen_model_cfg = MegatronGPTModel.restore_from(
+            cfg.get('language_model_path'),
+            trainer=trainer,
+            return_config=True,
+            save_restore_connector=save_restore_connector,
+        )
         self.adapter_name_keys = []
 
     def forward(
@@ -246,9 +258,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         ], "Adapter type should be 'linear_adapter' or 'parallel_adapter'"
 
         self.adapter_name_keys = [AdapterName.PRE_ATTN_ADAPTER, AdapterName.POST_ATTN_ADAPTER]
-        frozen_model_cfg = MegatronGPTModel.restore_from(
-            cfg.get('language_model_path'), trainer=trainer, return_config=True
-        )
         for _, layer in self.frozen_model.named_modules():
             if hasattr(layer, 'activations_checkpoint_method'):
                 layer.activations_checkpoint_method = (
@@ -259,7 +268,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         if cfg.adapter_tuning.type == "parallel_adapter":
             adapter_cfg = ParallelLinearAdapterConfig(
-                in_features=frozen_model_cfg.hidden_size,
+                in_features=self.frozen_model_cfg.hidden_size,
                 dim=cfg.adapter_tuning.adapter_dim,
                 norm_position=cfg.adapter_tuning.get('norm_position', 'pre'),
                 norm_type=cfg.adapter_tuning.get('norm_type', 'mixedfusedlayernorm'),
@@ -269,7 +278,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             )
         else:
             adapter_cfg = LinearAdapterConfig(
-                in_features=frozen_model_cfg.hidden_size,
+                in_features=self.frozen_model_cfg.hidden_size,
                 dim=cfg.adapter_tuning.adapter_dim,
                 norm_position=cfg.adapter_tuning.get('norm_position', 'pre'),
                 dropout=cfg.adapter_tuning.adapter_dropout,
@@ -306,9 +315,6 @@ class MegatronGPTInfusedAdapterModel(MegatronGPTBaseAdapterModel):
     def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer)
         self.adapter_name_keys = [AdapterName.KEY_INFUSED, AdapterName.VALUE_INFUSED, AdapterName.MLP_INFUSED]
-        frozen_model_cfg = MegatronGPTModel.restore_from(
-            cfg.get('language_model_path'), trainer=trainer, return_config=True
-        )
         for _, layer in self.frozen_model.named_modules():
             if hasattr(layer, 'activations_checkpoint_method'):
                 layer.activations_checkpoint_method = (
@@ -323,11 +329,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         for adapter_key in self.adapter_name_keys:
             if adapter_key == AdapterName.MLP_INFUSED:
                 cfg = MLPInfusedAdapterConfig(
-                    in_features=frozen_model_cfg.ffn_hidden_size // frozen_model_cfg.tensor_model_parallel_size
+                    in_features=self.frozen_model_cfg.ffn_hidden_size
+                    // self.frozen_model_cfg.tensor_model_parallel_size
                 )
             elif adapter_key in [AdapterName.KEY_INFUSED, AdapterName.VALUE_INFUSED]:
                 cfg = InfusedAdapterConfig(
-                    in_features=frozen_model_cfg.hidden_size // frozen_model_cfg.tensor_model_parallel_size
+                    in_features=self.frozen_model_cfg.hidden_size
+                    // self.frozen_model_cfg.tensor_model_parallel_size
                 )
             else:
                 raise ValueError(f"Adapter Key {adapter_key} is unknown.")
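
A note on the in_features arithmetic in the last hunk: the infused adapters operate on each tensor-parallel rank's shard of the hidden (or FFN) dimension, so the full width is divided by tensor_model_parallel_size. Worked with hypothetical sizes:

    # Hypothetical model dimensions, not values from this commit.
    hidden_size = 4096
    ffn_hidden_size = 16384
    tensor_model_parallel_size = 2

    # Each rank's infused adapter sees only its shard of the activation.
    in_features_infused = hidden_size // tensor_model_parallel_size          # 2048 per rank
    in_features_mlp_infused = ffn_hidden_size // tensor_model_parallel_size  # 8192 per rank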
