-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Generate: FLAX infers pad token in its absence and has functional example #21009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| eos_token_id = generation_config.eos_token_id | ||
| if isinstance(eos_token_id, list): | ||
| eos_token_id = eos_token_id[0] | ||
| logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | ||
| generation_config.pad_token_id = eos_token_id |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Took the opportunity also to copy the logic to TF, so it can also handle eos_token_id as a list 👀
sgugger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
|
The documentation is not available anymore as the PR was closed or merged. |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
sanchit-gandhi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix @gante!
What does this PR do?
Some bug fixing in advance of #21007 (PR that adds generation config to Flax), to ensure we start from a functional flax generate codebase.
In particular:
pad_token_idwhen it isNoneandeos_token_idis notNone, like TF and PT do. This is very helpful for open text generation examples, like with GPT2, was an open request (Generating with Flax fails when using Causal Language models #18884), and was one of the causes for failure in the existing example. This also includes the recent changes of Add custom stop token ids for generation #20727, whereeos_token_idcan be a list of tokens.int32type specification was missing in the special tokens -- when converted to JAX variables, JAX assumed they werefloat32;