Generate: correct default model input creation for decoder-only models #21580
Conversation
cc @dimitry12: this fixes the error you reported in #21575 (GPT2 was throwing the same error as GPT-J, and gets fixed here). cc @NielsRogge: please don't forget to add batched tests on models with generation capabilities; generate+batching is surprisingly tricky 🙏
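One reason batched generation is tricky for decoder-only models: rows of different lengths must be left-padded, so that the last position of every row holds a real token and generation continues from it. A minimal, library-free sketch of that padding step (the function name and values are illustrative, not transformers API):

```python
def left_pad_batch(sequences, pad_token_id):
    # Toy sketch (not the actual transformers implementation): left-pad
    # variable-length token sequences and build the matching attention
    # mask, with 0 marking padding positions.
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append([pad_token_id] * n_pad + seq)
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask

ids, mask = left_pad_batch([[5, 6, 7], [8]], pad_token_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Right-padding would instead place padding after the real tokens, and the model would generate from a pad position; this is the kind of subtlety batched tests need to cover.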
The documentation is not available anymore as the PR was closed or merged.
sgugger left a comment:
Thanks for your work on this!
But it does seem like it breaks a lot of tests 😅
(Merging -- the failing test is a known failing test)
@dimitry12 lmk if you see the error when using GPT-J :)
@gante GPT-J generation without dummy |
What does this PR do?
Fixes #21578 and addresses concerns in #21575
Context
Support for `.generate()` from `inputs_embeds` with selected decoder-only models was added recently (#21405). This feature enables a `.generate(inputs_embeds=inputs_embeds)` call, i.e. without an `input_ids`.

This specific call strategy, under the hood, implies that `.generate()` is in charge of creating a) an `input_ids` for later use (to concatenate the generated tokens) and b) the corresponding `attention_mask`. This automated creation was not working properly for `batch_size > 1`.

Changes
The changes in the PR respect the following desiderata (which required moving a few things around):
- `attention_mask` can be automatically inferred (with all ones) regardless of the shape of `inputs_embeds`;
- when `inputs_embeds` is passed and `input_ids` is not, the automated `input_ids` has a sequence length of 1. This is particularly relevant for BLIP, as we don't want `input_ids` to start with the embeddings' sequence length.

This PR also adds/enhances tests, to ensure we don't regress on this capability.
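As a rough illustration of the desiderata above, here is a toy, library-free sketch of the default-input creation (this is not the actual transformers code; `prepare_default_model_inputs` and the `bos_token_id` value are illustrative):

```python
def prepare_default_model_inputs(inputs_embeds, bos_token_id):
    # Toy sketch (not the transformers implementation) of the two
    # defaults described above, for a (batch, seq_len, hidden) embeds
    # input given as nested lists:
    # a) `attention_mask`: all ones, matching the batch and sequence
    #    dimensions of `inputs_embeds`;
    # b) `input_ids`: one token per batch row (sequence length 1),
    #    rather than mirroring the embeddings' sequence length.
    batch_size = len(inputs_embeds)
    seq_len = len(inputs_embeds[0])
    attention_mask = [[1] * seq_len for _ in range(batch_size)]
    input_ids = [[bos_token_id] for _ in range(batch_size)]
    return input_ids, attention_mask

# batch_size=2, seq_len=3, hidden_size=4
embeds = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
input_ids, attention_mask = prepare_default_model_inputs(embeds, bos_token_id=50256)
print(input_ids)       # [[50256], [50256]]
print(attention_mask)  # [[1, 1, 1], [1, 1, 1]]
```

Note how `input_ids` has shape `(2, 1)` while `attention_mask` has shape `(2, 3)`: the mask follows the embeddings, the ids do not.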