P-tuning refactor Part 2/N #6056

Merged Feb 24, 2023 — 105 commits

Commits
2b95406
patch to allow using tokenizers without additional_special_tokens_ids…
arendu Dec 15, 2022
c131a90
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 16, 2022
9e15c3a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 20, 2022
d0e3669
merge main
arendu Jan 5, 2023
0a19a5a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
ec3d57b
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
64e36ba
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
5bfde7e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 7, 2023
b04b145
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 10, 2023
b1906ab
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 11, 2023
9795062
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 12, 2023
0f83085
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 19, 2023
ee4dd1a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 20, 2023
53ba0b2
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 23, 2023
a6aee2a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 24, 2023
33442d4
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 30, 2023
8e6c5c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 31, 2023
efd263c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 8, 2023
ecfda4f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 10, 2023
15aee0c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
3fe8e34
early stop callback for prompt/p tuning
arendu Feb 15, 2023
0d2666c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2023
1af3e79
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
ecbb118
update
arendu Feb 15, 2023
8af4d3e
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
b0d9ea4
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
8651935
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
2b569e5
added exp manager config for early stop
arendu Feb 15, 2023
d6f48d1
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
409d94e
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
2b7f3de
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
a10284d
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
f62cde9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
f155911
pushed logic for creating early stopping inside exp manager
arendu Feb 15, 2023
858d46a
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
6ae6188
pushed logic for creating early stopping inside exp manager
arendu Feb 15, 2023
0aefe59
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
2f1111d
minor updates and added dataclass check
arendu Feb 16, 2023
60e0d25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
2b8d254
more args
arendu Feb 16, 2023
cac3df9
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 16, 2023
2f35842
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 16, 2023
a176efb
more args
arendu Feb 16, 2023
c31d4aa
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 16, 2023
4e90eef
Merge branch 'main' into adithyare/refac_ptuning
arendu Feb 16, 2023
83c2b69
wrap tpmlp inside prompt encoder
arendu Feb 17, 2023
31915c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 17, 2023
f381bb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
8d11d53
Merge branch 'main' into adithyare/refac_ptuning
arendu Feb 17, 2023
cc28888
Merge branch 'adithyare/refac_ptuning' of https://github.com/NVIDIA/N…
arendu Feb 17, 2023
1ed4bd7
Merge branch 'main' into adithyare/refac_ptuning
arendu Feb 17, 2023
30d5f88
Merge branch 'adithyare/refac_ptuning' of https://github.com/NVIDIA/N…
arendu Feb 17, 2023
65530b9
updates removed unused imports
arendu Feb 17, 2023
c6b367d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
9cab271
Merge branch 'main' into adithyare/refac_ptuning
arendu Feb 17, 2023
23fcef1
removes typecheck for tpmlp module
arendu Feb 18, 2023
df907c6
Merge branch 'adithyare/refac_ptuning' of https://github.com/NVIDIA/N…
arendu Feb 18, 2023
486d2c7
refac
arendu Feb 18, 2023
b893a84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2023
f849619
removing refs to PROMPT_TABLE
arendu Feb 18, 2023
29b3790
resolved conflict
arendu Feb 18, 2023
d08094b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2023
af803b0
Merge branch 'main' into adithyare/refac_ptuning_part2
arendu Feb 18, 2023
127ec6c
remove ref to PROMPT_TABLE
arendu Feb 18, 2023
b0968f9
minor fix
arendu Feb 18, 2023
fa22a1f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 18, 2023
9012d48
merged conficts
arendu Feb 19, 2023
e34e1bd
merged conficts
arendu Feb 19, 2023
47a124d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2023
9961382
bug fix with tpmlp
arendu Feb 19, 2023
567ed50
Merge branch 'adithyare/refac_ptuning_part2' of https://github.com/NV…
arendu Feb 19, 2023
6dc9603
inference seems to be working
arendu Feb 19, 2023
adcdcd7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2023
d06ec0a
phasing out prompt learning in t5
arendu Feb 19, 2023
6b70d53
Merge branch 'adithyare/refac_ptuning_part2' of https://github.com/NV…
arendu Feb 19, 2023
65cf1c6
revert prompt table to allow t5 to work
arendu Feb 19, 2023
56420ee
updates
arendu Feb 19, 2023
8e329c5
updates
arendu Feb 19, 2023
b012153
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2023
447c384
make prompt encoder None in init
arendu Feb 19, 2023
5437650
Merge branch 'adithyare/refac_ptuning_part2' of https://github.com/NV…
arendu Feb 19, 2023
4982603
update test
arendu Feb 19, 2023
ad7216c
fixed init
arendu Feb 19, 2023
c2def0f
setting lstm params the old way
arendu Feb 19, 2023
7b790ef
revert t5 dataset
arendu Feb 19, 2023
a26360e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2023
609af27
revert t5 dataset
arendu Feb 19, 2023
4af7825
revert t5 dataset
arendu Feb 19, 2023
4f83a0b
revert t5 dataset
arendu Feb 19, 2023
8b1a20f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2023
1b89180
revert t5 dataset
arendu Feb 19, 2023
88ec311
Merge branch 'adithyare/refac_ptuning_part2' of https://github.com/NV…
arendu Feb 19, 2023
b8e417f
save to now works with pp>1
arendu Feb 20, 2023
287e7f7
revert t5 datasets
arendu Feb 20, 2023
4025a30
unused args
arendu Feb 20, 2023
ca9c4b1
Merge branch 'main' into adithyare/refac_ptuning_part2
arendu Feb 22, 2023
1d7a392
make prompt encoder state_dict backwards compatible
arendu Feb 23, 2023
9bf4154
pipe taskname to prompt encoder
arendu Feb 23, 2023
832a847
update
arendu Feb 23, 2023
6a3849d
update
arendu Feb 23, 2023
dd75591
Merge branch 'main' into adithyare/refac_ptuning_part3
arendu Feb 23, 2023
4623a3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2023
30216ee
Merge branch 'adithyare/refac_ptuning_part3' into adithyare/refac_ptu…
arendu Feb 23, 2023
c93df19
comment
arendu Feb 23, 2023
c2b10cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2023
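Several commits above ("early stop callback for prompt/p tuning", "added exp manager config for early stop", "pushed logic for creating early stopping inside exp manager") wire early stopping into the experiment manager. A minimal sketch of the stopping rule involved — a hypothetical simplification for illustration, not NeMo's actual implementation, which delegates to PyTorch Lightning's EarlyStopping callback configured through exp_manager:

```python
class EarlyStopper:
    """Toy early-stopping rule: stop after `patience` consecutive epochs
    without the monitored loss improving by at least `min_delta`.
    (Illustrative only; NeMo uses Lightning's EarlyStopping callback.)"""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

With `patience=2`, two non-improving epochs in a row trigger the stop.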
3 changes: 2 additions & 1 deletion Jenkinsfile
@@ -3361,7 +3361,8 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
trainer.max_epochs=null \
model.data.num_workers=1 \
model.tensor_model_parallel_size=1 \
model.virtual_prompt_style='prompt-tuning' \
model.virtual_prompt_style='p-tuning' \
model.p_tuning.encoder_type='embedding' \
model.language_model_path='/home/TestData/nlp/megatron_gpt/tiny/megatron_14m_gpt_tp1_pp1.nemo' \
model.existing_tasks=[] \
model.new_tasks=['rte'] \
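The Jenkinsfile hunk above switches the CI job from `'prompt-tuning'` to `'p-tuning'` and pins `model.p_tuning.encoder_type` to `'embedding'`. A small sketch of how such a config string can be mapped onto an encoder variant — the enum members besides `'embedding'` are assumptions, not a copy of NeMo's actual enum:

```python
from enum import Enum

class PromptEncoderType(Enum):
    # Illustrative encoder variants; 'embedding' mirrors the value set in
    # the CI override above, the other members are assumed for the sketch.
    EMBEDDING = "embedding"
    MLP = "mlp"
    LSTM = "lstm"

def resolve_encoder_type(value: str) -> PromptEncoderType:
    """Validate a model.p_tuning.encoder_type override from the config."""
    try:
        return PromptEncoderType(value.lower())
    except ValueError:
        valid = [e.value for e in PromptEncoderType]
        raise ValueError(f"encoder_type must be one of {valid}, got {value!r}")
```

Validating the string once at startup turns a typo in the Hydra-style override into an immediate, readable error rather than a failure deep in model construction.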
@@ -34,7 +34,7 @@ class GPTPromptLearningDataset(Dataset):
Args:
data (list[strings], list[dicts]): (1) paths to .jsonl or .json files, (2) dict objects corresponding to each input example
tokenizer (tokenizer): Tokenizer from frozen language model
virtual_prompt_source (Enum): Either VirtualPromptSource.PROMPT_TABLE or VirtualPromptSource.PROMPT_ENCODER
virtual_prompt_source (Enum): Either VirtualPromptSource.NO_PROMPTS or VirtualPromptSource.PROMPT_ENCODER
task_templates (dict): Dictionary containing all task template information needed to format prompts. Created in the GPTPromptLearningModel class.
pseudo_tokens (list[strings]): A list of virtual prompt token placeholders e.g [<prompt_1>, <prompt_2>, ...] up to max num virtual tokens
pad_token_id (int): ID of pad token from tokenizer
@@ -179,10 +179,6 @@ def load_data(self, dataset):
if self.min_seq_length <= len(input_ids) <= self.max_seq_length:
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
taskname_id = self.tokenizer.text_to_ids(taskname)

elif self.virtual_prompt_source == VirtualPromptSource.PROMPT_TABLE:
taskname_id = self.task_templates[taskname]["task_id_num"]

elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT:
taskname_id = -1
else:
@@ -342,7 +338,7 @@ def collate_fn(self, batch, tp_workers=0):
taskname_ids = torch.tensor(taskname_ids)

# Task ids are just used to look up embeddings in the prompt table
elif self.virtual_prompt_source in [VirtualPromptSource.PROMPT_TABLE, VirtualPromptSource.NO_PROMPT]:
elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT:
taskname_ids = torch.tensor(taskname_ids)

# Get max sequence length of batch
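The dataset hunks above drop the `PROMPT_TABLE` branch, leaving `PROMPT_ENCODER` and `NO_PROMPT` as the only virtual prompt sources. A compact sketch of the resulting `taskname_id` selection — the enum values are assumptions; the branch structure follows the diff:

```python
from enum import Enum

class VirtualPromptSource(Enum):
    # Values are illustrative; per the diff only these two sources remain.
    PROMPT_ENCODER = "prompt_encoder"
    NO_PROMPT = "no_prompt"

def get_taskname_id(source, taskname, text_to_ids):
    """Mirror of the simplified load_data branch: the encoder path tokenizes
    the task name, while the no-prompt path uses a -1 sentinel."""
    if source == VirtualPromptSource.PROMPT_ENCODER:
        return text_to_ids(taskname)
    elif source == VirtualPromptSource.NO_PROMPT:
        return -1
    raise RuntimeError("invalid VirtualPromptSource")
```

Because the table lookup is gone, the dataset no longer needs the per-task `task_id_num` bookkeeping from `task_templates` at this point.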
@@ -237,7 +237,10 @@ def add_ptuned_prompts_to_prompt_table(self):
device = next(self.word_embeddings.parameters()).device
tokenized_taskname = torch.tensor(self.tokenizer.text_to_ids(taskname)).to(device)
taskname_embeddings = self.word_embeddings(tokenized_taskname).unsqueeze(0)
virtual_prompt_embeddings = self.prompt_encoder(taskname_embeddings=taskname_embeddings).squeeze(0)
batch_size = taskname_embeddings.shape[0]
virtual_prompt_embeddings = self.prompt_encoder(batch_size=batch_size, use_cached_reps=False).squeeze(
0
)
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]
self.prompt_table.add_prompt_from_p_tuning_encoder(
taskname, virtual_prompt_embeddings, total_virtual_tokens
@@ -343,8 +346,11 @@ def embed_input_train(self, input_ids: Tensor, taskname_ids: Tensor):
virtual_token_embeds = torch.stack(virtual_token_embeds)

elif self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
taskname_embeddings = self.word_embeddings(taskname_ids)
virtual_token_embeds = self.prompt_encoder(taskname_embeddings=taskname_embeddings)
# taskname_embeddings = self.word_embeddings(taskname_ids)
batch_size, _ = taskname_ids.size()
virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=False)
else:
raise RuntimeError("invalid VirtualPromptSource..")

# Create index template specifying where virtual token embeddings should be placed
batch_size, _, embedding_size = discrete_token_embeds.shape
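The last two hunks change the prompt encoder's calling convention: instead of receiving per-task word embeddings (`taskname_embeddings`), the encoder is now called with a `batch_size` and a `use_cached_reps` flag, and the same virtual-token embeddings are expanded across every batch row. A toy stand-in illustrating that signature — the module internals here are invented; only the calling convention follows the diff:

```python
import torch
import torch.nn as nn

class TinyPromptEncoder(nn.Module):
    """Toy stand-in for the refactored prompt encoder: forward() takes
    batch_size and use_cached_reps rather than taskname embeddings."""

    def __init__(self, total_virtual_tokens: int = 8, hidden_size: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(total_virtual_tokens, hidden_size)
        self.register_buffer("indices", torch.arange(total_virtual_tokens))
        self._cached = None

    def forward(self, batch_size: int, use_cached_reps: bool = False):
        if use_cached_reps and self._cached is not None:
            reps = self._cached  # reuse previously computed virtual tokens
        else:
            reps = self.embedding(self.indices)  # (tokens, hidden)
            self._cached = reps
        # Same virtual tokens for every batch row: (batch, tokens, hidden).
        return reps.unsqueeze(0).expand(batch_size, -1, -1)
```

This is why the call site only needs `taskname_ids.size()` for the batch dimension, and why the old `self.word_embeddings(taskname_ids)` line is commented out in the diff.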