P-tuning refactor Part 3/N (#6106)
* patch to allow using tokenizers without additional_special_tokens_ids attribute

Signed-off-by: arendu <[email protected]>

* steps to inherit gpt from base

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* steps to inherit gpt from base

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix in pred step

Signed-off-by: arendu <[email protected]>

* moved to base model

Signed-off-by: arendu <[email protected]>

* remove prompt table class; functionality moved into prompt encoder class

Signed-off-by: arendu <[email protected]>

* minor fix

Signed-off-by: arendu <[email protected]>

* minor fix

Signed-off-by: arendu <[email protected]>

* changes in dialogue models

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor updates to classifications

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] committed Feb 24, 2023
1 parent 3e8aa5e commit dabd8b8
Showing 12 changed files with 72 additions and 662 deletions.
1 change: 1 addition & 0 deletions examples/nlp/dialogue/conf/dialogue_config.yaml
@@ -70,6 +70,7 @@ model:
p_tuning: # P-tuning specific params
dropout: 0.0
num_layers: 2
encoder_type: mlp # lstm or tpmlp or embedding
prompt_learning_nemo_path: prompt_learning.nemo
data: {}
virtual_prompt_style: 'p-tuning' # 'prompt-tuning'
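For context, a hedged sketch of how this new key is consumed downstream; the model code shown later in this commit uses the same pattern with a "tpmlp" default. The dict here is a stand-in for the parsed YAML:

```python
from nemo.collections.nlp.modules.common import PromptEncoderType

p_tuning_cfg = {"encoder_type": "mlp"}  # stand-in for the parsed YAML section above

# Lower-case the configured string and map it onto the encoder-type enum,
# mirroring init_prompt_encoder later in this commit (which defaults to "tpmlp").
# Valid values are assumed to be: lstm, tpmlp, mlp, embedding.
encoder_type = PromptEncoderType(p_tuning_cfg.get("encoder_type", "tpmlp").lower())
```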
@@ -180,7 +180,7 @@ def pad_taskname_ids(self, taskname_ids):
taskname_ids = torch.tensor(taskname_ids)

# Task ids are just used to look up embeddings for the prompt table
elif self.virtual_prompt_source in [VirtualPromptSource.PROMPT_TABLE, VirtualPromptSource.NO_PROMPT]:
elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT:
taskname_ids = torch.tensor(taskname_ids)

return taskname_ids
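With the prompt table removed, `VirtualPromptSource` is effectively down to two members. A hedged reconstruction (member names are from the diff; the string values are assumptions):

```python
from enum import Enum

class VirtualPromptSource(Enum):
    # PROMPT_TABLE was dropped in this refactor; only these two paths remain.
    PROMPT_ENCODER = "prompt_encoder"  # assumed value
    NO_PROMPT = "no_prompt"            # assumed value
```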
@@ -152,9 +152,6 @@ def load_data(self, dataset):
if self.min_seq_length <= len(input_ids) <= self.max_seq_length:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
taskname_id = self.tokenizer.text_to_ids(taskname)

elif self.virtual_prompt_source == VirtualPromptSource.PROMPT_TABLE:
taskname_id = self.task_templates[taskname]["task_id_num"]
elif (
self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT
): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged.
@@ -291,7 +291,7 @@ def get_prompt_token_labels_for_megatron_gpt(self, input_ids, num_prompt_tokens)

def get_virtual_prompt_ids_for_megatron_gpt(self, input_ids):
if (
self.cfg.virtual_prompt_style == VirtualPromptStyle.PROMPT_TUNING
self.cfg.virtual_prompt_style == VirtualPromptStyle.P_TUNING
or not self.prompt_learning
or self.trainer.testing
):
@@ -712,10 +712,12 @@ def prepare_data(self):
def setup(self, stage=None):
super().setup(stage)
if self.cfg.library == "megatron" and self.prompt_learning and stage == "fit":
if self.cfg.virtual_prompt_style == VirtualPromptStyle.PROMPT_TUNING:
self.language_model.init_new_prompts()
else:
if self.cfg.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
self.language_model.init_prompt_encoder()
else:
raise ValueError(
"Use model.virtual_prompt_style='p-tuning' with model.p_tuning.encoder_type='embedding' to enable prompt-tuning."
)
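The new error message points at the replacement path: prompt-tuning is now expressed as p-tuning with an embedding encoder. A hedged config sketch of that override (keys follow the dialogue_config.yaml hunk above):

```python
from omegaconf import OmegaConf

# Hypothetical override reproducing old prompt-tuning behavior under the new API:
cfg = OmegaConf.create(
    {
        "virtual_prompt_style": "p-tuning",
        "p_tuning": {"encoder_type": "embedding", "dropout": 0.0, "num_layers": 2},
    }
)
```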

def update_data_dirs(self, data_dir: str, dialogues_example_dir: str):
"""

Large diffs are not rendered by default.

@@ -80,7 +80,7 @@ def forward(
return output

def setup(self, stage=None):
if stage == 'predict' or self.virtual_prompt_style == VirtualPromptStyle.INFERENCE:
if stage == 'predict':
self.frozen_model.freeze()
return

@@ -25,11 +25,13 @@

from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import (
MegatronBasePromptLearningModel,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common import (
PromptEncoder,
PromptEncoderType,
PromptTable,
VirtualPromptPlaceholderToken,
VirtualPromptSource,
VirtualPromptStyle,
@@ -62,7 +64,7 @@
__all__ = ['MegatronGPTPromptLearningModel']


class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration):
class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
"""
Model class for prompt-tuning or p-tuning a pretrained Megatron GPT model.
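The net effect of this hunk, sketched as a hedged outline of the new hierarchy (method placement inferred from the deletions below, not copied from NeMo source):

```python
# Outline only; the real base classes live in NeMo, stubbed here for illustration.
class MegatronBaseModel: ...
class TextGeneration: ...

class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration):
    """Shared prompt-learning logic: task templates, prompt encoder init,
    embed_input, custom state_dict/load_state_dict, optimizer param groups."""

class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
    """GPT-specific pieces only: frozen model loading, fwd/bwd step, datasets."""
```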
@@ -84,7 +86,9 @@ class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration):

def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer)
self.init_model(cfg, trainer)

def init_model(self, cfg: DictConfig, trainer: Trainer):
self.cfg = cfg
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.get('language_model_path')):
@@ -150,14 +154,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# Load templates for assigning virtual prompt token positions
self.load_task_templates(self.cfg.task_templates)

if self.frozen_model.model.pre_process and self.virtual_prompt_style in [
if self.first_stage_of_pipeline() and self.virtual_prompt_style in [
VirtualPromptStyle.P_TUNING,
]:

self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings

self.padded_vocab_size = self.frozen_model.padded_vocab_size
self._prompt_encoder_key = VirtualPromptSource.PROMPT_ENCODER.value

# Prepare pseudo token ids for virtual/virtual prompt tokens
self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens)
@@ -172,9 +175,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT:
self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT
else:
raise ValueError(
f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'"
)
raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}.'")

self._reduced_loss_buffer = []
self._inference_config = None
@@ -184,158 +185,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.lowest_val_loss = None
self.prompt_encoder = None

def load_task_templates(self, task_templates):
"""
Takes in the task template portion of the config and turns
it into a table where each task's prompt template and
the number of virtual tokens to insert in a given part of
the prompt template are specified.
"""
self.task_templates = {}
self.task_id_num_to_name = {}
self.max_virtual_tokens = 0

task_id_num = 0
for task in task_templates:
self.task_templates[task.taskname] = {
"prompt_template": task.prompt_template,
"prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template),
"answer_only_loss": task.get("answer_only_loss", False),
"answer_field": task.get("answer_field", None),
"truncate_field": task.truncate_field,
"total_virtual_tokens": task.total_virtual_tokens,
"virtual_token_splits": task.virtual_token_splits,
"task_id_num": task_id_num,
}

self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens)
self.task_id_num_to_name[task_id_num] = task.taskname
task_id_num += 1

# Check that all new tasks have the same total num virtual tokens
# Num virtual tokens for new tasks don't need to match num used for previously tuned tasks
if self.new_tasks:
new_task_name = self.new_tasks[0]
self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"]

assert all(
self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens
for taskname in self.new_tasks
), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately."

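For reference (this method now lives in the base class), here is a hypothetical entry of the table it builds; the keys come from the code above, the values are made up:

```python
# Hypothetical task template entry as produced by load_task_templates:
task_templates = {
    "squad": {
        "prompt_template": "<|VIRTUAL_PROMPT_0|> Context: {context} Question: {question} Answer: {answer}",
        "prompt_template_fields": ["context", "question", "answer"],
        "answer_only_loss": True,
        "answer_field": "answer",
        "truncate_field": "context",
        "total_virtual_tokens": 10,
        "virtual_token_splits": [10],
        "task_id_num": 0,
    }
}
```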
def init_new_prompts(self):
"""
Initialize new virtual prompts to be tuned using prompt tuning
"""
for idx, taskname in enumerate(self.new_tasks):
init_method = self.cfg.prompt_tuning.new_prompt_init_methods[idx].lower()
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]

if init_method == "text":
init_text = self.cfg.prompt_tuning.new_prompt_init_text[idx]
init_text_ids = self.tokenizer.text_to_ids(init_text)
self.prompt_table.init_prompt_from_text(
taskname, init_text_ids, self.word_embeddings, total_virtual_tokens
)

elif init_method == 'random':
self.prompt_table.init_prompt_from_random(taskname, total_virtual_tokens)

else:
raise AttributeError(
f'\nvirtual prompt init method {init_method} is not recognized\
please use one of text or random'
)

def init_prompt_encoder(self):
"""
Init the prompt encoder needed for p-tuning on a new task
"""
# Total virtual tokens should be the same across all new tasks, so just need one
new_task = self.new_tasks[0]
total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"]

encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower())
self.prompt_encoder = PromptEncoder(
encoder_type=encoder_type,
total_virtual_tokens=total_virtual_tokens,
token_dim=self.hidden_size,
hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2),
lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0),
num_layers=self.cfg.p_tuning.get("num_layers", 2),
init_std=self.cfg.p_tuning.get("init_std", 0.023),
taskname=new_task,
)

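The config keys read above, collected in one place. Defaults mirror the .get(...) calls; this is a hedged summary, not a new API:

```python
# Hypothetical p_tuning section with the defaults used by init_prompt_encoder:
p_tuning_defaults = {
    "encoder_type": "tpmlp",  # lstm | tpmlp | mlp | embedding
    "encoder_hidden": None,   # None meaning: falls back to hidden_size // 2
    "dropout": 0.0,
    "num_layers": 2,
    "init_std": 0.023,
}
```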
def freeze_existing_word_embeddings(self):
"""Freeze params of existing virtual prompts that should not be tuned further
"""
# Make sure word embeddings are frozen
for params in self.word_embeddings.parameters():
params.requires_grad = False

def get_model_tasks(self):
"""
For user to inspect which tasks the model has been
p-tuned/prompt-tuned to perform.
"""
tasks = {}
for taskname in self.prompt_table.prompt_table.keys():
tasks[taskname] = self.task_templates[taskname].copy()

return tasks

def state_dict(self, destination=None, prefix=None, keep_vars=False):
"""
Custom state dict that only contains prompt table and prompt encoder parameters.
No frozen model parameters are stored in the state dict. Prompt encoder parameters
are only in state dict for intermediate checkpoints saved during training. Final
nemo checkpoints at the end of training will contain prompt table parameters only.
"""
state_dict_ = {}
if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
state_dict_ = self.prompt_encoder.state_dict()
else:
raise RuntimeError("invalid virtual prompt source")

return state_dict_

def load_state_dict(self, state_dict, strict: bool = True):
"""
Custom load state dict method that only loads prompt table and prompt encoder
parameters. Matching load method for this class' custom state dict method.
"""
if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
if self.prompt_encoder is None:
self.init_prompt_encoder()
self.prompt_encoder.load_state_dict(state_dict, strict)
else:
raise RuntimeError("invalid virtual prompt source")

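Taken together, these two overrides mean a checkpoint carries only prompt-encoder weights. A hedged round-trip sketch (`model` and the path are hypothetical):

```python
import torch

# Only the prompt encoder's parameters are serialized; frozen GPT weights are not.
weights = model.state_dict()  # equivalent to model.prompt_encoder.state_dict()
torch.save(weights, "p_tuning_encoder.pt")

# load_state_dict lazily builds the encoder first if it does not exist yet.
model.load_state_dict(torch.load("p_tuning_encoder.pt"))
```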
def setup_optimizer_param_groups(self):
"""
ModelPT override. Optimizer will get self._optimizer_param_groups.
Only the prompt-encoder params are placed in an optimizer param group;
the frozen model's params have requires_grad set to False, effectively
freezing them while still allowing the needed gradients to be passed
around in pipeline parallel models. The prompt encoder uses the
learning rate set by the user.
"""
# Freeze frozen model
for param in self.frozen_model.parameters():
param.requires_grad = False

virtual_prompt_params = {'params': []}

if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
virtual_prompt_params['params'].extend([param for param in self.prompt_encoder.parameters()])
else:
raise RuntimeError("should not be here.")
self._optimizer_param_groups = (virtual_prompt_params,)

def first_stage_of_pipeline(self):
return self.frozen_model.model.pre_process

def forward(
self,
@@ -354,7 +205,7 @@ def forward(
in the MegatronGPT class.
"""
# Get embeddings for text tokens and insert virtual token embeddings
if self.frozen_model.model.pre_process:
if self.first_stage_of_pipeline():
input_embeds = self.embed_input(input_ids, taskname_ids, use_cached_reps=inference)
if hasattr(self.frozen_model.model.language_model.embedding, "position_embeddings"):
position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(
@@ -393,54 +244,6 @@ def forward(

return output

def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool):
"""
Replaces the virtual tokens in the input_ids with embeddings
calculated from either the 'prompt_table' or 'prompt_encoder'.
The virtual token placeholders have token_ids listed in
`self.pseudo_token_ids`.
params:
input_ids: the input token ids
taskname_ids: the NLP task tag token ids
returns:
the token embedding for the LM model.
"""
# Replace virtual token ids with padding for forward pass through vocab embeddings
discrete_token_ids = input_ids.clone()
discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id
discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone()

# Find the indices where virtual tokens should be inserted
virtual_token_locations = input_ids >= self.pseudo_token_ids_start

# If there are no virtual tokens, just return discrete token embeds
if not virtual_token_locations.any():
return discrete_token_embeds

# Get virtual token embeddings from the prompt table or prompt encoder
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
batch_size, _ = taskname_ids.size()
virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps)
else:
raise RuntimeError("invalid VirtualPromptSource..")

# Create index template specifying where virtual token embeddings should be placed
batch_size, _, embedding_size = discrete_token_embeds.shape
virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None]
virtual_token_index = virtual_token_index.expand(
batch_size, self.total_new_task_virtual_tokens, embedding_size
)

# Make sure discrete_token_embeds and virtual_token_embeds share the same dtype
discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype)

# Insert virtual token embeddings where they belong among the discrete token embeddings
discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds)
input_embeds = discrete_token_embeds

return input_embeds

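A runnable toy version of the scatter step above, with tiny hypothetical shapes:

```python
import torch

batch, seq_len, hidden, n_virtual = 2, 6, 4, 3
discrete = torch.zeros(batch, seq_len, hidden)  # stands in for discrete_token_embeds
virtual = torch.ones(batch, n_virtual, hidden)  # stands in for the prompt encoder output

# Virtual tokens occupy positions 0..2 of both sequences (hypothetical layout).
index = torch.tensor([[0, 1, 2], [0, 1, 2]])[:, :, None].expand(batch, n_virtual, hidden)
discrete.scatter_(1, index, virtual)  # rows 0..2 of each sequence are now all ones
```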
def fwd_bwd_step(self, batch, batch_idx, forward_only):
"""
Dataloader produces a global batch which is turned into a list of microbatches.
@@ -589,28 +392,6 @@ def test_epoch_end(self, outputs):
averaged_loss = average_losses_across_data_parallel_group(outputs)
logging.info(f'test_loss: {averaged_loss[0]}')

def on_train_end(self):
# Save p-tuned prompts to prompt table for inference or future task training
self.save_to(save_path=self.cfg.nemo_path)

def setup(self, stage=None):
if (stage == 'predict') and self.frozen_model.model.pre_process:
self.freeze_existing_word_embeddings()
return

self.setup_test_data()
if stage == 'test':
return

if self.frozen_model.model.pre_process:
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
self.init_prompt_encoder()

self.freeze_existing_word_embeddings()

self.setup_training_data()
self.setup_validation_data()

def setup_training_data(self, training_data_config=None):
if self.cfg.data.get('train_ds', None):
self._train_ds, self._train_dl = self.build_virtual_prompt_dataset(
@@ -735,12 +516,6 @@ def build_virtual_prompt_dataset(

return dataset, dataloader

def set_inference_config(self, inference_config):
self._inference_config = inference_config

def get_inference_config(self):
return self._inference_config

def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
@@ -87,7 +87,7 @@ def forward(
return output, None

def setup(self, stage=None):
if stage == 'predict' or self.virtual_prompt_style == VirtualPromptStyle.INFERENCE:
if stage == 'predict':
self.frozen_model.freeze()
return
