From 7eaad9d7a89e8a4e368cb9817acd45ef823d5fa4 Mon Sep 17 00:00:00 2001 From: Pzzzzz Date: Thu, 21 Mar 2024 15:02:51 +0800 Subject: [PATCH] Fix run_faster_whisper.py, Fix typo --- examples/whisper/run.py | 6 ++--- examples/whisper/run_faster_whisper.py | 36 ++++++++++++++------------ 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/whisper/run.py b/examples/whisper/run.py index 5f980754f..6e99912cf 100644 --- a/examples/whisper/run.py +++ b/examples/whisper/run.py @@ -373,14 +373,14 @@ def decode_dataset( tensorrt_llm.logger.set_level(args.log_level) model = WhisperTRTLLM(args.engine_dir, args.tokenizer_name, args.debug, args.assets_dir) - normallizer = EnglishTextNormalizer() + normalizer = EnglishTextNormalizer() if args.enable_warmup: results, total_duration = decode_dataset( model, "hf-internal-testing/librispeech_asr_dummy", batch_size=args.batch_size, num_beams=args.num_beams, - normalizer=normallizer, + normalizer=normalizer, mel_filters_dir=args.assets_dir) start_time = time.time() if args.input_file: @@ -398,7 +398,7 @@ def decode_dataset( dtype=args.dtype, batch_size=args.batch_size, num_beams=args.num_beams, - normalizer=normallizer, + normalizer=normalizer, mel_filters_dir=args.assets_dir) elapsed = time.time() - start_time results = sorted(results) diff --git a/examples/whisper/run_faster_whisper.py b/examples/whisper/run_faster_whisper.py index 70a72cbbd..8716ad551 100644 --- a/examples/whisper/run_faster_whisper.py +++ b/examples/whisper/run_faster_whisper.py @@ -15,8 +15,8 @@ import argparse import re import time +from pathlib import Path -import torch from datasets import load_dataset from torch.utils.data import DataLoader from whisper.normalizers import EnglishTextNormalizer @@ -32,12 +32,14 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--input_file', type=str, default=None) + parser.add_argument('--results_dir', type=str, default='tmp') parser.add_argument( '--name', type=str, default="librispeech_dummy_faster_whisper_large_v3_warmup") parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--num_beams', type=int, default=1) + parser.add_argument('--enable_warmup', action='store_true') return parser.parse_args() @@ -47,14 +49,12 @@ def decode_wav_file( model, text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", num_beams=1, - normalizer=None): - mel, total_duration = log_mel_spectrogram(input_file_path, - device='cuda', - return_duration=True) - mel = mel.type(torch.float16) - mel = mel.unsqueeze(0) - predictions = model.process_batch(mel, text_prefix, num_beams) - prediction = predictions[0] + normalizer=None, + sample_rate=16000): + segments, info = model.transcribe(input_file_path, + beam_size=num_beams, + language="en") + prediction = " ".join([segment.text for segment in segments]) # remove all special tokens in the prediction prediction = re.sub(r'<\|.*?\|>', '', prediction) @@ -62,7 +62,7 @@ def decode_wav_file( prediction = normalizer(prediction) print(f"prediction: {prediction}") results = [(0, [""], prediction.split())] - return results, total_duration + return results, info.duration def collate_wrapper(batch): @@ -116,16 +116,16 @@ def decode_dataset( if __name__ == '__main__': args = parse_arguments() - normallizer = EnglishTextNormalizer() + normalizer = EnglishTextNormalizer() model_size_or_path = "large-v3" model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16") - # warmup - results, total_duration = decode_dataset(model, - batch_size=args.batch_size, - num_beams=args.num_beams, - normalizer=normallizer) + if args.enable_warmup: + results, total_duration = decode_dataset(model, + batch_size=args.batch_size, + num_beams=args.num_beams, + normalizer=normalizer) start_time = time.time() if args.input_file: results, total_duration = decode_wav_file(args.input_file, @@ -135,9 +135,11 @@ def decode_dataset( results, total_duration = decode_dataset(model, batch_size=args.batch_size, num_beams=args.num_beams, - normalizer=normallizer) + normalizer=normalizer) elapsed = time.time() - start_time results = sorted(results) + + Path(args.results_dir).mkdir(parents=True, exist_ok=True) store_transcripts(filename=f"tmp/recogs-{args.name}.txt", texts=results) with open(f"tmp/errs-{args.name}.txt", "w") as f: