[Whisper] Fix whisper tokenizer #34537
Changes from all commits
src/transformers/models/whisper/generation_whisper.py

```diff
@@ -308,6 +308,7 @@ def generate(
         num_segment_frames: Optional[int] = None,
         attention_mask: Optional[torch.Tensor] = None,
         time_precision: float = 0.02,
+        time_precision_features: float = 0.01,
         return_token_timestamps: Optional[bool] = None,
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
```
```diff
@@ -417,6 +418,8 @@ def generate(
             time_precision (`int`, *optional*, defaults to 0.02):
                 The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
                 for 20 ms.
+            time_precision_features (`int`, *optional*, defaults to 0.01):
+                The duration represented by a feature frame in seconds.
             return_token_timestamps (`bool`, *optional*):
                 Whether to return token-level timestamps with the text. This can be used with or without the
                 `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
```
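For intuition, a small sketch (not part of the PR) of how the two precisions relate: the encoder downsamples the mel features by `input_stride`, so a seek position in feature frames converts to seconds either via `time_precision / input_stride` or directly via `time_precision_features`. The values below are Whisper's defaults.

```python
time_precision = 0.02           # seconds per generated timestamp token
time_precision_features = 0.01  # seconds per mel feature frame
input_stride = 2                # encoder downsampling factor

seek = 3000  # position in feature frames, i.e. one full 30 s window
assert seek * time_precision / input_stride == seek * time_precision_features == 30.0
```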
```diff
@@ -629,7 +632,7 @@ def generate(
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = seek * time_precision / input_stride
+            time_offset = seek.to(torch.float64) * time_precision / input_stride
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

             # 6.2 cut out next 30s segment from input features
```
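To see why the cast matters, a minimal illustration (assumed, not from the PR): the tests compare against the original Whisper implementation, which computes offsets in float64, and float32 arithmetic yields visibly different decimals even for small offsets.

```python
import torch

seek = torch.tensor([1])
time_precision, input_stride = 0.02, 2

# float32 drifts away from the float64 reference value
print((seek.to(torch.float32) * time_precision / input_stride).item())  # 0.009999999776482582
print((seek.to(torch.float64) * time_precision / input_stride).item())  # 0.01
```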
```diff
@@ -658,6 +661,7 @@ def generate(
                 config=self.config,
                 device=init_tokens.device,
                 suppress_tokens=suppress_tokens,
+                timestamp_begin=timestamp_begin,
                 kwargs=kwargs,
             )
```
```diff
@@ -718,6 +722,7 @@ def generate(
                     timestamp_begin=timestamp_begin,
                     seek_num_frames=seek_num_frames,
                     time_precision=time_precision,
+                    time_precision_features=time_precision_features,
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
```
```diff
@@ -1665,6 +1670,7 @@ def _prepare_decoder_input_ids(
         config,
         device,
         suppress_tokens,
+        timestamp_begin,
         kwargs,
     ):
         if "decoder_input_ids" in kwargs:
```
```diff
@@ -1684,6 +1690,14 @@ def _prepare_decoder_input_ids(
         # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
         active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

+        for segments in active_segments:
+            for seg in segments:
+                if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
+                    # the segment finishes with two timestamp tokens
+                    # we need to ignore the last timestamp token
+                    # see https://github.com/huggingface/transformers/pull/34537
+                    seg["tokens"] = seg["tokens"][:-1]
+
```
Comment on lines +1693 to +1699

Contributor: This looks a bit costly, I wonder if there's a cleaner way to compute all of this! Why do we need to get rid of the last timestamp when preparing the input ids?

Contributor (author): As mentioned above, we need to include the last timestamp in the case of double-ending timestamps (and not only the penultimate, as done before) to enable the tokenizer to differentiate the two cases (single and double ending). Nevertheless, OAI does not have to worry about that because they don't concatenate all the tokens as we do, and for this reason, when conditioning on the previous tokens, the last token in the case of double ending is omitted. To ensure we do the exact same, we need to remove it when preparing the decoder input ids.

Contributor (author): The issue is that any segment can be a double-ending one; they all need to be checked.

Contributor (author): Looks like it is not that costly (the check is linear in the number of previous segments).
```diff
         if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
             prev_ids = prompt_ids
         else:
```
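To make the single/double ending distinction concrete, here is a standalone sketch of the trimming step (the token ids are illustrative, assuming a multilingual Whisper vocabulary where `<|0.00|>` has id 50364):

```python
timestamp_begin = 50364  # id of "<|0.00|>" (multilingual vocab, assumed here)

single_ending = [50364, 440, 871, 50464]         # <|0.00|> text ... <|2.00|>
double_ending = [50364, 440, 871, 50464, 50464]  # <|0.00|> text ... <|2.00|><|2.00|>

def trim_for_conditioning(tokens):
    # same check as the added loop: a timestamp token in penultimate position
    # means a double ending, so the extra timestamp is dropped before
    # conditioning on previous tokens
    if len(tokens) > 2 and tokens[-2] >= timestamp_begin:
        return tokens[:-1]
    return tokens

assert trim_for_conditioning(double_ending) == single_ending
assert trim_for_conditioning(single_ending) == single_ending
```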
```diff
@@ -1778,6 +1792,7 @@ def _retrieve_segment(
         timestamp_begin,
         seek_num_frames,
         time_precision,
+        time_precision_features,
         input_stride,
         prev_idx,
         idx,
```
```diff
@@ -1799,17 +1814,22 @@ def _retrieve_segment(
         segments = []
         if single_timestamp_ending:
             slices.append(len(seek_sequence))
+        else:
+            # we want to include the last timestamp token in the last segment to know it was no single ending
+            slices[-1] += 1
```
Comment on lines +1817 to +1819

Contributor (author): That's the only way to know later that we have a double-timestamp-ending segment.

Contributor: So you're offsetting the last timestamp by one when the last two tokens are timestamps? Let's say we have …
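A small sketch of what keeping the trailing timestamp buys (same illustrative ids as before): on the last slice of a double-ended segment, the end time must be read from the penultimate token.

```python
timestamp_begin = 50364
seek_sequence = [50364, 440, 871, 50464, 50464]  # double-ended segment

# after `slices[-1] += 1`, the trailing timestamp stays inside the last slice
single_timestamp_ending = False
is_last_slice = True
sliced_tokens = seek_sequence  # the (only) last slice here

idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
assert end_timestamp_pos == 100  # 100 * 0.02 s = 2.0 s
```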
```diff
         last_slice = 0
         # Add each segment to list of all segments
-        for current_slice in slices:
+        for i, current_slice in enumerate(slices):
+            is_last_slice = i == len(slices) - 1
             sliced_tokens = seek_sequence[last_slice:current_slice]
-            start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
-            end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+            start_timestamp_pos = sliced_tokens[0] - timestamp_begin
+            idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
+            end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
             segments.append(
                 {
-                    "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
-                    "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
+                    "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
```
Commenter: I don't know if this is the right place or if there's an open issue, but this line now crashes on Macs (the MPS backend) because float64 is not supported on Apple silicon. I had to downgrade to 4.46.3. Is there some other way to fix this?
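One possible local workaround for the report above (hypothetical, not an official fix): pick the widest dtype the backend supports, at the cost of exact float64 agreement on MPS.

```python
import torch

def timestamp_dtype(device: torch.device) -> torch.dtype:
    # float64 is unsupported on the MPS backend, so fall back to float32 there
    return torch.float32 if device.type == "mps" else torch.float64

# time_offset = seek.to(timestamp_dtype(seek.device)) * time_precision / input_stride
```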
```diff
+                    "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
                     "tokens": sliced_tokens,
                     "result": seek_outputs[idx],
                 }
```
```diff
@@ -1827,16 +1847,16 @@ def _retrieve_segment(
             # otherwise, ignore the unfinished segment and seek to the last timestamp
             # here we throw away all predictions after the last predicted "end of segment"
             # since we are cutting right in the middle of an audio
-            last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
+            last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
             segment_offset = last_timestamp_pos * input_stride
         else:
             # If whisper does not predict any "end of segment" token, then
             # the whole decoding is considered a segment and we add it to the list of segments
             timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
-            last_timestamp_pos = seek_num_frames[prev_idx]
-            if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
+            last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
+            if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last one.
-                last_timestamp_pos = timestamps[-1].item() - timestamp_begin
+                last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
             segments = [
                 {
                     "start": time_offset[prev_idx],
```
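The `seek_num_frames` change above is a unit fix; a sketch with Whisper's default values: feature frames last 0.01 s, while `last_timestamp_pos` is later multiplied by `time_precision` (0.02 s), so the raw frame count overshot the segment end time by a factor of two.

```python
seek_num_frames = 3000  # mel feature frames in one full 30 s window
time_precision_features, time_precision = 0.01, 0.02

old = seek_num_frames                                                  # 3000 -> 60.0 s end time
new = int(seek_num_frames * time_precision_features / time_precision)  # 1500 -> 30.0 s end time
assert new * time_precision == 30.0
```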
src/transformers/models/whisper/tokenization_whisper.py

```diff
@@ -528,7 +528,9 @@ def basic_normalize(text, remove_diacritics=False):
         normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
         return normalizer(text)

-    def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
+    def _decode_with_timestamps(
+        self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
+    ) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
         given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
```
```diff
@@ -538,15 +540,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         cur_max_timestamp = 0.0
         prev_segments_len = 0.0
+        penultimate_timestamp = 0.0

-        for token in token_ids:
+        for i, token in enumerate(token_ids):
             if token >= timestamp_begin:
                 timestamp = float((token - timestamp_begin) * time_precision)

                 if timestamp < cur_max_timestamp:
                     # next segment has started
-                    prev_segments_len += cur_max_timestamp
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
```
Comment on lines +551 to +552

Contributor: Are `token_ids[i - 1]` and `token_ids[i - 2]` …

Contributor: Aren't we supposed to always have two subsequent timestamp tokens, now that you've dealt with single timestamp tokens in the generation file?

Contributor (author): Not necessarily. After concatenating all the generated sequences for each 30 s segment, at the tokenizer decoding phase we have two possibilities: a segment can end with a single timestamp token or with two consecutive ones. Note that the only reason we can differentiate those two cases here is thanks to the change above.
```diff
+                    )
+                    if last_was_single_ending:
+                        prev_segments_len += time_precision * segment_size
+                    else:
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+                        outputs = outputs[:-2]

+                penultimate_timestamp = cur_max_timestamp
                 cur_max_timestamp = timestamp

                 outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
```
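For reference, this private method is exercised through the public `decode` API; a hedged usage sketch (the ids are illustrative, assuming the multilingual vocabulary; real ids would come from `model.generate(..., return_timestamps=True)`):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
token_ids = [50364, 440, 871, 50464]  # <|0.00|> ... <|2.00|>, illustrative ids
print(tokenizer.decode(token_ids, decode_with_timestamps=True))
```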
```diff
@@ -558,7 +570,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
             ]
         return "".join(outputs)

-    def _compute_offsets(self, token_ids, time_precision=0.02):
+    def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
         """
         Compute offsets for a given tokenized input
```
```diff
@@ -567,6 +579,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             time_precision (`float`, *optional*, defaults to 0.02):
                 The time ratio to convert from token to time.
+            segment_size (`int`, *optional*, defaults to 1500):
+                The number of features in the input mel spectrogram.
         """
         offsets = []
         # ensure torch tensor of token ids is placed on cpu
```
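Similarly, `_compute_offsets` is reached via `decode(..., output_offsets=True)`; a sketch with the same illustrative ids:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
token_ids = [50364, 440, 871, 50464]  # <|0.00|> ... <|2.00|>, illustrative ids
out = tokenizer.decode(token_ids, output_offsets=True)
print(out["offsets"])  # e.g. [{'text': '...', 'timestamp': (0.0, 2.0)}]
```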
```diff
@@ -597,7 +611,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

                 if start_timestamp_position < cur_max_timestamp:
                     # next segment has started
-                    prev_segments_len += cur_max_timestamp
+                    is_single_ending = last_slice >= 2 and not (
+                        token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+                    )
+                    if is_single_ending:
+                        prev_segments_len += segment_size
+                    else:
+                        prev_segments_len += cur_max_timestamp

                 cur_max_timestamp = end_timestamp_position
```
```diff
@@ -609,8 +629,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                     {
                         "text": text,
                         "timestamp": (
-                            (start_timestamp_position + prev_segments_len) * time_precision,
-                            (end_timestamp_position + prev_segments_len) * time_precision,
+                            start_timestamp_position * time_precision + prev_segments_len * time_precision,
+                            end_timestamp_position * time_precision + prev_segments_len * time_precision,
```
Comment on lines +632 to +633

Contributor (author): It's done this way in the original codebase: summing in float64 after the multiplication with the position. This way we avoid annoying floating-point arithmetic issues.
```diff
                         ),
                     }
                 )
```
Contributor: The original Whisper codebase does such computations in float64. We need to ensure we do the same, especially since we are comparing against the original Whisper outputs in the tests.
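An illustration of this point (the values are assumed): the two orderings can disagree in the last bits even in float64, so matching the original ordering keeps the comparisons in the tests exact.

```python
start_timestamp_position = 7
prev_segments_len = 1500.0  # accumulated prior segments, in token-time units
time_precision = 0.02

old = (start_timestamp_position + prev_segments_len) * time_precision
new = start_timestamp_position * time_precision + prev_segments_len * time_precision
print(old == new, old, new)  # the two results may differ in the final bits
```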