[generate] ✨ vectorized beam search ✨ #35802
Conversation
zucchini-nlp left a comment:
Super cool 🔥 Thanks for providing a benchmark with the vectorized beams.
Very nice idea to add comments, it is quite hard to follow the logic sometimes. I believe beam search will mostly be tweaked by advanced users anyway. I still got lost trying to read the code hehe
src/transformers/generation/utils.py
```python
top_num_beam_mask = torch.cat(
    (
        torch.ones((num_beams), dtype=torch.bool),
        torch.zeros((beams_to_keep - num_beams), dtype=torch.bool),
    ),
    dim=0,
).to(next_token_hits_stopping_criteria.device)
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
```
This part got me a bit lost. Aren't we choosing beam candidates that have already stopped with EOS, but happened to be at indices > num_beams? Since the mask zeroes out the second half after num_beams.
did_top_num_beams_just_finished selects beams that are valid to be considered as finished, i.e. beams in the top num_beams that have hit some stopping criterion.
Beams in the latter part of the num_beams dimension are indeed discarded by this mask. But this mask is used to tag "finished beams", not the beams we will use in the next iteration of beam search. Finished beams can only come from the top num_beams beams.
Does this make sense? How can I make this part clearer in the comments?
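For illustration, here is a minimal runnable sketch of that masking logic. The shapes and values (batch size 1, num_beams=2, beams_to_keep=4, and the stopping-criteria hits) are assumptions chosen for the example, not values taken from the PR:

```python
import torch

num_beams, beams_to_keep = 2, 4  # assumed values, for illustration only

# True where a candidate continuation hit a stopping criterion (shape: batch=1 x beams_to_keep)
next_token_hits_stopping_criteria = torch.tensor([[True, False, True, True]])

# Mask that is True only for the first `num_beams` candidates of each row
top_num_beam_mask = torch.cat(
    (
        torch.ones(num_beams, dtype=torch.bool),
        torch.zeros(beams_to_keep - num_beams, dtype=torch.bool),
    ),
    dim=0,
)

# A candidate is tagged as "just finished" only if it hit a stopping criterion
# AND it sits in the top `num_beams` slots
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
print(did_top_num_beams_just_finished)  # tensor([[ True, False, False, False]])
```

Here the third and fourth candidates also hit the stopping criterion, but because they sit outside the top num_beams slots they are never tagged as finished; they only compete to continue the search.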
zucchini-nlp left a comment:
For me, the new logic with separate functions made things a bit clearer, in the sense that each small function has a docstring explaining what we do. I think the docstrings are more verbose than the comments we had previously, but it is a bit annoying to keep going back and forth when reading.
If I had to choose, I would go with the previous version, with a note in mind for refactoring. AFAIK most parts of generate will get separated out into functions when we start refactoring. The decoding methods share a lot and imo we can decompose them.
That way people could also rewrite and plug in their own functions. As it is now, it doesn't seem easy to rewrite a small function (say, for example, how sampling is done in get_top_k_continuations) and plug it into beam search without copying the whole decoding method. WDYT? Do we even need to offer users the ability to override and plug in small functions?
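For what it's worth, if the helpers end up at module level, "plugging in" a custom one could be as simple as re-binding the module attribute. This is only a hypothetical sketch: the module path and helper name below are assumptions for illustration, not an existing or documented transformers API:

```python
# Hypothetical sketch: swap one beam-search helper without copying the whole decoding loop.
# `get_top_k_continuations` and its location are assumed names, not a documented API.
from transformers.generation import utils as generation_utils

_original_get_top_k = generation_utils.get_top_k_continuations  # assumed helper name

def my_get_top_k_continuations(*args, **kwargs):
    # e.g. replace deterministic top-k selection with sampling-based selection here
    return _original_get_top_k(*args, **kwargs)

generation_utils.get_top_k_continuations = my_get_top_k_continuations
```

Whether this kind of override should be an officially supported extension point is exactly the open question above.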
ArthurZucker left a comment:
- let's move the functions outside the scope of the functions themselves
- let's keep what you did!
The main arguments:
- re-usability: users can easily change a small part of a function
- more readable
- easier to debug and isolate changes to a single function!
🤗
ArthurZucker left a comment:
oh and

```python
merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
```

is indeed better (more readable and less slicing)
as you want, for 1% perf it's not super important
@zucchini-nlp @ArthurZucker PR comments addressed (namely, fns moved out)
zucchini-nlp left a comment:
Thanks for iterating!
What does this PR do?
✨ Vectorizes beam search ✨
Fixes #34843
Fixes #35618
Fixes #34574
Groundwork for: #35561 #35451 #35436 #32097 #30810
Other related issues: #35023
- `BeamScorer` is no longer used inside `_beam_search()`. Advanced users relying on their custom `BeamScorer` instances should pin transformers to `<=4.48` for now. Custom beam scoring metrics will be added in a future PR;
- the `low_memory=True` path [which runs the forward pass for each beam in sequence rather than in parallel] was not coded in the refactor. Since this flag was added, several memory-reduction techniques have been introduced, so I'm assuming it's niche now. It would also bloat the code significantly. An informative error message was added.

Note to reviewers
Don't look at the diff for beam search, but rather read the new `_beam_search` as if it were a new function. One of the goals of the refactor is improving code readability for future hacking (beam search is a form of test-time compute), and thus the new code is heavily commented -- please add a review comment if a certain part of the code isn't immediately understandable, so we can improve it.

Benefits of the vectorization:
- better throughput and memory usage, especially at larger `num_beams` (see benchmarks below). Note that we can further improve these numbers -- `beam_search` should now be compatible with `torch.compile` (to be explored in a future PR);

Benchmarks
Test script (small model, beam search, no compilation)
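The script itself was posted in a collapsed block that is not reproduced here. A rough sketch of a comparable measurement (the checkpoint, prompt, dtype, token counts, and beam sizes below are assumptions, not the exact setup behind the numbers) could look like:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small model chosen only for illustration; the PR does not state which checkpoint was used.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("The quick brown fox jumps over", return_tensors="pt").to("cuda")

for num_beams in (1, 2, 4, 8, 16):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, num_beams=num_beams, do_sample=False, max_new_tokens=128)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    print(f"num_beams={num_beams:2d}  {new_tokens / elapsed:7.1f} tok/s  {peak_mb:8.1f} MB")
```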
Measured on an RTX4090
Throughput (tokens/sec):
Memory (MB):
Tests run
checked box = no regressions vs `main`

- `RUN_SLOW=1 py.test tests/models/gpt2/test_modeling_gpt2.py -vv` -- slow gpt2 tests (reference model for older implementations)
- `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv` -- slow llama tests (key model)
- `RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv` -- slow whisper tests (key model, has hard beam search tests) [note: 13 failures on `main`]
- `RUN_SLOW=1 py.test tests/models/t5/test_modeling_t5.py -vv` -- slow t5 tests (has hard beam search tests)
- `RUN_SLOW=1 py.test tests/models/bart/test_modeling_bart.py -vv` -- slow bart tests (has hard beam search tests)
- `RUN_SLOW=1 py.test tests/models/rag/test_modeling_rag.py -vv` -- slow rag tests (has hard beam search tests) [note: 4 failures on `main`]
- `RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv` -- slow cache tests (sanity check)

Follow-up
- `BeamScorer` is still used by the other beam search methods (e.g. `constrained_beam_search`) -- maybe we can expose the scoring method, and write the other beam methods as a scoring modification?
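A purely hypothetical sketch of what "exposing the scoring method" could look like; the hook name, shapes, and signature are illustrative, not an existing transformers API:

```python
from typing import Callable

import torch

# Hypothetical hook: vanilla beam search and its variants would differ only in the
# scoring function they pass to a shared, vectorized decoding loop.
ScoreFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

def cumulative_logprob_score(running_scores: torch.Tensor, log_probs: torch.Tensor) -> torch.Tensor:
    # Standard beam search: add the new token log-probs to each beam's running score.
    # running_scores: (batch, num_beams); log_probs: (batch, num_beams, vocab)
    return running_scores[:, :, None] + log_probs

def diversity_penalized_score(running_scores, log_probs, diversity_penalty=0.5, token_frequencies=None):
    # Example variant in the spirit of diverse/group beam search: penalize tokens already
    # chosen by other beam groups. `token_frequencies` has shape (batch, vocab).
    scores = running_scores[:, :, None] + log_probs
    if token_frequencies is not None:
        scores = scores - diversity_penalty * token_frequencies[:, None, :]
    return scores
```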