Skip to content

Commit

Permalink
Remove negative duration segments from whisper (#928)
Browse files Browse the repository at this point in the history
This PR addresses #891.

* Remove segments with non-positive duration from the whisper output
* Segment post-processing to force non-overlapping is made optional
(disabled by default)
* Allow overlapping segments in forced alignment workflow
  • Loading branch information
pzelasko authored Dec 16, 2022
2 parents e83afd3 + f411440 commit 891bad1
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 16 deletions.
7 changes: 7 additions & 0 deletions lhotse/bin/modes/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def workflows():
"-d", "--device", default="cpu", help="Device on which to run the inference."
)
@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.")
@click.option(
"--force-nonoverlapping/--keep-overlapping",
default=False,
help="If True, the Whisper segment time-stamps will be processed to make sure they are non-overlapping.",
)
def annotate_with_whisper(
out_cuts: str,
recordings_manifest: Optional[str],
Expand All @@ -67,6 +72,7 @@ def annotate_with_whisper(
language: Optional[str],
device: str,
jobs: int,
force_nonoverlapping: bool,
):
"""
Use OpenAI Whisper model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST.
Expand Down Expand Up @@ -101,6 +107,7 @@ def annotate_with_whisper(
language=language,
model_name=model_name,
device=device,
force_nonoverlapping=force_nonoverlapping,
),
total=len(manifest),
desc="Annotating with Whisper",
Expand Down
2 changes: 1 addition & 1 deletion lhotse/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def trim_supervisions_to_recordings(
continue
if s.end > end:
trimmed += 1
s = s.trim(recordings[s.recording_id].duration)
s = s.trim(end=end)
sups.append(s)
if verbose and removed:
logging.warning(
Expand Down
4 changes: 0 additions & 4 deletions lhotse/workflows/forced_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ def align_with_torchaudio(
discard_symbols = _make_discard_symbols_regex(labels)

for cut in cuts:
assert not cut.has_overlapping_supervisions, (
f"We don't support forced alignment of cuts with overlapping supervisions "
f"(cut ID: '{cut.id}')"
)

for idx, subcut in enumerate(cut.trim_to_supervisions(keep_overlapping=False)):
sup = subcut.supervisions[0]
Expand Down
53 changes: 42 additions & 11 deletions lhotse/workflows/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def annotate_with_whisper(
language: Optional[str] = None,
model_name: str = "base",
device: str = "cpu",
force_nonoverlapping: bool = False,
) -> Generator[MonoCut, None, None]:
"""
Use OpenAI Whisper model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST.
Expand All @@ -35,6 +36,8 @@ def annotate_with_whisper(
:param language: specify the language if known upfront, otherwise it will be auto-detected.
:param model_name: one of available Whisper variants (base, medium, large, etc.).
:param device: Where to run the inference (cpu, cuda, etc.).
:param force_nonoverlapping: if True, the Whisper segment time-stamps will be processed to make
sure they are non-overlapping.
:return: a generator of cuts (use ``CutSet.open_writer()`` to write them).
"""
assert is_module_available("whisper"), (
Expand All @@ -44,15 +47,23 @@ def annotate_with_whisper(
)

if isinstance(manifest, RecordingSet):
yield from _annotate_recordings(manifest, language, model_name, device)
yield from _annotate_recordings(
manifest, language, model_name, device, force_nonoverlapping
)
elif isinstance(manifest, CutSet):
yield from _annotate_cuts(manifest, language, model_name, device)
yield from _annotate_cuts(
manifest, language, model_name, device, force_nonoverlapping
)
else:
raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.")


def _annotate_recordings(
recordings: RecordingSet, language: str, model_name: str, device: str
recordings: RecordingSet,
language: str,
model_name: str,
device: str,
force_nonoverlapping: bool,
):
"""
Helper function that annotates a RecordingSet with Whisper.
Expand All @@ -70,6 +81,7 @@ def _annotate_recordings(
continue
audio = torch.from_numpy(recording.resample(16000).load_audio()).squeeze(0)
result = whisper.transcribe(model=model, audio=audio, language=language)
# Create supervisions from segments while filtering out those with negative duration.
supervisions = [
SupervisionSegment(
id=f"{recording.id}-{segment['id']:06d}",
Expand All @@ -82,10 +94,15 @@ def _annotate_recordings(
language=result["language"],
)
for segment in result["segments"]
if segment["end"] - segment["start"] > 0
]
cut = recording.to_cut()
if supervisions:
supervisions = _postprocess_timestamps(supervisions)
supervisions = (
_postprocess_timestamps(supervisions)
if force_nonoverlapping
else supervisions
)
cut.supervisions = list(
trim_supervisions_to_recordings(
recordings=recording, supervisions=supervisions, verbose=False
Expand All @@ -94,7 +111,13 @@ def _annotate_recordings(
yield cut


def _annotate_cuts(cuts: CutSet, language: str, model_name: str, device: str):
def _annotate_cuts(
cuts: CutSet,
language: str,
model_name: str,
device: str,
force_nonoverlapping: bool,
):
"""
Helper function that annotates a CutSet with Whisper.
"""
Expand All @@ -111,23 +134,29 @@ def _annotate_cuts(cuts: CutSet, language: str, model_name: str, device: str):
continue
audio = torch.from_numpy(cut.resample(16000).load_audio()).squeeze(0)
result = whisper.transcribe(model=model, audio=audio, language=language)
# Create supervisions from segments while filtering out those with negative duration.
supervisions = [
SupervisionSegment(
id=f"{cut.id}-{segment['id']:06d}",
recording_id=cut.recording_id,
start=round(segment["start"], ndigits=8),
duration=max(
cut.duration,
add_durations(
segment["end"], -segment["start"], sampling_rate=16000
),
duration=add_durations(
min(segment["end"], cut.duration),
-segment["start"],
sampling_rate=16000,
),
text=segment["text"].strip(),
language=result["language"],
)
for segment in result["segments"]
if segment["end"] - segment["start"] > 0
]
new_cut = fastcopy(cut, supervisions=_postprocess_timestamps(supervisions))
new_cut = fastcopy(
cut,
supervisions=_postprocess_timestamps(supervisions)
if force_nonoverlapping
else supervisions,
)
yield new_cut


Expand All @@ -139,6 +168,8 @@ def _postprocess_timestamps(supervisions: List[SupervisionSegment]):
"""
from cytoolz import sliding_window

supervisions = sorted(supervisions, key=lambda s: s.start)

if len(supervisions) < 2:
return supervisions
out = []
Expand Down

0 comments on commit 891bad1

Please sign in to comment.