gante (Contributor) commented May 21, 2025

What does this PR do?

Fixes the shape issues reported in #36093, which have been around since the code was added 👀. It doesn't fix the quality of the word-timestamp outputs (see e.g. #36632); rather, it fixes how we gather the cross attentions from the right beams during beam search, which was broken.

`test_tiny_token_timestamp_batch_generation` exercises the same pattern (beam search + timestamps) and is failing on main with the same exception reported in #36093. This PR does NOT fix that test, but it lets the test move past the shape exception to the output-quality checks, which are still broken 🙃
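As a minimal sketch of what "gathering the cross attentions from the right beams" means, here is a toy example. All shapes and values below are invented for illustration, not taken from the PR:

```python
import torch

# Toy shapes (hypothetical, not from the PR): 2 final sequences, 3 attention
# heads, 4 generated tokens, 5 encoder positions.
num_sequences, heads, seq_len, enc_len = 2, 3, 4, 5

# `generate` returns one cross-attention tensor per generated step, each of
# shape (num_beams, heads, 1, encoder_length); here num_beams == num_sequences.
cross_attentions = [torch.randn(num_sequences, heads, 1, enc_len) for _ in range(seq_len)]

# beam_indices[b, t] records which beam produced token t of final sequence b.
beam_indices = torch.tensor([[0, 1, 1, 0],
                             [1, 0, 0, 1]])

# For each final sequence, pick each step's weights from the beam the token
# actually came from, yielding (num_sequences, heads, seq_len, enc_len).
gathered = torch.stack([
    torch.stack([cross_attentions[t][beam_indices[b, t], :, 0] for t in range(seq_len)], dim=1)
    for b in range(num_sequences)
])
```

Indexing each step's tensor with the wrong beam row is exactly the kind of misalignment the PR fixes; the real implementation does this gathering vectorized rather than with Python loops.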

```python
            tensor containing the timestamps in seconds for each predicted token
        """
        # Create a list with `decoder_layers` elements, each a tensor of shape
        # (batch size, attention_heads, output length, input length).
```
gante (author): shape comments were incorrect for the case with beam search

```python
        weight_length = None

        if "beam_indices" in generate_outputs:
            # If beam search has been used, the output sequences may have been
            # generated for more timesteps than their sequence_lengths
```
gante (author): in this `if` block, I've rewritten the comments to better explain what's happening

Comment on lines -253 to -257
```python
            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
            # we actually shift the beam indices here
            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
```
gante (author): This is no longer correct after the beam search refactor (#35802): `beam_indices` was corrected to have the same output length as the other optional outputs (= the length of the generated tokens)
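To make the refactored semantics concrete, here is a hypothetical toy example (values invented for illustration): one `beam_indices` entry per generated token, padded with -1 past each sequence's end, which is what the `(beam_indices != -1)` mask in the existing code relies on. With this layout, no shift by `num_input_ids` is needed to align it with the generated tokens:

```python
import torch

# Hypothetical example (not taken from the PR): after the refactor,
# `beam_indices` has one entry per *generated* token, padded with -1
# beyond each sequence's generated length.
beam_indices = torch.tensor([[0, 1, 1, -1],   # sequence 0: 3 generated tokens
                             [1, 0, 0,  1]])  # sequence 1: 4 generated tokens

# Each sequence's generated length can be read straight off the -1 padding,
# with no shift by the number of decoder prompt tokens.
gen_lengths = (beam_indices != -1).sum(-1)
```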

```python
            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
            # the right `output_length`

            weights = weights[:, :, :weight_length]
```
gante (author): redundant: we rebuild `weights` below with sequence length = `range(unrolled_beam_indices.shape[1])` (= `weight_length`)
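A minimal sketch of why that slice is a no-op (shapes are hypothetical, not from the PR): a tensor rebuilt step-by-step over `range(weight_length)` already has exactly that length on the time axis, so slicing it to `:weight_length` again changes nothing.

```python
import torch

# Hypothetical shapes: `rebuilt` stands in for weights that were already
# reconstructed with one entry per step in range(weight_length).
heads, weight_length, enc_len = 3, 4, 5
rebuilt = torch.randn(heads, weight_length, enc_len)

# Slicing an axis that is already `weight_length` long is a no-op.
sliced = rebuilt[:, :weight_length]
```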

```python
        # since the beam search strategy chooses the most probable sequences at the end of the search.
        # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
        weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
        weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
```
gante (author): root issue of #36093: `weight_length` is off by 1. The comments in the new version explain why :)
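To see what the old lines computed, here is a toy run (all values are hypothetical, not from the PR): the longest generated length is read off the -1 padding, then extended by the number of decoder prompt tokens.

```python
import torch

# Hypothetical inputs, not from the PR: -1 marks padding past each
# sequence's generated length.
beam_indices = torch.tensor([[0, 1, 1, -1],
                             [1, 0, 0,  1]])
num_input_ids = 2  # hypothetical number of decoder prompt tokens

# Old computation: longest generated length, then extended by the prompt length.
weight_length = (beam_indices != -1).sum(-1).max()
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
```

With the refactored `beam_indices` already sized to the generated tokens, this prompt-length extension is where the length mismatch the PR fixes comes from.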

HuggingFaceDocBuilderDev (bot) commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

vasqu (Contributor) left a review comment:
Just some nits on the shape comments. An important step toward having something "work" again, even if it doesn't produce the correct output quality-wise at first :)

@gante gante enabled auto-merge (squash) May 23, 2025 09:53
@gante gante merged commit a6b51e7 into huggingface:main May 23, 2025
20 checks passed