Adapter tuning accepts expanded language model dir #6376

Merged 2 commits on Apr 5, 2023
@@ -17,6 +17,8 @@
 # Adapted by: @adithyare


+import os
+
 import torch
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer
@@ -34,6 +36,7 @@
     ParallelLinearAdapterConfig,
 )
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
+from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes.mixins import adapter_mixins
 from nemo.utils import logging, model_utils
@@ -42,6 +45,15 @@
 class MegatronGPTBaseAdapterModel(MegatronGPTPromptLearningModel):
     def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer)
+        save_restore_connector = NLPSaveRestoreConnector()
+        if os.path.isdir(cfg.get('language_model_path')):
+            save_restore_connector.model_extracted_dir = cfg.get('language_model_path')
+        self.frozen_model_cfg = MegatronGPTModel.restore_from(
+            cfg.get('language_model_path'),
+            trainer=trainer,
+            return_config=True,
+            save_restore_connector=save_restore_connector,
+        )
         self.adapter_name_keys = []

     def forward(
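
In effect, the hunk above moves the frozen-model config restore into the shared base class and lets `language_model_path` be either a packed `.nemo` file or a directory that already holds the extracted model. Below is a minimal standalone sketch of that restore pattern; the helper name and its arguments are illustrative, not part of this PR, while the NeMo calls mirror the diff:

```python
# Sketch of the restore pattern added in MegatronGPTBaseAdapterModel.__init__.
# The helper name and signature are hypothetical.
import os

from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector


def restore_frozen_model_cfg(language_model_path: str, trainer: Trainer):
    """Return the frozen GPT model's config from a .nemo file or an extracted dir."""
    connector = NLPSaveRestoreConnector()
    if os.path.isdir(language_model_path):
        # Pointing the connector at an already-extracted directory lets
        # restore_from reuse it instead of unpacking a .nemo archive.
        connector.model_extracted_dir = language_model_path
    return MegatronGPTModel.restore_from(
        language_model_path, trainer=trainer, return_config=True, save_restore_connector=connector,
    )
```

The subclasses then read sizes such as `self.frozen_model_cfg.hidden_size` from this single restored config instead of calling `restore_from` again, as the later hunks show.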
@@ -246,9 +258,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         ], "Adapter type should be 'linear_adapter' or 'parallel_adapter'"

         self.adapter_name_keys = [AdapterName.PRE_ATTN_ADAPTER, AdapterName.POST_ATTN_ADAPTER]
-        frozen_model_cfg = MegatronGPTModel.restore_from(
-            cfg.get('language_model_path'), trainer=trainer, return_config=True
-        )
         for _, layer in self.frozen_model.named_modules():
             if hasattr(layer, 'activations_checkpoint_method'):
                 layer.activations_checkpoint_method = (
@@ -259,7 +268,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

         if cfg.adapter_tuning.type == "parallel_adapter":
             adapter_cfg = ParallelLinearAdapterConfig(
-                in_features=frozen_model_cfg.hidden_size,
+                in_features=self.frozen_model_cfg.hidden_size,
                 dim=cfg.adapter_tuning.adapter_dim,
                 norm_position=cfg.adapter_tuning.get('norm_position', 'pre'),
                 norm_type=cfg.adapter_tuning.get('norm_type', 'mixedfusedlayernorm'),
@@ -269,7 +278,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             )
         else:
             adapter_cfg = LinearAdapterConfig(
-                in_features=frozen_model_cfg.hidden_size,
+                in_features=self.frozen_model_cfg.hidden_size,
                 dim=cfg.adapter_tuning.adapter_dim,
                 norm_position=cfg.adapter_tuning.get('norm_position', 'pre'),
                 dropout=cfg.adapter_tuning.adapter_dropout,
@@ -306,9 +315,6 @@ class MegatronGPTInfusedAdapterModel(MegatronGPTBaseAdapterModel):
     def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer)
         self.adapter_name_keys = [AdapterName.KEY_INFUSED, AdapterName.VALUE_INFUSED, AdapterName.MLP_INFUSED]
-        frozen_model_cfg = MegatronGPTModel.restore_from(
-            cfg.get('language_model_path'), trainer=trainer, return_config=True
-        )
         for _, layer in self.frozen_model.named_modules():
             if hasattr(layer, 'activations_checkpoint_method'):
                 layer.activations_checkpoint_method = (
@@ -323,11 +329,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         for adapter_key in self.adapter_name_keys:
             if adapter_key == AdapterName.MLP_INFUSED:
                 cfg = MLPInfusedAdapterConfig(
-                    in_features=frozen_model_cfg.ffn_hidden_size // frozen_model_cfg.tensor_model_parallel_size
+                    in_features=self.frozen_model_cfg.ffn_hidden_size
+                    // self.frozen_model_cfg.tensor_model_parallel_size
                 )
             elif adapter_key in [AdapterName.KEY_INFUSED, AdapterName.VALUE_INFUSED]:
                 cfg = InfusedAdapterConfig(
-                    in_features=frozen_model_cfg.hidden_size // frozen_model_cfg.tensor_model_parallel_size
+                    in_features=self.frozen_model_cfg.hidden_size
+                    // self.frozen_model_cfg.tensor_model_parallel_size
                 )
             else:
                 raise ValueError(f"Adapter Key {adapter_key} is unknown.")
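
For users of the adapter-tuning scripts, the visible effect is on the `language_model_path` entry of the training config: per the PR title, a directory containing the expanded (already-extracted) model is now accepted alongside a packed `.nemo` file. A hedged sketch, with paths and the minimal config chosen purely for illustration:

```python
# Illustrative only: both forms of `language_model_path` should now work for
# restoring the frozen model config; neither path exists outside this example.
from omegaconf import OmegaConf

packed = OmegaConf.create({"language_model_path": "/checkpoints/megatron_gpt.nemo"})
expanded = OmegaConf.create({"language_model_path": "/checkpoints/megatron_gpt_extracted/"})

# Handing the expanded directory to the adapter models avoids re-extracting the
# .nemo archive at the start of every run.
```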