P-tuning refactor Part 3/N #6106

Merged 46 commits on Feb 24, 2023

Changes from all commits

Commits
2b95406
patch to allow using tokenizers without additional_special_tokens_ids…
arendu Dec 15, 2022
c131a90
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 16, 2022
9e15c3a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 20, 2022
d0e3669
merge main
arendu Jan 5, 2023
0a19a5a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
ec3d57b
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
64e36ba
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
5bfde7e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 7, 2023
b04b145
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 10, 2023
b1906ab
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 11, 2023
9795062
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 12, 2023
0f83085
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 19, 2023
ee4dd1a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 20, 2023
53ba0b2
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 23, 2023
a6aee2a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 24, 2023
33442d4
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 30, 2023
8e6c5c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 31, 2023
efd263c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 8, 2023
ecfda4f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 10, 2023
15aee0c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
2b7f3de
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
f62cde9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
31915c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 17, 2023
fa22a1f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 18, 2023
e62cd47
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 23, 2023
321a907
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 23, 2023
aeaf13f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 24, 2023
c9a61f1
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 24, 2023
252c5ed
steps to inheret gpt from base
arendu Feb 24, 2023
8a7e23c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
50bcfb5
steps to inheret gpt from base
arendu Feb 24, 2023
7c0cb77
steps to inheret gpt from base
arendu Feb 24, 2023
1808262
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
bc25a8f
fix in pred step
arendu Feb 24, 2023
963538f
Merge branch 'adithyare/refac_ptuning_part3.2' of https://github.com/…
arendu Feb 24, 2023
6cae9d8
moved to base model
arendu Feb 24, 2023
ca5bc2c
removing prompt table class moved into prompt encoder class
arendu Feb 24, 2023
204a3d0
minor fix
arendu Feb 24, 2023
5c2734a
minor fix
arendu Feb 24, 2023
c21814c
changes in dialogue models
arendu Feb 24, 2023
e82bf3c
Merge branch 'main' into adithyare/refac_ptuning_part3.2
arendu Feb 24, 2023
69158c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
5e16431
minor updates to classifications
arendu Feb 24, 2023
9b834b1
Merge branch 'adithyare/refac_ptuning_part3.2' of https://github.com/…
arendu Feb 24, 2023
58c70a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
86312d7
Merge branch 'main' into adithyare/refac_ptuning_part3.2
arendu Feb 24, 2023
1 change: 1 addition & 0 deletions examples/nlp/dialogue/conf/dialogue_config.yaml
@@ -70,6 +70,7 @@ model:
p_tuning: # P-tuning specific params
dropout: 0.0
num_layers: 2
encoder_type: mlp # lstm or tpmlp or embedding
prompt_learning_nemo_path: prompt_learning.nemo
data: {}
virtual_prompt_style: 'p-tuning' # 'prompt-tuning'
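
For context, a rough sketch of how the new encoder_type key is consumed, mirroring the init_prompt_encoder code that appears further down in this PR; cfg, hidden_size, total_virtual_tokens, and new_task are placeholder names for this illustration only:

from omegaconf import OmegaConf
from nemo.collections.nlp.modules.common import PromptEncoder, PromptEncoderType

# Placeholder config for illustration; real values come from the YAML above.
cfg = OmegaConf.create({"p_tuning": {"encoder_type": "mlp", "dropout": 0.0, "num_layers": 2}})
hidden_size, total_virtual_tokens, new_task = 1024, 10, "dialogue"  # assumed values

encoder_type = PromptEncoderType(cfg.p_tuning.get("encoder_type", "tpmlp").lower())
prompt_encoder = PromptEncoder(
    encoder_type=encoder_type,  # mlp, lstm, tpmlp, or embedding
    total_virtual_tokens=total_virtual_tokens,
    token_dim=hidden_size,
    hidden_size=cfg.p_tuning.get("encoder_hidden", hidden_size // 2),
    lstm_dropout=cfg.p_tuning.get("dropout", 0.0),
    num_layers=cfg.p_tuning.get("num_layers", 2),
    init_std=cfg.p_tuning.get("init_std", 0.023),
    taskname=new_task,
)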
@@ -180,7 +180,7 @@ def pad_taskname_ids(self, taskname_ids):
taskname_ids = torch.tensor(taskname_ids)

# Task ids are just used for a look up embeddings for prompt-table
elif self.virtual_prompt_source in [VirtualPromptSource.PROMPT_TABLE, VirtualPromptSource.NO_PROMPT]:
elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT:
taskname_ids = torch.tensor(taskname_ids)

return taskname_ids
@@ -152,9 +152,6 @@ def load_data(self, dataset):
if self.min_seq_length <= len(input_ids) <= self.max_seq_length:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
taskname_id = self.tokenizer.text_to_ids(taskname)

elif self.virtual_prompt_source == VirtualPromptSource.PROMPT_TABLE:
taskname_id = self.task_templates[taskname]["task_id_num"]
elif (
self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT
): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged.
@@ -291,7 +291,7 @@ def get_prompt_token_labels_for_megatron_gpt(self, input_ids, num_prompt_tokens)

def get_virtual_prompt_ids_for_megatron_gpt(self, input_ids):
if (
self.cfg.virtual_prompt_style == VirtualPromptStyle.PROMPT_TUNING
self.cfg.virtual_prompt_style == VirtualPromptStyle.P_TUNING
Collaborator: VirtualPromptStyle.PROMPT_TUNING is removed from the enum?

Collaborator: Okay, saw it. Never mind.

or not self.prompt_learning
or self.trainer.testing
):
@@ -712,10 +712,12 @@ def prepare_data(self):
def setup(self, stage=None):
super().setup(stage)
if self.cfg.library == "megatron" and self.prompt_learning and stage == "fit":
if self.cfg.virtual_prompt_style == VirtualPromptStyle.PROMPT_TUNING:
self.language_model.init_new_prompts()
else:
if self.cfg.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
self.language_model.init_prompt_encoder()
else:
raise ValueError(
"Use model.virtual_prompt_style='p-tuning' with model.p_tuning.encoder_type='embedding' to enable prompt-tuning."
)

def update_data_dirs(self, data_dir: str, dialogues_example_dir: str):
"""
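In other words, the dialogue model no longer has a separate prompt-tuning branch; per the ValueError above, prompt-tuning is now requested through the p-tuning path. A minimal sketch of the corresponding config, assuming OmegaConf and using the key names from the dialogue config earlier in this diff:

from omegaconf import OmegaConf

# Sketch only: prompt-tuning-style behaviour is selected via the p-tuning path
# with an 'embedding' encoder, as the ValueError added above instructs.
cfg = OmegaConf.create(
    {
        "virtual_prompt_style": "p-tuning",
        "p_tuning": {"encoder_type": "embedding", "dropout": 0.0, "num_layers": 2},
    }
)
assert cfg.virtual_prompt_style == "p-tuning"
assert cfg.p_tuning.encoder_type == "embedding"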

Large diffs are not rendered by default.

@@ -80,7 +80,7 @@ def forward(
return output

def setup(self, stage=None):
if stage == 'predict' or self.virtual_prompt_style == VirtualPromptStyle.INFERENCE:
if stage == 'predict':
self.frozen_model.freeze()
return

@@ -25,11 +25,13 @@

from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import (
MegatronBasePromptLearningModel,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common import (
PromptEncoder,
PromptEncoderType,
PromptTable,
VirtualPromptPlaceholderToken,
VirtualPromptSource,
VirtualPromptStyle,
@@ -62,7 +64,7 @@
__all__ = ['MegatronGPTPromptLearningModel']


class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration):
class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
Collaborator: Why drop the TextGeneration interface?

Collaborator (author): The base model contains it now!

Collaborator (author): The base class now has TextGeneration.
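
A small sketch of the inheritance change under discussion; the base-class definition is not shown on this page, so its bases here are an assumption taken from the author's replies, and the real classes are stubbed out:

class MegatronBaseModel: ...  # stub stand-in for the real NeMo class
class TextGeneration: ...     # stub stand-in for the text-generation mixin

# Assumed shape of the refactor, per the replies above: the shared base class
# mixes TextGeneration in, so the GPT prompt-learning model inherits it transitively.
class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration): ...

class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel): ...

assert issubclass(MegatronGPTPromptLearningModel, TextGeneration)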

"""
Model class for prompt-tuning or p-tuning a pretrained Megatron GPT model.

@@ -84,7 +86,9 @@ class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration):

def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer)
self.init_model(cfg, trainer)

def init_model(self, cfg: DictConfig, trainer: Trainer):
self.cfg = cfg
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.get('language_model_path')):
@@ -150,14 +154,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# Load templates for assigning virtual prompt token positions
self.load_task_templates(self.cfg.task_templates)

if self.frozen_model.model.pre_process and self.virtual_prompt_style in [
if self.first_stage_of_pipeline() and self.virtual_prompt_style in [
VirtualPromptStyle.P_TUNING,
]:

self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings

self.padded_vocab_size = self.frozen_model.padded_vocab_size
self._prompt_encoder_key = VirtualPromptSource.PROMPT_ENCODER.value

# Prepare pseudo token ids for virtual/virtual prompt tokens
self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens)
@@ -172,9 +175,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT:
self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT
else:
raise ValueError(
f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'"
)
raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}.'")

self._reduced_loss_buffer = []
self._inference_config = None
@@ -184,158 +185,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.lowest_val_loss = None
self.prompt_encoder = None

def load_task_templates(self, task_templates):
"""
Takes in the task template portion of the config and turns
it into a table where each task's prompt template and
the number of virtual tokens to insert in a given part of
the prompt template are specified.
"""
self.task_templates = {}
self.task_id_num_to_name = {}
self.max_virtual_tokens = 0

task_id_num = 0
for task in task_templates:
self.task_templates[task.taskname] = {
"prompt_template": task.prompt_template,
"prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template),
"answer_only_loss": task.get("answer_only_loss", False),
"answer_field": task.get("answer_field", None),
"truncate_field": task.truncate_field,
"total_virtual_tokens": task.total_virtual_tokens,
"virtual_token_splits": task.virtual_token_splits,
"task_id_num": task_id_num,
}

self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens)
self.task_id_num_to_name[task_id_num] = task.taskname
task_id_num += 1

# Check that all new tasks have the same total num virtual tokens
# Num virtual tokens for new tasks don't need to match num used for previously tuned tasks
if self.new_tasks:
new_task_name = self.new_tasks[0]
self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"]

assert all(
self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens
for taskname in self.new_tasks
), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately."

def init_new_prompts(self):
"""
Initialize new virtual prompts to be tuned using prompt tuning
"""
for idx, taskname in enumerate(self.new_tasks):
init_method = self.cfg.prompt_tuning.new_prompt_init_methods[idx].lower()
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]

if init_method == "text":
init_text = self.cfg.prompt_tuning.new_prompt_init_text[idx]
init_text_ids = self.tokenizer.text_to_ids(init_text)
self.prompt_table.init_prompt_from_text(
taskname, init_text_ids, self.word_embeddings, total_virtual_tokens
)

elif init_method == 'random':
self.prompt_table.init_prompt_from_random(taskname, total_virtual_tokens)

else:
raise AttributeError(
f'\nvirtual prompt init method {init_method} is not recognized\
please use one of text or random'
)

def init_prompt_encoder(self):
"""
Init the prompt encoder needed for p-tuning on a new task
"""
# Total virtual tokens should be the same across all new tasks, so just need one
new_task = self.new_tasks[0]
total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"]

encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower())
self.prompt_encoder = PromptEncoder(
encoder_type=encoder_type,
total_virtual_tokens=total_virtual_tokens,
token_dim=self.hidden_size,
hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2),
lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0),
num_layers=self.cfg.p_tuning.get("num_layers", 2),
init_std=self.cfg.p_tuning.get("init_std", 0.023),
taskname=new_task,
)

def freeze_existing_word_embeddings(self):
"""Freeze params of existing virtual prompts that should not be tuned further
"""
# Make sure word embeddings are frozen
for params in self.word_embeddings.parameters():
params.requires_grad = False

def get_model_tasks(self):
"""
For user to inspect which tasks the model has been
p-tuned/prompt-tuned to preform.
"""
tasks = {}
for taskname in self.prompt_table.prompt_table.keys():
tasks[taskname] = self.task_templates[taskname].copy()

return tasks

def state_dict(self, destination=None, prefix=None, keep_vars=False):
"""
Custom state dict that only contains prompt table and prompt encoder parameters.
No frozen model parameters are stored in the state dict. Prompt encoder parameters
are only in state dict for intermediate checkpoints saved during training. Final
nemo checkpoints at the end of training will contain prompt table parameters only.
"""
state_dict_ = {}
if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
state_dict_ = self.prompt_encoder.state_dict()
else:
raise RuntimeError("invalid virtual prompt source")

return state_dict_

def load_state_dict(self, state_dict, strict: bool = True):
"""
Custom load state dict method that only loads prompt table and prompt encoder
parameters. Matching load method for this class' custom state dict method.
"""
if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
if self.prompt_encoder is None:
self.init_prompt_encoder()
self.prompt_encoder.load_state_dict(state_dict, strict)
else:
raise RuntimeError("invalid virtual prompt source")

def setup_optimizer_param_groups(self):
"""
ModelPT override. Optimizer will get self._optimizer_param_groups.
Makes two optimizer param groups, one for the frozen model params
and one for the prompt-table/prompt-encoder params. The learning
rate for the frozen model's params will always be zero effectively
freezing the model's params but still allowing for the needed gradients
to be passed around in pipeline parallel models. The prompt-encoder
and/or prompt table will use the learning rate set by the user.
"""
# Freeze frozen model
for param in self.frozen_model.parameters():
param.requires_grad = False

virtual_prompt_params = {'params': []}

if self.frozen_model.model.pre_process:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
virtual_prompt_params['params'].extend([param for param in self.prompt_encoder.parameters()])
else:
raise RuntimeError("should not be here.")
self._optimizer_param_groups = (virtual_prompt_params,)
def first_stage_of_pipeline(self):
return self.frozen_model.model.pre_process

def forward(
self,
@@ -354,7 +205,7 @@ def forward(
in the MegatronGPT class.
"""
# Get embeddings for text tokens and insert virtual token embeddings
if self.frozen_model.model.pre_process:
if self.first_stage_of_pipeline():
input_embeds = self.embed_input(input_ids, taskname_ids, use_cached_reps=inference)
if hasattr(self.frozen_model.model.language_model.embedding, "position_embeddings"):
position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(
@@ -393,54 +244,6 @@ def forward(

return output

def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool):
"""
Replaces the virtual tokens in the input_ids with embeddings
calculated from either the 'prompt_table' or 'prompt_encoder'.
The virtual token placeholders have token_ids listed in
`self.pseudo_token_ids`.

params:
input_ids: the input token ids
taskname_ids: the NLP task tag token ids
returns:
the token embedding for the LM model.
"""
# Replace virtual token ids with padding for forward pass through vocab embeddings
discrete_token_ids = input_ids.clone()
discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id
discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone()

# Find the indicies where virtual tokens should be inserted
virtual_token_locations = input_ids >= self.pseudo_token_ids_start

# If there are no virtual tokens, just return discrete token embeds
if not virtual_token_locations.any():
return discrete_token_embeds

# Get virtual token embeddings from the prompt table or prompt encoder
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
batch_size, _ = taskname_ids.size()
virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps)
else:
raise RuntimeError("invalid VirtualPromptSource..")

# Create index template specifying where virtual token embeddings should be placed
batch_size, _, embedding_size = discrete_token_embeds.shape
virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None]
virtual_token_index = virtual_token_index.expand(
batch_size, self.total_new_task_virtual_tokens, embedding_size
)

# Make sure discrete_token_embeds and virtual_token_embeds share the same dtype
discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype)

# Insert virtual token embeddings where they belong amoung the discrete token embeddings
discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds)
input_embeds = discrete_token_embeds

return input_embeds
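
Aside: the scatter_ call in embed_input above (code that is moving to the base class in this PR) can be illustrated with a tiny standalone example; shapes and values below are made up for the illustration:

import torch

# Two virtual-token positions per sequence are overwritten in place with
# encoder-produced embeddings, mirroring the scatter_ step in embed_input.
batch_size, seq_len, emb = 2, 5, 4
discrete_token_embeds = torch.zeros(batch_size, seq_len, emb)
virtual_token_embeds = torch.ones(batch_size, 2, emb)  # stand-in for prompt_encoder output
virtual_token_index = torch.tensor([[1, 2], [0, 1]])[:, :, None].expand(batch_size, 2, emb)
discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds)
assert discrete_token_embeds[0, 1].eq(1).all() and discrete_token_embeds[0, 2].eq(1).all()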

def fwd_bwd_step(self, batch, batch_idx, forward_only):
"""
Dataloader produces a global batch which is turned into a list of microbatches.
@@ -589,28 +392,6 @@ def test_epoch_end(self, outputs):
averaged_loss = average_losses_across_data_parallel_group(outputs)
logging.info(f'test_loss: {averaged_loss[0]}')

def on_train_end(self):
# Save p-tuned prompts to prompt table for inference or future task training
self.save_to(save_path=self.cfg.nemo_path)

def setup(self, stage=None):
if (stage == 'predict') and self.frozen_model.model.pre_process:
self.freeze_existing_word_embeddings()
return

self.setup_test_data()
if stage == 'test':
return

if self.frozen_model.model.pre_process:
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
self.init_prompt_encoder()

self.freeze_existing_word_embeddings()

self.setup_training_data()
self.setup_validation_data()

def setup_training_data(self, training_data_config=None):
if self.cfg.data.get('train_ds', None):
self._train_ds, self._train_dl = self.build_virtual_prompt_dataset(
@@ -735,12 +516,6 @@ def build_virtual_prompt_dataset(

return dataset, dataloader

def set_inference_config(self, inference_config):
self._inference_config = inference_config

def get_inference_config(self):
return self._inference_config

def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
@@ -87,7 +87,7 @@ def forward(
return output, None

def setup(self, stage=None):
if stage == 'predict' or self.virtual_prompt_style == VirtualPromptStyle.INFERENCE:
if stage == 'predict':
self.frozen_model.freeze()
return
