-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalized the P-tuning method to support various NLP tasks #3623
Conversation
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
This pull request introduces 10 alerts when merging 844fc83 into 4697662 - view on LGTM.com new alerts:
|
This pull request introduces 10 alerts when merging 846cf4d into 4697662 - view on LGTM.com new alerts:
|
846cf4d
to
a94c8af
Compare
Signed-off-by: Yi Dong <[email protected]>
This pull request introduces 10 alerts when merging a94c8af into 4697662 - view on LGTM.com new alerts:
|
Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
This pull request introduces 11 alerts when merging 3a7b4fb into c3101cc - view on LGTM.com new alerts:
|
This pull request introduces 11 alerts when merging 306d9bc into 8e15ba4 - view on LGTM.com new alerts:
|
Can you take a look at these? |
This pull request introduces 11 alerts when merging a6bfd92 into 64eb620 - view on LGTM.com new alerts:
|
# shared params for dataset and data loaders | ||
# tokenizer needs to get initialized before the super.__init__() | ||
# as dataloaders and datasets need it to process the data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clean up comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleaned.
# register the file containing the labels into the artifacts to get stored in the '.nemo' file later | ||
self.embeddings = self.model.model.language_model.embedding.word_embeddings | ||
|
||
# self.vocab = self.tokenizer.tokenizer.get_vocab() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clean up comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleaned
# _, returned_pred = self.get_prediction(batch_size, label_position, logits) | ||
# returned_label = self.get_ground_truth_labels(batch_size, label_ids) | ||
# return floss, returned_pred, returned_label |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clean up comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleaned
# predicted_tokens_dec = torch.cat([predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], 1) | ||
# new_pred = torch.zeros_like(token_ids[:, 0:1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clean up comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleaned
# slow version of above scatter logics | ||
# for bidx in range(bz): | ||
# position = blocked_indices[bidx].nonzero()[:, 0] | ||
# for i in range(len(position)): | ||
# raw_embeds[bidx, position[i], :] = replace_embeds[bidx, i, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clean up comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will leave it since it helps people to understand the code.
"id": "0b0f69a9", | ||
"metadata": {}, | ||
"source": [ | ||
"## Download the SQuDA dataset\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
This pull request introduces 11 alerts when merging 4d3ebc6 into 9aef14f - view on LGTM.com new alerts:
|
Signed-off-by: Yi Dong <[email protected]>
This pull request introduces 11 alerts when merging f60f7ab into 298b686 - view on LGTM.com new alerts:
|
This pull request introduces 11 alerts when merging 6e6a306 into c645c4c - view on LGTM.com new alerts:
|
This pull request introduces 11 alerts when merging 85100b4 into 461a866 - view on LGTM.com new alerts:
|
This pull request introduces 11 alerts when merging 14739bc into d364b3f - view on LGTM.com new alerts:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
* memory fix Signed-off-by: Yi Dong <[email protected]> * fix exp cause inf Signed-off-by: Yi Dong <[email protected]> * clean comments Signed-off-by: Yi Dong <[email protected]> * added the ptune gpt model Signed-off-by: Yi Dong <[email protected]> * fix the decode Signed-off-by: Yi Dong <[email protected]> * text gen loss is done Signed-off-by: Yi Dong <[email protected]> * working now Signed-off-by: Yi Dong <[email protected]> * fix style Signed-off-by: Yi Dong <[email protected]> * auto infer decoder len Signed-off-by: Yi Dong <[email protected]> * evaluate on the best checkpoint Signed-off-by: Yi Dong <[email protected]> * update the dataset to handle multiple tasks Signed-off-by: Yi Dong <[email protected]> * added the inference logics Signed-off-by: Yi Dong <[email protected]> * simplify the data processor Signed-off-by: Yi Dong <[email protected]> * task dependent Signed-off-by: Yi Dong <[email protected]> * multiple task working Signed-off-by: Yi Dong <[email protected]> * performance improvement Signed-off-by: Yi Dong <[email protected]> * fix the old ptune classification Signed-off-by: Yi Dong <[email protected]> * added tutorial notebook Signed-off-by: Yi Dong <[email protected]> * remove the outputs Signed-off-by: Yi Dong <[email protected]> * added template data processor Signed-off-by: Yi Dong <[email protected]> * fix bug of non task depedent and detects negative cut Signed-off-by: Yi Dong <[email protected]> * remove the sentiment analysis notebook Signed-off-by: Yi Dong <[email protected]> * comments clean up Signed-off-by: Yi Dong <[email protected]> Co-authored-by: Eric Harper <[email protected]>
What does this PR do ?
It generalizes the P-tuning method to support various NLP tasks including text-cls, NER, text summation, Q/A etc. It supports multiple task with one prompt encoder model. It also includes a tutorial notebook to show how to use it.
Changelog
Usage
Check the notebook
https://github.com/NVIDIA/NeMo/blob/generalized_gpt_ptuning/tutorials/nlp/PTune_multiple_NLP_tasks.ipynb
PR Type: