P-tuning refactor Part 3/N #6106
@@ -25,11 +25,13 @@ | |
|
||
from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset | ||
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel | ||
from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import ( | ||
MegatronBasePromptLearningModel, | ||
) | ||
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel | ||
from nemo.collections.nlp.modules.common import ( | ||
PromptEncoder, | ||
PromptEncoderType, | ||
PromptTable, | ||
VirtualPromptPlaceholderToken, | ||
VirtualPromptSource, | ||
VirtualPromptStyle, | ||
|
@@ -62,7 +64,7 @@ | |
__all__ = ['MegatronGPTPromptLearningModel'] | ||
|
||
|
||
class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration): | ||
class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel): | ||
Review comment: why drop the TextGeneration base class?
Reply: The base model contains it now!
Reply: The base class now has TextGeneration.
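To make the thread above concrete, here is a minimal sketch of the inheritance change under discussion. The stand-in classes are illustrative only, not the actual NeMo implementations; only the class names are taken from this diff's imports.

class TextGeneration:  # stand-in for NeMo's text generation mixin
    def generate(self, *args, **kwargs):
        raise NotImplementedError


class MegatronBaseModel:  # stand-in for the Megatron base model
    pass


class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration):
    """Shared prompt-learning logic, now also carrying TextGeneration."""


class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
    """GPT prompt-learning model; TextGeneration arrives via the base class."""


# TextGeneration is still in the GPT model's MRO even though it was dropped
# from the class statement in this diff.
assert TextGeneration in MegatronGPTPromptLearningModel.__mro__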
||
""" | ||
Model class for prompt-tuning or p-tuning a pretrained Megatron GPT model. | ||
|
||
|
@@ -84,7 +86,9 @@ class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration): | |
|
||
def __init__(self, cfg: DictConfig, trainer: Trainer): | ||
super().__init__(cfg, trainer) | ||
self.init_model(cfg, trainer) | ||
|
||
def init_model(self, cfg: DictConfig, trainer: Trainer): | ||
self.cfg = cfg | ||
save_restore_connector = NLPSaveRestoreConnector() | ||
if os.path.isdir(cfg.get('language_model_path')): | ||
|
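A rough, self-contained sketch of the constructor split introduced in the hunk above, using placeholder classes rather than the actual NeMo code: __init__ stays thin and defers the heavy setup to an init_model hook that subclasses can reuse or override.

class BasePromptLearningModel:
    def __init__(self, cfg, trainer):
        self.cfg = cfg
        self.trainer = trainer


class GPTPromptLearningModel(BasePromptLearningModel):
    def __init__(self, cfg, trainer):
        super().__init__(cfg, trainer)
        self.init_model(cfg, trainer)  # same delegation pattern as the diff

    def init_model(self, cfg, trainer):
        # Model restoration, task-template loading, etc. would live here,
        # so it can be re-run or overridden without rebuilding the object.
        self.language_model_path = cfg.get("language_model_path")


model = GPTPromptLearningModel(cfg={"language_model_path": "gpt.nemo"}, trainer=None)
print(model.language_model_path)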
@@ -150,14 +154,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): | |
# Load templates for assigning virtual prompt token positions | ||
self.load_task_templates(self.cfg.task_templates) | ||
|
||
if self.frozen_model.model.pre_process and self.virtual_prompt_style in [ | ||
if self.first_stage_of_pipeline() and self.virtual_prompt_style in [ | ||
VirtualPromptStyle.P_TUNING, | ||
]: | ||
|
||
self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings | ||
|
||
self.padded_vocab_size = self.frozen_model.padded_vocab_size | ||
self._prompt_encoder_key = VirtualPromptSource.PROMPT_ENCODER.value | ||
|
||
# Prepare pseudo token ids for virtual/virtual prompt tokens | ||
self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens) | ||
|
@@ -172,9 +175,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): | |
elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT: | ||
self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT | ||
else: | ||
raise ValueError( | ||
f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'" | ||
) | ||
raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}.'") | ||
|
||
self._reduced_loss_buffer = [] | ||
self._inference_config = None | ||
|
@@ -184,158 +185,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): | |
self.lowest_val_loss = None | ||
self.prompt_encoder = None | ||
|
||
def load_task_templates(self, task_templates): | ||
""" | ||
Takes in the task template portion of the config and turns | ||
it into a table where each task's prompt template and | ||
the number of virtual tokens to insert in a given part of | ||
the prompt template are specified. | ||
""" | ||
self.task_templates = {} | ||
self.task_id_num_to_name = {} | ||
self.max_virtual_tokens = 0 | ||
|
||
task_id_num = 0 | ||
for task in task_templates: | ||
self.task_templates[task.taskname] = { | ||
"prompt_template": task.prompt_template, | ||
"prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template), | ||
"answer_only_loss": task.get("answer_only_loss", False), | ||
"answer_field": task.get("answer_field", None), | ||
"truncate_field": task.truncate_field, | ||
"total_virtual_tokens": task.total_virtual_tokens, | ||
"virtual_token_splits": task.virtual_token_splits, | ||
"task_id_num": task_id_num, | ||
} | ||
|
||
self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens) | ||
self.task_id_num_to_name[task_id_num] = task.taskname | ||
task_id_num += 1 | ||
|
||
# Check that all new tasks have the same total num virtual tokens | ||
# Num virtual tokens for new tasks don't need to match num used for previously tuned tasks | ||
if self.new_tasks: | ||
new_task_name = self.new_tasks[0] | ||
self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"] | ||
|
||
assert all( | ||
self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens | ||
for taskname in self.new_tasks | ||
), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately." | ||
|
||
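For readers unfamiliar with the config that load_task_templates parses, here is a hypothetical task_templates entry; the key names mirror what the method reads above, while the task name and values are made up.

import re
from omegaconf import OmegaConf

# Hypothetical task template; only the keys are taken from the code above.
task_templates = OmegaConf.create([
    {
        "taskname": "sentiment",
        "prompt_template": "<|VIRTUAL_PROMPT_0|> {sentence} sentiment: {label}",
        "total_virtual_tokens": 10,
        "virtual_token_splits": [10],
        "truncate_field": "sentence",
        "answer_only_loss": True,
        "answer_field": "label",
    }
])

# The same regex as above extracts the prompt template fields:
print(re.findall(r"\{(.*?)\}", task_templates[0]["prompt_template"]))  # ['sentence', 'label']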
def init_new_prompts(self): | ||
""" | ||
Initialize new virtual prompts to be tuned using prompt tuning | ||
""" | ||
for idx, taskname in enumerate(self.new_tasks): | ||
init_method = self.cfg.prompt_tuning.new_prompt_init_methods[idx].lower() | ||
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] | ||
|
||
if init_method == "text": | ||
init_text = self.cfg.prompt_tuning.new_prompt_init_text[idx] | ||
init_text_ids = self.tokenizer.text_to_ids(init_text) | ||
self.prompt_table.init_prompt_from_text( | ||
taskname, init_text_ids, self.word_embeddings, total_virtual_tokens | ||
) | ||
|
||
elif init_method == 'random': | ||
self.prompt_table.init_prompt_from_random(taskname, total_virtual_tokens) | ||
|
||
else: | ||
raise AttributeError( | ||
f'\nvirtual prompt init method {init_method} is not recognized\ | ||
please use one of text or random' | ||
) | ||
|
||
def init_prompt_encoder(self): | ||
""" | ||
Init the prompt encoder needed for p-tuning on a new task | ||
""" | ||
# Total virtual tokens should be the same across all new tasks, so just need one | ||
new_task = self.new_tasks[0] | ||
total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"] | ||
|
||
encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower()) | ||
self.prompt_encoder = PromptEncoder( | ||
encoder_type=encoder_type, | ||
total_virtual_tokens=total_virtual_tokens, | ||
token_dim=self.hidden_size, | ||
hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2), | ||
lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0), | ||
num_layers=self.cfg.p_tuning.get("num_layers", 2), | ||
init_std=self.cfg.p_tuning.get("init_std", 0.023), | ||
taskname=new_task, | ||
) | ||
|
||
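As a reference for the .get() calls in init_prompt_encoder above, the corresponding p_tuning config section could look like the sketch below; the values simply echo the fallback defaults in the code and are not prescriptive, and the encoder_type string must match a member of PromptEncoderType.

from omegaconf import OmegaConf

# Illustrative p-tuning section; key names come from init_prompt_encoder above.
p_tuning = OmegaConf.create({
    "encoder_type": "tpmlp",
    "encoder_hidden": 2048,   # defaults to hidden_size // 2 when omitted
    "dropout": 0.0,
    "num_layers": 2,
    "init_std": 0.023,
})
print(p_tuning.get("encoder_type", "tpmlp"))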
def freeze_existing_word_embeddings(self): | ||
"""Freeze params of existing virtual prompts that should not be tuned further | ||
""" | ||
# Make sure word embeddings are frozen | ||
for params in self.word_embeddings.parameters(): | ||
params.requires_grad = False | ||
|
||
def get_model_tasks(self): | ||
""" | ||
For user to inspect which tasks the model has been | ||
p-tuned/prompt-tuned to perform. | ||
""" | ||
tasks = {} | ||
for taskname in self.prompt_table.prompt_table.keys(): | ||
tasks[taskname] = self.task_templates[taskname].copy() | ||
|
||
return tasks | ||
|
||
def state_dict(self, destination=None, prefix=None, keep_vars=False): | ||
""" | ||
Custom state dict that only contains prompt table and prompt encoder parameters. | ||
No frozen model parameters are stored in the state dict. Prompt encoder parameters | ||
are only in state dict for intermediate checkpoints saved during training. Final | ||
nemo checkpoints at the end of training will contain prompt table parameters only. | ||
""" | ||
state_dict_ = {} | ||
if self.frozen_model.model.pre_process: | ||
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: | ||
state_dict_ = self.prompt_encoder.state_dict() | ||
else: | ||
raise RuntimeError("invalid virtual prompt source") | ||
|
||
return state_dict_ | ||
|
||
def load_state_dict(self, state_dict, strict: bool = True): | ||
""" | ||
Custom load state dict method that only loads prompt table and prompt encoder | ||
parameters. Matching load method for this class' custom state dict method. | ||
""" | ||
if self.frozen_model.model.pre_process: | ||
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: | ||
if self.prompt_encoder is None: | ||
self.init_prompt_encoder() | ||
self.prompt_encoder.load_state_dict(state_dict, strict) | ||
else: | ||
raise RuntimeError("invalid virtual prompt source") | ||
|
||
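A toy illustration of the checkpointing behaviour described in the two docstrings above: only the tunable prompt-encoder weights end up in the state dict, so checkpoints stay tiny compared to the frozen LM. The stub module below is not NeMo's PromptEncoder.

import torch

class TinyPromptEncoder(torch.nn.Module):
    """Stub standing in for the real prompt encoder."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

encoder = TinyPromptEncoder()
ckpt = encoder.state_dict()            # only prompt-encoder tensors are saved
print(sorted(ckpt.keys()))             # ['proj.bias', 'proj.weight']

restored = TinyPromptEncoder()
restored.load_state_dict(ckpt)         # matching custom load path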
def setup_optimizer_param_groups(self): | ||
""" | ||
ModelPT override. Optimizer will get self._optimizer_param_groups. | ||
Makes two optimizer param groups, one for the frozen model params | ||
and one for the prompt-table/prompt-encoder params. The learning | ||
rate for the frozen model's params will always be zero effectively | ||
freezing the model's params but still allowing for the needed gradients | ||
to be passed around in pipeline parallel models. The prompt-encoder | ||
and/or prompt table will use the learning rate set by the user. | ||
""" | ||
# Freeze frozen model | ||
for param in self.frozen_model.parameters(): | ||
param.requires_grad = False | ||
|
||
virtual_prompt_params = {'params': []} | ||
|
||
if self.frozen_model.model.pre_process: | ||
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: | ||
virtual_prompt_params['params'].extend([param for param in self.prompt_encoder.parameters()]) | ||
else: | ||
raise RuntimeError("should not be here.") | ||
self._optimizer_param_groups = (virtual_prompt_params,) | ||
def first_stage_of_pipeline(self): | ||
return self.frozen_model.model.pre_process | ||
|
||
def forward( | ||
self, | ||
|
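The new first_stage_of_pipeline() helper simply names the frozen model's pre_process flag. Below is a hedged, standalone sketch of why that gate matters under pipeline parallelism, built from mock objects rather than NeMo classes.

class _FrozenInner:
    pre_process = True          # True only on pipeline stage 0


class _FrozenModel:
    model = _FrozenInner()


class PromptLearningStub:
    def __init__(self):
        self.frozen_model = _FrozenModel()

    def first_stage_of_pipeline(self) -> bool:
        # Mirrors the helper added in this diff.
        return self.frozen_model.model.pre_process

    def setup_prompt_encoder_if_needed(self):
        if self.first_stage_of_pipeline():
            print("stage 0: embeddings live here, build the prompt encoder")
        else:
            print("later stage: no embedding layer, nothing to build")


PromptLearningStub().setup_prompt_encoder_if_needed()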
@@ -354,7 +205,7 @@ def forward( | |
in the MegatronGPT class. | ||
""" | ||
# Get embeddings for text tokens and insert virtual token embeddings | ||
if self.frozen_model.model.pre_process: | ||
if self.first_stage_of_pipeline(): | ||
input_embeds = self.embed_input(input_ids, taskname_ids, use_cached_reps=inference) | ||
if hasattr(self.frozen_model.model.language_model.embedding, "position_embeddings"): | ||
position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings( | ||
|
@@ -393,54 +244,6 @@ def forward( | |
|
||
return output | ||
|
||
def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool): | ||
""" | ||
Replaces the virtual tokens in the input_ids with embeddings | ||
calculated from either the 'prompt_table' or 'prompt_encoder'. | ||
The virtual token placeholders have token_ids listed in | ||
`self.pseudo_token_ids`. | ||
|
||
params: | ||
input_ids: the input token ids | ||
taskname_ids: the NLP task tag token ids | ||
returns: | ||
the token embedding for the LM model. | ||
""" | ||
# Replace virtual token ids with padding for forward pass through vocab embeddings | ||
discrete_token_ids = input_ids.clone() | ||
discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id | ||
discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone() | ||
|
||
# Find the indices where virtual tokens should be inserted | ||
virtual_token_locations = input_ids >= self.pseudo_token_ids_start | ||
|
||
# If there are no virtual tokens, just return discrete token embeds | ||
if not virtual_token_locations.any(): | ||
return discrete_token_embeds | ||
|
||
# Get virtual token embeddings from the prompt table or prompt encoder | ||
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: | ||
batch_size, _ = taskname_ids.size() | ||
virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps) | ||
else: | ||
raise RuntimeError("invalid VirtualPromptSource..") | ||
|
||
# Create index template specifying where virtual token embeddings should be placed | ||
batch_size, _, embedding_size = discrete_token_embeds.shape | ||
virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None] | ||
virtual_token_index = virtual_token_index.expand( | ||
batch_size, self.total_new_task_virtual_tokens, embedding_size | ||
) | ||
|
||
# Make sure discrete_token_embeds and virtual_token_embeds share the same dtype | ||
discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype) | ||
|
||
# Insert virtual token embeddings where they belong among the discrete token embeddings | ||
discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds) | ||
input_embeds = discrete_token_embeds | ||
|
||
return input_embeds | ||
|
||
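A small, self-contained PyTorch illustration of the scatter_ trick that embed_input uses above to splice prompt-encoder outputs into the frozen word embeddings. Vocabulary size, ids, and shapes are toy values, not the real model's.

import torch

pseudo_start = 100                     # ids >= this are virtual-prompt placeholders
pad_id = 0
batch, hidden, num_virtual = 2, 4, 2

input_ids = torch.tensor([[101, 102, 5, 6, 7, 8],
                          [101, 102, 9, 10, 11, 12]])
word_embeddings = torch.nn.Embedding(50, hidden)   # toy frozen embedding table

# 1. Swap placeholders for padding before the vocab lookup.
discrete_ids = input_ids.clone()
discrete_ids[input_ids >= pseudo_start] = pad_id
discrete_embeds = word_embeddings(discrete_ids).clone()

# 2. Pretend the prompt encoder produced the virtual-token embeddings.
virtual_embeds = torch.randn(batch, num_virtual, hidden)

# 3. Locate the placeholders and scatter the virtual embeddings over them.
virtual_locs = input_ids >= pseudo_start
index = virtual_locs.nonzero().reshape(batch, -1, 2)[:, :, 1][:, :, None]
index = index.expand(batch, num_virtual, hidden)
discrete_embeds = discrete_embeds.type(virtual_embeds.dtype)
discrete_embeds.scatter_(1, index, virtual_embeds)

print(discrete_embeds.shape)           # torch.Size([2, 6, 4])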
def fwd_bwd_step(self, batch, batch_idx, forward_only): | ||
""" | ||
Dataloader produces a global batch which is turned into a list of microbatches. | ||
|
@@ -589,28 +392,6 @@ def test_epoch_end(self, outputs): | |
averaged_loss = average_losses_across_data_parallel_group(outputs) | ||
logging.info(f'test_loss: {averaged_loss[0]}') | ||
|
||
def on_train_end(self): | ||
# Save p-tuned prompts to prompt table for inference or future task training | ||
self.save_to(save_path=self.cfg.nemo_path) | ||
|
||
def setup(self, stage=None): | ||
if (stage == 'predict') and self.frozen_model.model.pre_process: | ||
self.freeze_existing_word_embeddings() | ||
return | ||
|
||
self.setup_test_data() | ||
if stage == 'test': | ||
return | ||
|
||
if self.frozen_model.model.pre_process: | ||
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: | ||
self.init_prompt_encoder() | ||
|
||
self.freeze_existing_word_embeddings() | ||
|
||
self.setup_training_data() | ||
self.setup_validation_data() | ||
|
||
def setup_training_data(self, training_data_config=None): | ||
if self.cfg.data.get('train_ds', None): | ||
self._train_ds, self._train_dl = self.build_virtual_prompt_dataset( | ||
|
@@ -735,12 +516,6 @@ def build_virtual_prompt_dataset( | |
|
||
return dataset, dataloader | ||
|
||
def set_inference_config(self, inference_config): | ||
self._inference_config = inference_config | ||
|
||
def get_inference_config(self): | ||
return self._inference_config | ||
|
||
def set_input_tensor(self, input_tensor): | ||
"""Set input tensor to be used instead of forward()'s input. | ||
When doing pipeline parallelism the input from the previous | ||
|
Review comment: VirtualPromptStyle.PROMPT_TUNING is removed from the enum?
Reply: okay, saw it. NVM.