Framework for PEFT via mixins #6391

Merged: 92 commits merged into main from adithyare/ptuning_w_mixin on May 5, 2023

Commits (92). The diff below shows changes from 11 commits.
18f32ea
init commit ptuning via mixin
arendu Apr 7, 2023
e9059b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
2a470b2
updates
arendu Apr 7, 2023
395285c
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 7, 2023
3d05b45
gpt ptuning places virtual tokens on the left only
arendu Apr 7, 2023
fb1e6a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
4d7f5b9
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 7, 2023
cc81184
encoder input modified when pre_process is true
arendu Apr 7, 2023
6bf990f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
074082a
optimizer group and state dict updates
arendu Apr 8, 2023
eacf1b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2023
1e581a9
adapter ptuning working for pp>1
arendu Apr 8, 2023
fd76ce1
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 8, 2023
ec1212b
adapter defaults
arendu Apr 8, 2023
07b7456
adapter ptuning config defaults
arendu Apr 8, 2023
b9bb8fa
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 10, 2023
71d685a
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 12, 2023
91e8f4a
training works
arendu Apr 14, 2023
0c06620
loading and saving adapter only params during training
arendu Apr 14, 2023
643cab6
added checks and comments
arendu Apr 14, 2023
009928f
clean up
arendu Apr 14, 2023
7d62b18
checks for grad is None before calling all_reduce
arendu Apr 14, 2023
91287e9
load adapter .nemo file working
arendu Apr 15, 2023
43ab6c3
resume training for adapters
arendu Apr 17, 2023
56a110d
merged and resolved conflicts
arendu Apr 17, 2023
f410078
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
d71a047
peft tuning
arendu Apr 18, 2023
b1d661b
resolve conflicts
arendu Apr 18, 2023
532971b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
3564cb5
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 18, 2023
fc9a733
minor
arendu Apr 18, 2023
76f7f3d
file not needed
arendu Apr 18, 2023
24eb2bd
undo prompt learning dataset changes
arendu Apr 18, 2023
d65d206
undo updates to gpt prompt learning model
arendu Apr 18, 2023
363cef0
naming updates
arendu Apr 18, 2023
dfe7671
decoding
arendu Apr 19, 2023
1204027
predict_step in gpt_sft_model
arendu Apr 20, 2023
4199546
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 20, 2023
7ac4d8e
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 24, 2023
4cc768e
updates
arendu Apr 24, 2023
9094282
merged
arendu Apr 24, 2023
8be5bba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
e11965c
removed inference from tuning config
arendu Apr 24, 2023
76dcc64
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 24, 2023
ab2fb09
no test in peft training
arendu Apr 24, 2023
d42bded
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 24, 2023
17e9157
answer only loss and correct defaults for val_loss
arendu Apr 24, 2023
d73e2ae
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 24, 2023
b001fcd
hybrid adapters and ptuning
arendu Apr 24, 2023
6f02c0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
56ba439
eval working..
arendu Apr 25, 2023
1a9524e
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 25, 2023
a4f3016
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2023
12d6a6a
prepending tokens for ptuning
arendu Apr 25, 2023
6a51758
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu Apr 25, 2023
613653e
cleaned up eval config
arendu Apr 25, 2023
95d6a24
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 25, 2023
e75f1b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2023
92c1881
clean up
arendu Apr 25, 2023
aa96731
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 25, 2023
25d7bea
update
arendu Apr 25, 2023
4d14f96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2023
4328edd
default prompt template
arendu Apr 26, 2023
14ef26d
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 26, 2023
4b2a856
Lora added
arendu Apr 26, 2023
1bde903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2023
51c41e9
Support dynamic length with GPT SFT
aklife97 Apr 27, 2023
73c0c25
make branch functional
aklife97 Apr 28, 2023
e226940
merged sft dynamic len fix
arendu Apr 28, 2023
429554c
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu Apr 28, 2023
b9e8494
defaults to max_pad_length=False in GPT SFT dataset
arendu Apr 28, 2023
4e3d293
adapter parallel_adapters to support Lora
arendu Apr 28, 2023
d0042f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2023
9711d37
added early stopping by default
arendu May 1, 2023
7510238
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu May 1, 2023
823a71f
eval script for peft and eval config. bug fixes in predict step and a…
arendu May 2, 2023
1f6863b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
ce5ea7a
updates
arendu May 2, 2023
a16636e
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu May 2, 2023
166d91c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
6e8d551
updates
arendu May 2, 2023
3d72c87
updates
arendu May 2, 2023
1e5b62a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
2c209ae
docs
arendu May 2, 2023
a148d8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
3fa310c
better defaults
arendu May 2, 2023
7ff1dfe
updates
arendu May 2, 2023
248c260
Merge branch 'adithyare/ptuning_w_mixin' of https://github.com/NVIDIA…
arendu May 2, 2023
546e227
update
arendu May 2, 2023
7f83a83
docs
arendu May 2, 2023
ea93b73
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu May 2, 2023
a9afe96
Merge branch 'main' into adithyare/ptuning_w_mixin
arendu May 2, 2023
151 changes: 151 additions & 0 deletions examples/nlp/language_modeling/tuning/conf/megatron_gpt_adapter_ptuning_config.yaml
@@ -0,0 +1,151 @@
name: adapter_ptuning

trainer:
  devices: 1
  accelerator: gpu
  num_nodes: 1
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: -1
  max_steps: 100 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 0.2
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  benchmark: False


exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: null
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 1
    mode: min
    save_nemo_on_train_end: True # Should be false, correct prompt learning model file is saved at model.nemo_path set below
    filename: 'megatron_gpt_adapter_ptuning--{val_loss:.3f}-{step}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: True
    save_best_model: True
  create_early_stopping_callback: True
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
    min_delta: 0.001
    patience: 10
    verbose: True

model:
  seed: 1234
  nemo_path: ${exp_manager.exp_dir}/${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved
  virtual_prompt_style: 'no-prompts' # adapter tuning requires no virtual prompts
  encoder_seq_length: 2048
  gradient_as_bucket_view: false
  tensor_model_parallel_size: 1 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism
  global_batch_size: 8
  micro_batch_size: 4
  validation_global_batch_size: ${model.global_batch_size}
  validation_micro_batch_size: ${model.micro_batch_size}
  validation_drop_last: False

  restore_path: null # Path to an existing adapter .nemo model you wish to add new tasks to or run inference with
  language_model_path: ??? # Path to the GPT language model .nemo file, always required
  existing_tasks: [] # List of tasks the model has already been p-tuned/prompt-tuned for, needed when a restore path is given
  new_tasks: ["rte"] # List of new tasknames to be prompt-tuned

  task_templates: # Add more/replace tasks as needed, these are just examples
  - taskname: "boolq" # The task name
    prompt_template: "Passage: {passage} \nQuestion: {question} \nAnswer: {answer}" # Prompt template for task, specify virtual prompt positions with <|VIRTUAL_PROMPT_#|>
    total_virtual_tokens: 0 # Sum of tokens in virtual_token_splits must add to this number. Can differ between new and existing tasks, but must match across all new tasks being tuned at the same time.
    virtual_token_splits: [] # number of virtual tokens to be inserted at each VIRTUAL PROMPT location, must add to total_virtual_tokens
    truncate_field: "passage" # The {field} in the prompt template whose text will be truncated if the input is too long, if null, inputs that are too long will just be skipped.
    answer_only_loss: True
    answer_field: "answer"

  - taskname: "intent_and_slot"
    prompt_template: "intent options: {intent_options} slot options: {slot_options} {utterance} \nintent: {intent} \nslot: {slot}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    answer_only_loss: False
    truncate_field: null

  - taskname: "rte"
    prompt_template: "sentence1: {premise} sentence2: {hypothesis} Answer: {answer}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    truncate_field: null
    answer_only_loss: True
    answer_field: "answer"

  - taskname: "squad"
    prompt_template: "context: {context} question: {question} answer: {answer}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    truncate_field: null
    answer_only_loss: True
    answer_field: "answer"

  - taskname: "arc-challenge"
    prompt_template: "question: {question} choices: {choices} answer: {answer}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    truncate_field: null
    answer_only_loss: True
    answer_field: "answer"

  - taskname: "xsum"
    prompt_template: "{source} Summary: {target}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    truncate_field: null
    answer_only_loss: True
    answer_field: "target"

  - taskname: "taskname"
    prompt_template: "{prompt} {completion}"
    total_virtual_tokens: 10
    virtual_token_splits: [10]
    truncate_field: "prompt"
    answer_only_loss: True
    answer_field: "completion"

  prompt_encoder_adapter:
    virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence
    bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck
    embedding_dim: 1024 # the size of the prompt encoder embeddings
    init_std: 0.023
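For intuition, the prompt_encoder_adapter block describes a small trainable network: virtual_tokens learned embeddings of size embedding_dim are passed through an MLP with a bottleneck_dim hidden layer, and the outputs are prepended to the sequence in place of the reserved pad positions. Below is a minimal PyTorch sketch of that idea, assuming the encoder projects up to the frozen model's hidden size; the names and exact architecture are illustrative, not NeMo's actual implementation.

import torch
import torch.nn as nn

class PromptEncoderSketch(nn.Module):
    # Illustrative only: learns `virtual_tokens` embeddings and maps them through
    # an MLP bottleneck to the frozen model's hidden size.
    def __init__(self, hidden_size, virtual_tokens=10, embedding_dim=1024, bottleneck_dim=1024, init_std=0.023):
        super().__init__()
        self.virtual_embeddings = nn.Parameter(torch.randn(virtual_tokens, embedding_dim) * init_std)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, hidden_size),
        )

    def forward(self, batch_size):
        virtual = self.mlp(self.virtual_embeddings)              # (virtual_tokens, hidden_size)
        return virtual.unsqueeze(0).expand(batch_size, -1, -1)   # (batch, virtual_tokens, hidden_size)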

  data:
    train_ds: ??? # expects a list of paths to training data files
    validation_ds: ??? # expects a list of paths to validation data files
    add_eos: True
    shuffle: True
    num_workers: 8
    pin_memory: True
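train_ds and validation_ds point at JSONL files (see the usage note in the training script below). Each record supplies the fields named in its task's prompt_template; for the generic "taskname" template above ("{prompt} {completion}"), a line could look roughly like the sketch below, where the field values are invented for illustration.

import json

record = {
    "taskname": "taskname",  # must match a taskname defined under model.task_templates
    "prompt": "question: Which tuning methods does this config cover?",  # fills {prompt}
    "completion": "Adapter tuning and p-tuning.",  # fills {completion}, the answer_field for this task
}
print(json.dumps(record))  # one JSON object per line of the .jsonl file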

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 50
      constant_steps: 0 # Constant steps should also be 0 when min_lr=0
      min_lr: 0.0 # min_lr must be 0.0 for prompt learning
      monitor: val_loss
      reduce_on_plateau: false
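The optim section requests fused Adam with a CosineAnnealing schedule: a 50-step linear warmup followed by cosine decay toward min_lr, which the comment says must stay at 0.0 for prompt learning. A rough sketch of the learning-rate curve those settings describe follows; NeMo's scheduler implementation may differ in details such as how max_steps is resolved.

import math

def cosine_annealing_lr(step, max_lr=1e-4, min_lr=0.0, warmup_steps=50, max_steps=100):
    # Linear warmup to max_lr, then cosine decay toward min_lr.
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)
    progress = min((step - warmup_steps) / max(max_steps - warmup_steps, 1), 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))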
110 changes: 110 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_adapter_ptuning.py
@@ -0,0 +1,110 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment

from nemo.collections.nlp.models.language_modeling.megatron_gpt_adapter_model import (
    MegatronGPTAdapterLearningModel,
    MegatronPTuningAdapterLearningModel,
)
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)

"""
This is the script to train an Adapter infused GPT Model for text generation.
A base GPT Model is required as a starting point. This script will then insert
Adapters into each Transformer layer and will train/update only these adapters
during training. The base GPT Model weights will remain frozen.

During training this script will only save the newly trained Adapter weights
in checkpoints. At the end of training a .nemo file of Adapter weights will
be saved.

Usage:
Assuming the base model is a 125m GPT Model, with TP=1, PP=1:
a. run a training run for a base gpt nemo file:
python megatron_gpt_adapter_ptuning.py \
"model.data.train_ds=[PATH TO TRAINING JSONL FILE]",
"model.data.validation_ds=[PATH TO VALIDATION JSONL FILE]",
model.language_model_path="PATH TO BASE GPT MODEL .nemo FILE"
name="NAME OF TRAINING RUN"
exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE",
trainer.max_epochs=2
"""


@hydra_runner(config_path="conf", config_name="megatron_gpt_adapter_ptuning_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'

    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2 and not with_distributed_adam:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    # load existing or init new soft prompt GPT model
    if cfg.model.get("restore_path", None):
        model = MegatronPTuningAdapterLearningModel.restore_from(
            cfg.model.restore_path, cfg.model, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector()
        )
    else:
        model = MegatronPTuningAdapterLearningModel(cfg.model, trainer=trainer)

    trainer.fit(model)


if __name__ == '__main__':
    main()
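As the docstring states, only the newly inserted adapter and prompt-encoder parameters are trained and checkpointed; the base GPT weights stay frozen. The following is a generic PyTorch sketch of that pattern, not NeMo's actual mixin code; filtering parameter names on an "adapter" substring is an assumption made purely for illustration.

import torch.nn as nn

def freeze_all_but_adapters(model: nn.Module, adapter_key: str = "adapter"):
    # Freeze everything except parameters whose name contains `adapter_key`.
    adapter_state = {}
    for name, param in model.named_parameters():
        param.requires_grad = adapter_key in name
        if param.requires_grad:
            adapter_state[name] = param
    trainable = sum(p.numel() for p in adapter_state.values())
    total = sum(p.numel() for p in model.parameters())
    print(f"training {trainable} / {total} parameters ({100.0 * trainable / total:.2f}%)")
    return adapter_state  # only this subset needs to go into adapter-only checkpoints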
Changes to the prompt learning dataset:
@@ -75,7 +75,6 @@ def __init__(
self.add_eos = add_eos
self.for_train = for_train
self.examples = []

if not self.for_train:
self.tokens_to_generate = tokens_to_generate

@@ -134,16 +133,13 @@ def load_data(self, dataset):
prompt_template = self.task_templates[taskname]["prompt_template"]
prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"]
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]
virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"]
truncation_field = self.task_templates[taskname]['truncate_field']
answer_only_loss = self.task_templates[taskname]["answer_only_loss"]
answer_field = self.task_templates[taskname]["answer_field"]

input_example = prompt_template

self._input_sanity_checks(
total_virtual_tokens,
virtual_token_splits,
prompt_template,
prompt_template_fields,
truncation_field,
@@ -153,26 +149,25 @@
)

# Format the input example according to the template
input_example = input_example.replace(
"<|VIRTUAL_PROMPT_0|>", ""
).strip() # Ignores virtual prompt positions in template...
input_example = self._insert_text_in_template(input_example, prompt_template_fields, doc)
input_example = self._insert_virtual_token_placeholders(input_example, virtual_token_splits)
input_ids = self.tokenizer.text_to_ids(input_example)

# Add BOS/EOS if desired, adds EOS by default
if self.add_bos:
input_ids = [self.tokenizer.bos_id] + input_ids
if self.add_eos:
input_ids = input_ids + [self.tokenizer.eos_id]

pad = [
self.pad_token_id
] * total_virtual_tokens # Pad tokens are placed and will be replaced by virtual embeddings.
input_ids = pad + input_ids
# Try to truncate input text to fit into the max sequence length
if len(input_ids) > self.max_seq_length:
input_ids = self._truncate_input(
truncation_field,
input_ids,
taskname,
doc,
prompt_template,
prompt_template_fields,
virtual_token_splits,
truncation_field, input_ids, taskname, doc, prompt_template, prompt_template_fields,
)

# Skip example if the final length doesn't fit length requirements even after truncation
@@ -198,7 +193,6 @@
def _input_sanity_checks(
self,
total_virtual_tokens,
virtual_token_splits,
prompt_template,
prompt_template_fields,
truncation_field,
@@ -211,16 +205,6 @@
total_virtual_tokens < self.max_seq_length
), "virtual prompt tokens should not exceed max sequence length"

# Make sure virtual token splits add up to the total number of virtual tokens
assert (
sum(virtual_token_splits) == total_virtual_tokens
), "Sum of prompt token split values must equal total number of prompt tokens"

# Make sure number of virtual prompt locations match the number of virtual prompt splits
assert prompt_template.count('<|VIRTUAL_PROMPT_') == len(
virtual_token_splits
), "The number of '<|VIRTUAL_PROMPT_n|>' markers and the number of prompt token splits must match"

# Check if input example has fields not present in template
keys_not_in_template = list(set(doc.keys()) - set(prompt_template_fields) - set(['taskname']))
assert (
@@ -254,21 +238,8 @@ def _insert_text_in_template(self, input_example, prompt_template_fields, doc):

return input_example.strip(" ")

def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits):
""" Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers """
total_inserted_tokens = 0

for idx in range(len(virtual_token_splits)):
split_start = total_inserted_tokens
split_end = total_inserted_tokens + virtual_token_splits[idx]
pseudo_tokens_for_split = "".join(self.pseudo_tokens[split_start:split_end])
input_example = input_example.replace(f'<|VIRTUAL_PROMPT_{idx}|>', pseudo_tokens_for_split)
total_inserted_tokens = split_end

return input_example

def _truncate_input(
self, truncation_field, input_ids, taskname, doc, prompt_template, prompt_template_fields, virtual_token_splits
self, truncation_field, input_ids, taskname, doc, prompt_template, prompt_template_fields,
):
""" Try to truncate input text to fit into the max sequence length """
logging.info(
@@ -289,7 +260,6 @@ def _truncate_input(
# Re-insert the truncated text string into the text prompt
input_example = prompt_template
input_example = self._insert_text_in_template(input_example, prompt_template_fields, doc)
input_example = self._insert_virtual_token_placeholders(input_example, virtual_token_splits)

# Re-tokenize the whole prompt
input_ids = self.tokenizer.text_to_ids(input_example)
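The dataset changes above remove the <|VIRTUAL_PROMPT_n|> placeholder machinery: instead of splicing pseudo-tokens into the template text, the loader now prepends total_virtual_tokens pad positions to the tokenized example, and the model later overwrites those positions with the prompt encoder's virtual embeddings. A condensed sketch of the new flow is shown below, using a hypothetical tokenizer object with the same method names that appear in the diff.

def build_input_ids(tokenizer, formatted_text, total_virtual_tokens, pad_token_id, add_bos=False, add_eos=True):
    # Tokenize the fully formatted prompt (placeholder markers already stripped).
    input_ids = tokenizer.text_to_ids(formatted_text)
    if add_bos:
        input_ids = [tokenizer.bos_id] + input_ids
    if add_eos:
        input_ids = input_ids + [tokenizer.eos_id]
    # Reserve virtual-token slots on the left; their pad embeddings are later
    # replaced by the prompt encoder's output.
    return [pad_token_id] * total_virtual_tokens + input_ids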