Add inputs_embeds support when generating with GPT-J
#21575
inputs_embeds support when generating with GPT-J
…sformers into dzmitry/gptj_input_embeds
gante
left a comment
LGTM, thank you for the addition 👍
|
@dimitry12 to clarify, if you don't pass them, the output of |
|
@dimitry12 oh I see, the case when the batch size is larger than one is not handled, it is the same issue as in #21578! I'll open a PR soon that fixes it |
|
@gante thank you for the prompt review! Is merging done by HF staff? |
|
Hey @dimitry12 yes it is, but we still need a review from our transformers master @sgugger :) |
Yes, it fails without dummy `input_ids`. To replicate, a slightly modified results in: @gante I can open a draft PR with an updated failing test, and see if I can figure it out or if your planned fix for #21578 will also fix it. Please advise what the best process is; I am definitely willing to help here. |
sgugger
left a comment
Thanks for the fix!
|
(For context, the fix to the issue above is in #21580. Merging :D ) |
What does this PR do?
This PR extends #21405 by @gante to GPT-J, making it accept `inputs_embeds` when generating.
This is generally useful for soft-prompting, but I am specifically using this with https://github.com/jmerullo/limber by @jmerullo.
Importantly, I find that dummy `input_ids` are still required. Sample code using this feature with GPT-J:
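A minimal sketch of this usage, not the PR's original snippet. It assumes a tiny, randomly initialized GPT-J built from a hypothetical `GPTJConfig` so that it runs quickly on CPU, and it uses made-up token ids and a random tensor standing in for a learned soft prompt. It keeps the batch size at 1, since the discussion above notes that larger batches were not handled at the time, and it passes only `inputs_embeds` and an `attention_mask` to `generate`.

```python
import torch
from transformers import GPTJConfig, GPTJForCausalLM

torch.manual_seed(0)

# ASSUMPTION: a tiny, randomly initialized GPT-J so the sketch runs on CPU.
# Its outputs are meaningless; a real run would load pretrained weights instead.
config = GPTJConfig(
    vocab_size=100,
    n_positions=64,
    n_embd=32,
    n_layer=2,
    n_head=4,
    rotary_dim=4,   # must not exceed the head size (n_embd / n_head = 8)
    bos_token_id=1,
    eos_token_id=2,
)
model = GPTJForCausalLM(config).eval()

# Hypothetical prompt token ids (batch size 1), embedded with the model's own table.
prompt_ids = torch.tensor([[5, 17, 42, 8]])
with torch.no_grad():
    prompt_embeds = model.get_input_embeddings()(prompt_ids)

# Hypothetical soft prompt: 3 random vectors standing in for learned prefix embeddings.
soft_prompt = torch.randn(1, 3, config.n_embd)

# Prepend the soft prompt to the token embeddings along the sequence dimension.
inputs_embeds = torch.cat([soft_prompt, prompt_embeds], dim=1)
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)

generated = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    max_new_tokens=5,
    do_sample=False,
    pad_token_id=config.eos_token_id,
)
print(generated.shape)
```

Whether the returned ids include a placeholder token ahead of the newly generated ones depends on the installed `transformers` version, so the sketch only prints the shape of the result.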