[Whisper + beam search] fix usage of beam_indices
#38259
Conversation
tensor containing the timestamps in seconds for each predicted token
"""
# Create a list with `decoder_layers` elements, each a tensor of shape
# (batch size, attention_heads, output length, input length).
shape comments were incorrect for the case w/beam search
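For context, a minimal sketch of the shape difference (illustrative dimension names and sizes, not taken from the PR): with beam search the decoder runs on `batch_size * num_beams` sequences, so the stacked cross-attention weights carry that larger leading dimension.

```python
import torch

batch_size, num_beams, num_heads = 2, 3, 6
output_length, input_length = 10, 1500  # 1500 encoder frames is an assumption

# Without beam search: (batch size, attention heads, output length, input length)
greedy_weights = torch.randn(batch_size, num_heads, output_length, input_length)

# With beam search the leading dimension is batch_size * num_beams,
# which is what the old shape comment did not account for
beam_weights = torch.randn(batch_size * num_beams, num_heads, output_length, input_length)
```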
weight_length = None

if "beam_indices" in generate_outputs:
    # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
In this if block, I've rewritten comments to better explain what's happening
    # beam search takes `decoder_input_ids` into account in the `beam_indices` length
    # but forgot to shift the beam_indices by the number of `decoder_input_ids`
    beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
    # we actually shift the beam indices here
    beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
This is not correct with the beam search refactor (#35802): beam_indices was corrected to have the same output length as the other optional outputs (= length of generated tokens)
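To illustrate (toy values, assumed semantics rather than the PR's actual code): after the refactor, `beam_indices[i, t]` already refers to the beam that produced generated token `t` of sequence `i`, so shifting it by `num_input_ids` as above would misalign it.

```python
import torch

# 2 sequences, 5 generation steps; -1 pads the steps after a sequence finished early
beam_indices = torch.tensor([
    [0, 1, 1, 0, -1],
    [2, 2, 0, 0,  0],
])

# the generated length per sequence can be read off directly, no offset needed
generated_lengths = (beam_indices != -1).sum(-1)  # tensor([4, 5])
```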
    # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
    # the right `output_length`

    weights = weights[:, :, :weight_length]
redundant: we rebuild weights below with sequence length = range(unrolled_beam_indices.shape[1]) (= weight_length)
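A rough sketch of why the slice adds nothing (variable names and shapes are assumptions, not the PR's code): rebuilding the weights by picking, at each generation step, the attention of the beam that produced that step's token already bounds the output length to `unrolled_beam_indices.shape[1]`.

```python
import torch

num_sequences, total_beams, num_heads, enc_len = 2, 6, 4, 1500
weight_length = 5

# stacked cross-attention weights for all beams: (beams, heads, steps, encoder frames)
weights = torch.randn(total_beams, num_heads, weight_length + 3, enc_len)
# beam chosen at each of the `weight_length` steps for each final sequence
unrolled_beam_indices = torch.tensor([[0, 1, 1, 0, 2], [3, 3, 4, 5, 5]])

rebuilt = torch.stack(
    [weights[unrolled_beam_indices[:, step], :, step] for step in range(unrolled_beam_indices.shape[1])],
    dim=2,
)
# rebuilt.shape == (num_sequences, num_heads, weight_length, enc_len), already the right length
```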
    # since the beam search strategy chooses the most probable sequences at the end of the search.
    # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
    weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
    weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
root issue of #36093: weight_length is off by 1. The comments in the new version explain why :)
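One way to see the off-by-one, offered as a reading rather than the PR's own explanation: with caching, the first decoder forward pass covers all `num_input_ids` prompt positions while predicting the first new token, so the stacked cross-attention length is `num_input_ids + num_generated - 1`, one less than what the code above computes.

```python
# toy numbers only (assumed setup, not taken from the PR)
num_input_ids = 3   # e.g. <|startoftranscript|><|en|><|transcribe|>
num_generated = 5   # max count of non -1 entries in beam_indices

old_weight_length = num_generated + num_input_ids             # 8
stacked_attention_length = num_input_ids + num_generated - 1  # 7 -> off by one
```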
vasqu left a comment
Just some nits on the shape comments. Important step to have something to "work" again even if it's not producing the correct output quality-wise at first :)
Co-authored-by: Anton Vlasjuk <[email protected]>
What does this PR do?
Fixes the shape issues reported in #36093, which have been around since the code was added 👀. It doesn't fix the quality of word-timestamp outputs (see e.g. #36632), but rather fixes how we gather the cross-attentions from the right beams when beam search is used, which was broken.
`test_tiny_token_timestamp_batch_generation` is a test that has the same pattern, beam search + timestamps, and is failing on `main` with the same exception as reported in #36093. This PR does NOT fix that test, but allows it to move past the shape exception and reach the output-quality checks, which are still broken 🙃
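For reference, a hedged reproduction sketch of that pattern (beam search + word timestamps); the model id and the silent dummy audio are placeholders, not taken from the PR or the failing test.

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# one second of silence at 16 kHz in place of a real audio sample
audio = torch.zeros(16000).numpy()
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# On `main` before this fix, combining beam search with token timestamps
# raised a shape error while gathering cross-attentions from the beams.
out = model.generate(
    inputs.input_features,
    num_beams=3,
    return_timestamps=True,
    return_token_timestamps=True,
)
print(out["token_timestamps"])
```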