
fix ptuning residuals bug #6866

Merged: 12 commits merged into main on Jun 22, 2023

Conversation

arendu (Collaborator) commented on Jun 14, 2023

What does this PR do?

  • Fixes a PEFT p-tuning bug that was causing p-tuning to apply a residual connection (illustrated in the sketch below)
  • Simplifies mixin use for all NLP PEFT models
  • P-tuning now includes an inference-only table, which is needed for FT
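
As a rough illustration of the first bullet (a toy sketch with made-up stand-in classes, not NeMo code): a generic adapter strategy that adds its input back onto the adapter output is fine for adapters such as LoRA, but a p-tuning adapter produces a fresh table of virtual-token embeddings that must be used as-is, so routing it through a residual strategy corrupts those embeddings.

# Toy sketch of the bug; stand-in classes only, not NeMo's implementation.
import torch
import torch.nn as nn

class ResidualAddStrategy:
    """Generic adapter strategy: output = input + adapter(input)."""

    def compute_output(self, inp, adapter):
        return inp + adapter(inp)

class ToyPtuningAdapter(nn.Module):
    """Maps a batch size to a table of virtual-token embeddings."""

    def __init__(self, total_virtual_tokens=4, hidden_size=8):
        super().__init__()
        self.prompt_embeddings = nn.Embedding(total_virtual_tokens, hidden_size)
        self.register_buffer("indices", torch.arange(total_virtual_tokens), persistent=False)

    def forward(self, batch_size):
        virtual = self.prompt_embeddings(self.indices)           # [T, H]
        return virtual.unsqueeze(0).expand(batch_size, -1, -1)   # [B, T, H]

adapter = ToyPtuningAdapter()
batch_size = 2

buggy = ResidualAddStrategy().compute_output(batch_size, adapter)  # residual add leaks the input in
fixed = adapter(batch_size)                                        # embeddings used as-is

print(torch.allclose(buggy, fixed))  # False: the residual connection alters every embedding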

Collection: [NLP]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@github-actions bot added the NLP label on Jun 14, 2023
@arendu requested a review from titu1994 on June 14, 2023 at 05:23
@arendu marked this pull request as ready for review on June 14, 2023 at 05:24
@arendu requested a review from Davood-M on June 14, 2023 at 16:31

def forward(self, batch_size):
def _forward(self,):

Collaborator:

Better not to make this private for subclasses; rename it to forward_inner().

Collaborator (Author):

good point!
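
For reference, a small sketch of the pattern suggested above (illustrative names only; forward_inner is the name proposed in the comment): the public forward stays the single entry point while subclasses override the non-private hook.

# Illustrative sketch of the suggested pattern; these are not NeMo's classes.
import torch.nn as nn

class ToyAdapterBase(nn.Module):
    def forward(self, batch_size):
        # Public entry point; shared pre/post-processing could live here.
        return self.forward_inner(batch_size)

    def forward_inner(self, batch_size):
        # Overridable hook for subclasses, kept non-private per the review comment.
        raise NotImplementedError

class ToyPromptEncoder(ToyAdapterBase):
    def __init__(self, total_virtual_tokens=4, hidden_size=8):
        super().__init__()
        self.prompt_embeddings = nn.Embedding(total_virtual_tokens, hidden_size)

    def forward_inner(self, batch_size):
        virtual = self.prompt_embeddings.weight                  # [T, H]
        return virtual.unsqueeze(0).expand(batch_size, -1, -1)   # [B, T, H]

print(ToyPromptEncoder()(3).shape)  # torch.Size([3, 4, 8])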

virtual_embeddings = self.forward_single_enabled_adapter_(
_bs, ptuning_adapter, adapter_name=AdapterName.PTUNING_ADAPTER, adapter_strategy=strategy,
)
virtual_embeddings = ptuning_adapter(_bs)

Collaborator:

Each adapter type has its own strategy; by removing this you're hard-coding the logic and sidestepping the strategy. It's not necessary to do that, is it?

Collaborator (Author):

Agreed, it's not necessary, but it makes readability and maintainability easier. It's just clearer what is happening when you read the code and see a residual connection explicitly, rather than having to track down where a default strategy is coming from.

Collaborator (Author):

Oh, one other limitation, I think, is that I cannot pass additional args. For example, the ptuning_adapter forward now accepts an additional arg like used_cached_reps.

Collaborator:

After discussion, it seems this is never going to be needed by the NLP domain, so it's fine to call the adapter directly.
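
To illustrate the argument-passing point above, a hypothetical sketch (the helper and classes below are stand-ins, and the cached-reps flag mirrors the used_cached_reps argument mentioned in the thread): a fixed-signature helper leaves no slot for adapter-specific keyword arguments, whereas a direct call can pass them.

# Hypothetical sketch only; not the actual NeMo helper, strategy, or adapter.
def forward_single_enabled_adapter(inp, adapter, adapter_name, adapter_strategy):
    # Fixed signature: adapter-specific kwargs cannot be threaded through here.
    return adapter_strategy.compute_output(inp, adapter)

class PassThroughStrategy:
    def compute_output(self, inp, adapter):
        # The strategy only ever sees (inp, adapter); no place for extra flags.
        return adapter(inp)

class CachingPtuningAdapter:
    def __init__(self):
        self._cached_reps = None

    def __call__(self, batch_size, use_cached_reps=False):
        # Extra flag in the spirit of the argument discussed above.
        if use_cached_reps and self._cached_reps is not None:
            return self._cached_reps
        self._cached_reps = [[0.0] * 8 for _ in range(batch_size)]  # placeholder reps
        return self._cached_reps

adapter = CachingPtuningAdapter()
reps = forward_single_enabled_adapter(2, adapter, "ptuning_adapter", PassThroughStrategy())
reps_cached = adapter(2, use_cached_reps=True)  # direct call: the extra kwarg is easy to pass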

@@ -70,7 +70,7 @@ def __init__(
self.prompt_embeddings.weight.requires_grad = False

# Set fixed indicies for forward pass
self.register_buffer('indices', torch.LongTensor(list(range(self.total_virtual_tokens))))
self.register_buffer("indices", torch.LongTensor(list(range(self.total_virtual_tokens))), persistent=False)

Collaborator:

Probably a breaking change, since older PEFT modules will have this but newer ones won't. Need to check.

Collaborator (Author):

good catch! will check!

Collaborator (Author):

OK, this will indeed break older p-tuning checkpoints. However, it is a nice "feature" because older checkpoints will need to be converted to a new param naming format anyway. In that conversion step (which needs to be written) I will remove the indices.

Collaborator:

The persistent part will be the issue, but if that's OK then it's good enough.
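
For background on the compatibility concern, a minimal PyTorch sketch (toy modules, not the NeMo adapter): a buffer registered with persistent=False is left out of state_dict, so a checkpoint saved by the old code carries an "indices" key the new module no longer expects, and a strict load fails until the key is dropped (for example in the conversion step mentioned above) or strict=False is used.

# Toy modules showing the checkpoint-compatibility issue; not the NeMo adapter.
import torch
import torch.nn as nn

class OldPromptEncoder(nn.Module):
    def __init__(self, total_virtual_tokens=4):
        super().__init__()
        # Old behaviour: the buffer is persistent, so it is saved in state_dict.
        self.register_buffer("indices", torch.LongTensor(list(range(total_virtual_tokens))))

class NewPromptEncoder(nn.Module):
    def __init__(self, total_virtual_tokens=4):
        super().__init__()
        # New behaviour: persistent=False keeps the buffer out of state_dict.
        self.register_buffer(
            "indices", torch.LongTensor(list(range(total_virtual_tokens))), persistent=False
        )

old_ckpt = OldPromptEncoder().state_dict()   # contains "indices"
new_model = NewPromptEncoder()

try:
    new_model.load_state_dict(old_ckpt)      # strict=True by default
except RuntimeError as err:
    print("strict load failed:", err)        # unexpected key "indices"

old_ckpt.pop("indices")                      # e.g. handled by a checkpoint conversion step
new_model.load_state_dict(old_ckpt)          # now loads cleanly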

@arendu changed the title from "fix lora residuals bug" to "fix ptuning residuals bug" on Jun 14, 2023
@arendu merged commit a8609ab into main on Jun 22, 2023
14 of 15 checks passed
@arendu deleted the adithyare/lora_fix branch on June 22, 2023 at 16:22