[WIP] Add Flax diverse group search #24508
Conversation
sanchit-gandhi
left a comment
This looks promising already @yeandy! Left some comments regarding the design below. In addition, could we add a few tests to confirm that:
- Group beam search runs when we call `model.generate`
- That group beam search is jit'able
- And that we get equivalence with PyTorch
```python
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
num_beam_groups: Optional[int] = 1,
```
In PyTorch we define a separate beam search method for group beam search:
transformers/src/transformers/generation/utils.py
Line 3375 in 33b5ef5
def group_beam_search(
We only trigger this method if num_beam_groups>1:
transformers/src/transformers/generation/utils.py
Line 1426 in 33b5ef5
is_group_beam_gen_mode = (
My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one. IMO this is cleaner for the reader and more compartmentalised to build on.
cc @gante as well for Flax generate design decision
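For concreteness, a minimal sketch of what this dispatch could look like on the Flax side (the helper name and mode strings are illustrative assumptions, not from this PR):

```python
# Hypothetical dispatch helper, mirroring PyTorch's is_group_beam_gen_mode check
def get_generation_mode(num_beams: int, num_beam_groups: int) -> str:
    if num_beam_groups > 1:
        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        return "group_beam_search"  # route to the separate group beam search method
    if num_beams > 1:
        return "beam_search"
    return "greedy_search"
```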
Thanks @sanchit-gandhi!
My first commit was to get a prototype working for num_beam_groups=1. I intend to refactor the beam search logic to make sure it works for other num_beam_groups sizes.
- Will do.
- My current logic is jittable, as I've been doing some testing from this example. Are there tests in the HF repo that explicitly test whether a function is jittable? Or is it sufficient to have an E2E test that jits the function?
- Will do.
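For reference, one self-contained way to check jittability outside the test suite is to trace the function with `jax.jit` and compare against the eager result (a toy sketch, not an existing HF test helper):

```python
import jax
import jax.numpy as jnp

def beam_step(log_probs):
    # toy stand-in for one step of (group) beam search
    return jax.lax.top_k(log_probs, k=2)

x = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
eager_vals, eager_idx = beam_step(x)
jit_vals, jit_idx = jax.jit(beam_step)(x)
# if tracing succeeded and the outputs match, beam_step is jittable
assert jnp.allclose(eager_vals, jit_vals)
assert (eager_idx == jit_idx).all()
```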
My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one.
+1 :)
(btw, there was a recent bugfix on the PT side, might be relevant here)
Awesome, sounds good @yeandy! Excited to see how this pans out!
```python
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += add_penalty * np.array(-1.0e7)

# Add additional logic for diverse beam search
```
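For context, the "additional logic" referenced here is presumably along the lines of the Hamming diversity penalty that PyTorch applies via `HammingDiversityLogitsProcessor`. A minimal JAX sketch, with assumed names and shapes:

```python
import jax.numpy as jnp

def apply_diversity_penalty(scores, previous_group_tokens, diversity_penalty, vocab_size):
    # count how often each token was already chosen by earlier beam groups
    # at this decoding step, then penalise those tokens for the current group
    token_frequency = jnp.bincount(previous_group_tokens, length=vocab_size)
    return scores - diversity_penalty * token_frequency
```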
Nice! My only nit is that we try and avoid lambda functions in transformers - would you be able to re-write these as standard function definitions please?
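For illustration, the kind of rewrite being asked for (the lambda below is a hypothetical example, not the actual code from the diff):

```python
# before: a lambda, which transformers style avoids
get_group_offset = lambda group_idx, group_size: group_idx * group_size

# after: a standard, named function definition
def get_group_offset(group_idx, group_size):
    # index of the first beam belonging to this group
    return group_idx * group_size
```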
Yes
Lovely, thanks!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Hey @yeandy! This PR is looking in good shape - thanks for your efforts so far! Would you like to go all the way and see it to completion? Happy to help with the remainder of the integration!

Hey @sanchit-gandhi. Due to other commitments, I currently don't have bandwidth to continue this. And the timeline for me to get to this is unknown right now. If someone else wants to work on this, I'm ok with that.

Thanks for letting me know @yeandy! Best of luck with your other commitments, I hope they go well 🤗 Opening this one up to the community to complete!

For those wondering about the status of this PR: it seems all TF/Flax support has been deprecated, so this PR is no longer in scope.

Yes, this should have been closed long ago!
What does this PR do?
Mimics #9006, but for Flax.
We want to match how PyTorch's logic accounts for `group_size` and `num_beam_groups` here and here.
Fixes # (issue)
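As a rough illustration of the grouping arithmetic being mirrored here (a sketch; the variable names follow the PyTorch implementation, the loop itself is illustrative):

```python
num_beams = 6
num_beam_groups = 3
group_size = num_beams // num_beam_groups  # beams per group, as in PyTorch

for group_idx in range(num_beam_groups):
    group_start = group_idx * group_size
    group_end = min(group_start + group_size, num_beams)
    # beams[group_start:group_end] form one group; tokens already chosen by
    # earlier groups at this step drive the diversity penalty for this group
    print(f"group {group_idx}: beams {group_start}..{group_end - 1}")
```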
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.