P-tuning refactor Part 3/N #6106
Conversation
```diff
@@ -291,7 +291,7 @@ def get_prompt_token_labels_for_megatron_gpt(self, input_ids, num_prompt_tokens)

     def get_virtual_prompt_ids_for_megatron_gpt(self, input_ids):
         if (
-            self.cfg.virtual_prompt_style == VirtualPromptStyle.PROMPT_TUNING
+            self.cfg.virtual_prompt_style == VirtualPromptStyle.P_TUNING
```
`VirtualPromptStyle.PROMPT_TUNING` is removed from the enum?
Okay, saw it. NVM.
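For reference, the concern is easy to demonstrate in isolation: once a member is dropped from an enum, any stale comparison against it fails at attribute lookup. A minimal sketch with a stand-in enum (NeMo's real `VirtualPromptStyle` may carry other members):

```python
from enum import Enum


class VirtualPromptStyle(Enum):
    # stand-in mirroring the enum after prompt-tuning support is dropped
    P_TUNING = "p-tuning"


# A leftover reference like VirtualPromptStyle.PROMPT_TUNING now fails loudly,
# which is why the comparison in the hunk above had to switch to P_TUNING:
try:
    VirtualPromptStyle.PROMPT_TUNING
except AttributeError as err:
    print(err)
```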
```diff
             total_virtual_tokens=total_virtual_tokens,
             token_dim=self.hidden_size,
-            hidden_size=self.cfg.p_tuning.get("encoder_hidden", 2048),
+            hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2),
```
Is `self.hidden_size // 2` a good number to use?
It's a better default for when base models have different hidden sizes.
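To see why the new fallback scales with the model, here is a toy snippet (not NeMo code; the hidden sizes are invented) showing the `cfg.get` behavior:

```python
from omegaconf import OmegaConf

# p_tuning section with no explicit encoder_hidden entry
cfg = OmegaConf.create({"p_tuning": {}})

for hidden_size in (2048, 4096, 8192):  # hypothetical base-model widths
    # the default now tracks the base model instead of a hard-coded 2048
    encoder_hidden = cfg.p_tuning.get("encoder_hidden", hidden_size // 2)
    print(hidden_size, "->", encoder_hidden)  # 2048 -> 1024, 4096 -> 2048, ...
```

An explicit `encoder_hidden` in the config still takes precedence, so setups that pinned 2048 keep their value.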
```diff
@@ -62,7 +64,7 @@

 __all__ = ['MegatronGPTPromptLearningModel']

-class MegatronGPTPromptLearningModel(MegatronBaseModel, TextGeneration):
+class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
```
Why drop the `TextGeneration` interface?
The base model contains it now!
The base class now has `TextGeneration`.
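A minimal sketch of the resulting hierarchy (stand-in classes; the real ones live in NeMo), showing that the GPT model still picks up `TextGeneration` transitively:

```python
class MegatronBaseModel:
    pass


class TextGeneration:
    pass  # stand-in for NeMo's text-generation interface


class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration):
    pass


class MegatronGPTPromptLearningModel(MegatronBasePromptLearningModel):
    pass


# dropping TextGeneration from the GPT class signature loses nothing:
assert issubclass(MegatronGPTPromptLearningModel, TextGeneration)
```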
LGTM. Great to see a lot of redundant code removed in this PR. Minor stylistic changes, but OK to merge if you can address them in the next PR.
```diff
@@ -59,7 +58,9 @@ class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration):

     def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer)
+        self.init_model(cfg, trainer)
```
I see here (in the lines above) that you have imports you don't use; maybe remove them in refactor 4/N? VS Code should be able to identify them.
There are quite a few in other files as well.
Agreed! Will do!
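On the `__init__` hunk above, a sketch of the delegation pattern it introduces (assuming, per the hunk, that `super().__init__` stays and `self.init_model` is the added call; the types are stand-ins):

```python
class MegatronBaseModel:
    def __init__(self, cfg, trainer):
        self.cfg, self.trainer = cfg, trainer


class MegatronBasePromptLearningModel(MegatronBaseModel):
    def __init__(self, cfg, trainer):
        super().__init__(cfg, trainer)
        self.init_model(cfg, trainer)  # shared prompt-learning setup

    def init_model(self, cfg, trainer):
        # subclasses override or extend this instead of rewriting __init__
        self.prompt_encoder = None  # placeholder


model = MegatronBasePromptLearningModel(cfg={}, trainer=None)
```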
```diff
-        raise ValueError(
-            f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'"
-        )
+        raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}'")
```
Please indicate what they should use instead, e.g. "please use p-tuning".
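One possible rewording along those lines (a sketch only; the final message wording and the helper name are not from the PR):

```python
def validate_virtual_prompt_style(cfg):
    # mirrors the check in the hunk above, with an actionable hint restored
    if cfg.virtual_prompt_style != "p-tuning":
        raise ValueError(
            f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not "
            "recognized; please set cfg.virtual_prompt_style to 'p-tuning'"
        )
```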
```diff
@@ -211,68 +170,25 @@ def init_prompt_encoder(self):
         new_task = self.new_tasks[0]
         total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"]

-        encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "mlp").lower())
+        encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower())
```
Is there a reason why the default was changed?
No reason; it crept in during my copy-paste from the child model.
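For context, here is how the default plays out (stand-in enum; treat the member list as an assumption about NeMo's `PromptEncoderType`):

```python
from enum import Enum

from omegaconf import OmegaConf


class PromptEncoderType(Enum):
    TPMLP = "tpmlp"  # member list assumed
    MLP = "mlp"
    LSTM = "lstm"


cfg = OmegaConf.create({"p_tuning": {}})

# with encoder_type unset, the fallback string picks the encoder;
# restoring the old behavior is just "tpmlp" -> "mlp" here
encoder_type = PromptEncoderType(cfg.p_tuning.get("encoder_type", "tpmlp").lower())
print(encoder_type)  # PromptEncoderType.TPMLP
```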
```diff
         self.prompt_encoder.load_state_dict(state_dict_, strict)

     def embed_input_train(self, input_ids: Tensor, taskname_ids: Tensor):
         raise ValueError("invalid virtual prompt source")
```
Please indicate what the correct config is, e.g. "please set this cfg.... to p_tuning".
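A sketch of a more actionable message for that branch (the attribute and config key names are assumptions for illustration):

```python
def embed_input_train(self, input_ids, taskname_ids):
    # hypothetical rewrite of the bare error in the hunk above
    raise ValueError(
        f"invalid virtual prompt source '{getattr(self, 'virtual_prompt_source', None)}'; "
        "please set cfg.virtual_prompt_style to 'p-tuning'"
    )
```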
Squashed commit message:

* patch to allow using tokenizers without additional_special_tokens_ids attribute
* steps to inherit gpt from base
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* fix in pred step
* moved to base model
* removing prompt table class, moved into prompt encoder class
* minor fixes
* changes in dialogue models
* minor updates to classifications

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
What does this PR do?
Part 3/N of the p-tuning refactor: MegatronGPTPromptLearningModel now inherits from MegatronBasePromptLearningModel, shared logic (including the TextGeneration interface) moves into the base model, and the prompt table class is folded into the prompt encoder.
Collection: NLP
Changelog
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.
Additional Information