Improve timestamp heuristics. (#1461)
* Improve timestamp heuristics.

* Track pauses with last_speech_timestamp
ryanheise authored Jun 29, 2023
1 parent 248b6cb commit f572f21
Showing 2 changed files with 60 additions and 28 deletions.
84 changes: 56 additions & 28 deletions whisper/timing.py
@@ -225,28 +225,6 @@ def find_alignment(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
 
-    # hack: truncate long words at the start of a window and the start of a sentence.
-    # a better segmentation algorithm based on VAD should be able to replace this.
-    word_durations = end_times - start_times
-    word_durations = word_durations[word_durations.nonzero()]
-    if len(word_durations) > 0:
-        median_duration = np.median(word_durations)
-        max_duration = median_duration * 2
-        sentence_end_marks = ".。!!??"
-        # ensure words at sentence boundaries are not longer than twice the median word duration.
-        for i in range(1, len(start_times)):
-            if end_times[i] - start_times[i] > max_duration:
-                if words[i] in sentence_end_marks:
-                    end_times[i] = start_times[i] + max_duration
-                elif words[i - 1] in sentence_end_marks:
-                    start_times[i] = end_times[i] - max_duration
-        # ensure the first and second word is not longer than twice the median word duration.
-        if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration:
-            if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration:
-                boundary = max(end_times[1] / 2, end_times[1] - max_duration)
-                end_times[0] = start_times[1] = boundary
-            start_times[0] = max(0, end_times[0] - max_duration)
-
     return [
         WordTiming(word, tokens, start, end, probability)
         for word, tokens, start, end, probability in zip(
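For context, here is a toy run of the first/second-word clamp that this hunk removes (invented numbers, not part of the diff; the heuristic itself reappears, generalized, in add_word_timestamps below):

    import numpy as np

    start_times = np.array([0.0, 4.5, 5.2])
    end_times = np.array([4.4, 5.0, 5.6])        # the first "word" spans 4.4 s
    word_durations = end_times - start_times      # [4.4, 0.5, 0.4]
    max_duration = np.median(word_durations) * 2  # 1.0 s

    # the same clamp as the removed block above
    if end_times[0] - start_times[0] > max_duration:
        if end_times[1] - start_times[1] > max_duration:
            boundary = max(end_times[1] / 2, end_times[1] - max_duration)
            end_times[0] = start_times[1] = boundary
        start_times[0] = max(0, end_times[0] - max_duration)

    print(start_times[0])  # 3.4: the stretched first word is pulled in toward its end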
@@ -298,6 +276,7 @@ def add_word_timestamps(
     num_frames: int,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    last_speech_timestamp: float,
     **kwargs,
 ):
     if len(segments) == 0:
@@ -310,6 +289,25 @@ def add_word_timestamps(
 
     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
+    word_durations = np.array([t.end - t.start for t in alignment])
+    word_durations = word_durations[word_durations.nonzero()]
+    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    max_duration = median_duration * 2
+
+    # hack: truncate long words at sentence boundaries.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    if len(word_durations) > 0:
+        median_duration = np.median(word_durations)
+        max_duration = median_duration * 2

[Inline diff comments]

MoMaT (Jul 7, 2023): Both median_duration and max_duration are calculated before (lines 294-295).

ryanheise (Author, Contributor; Jul 7, 2023): Well spotted, I forgot to delete one of those after refactoring made reuse possible. Do you want to submit a PR?
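For clarity, a sketch of the deduplicated version the thread is pointing at (just the redundant recomputation inside the if dropped; untested, and assuming the surrounding function's alignment and np):

    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # ... (truncation loop unchanged)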

+        sentence_end_marks = ".。!!??"
+        # ensure words at sentence boundaries are not longer than twice the median word duration.
+        for i in range(1, len(alignment)):
+            if alignment[i].end - alignment[i].start > max_duration:
+                if alignment[i].word in sentence_end_marks:
+                    alignment[i].end = alignment[i].start + max_duration
+                elif alignment[i - 1].word in sentence_end_marks:
+                    alignment[i].start = alignment[i].end - max_duration
+
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
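To see the sentence-boundary clamp in action, a self-contained toy example (hypothetical words and timings, not part of this commit):

    from dataclasses import dataclass

    @dataclass
    class Word:
        word: str
        start: float
        end: float

    # a "." stretched across a 3 s pause between sentences
    alignment = [Word("Hello", 0.0, 0.5), Word(".", 0.5, 3.5), Word("Hi", 3.5, 3.9)]
    median_duration = 0.5  # median of the durations [0.5, 3.0, 0.4]
    max_duration = median_duration * 2
    sentence_end_marks = ".。!!??"

    for i in range(1, len(alignment)):
        if alignment[i].end - alignment[i].start > max_duration:
            if alignment[i].word in sentence_end_marks:
                alignment[i].end = alignment[i].start + max_duration
            elif alignment[i - 1].word in sentence_end_marks:
                alignment[i].start = alignment[i].end - max_duration

    print(alignment[1])  # Word(word='.', start=0.5, end=1.5)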
@@ -335,18 +333,48 @@ def add_word_timestamps(
             saved_tokens += len(timing.tokens)
             word_index += 1
 
         # hack: truncate long words at segment boundaries.
         # a better segmentation algorithm based on VAD should be able to replace this.
         if len(words) > 0:
-            segment["start"] = words[0]["start"]
-            # hack: prefer the segment-level end timestamp if the last word is too long.
-            # a better segmentation algorithm based on VAD should be able to replace this.
+            # ensure the first and second word after a pause is not longer than
+            # twice the median word duration.
+            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
+                words[0]["end"] - words[0]["start"] > max_duration
+                or (
+                    len(words) > 1
+                    and words[1]["end"] - words[0]["start"] > max_duration * 2
+                )
+            ):
+                if (
+                    len(words) > 1
+                    and words[1]["end"] - words[1]["start"] > max_duration
+                ):
+                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
+                    words[0]["end"] = words[1]["start"] = boundary
+                words[0]["start"] = max(0, words[0]["end"] - max_duration)
+
+            # prefer the segment-level start timestamp if the first word is too long.
+            if (
+                segment["start"] < words[0]["end"]
+                and segment["start"] - 0.5 > words[0]["start"]
+            ):
+                words[0]["start"] = max(
+                    0, min(words[0]["end"] - median_duration, segment["start"])
+                )
+            else:
+                segment["start"] = words[0]["start"]
+
+            # prefer the segment-level end timestamp if the last word is too long.
             if (
                 segment["end"] > words[-1]["start"]
                 and segment["end"] + 0.5 < words[-1]["end"]
             ):
-                # adjust the word-level timestamps based on the segment-level timestamps
-                words[-1]["end"] = segment["end"]
+                words[-1]["end"] = max(
+                    words[-1]["start"] + median_duration, segment["end"]
+                )
             else:
-                # adjust the segment-level timestamps based on the word-level timestamps
                 segment["end"] = words[-1]["end"]
 
+            last_speech_timestamp = segment["end"]
+
         segment["words"] = words
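Putting rough numbers on the new pause heuristic (a toy walk-through with invented values): with a median word duration of 0.5 s, a leading word that ends more than 2 s (median * 4) after last_speech_timestamp and lasts over 1 s (max_duration) gets its start pulled in:

    median_duration = 0.5               # toy value
    max_duration = median_duration * 2  # 1.0 s
    last_speech_timestamp = 10.0

    words = [
        {"word": "so", "start": 10.5, "end": 14.0},  # 3.5 s long, ends 4 s after last speech
        {"word": "anyway", "start": 14.0, "end": 14.4},
    ]

    if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
        words[0]["end"] - words[0]["start"] > max_duration
        or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)
    ):
        if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration:
            boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
            words[0]["end"] = words[1]["start"] = boundary
        words[0]["start"] = max(0, words[0]["end"] - max_duration)

    print(words[0])  # {'word': 'so', 'start': 13.0, 'end': 14.0}

The segment-level preferences that follow work the same way: when a word-level boundary disagrees with the segment-level one by more than 0.5 s, the segment timestamp wins, clamped so the word keeps at least median_duration of audio.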
4 changes: 4 additions & 0 deletions whisper/transcribe.py
@@ -222,6 +222,7 @@ def new_segment(
     with tqdm.tqdm(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
+        last_speech_timestamp = 0.0
         while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             mel_segment = mel[:, seek : seek + N_FRAMES]
@@ -321,10 +322,13 @@
                         num_frames=segment_size,
                         prepend_punctuations=prepend_punctuations,
                         append_punctuations=append_punctuations,
+                        last_speech_timestamp=last_speech_timestamp,
                     )
                     word_end_timestamps = [
                         w["end"] for s in current_segments for w in s["words"]
                     ]
+                    if len(word_end_timestamps) > 0:
+                        last_speech_timestamp = word_end_timestamps[-1]
                     if not single_timestamp_ending and len(word_end_timestamps) > 0:
                         seek_shift = round(
                             (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
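In outline, the transcribe loop now carries the end of the last attributed word across windows; a simplified runnable sketch (decode_window is a hypothetical stub, not the real decoding path):

    def decode_window(seek: int):
        # stub: pretend each 1 s window yields one segment with one timed word
        return [{"words": [{"word": "hi", "start": seek + 0.1, "end": seek + 0.4}]}]

    last_speech_timestamp = 0.0
    for seek in range(3):
        current_segments = decode_window(seek)
        # add_word_timestamps(..., last_speech_timestamp=last_speech_timestamp)
        # would use the carried value to detect long pauses at the window edge
        word_end_timestamps = [
            w["end"] for s in current_segments for w in s["words"]
        ]
        if len(word_end_timestamps) > 0:
            last_speech_timestamp = word_end_timestamps[-1]

    print(last_speech_timestamp)  # 2.4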
