Merged

Commits (27)
751f88a
handle single timestamp ending
eustlb Oct 31, 2024
000dccd
include last timestamp token
eustlb Oct 31, 2024
70c8aac
handle single timestamp ending
eustlb Oct 31, 2024
3f9f7ea
Merge branch 'huggingface:main' into fix-whispertokenizer
eustlb Oct 31, 2024
e753206
avoid floating points arithm limitations
eustlb Oct 31, 2024
c53fb2c
ensure float64 operations
eustlb Oct 31, 2024
185fb55
new test
eustlb Oct 31, 2024
429904c
make fixup
eustlb Oct 31, 2024
1c72244
make copies
eustlb Oct 31, 2024
ad4f355
Merge branch 'main' into fix-whispertokenizer
eustlb Oct 31, 2024
09af9de
Merge branch 'main' into fix-whispertokenizer
eustlb Nov 2, 2024
7d6f9b4
handle edge case double tokens ending with different tokens
eustlb Nov 4, 2024
937cd2a
handle single timestamp ending
eustlb Nov 4, 2024
9b1d51e
make fixup
eustlb Nov 4, 2024
a3cbe9f
handle conditioning on prev segments
eustlb Nov 5, 2024
7c0da36
fix
eustlb Nov 5, 2024
4a21249
Update src/transformers/models/whisper/generation_whisper.py
eustlb Nov 21, 2024
5b195a8
Merge branch 'main' into fix-whispertokenizer
eustlb Nov 22, 2024
e8f2f69
[run-slow] whisper
Nov 27, 2024
d71d40a
Merge branch 'main' into fix-whispertokenizer
eustlb Nov 27, 2024
961d5f6
Merge branch 'main' into fix-whispertokenizer
eustlb Nov 27, 2024
e04aa92
Merge branch 'main' into fix-whispertokenizer
eustlb Dec 2, 2024
69abfe0
Merge branch 'main' into fix-whispertokenizer
eustlb Dec 4, 2024
16b1ecb
Merge branch 'main' into fix-whispertokenizer
eustlb Dec 4, 2024
5fba3e0
don't call item() to avoid unnecessary sync
eustlb Dec 5, 2024
9624ce9
Merge branch 'main' into fix-whispertokenizer
eustlb Dec 5, 2024
88587bb
fix
eustlb Dec 5, 2024
40 changes: 30 additions & 10 deletions src/transformers/models/whisper/generation_whisper.py

@@ -308,6 +308,7 @@ def generate(
num_segment_frames: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
time_precision: float = 0.02,
time_precision_features: float = 0.01,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
@@ -417,6 +418,8 @@ def generate(
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
time_precision_features (`int`, *optional*, defaults to 0.01):
The duration represented by a feature frame in seconds.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
@@ -629,7 +632,7 @@ def generate(
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
time_offset = seek * time_precision / input_stride
time_offset = seek.to(torch.float64) * time_precision / input_stride
Contributor Author:
The original Whisper code base does such computations in float64. We need to ensure we do the same, especially since we are comparing against the original Whisper outputs in the tests.
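For illustration, a minimal sketch (not from the PR; the frame index and constants are made up) of the kind of drift this avoids:

import torch

seek = torch.tensor([262144])  # hypothetical frame index deep into a long audio
time_precision = 0.02          # seconds per output token
input_stride = 2               # feature frames per output token

offset_f32 = seek.to(torch.float32) * time_precision / input_stride
offset_f64 = seek.to(torch.float64) * time_precision / input_stride

# float32 carries only ~7 significant digits, so the two offsets can differ in the
# trailing decimals and no longer match OpenAI's float64 reference outputs.
print(offset_f32.item(), offset_f64.item())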

seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

# 6.2 cut out next 30s segment from input features
@@ -658,6 +661,7 @@ def generate(
config=self.config,
device=init_tokens.device,
suppress_tokens=suppress_tokens,
timestamp_begin=timestamp_begin,
kwargs=kwargs,
)

@@ -718,6 +722,7 @@ def generate(
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
time_precision=time_precision,
time_precision_features=time_precision_features,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
@@ -1665,6 +1670,7 @@ def _prepare_decoder_input_ids(
config,
device,
suppress_tokens,
timestamp_begin,
kwargs,
):
if "decoder_input_ids" in kwargs:
@@ -1684,6 +1690,14 @@ def _prepare_decoder_input_ids(
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

for segments in active_segments:
for seg in segments:
if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
# the segment finishes with two timestamp tokens
# we need to ignore the last timestamp token
# see https://github.com/huggingface/transformers/pull/34537
seg["tokens"] = seg["tokens"][:-1]
Comment on lines +1693 to +1699

Contributor:
This looks a bit costly; I wonder if there's a cleaner way to compute all of this!

Why do we need to get rid of the last timestamp when preparing the input ids?

Contributor Author:
As mentioned above, we need to include the last timestamp in the case of double-ending timestamps (and not only the penultimate one, as done before) so that the tokenizer can differentiate the two cases (single and double ending). OAI does not have to worry about that because they don't concatenate all the tokens the way we do; instead, when conditioning on the previous tokens, the last token of a double-ending segment is omitted. To ensure we do the exact same thing, we need to remove it when preparing the decoder_input_ids.
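A minimal sketch of that trimming step (token ids are hypothetical; assume timestamp_begin is the id of <|0.00|>):

timestamp_begin = 50364  # hypothetical id of <|0.00|>

# a stored segment that ended with two timestamp tokens, e.g. ... <|27.00|><|28.00|>
seg = {"tokens": [50364, 291, 366, 51714, 51764]}

# if the penultimate token is already a timestamp, the segment is double-ending,
# so drop the trailing timestamp before concatenating it into the decoder prompt
if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
    seg["tokens"] = seg["tokens"][:-1]

print(seg["tokens"])  # [50364, 291, 366, 51714]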

Contributor Author:
The issue is that any segment can be a double-ending one, so they all need to be checked.

Contributor Author:
It looks like it is not that costly (complexity is O(batch_size * max number of segments)). When testing, it has no measurable impact on inference speed (see the long-form results linked here; our implementation is on par with OAI's).


if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
@@ -1778,6 +1792,7 @@ def _retrieve_segment(
timestamp_begin,
seek_num_frames,
time_precision,
time_precision_features,
input_stride,
prev_idx,
idx,
@@ -1799,17 +1814,22 @@ def _retrieve_segment(
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
else:
# we want to include the last timestamp token in the last segment to know it was no single ending
slices[-1] += 1
Comment on lines +1817 to +1819

Contributor Author:
That's the only way to know later that we have a double-timestamp-ending segment.

Contributor:
So you're offsetting the last timestamp by one when the last two tokens are timestamps? Let's say we have [..., T1, T2]; are you doing [..., T1, T2+1]?

Contributor Author:
slices is a list of indexes along the generated sequence of tokens seek_sequence. Say we have 100 tokens and the sequence is not single-ending, i.e. the last two tokens are timestamps [..., T1, T2]. In that case slices[-1] = 99, yet when we later slice out the segments with seek_sequence[last_slice:current_slice], we want to make sure T2 is included in the slice (so that we can later tell it is a double-timestamp-ending segment, as explained in the PR's comment). By adding 1 to the last slice, we ensure the last iteration slices seek_sequence[last_slice:100] and that T2 gets included.
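Here is a toy version of that slicing (numbers are made up):

seek_sequence = list(range(100))  # stand-in for 100 generated token ids
slices = [40, 70, 99]             # indices of consecutive-timestamp boundaries
single_timestamp_ending = False   # last two tokens are timestamps T1, T2

if single_timestamp_ending:
    slices.append(len(seek_sequence))
else:
    # include T2 (index 99) in the last slice so the tokenizer can later tell
    # that this segment had a double-timestamp ending
    slices[-1] += 1  # -> [40, 70, 100]

last_slice = 0
segments = []
for current_slice in slices:
    segments.append(seek_sequence[last_slice:current_slice])
    last_slice = current_slice

print(segments[-1][-1])  # 99 -> T2 is part of the last segment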


last_slice = 0
# Add each segment to list of all segments
for current_slice in slices:
for i, current_slice in enumerate(slices):
is_last_slice = i == len(slices) - 1
sliced_tokens = seek_sequence[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
start_timestamp_pos = sliced_tokens[0] - timestamp_begin
idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
Comment:
I don't know if this is the right place or if there's an open issue, but this line now crashes on Macs / the MPS backend because float64 is not supported on Apple silicon. I had to downgrade to 4.46.3; is there some other way to fix this?

"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
@@ -1827,16 +1847,16 @@ def _retrieve_segment(
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
last_timestamp_pos = seek_num_frames[prev_idx]
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
segments = [
{
"start": time_offset[prev_idx],
34 changes: 27 additions & 7 deletions src/transformers/models/whisper/tokenization_whisper.py

@@ -528,7 +528,9 @@ def basic_normalize(text, remove_diacritics=False):
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
return normalizer(text)

def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
def _decode_with_timestamps(
self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -538,15 +540,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre

cur_max_timestamp = 0.0
prev_segments_len = 0.0
penultimate_timestamp = 0.0

for token in token_ids:
for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp
last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
Comment on lines +551 to +552

Contributor:
Are token_ids always in ascending order? In that case, don't we always have token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin when entering the if token >= timestamp_begin branch, if i >= 2?

Contributor:
Aren't we supposed to always have two subsequent timestamp tokens, now that you've dealt with the single timestamp token in the generation file?

Contributor Author:
Not necessarily. After concatenating all the generated sequences for each 30-second segment, at the tokenizer decoding phase we have two possibilities:

  1. [..., <t1>, <t2>, <0.0>, ...] (last segment was double-timestamp ending)
  2. [..., <t1>, <0.0>, ...] (last segment was single-timestamp ending)

Note that the only reason we can differentiate those two cases here is the slices[-1] += 1 above, which ensures <t2> is included when we are not single-timestamp ending. So in case 2 we have token >= timestamp_begin (the <0.0>) and token_ids[i - 1] >= timestamp_begin, but not token_ids[i - 2] >= timestamp_begin.
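A small sketch of that check (token ids are hypothetical; assume any id >= timestamp_begin is a timestamp token):

timestamp_begin = 50364  # hypothetical id of <|0.00|>

def previous_segment_was_single_ending(token_ids, i):
    # i points at the <|0.00|> token that opens the next segment
    return i >= 2 and not (
        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
    )

double_ending = [291, 51714, 51764, 50364]  # ..., <t1>, <t2>, <0.0>
single_ending = [291, 51714, 50364]         # ..., <t1>, <0.0>

print(previous_segment_was_single_ending(double_ending, 3))  # False
print(previous_segment_was_single_ending(single_ending, 2))  # True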

)
if last_was_single_ending:
prev_segments_len += time_precision * segment_size
else:
cur_max_timestamp = penultimate_timestamp
prev_segments_len += penultimate_timestamp
outputs = outputs[:-2]

penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp

outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
@@ -558,7 +570,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
]
return "".join(outputs)

def _compute_offsets(self, token_ids, time_precision=0.02):
def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
"""
Compute offsets for a given tokenized input

@@ -567,6 +579,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
segment_size (`int`, *optional*, defaults to 1500):
The number of features in the input mel spectrogram.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
@@ -597,7 +611,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp
is_single_ending = last_slice >= 2 and not (
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
)
if is_single_ending:
prev_segments_len += segment_size
else:
prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

@@ -609,8 +629,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
start_timestamp_position * time_precision + prev_segments_len * time_precision,
end_timestamp_position * time_precision + prev_segments_len * time_precision,
Comment on lines +632 to +633

Contributor Author:
It's done this way in the original codebase: the float64 values are summed after the multiplication with the position. This way we avoid annoying floating-point arithmetic issues.
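A tiny sketch of the difference (positions are made up):

start_timestamp_position = 301  # token position within the current segment
prev_segments_len = 1500        # accumulated previous-segment length, in positions
time_precision = 0.02

summed_first = (start_timestamp_position + prev_segments_len) * time_precision
multiplied_first = start_timestamp_position * time_precision + prev_segments_len * time_precision

# the two forms can differ in the last bits; the original Whisper code uses the
# second one, so matching it keeps offsets identical to the reference outputs
print(summed_first, multiplied_first, summed_first == multiplied_first)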

),
}
)
36 changes: 28 additions & 8 deletions src/transformers/models/whisper/tokenization_whisper_fast.py

@@ -169,7 +169,9 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
return super()._encode_plus(*args, **kwargs)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
def _decode_with_timestamps(
self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -179,15 +181,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre

cur_max_timestamp = 0.0
prev_segments_len = 0.0
penultimate_timestamp = 0.0

for token in token_ids:
for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp

last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
)
if last_was_single_ending:
prev_segments_len += time_precision * segment_size
else:
cur_max_timestamp = penultimate_timestamp
prev_segments_len += penultimate_timestamp
outputs = outputs[:-2]

penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp

outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
@@ -200,7 +212,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
return "".join(outputs)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
def _compute_offsets(self, token_ids, time_precision=0.02):
def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
"""
Compute offsets for a given tokenized input

@@ -209,6 +221,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
segment_size (`int`, *optional*, defaults to 1500):
The number of features in the input mel spectrogram.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
@@ -239,7 +253,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp
is_single_ending = last_slice >= 2 and not (
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
)
if is_single_ending:
prev_segments_len += segment_size
else:
prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

@@ -251,8 +271,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
start_timestamp_position * time_precision + prev_segments_len * time_precision,
end_timestamp_position * time_precision + prev_segments_len * time_precision,
),
}
)
88 changes: 88 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py

@@ -2096,6 +2096,94 @@ def test_tiny_longform_timestamps_generation(self):
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)

@slow
def test_small_longform_timestamps_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
model.to(torch_device)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sampling_rate = dataset[0]["audio"]["sampling_rate"]

sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]]
sample = np.array(sample)

input_features = processor(
sample,
sampling_rate=16_000,
padding="longest",
truncation=False,
return_attention_mask=True,
return_tensors="pt",
).input_features

input_features = input_features.to(torch_device)
generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)

EXPECTED_TRANSCRIPT = [
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"timestamp": (0.0, 6.38),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.38, 11.32),
},
{
"text": " He tells us that at this festive season of the year,",
"timestamp": (11.32, 15.0),
},
{
"text": " With Christmas and roast beef looming before us, similes drawn from eating and its results",
"timestamp": (30.0, 36.76),
},
{
"text": " occur most readily to the mind.",
"timestamp": (36.76, 39.80),
},
{
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
"timestamp": (39.80, 45.36),
},
{
"text": " can discover in it but little of rocky Ithaca.",
"timestamp": (45.36, 49.0),
},
{
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
"timestamp": (49.0, 56.28),
},
{
"text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in",
"timestamp": (56.28, 64.12),
},
{
"text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his",
"timestamp": (64.12, 70.76),
},
{
"text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,",
"timestamp": (70.76, 77.16),
},
{
"text": " Next Man",
"timestamp": (77.16, 78.16),
},
]

transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)

transcript_segments = [
{
"text": processor.decode(seg["tokens"], skip_special_tokens=True),
"timestamp": (seg["start"].item(), seg["end"].item()),
}
for seg in generated_ids["segments"][0]
]
self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)