diff --git a/docs/evaluation/speech-audio.md b/docs/evaluation/speech-audio.md index 2e170891b3..89a4140a3f 100644 --- a/docs/evaluation/speech-audio.md +++ b/docs/evaluation/speech-audio.md @@ -75,7 +75,7 @@ eval( server_entrypoint="/workspace/megatron-lm/server.py", server_container="/path/to/container.sqsh", data_dir="/dataset", - installation_command="pip install sacrebleu jiwer openai-whisper" + installation_command="pip install -r requirements/audio.txt", server_args="--inference-max-requests 1 --model-config /workspace/checkpoint/config.yaml", ) ``` @@ -98,8 +98,8 @@ eval(benchmarks="asr-leaderboard", split="librispeech_clean", ...) --model=/workspace/path/to/checkpoint \ --server_entrypoint=/workspace/megatron-lm/server.py \ --server_container=/path/to/container.sqsh \ - --data_dir=/dataset - --installation_command="pip install sacrebleu jiwer openai-whisper" + --data_dir=/dataset \ + --installation_command="pip install -r requirements/audio.txt" ``` ### MMAU-Pro diff --git a/nemo_skills/dataset/asr-leaderboard/__init__.py b/nemo_skills/dataset/asr-leaderboard/__init__.py index b81cace3bc..7c3c2661ee 100644 --- a/nemo_skills/dataset/asr-leaderboard/__init__.py +++ b/nemo_skills/dataset/asr-leaderboard/__init__.py @@ -13,9 +13,10 @@ # limitations under the License. 
# Settings that define how evaluation should be done by default (all can be changed from cmdline) -# Uses the audio evaluator which computes WER with HuggingFace leaderboard preprocessing -# Data samples should have task_type="ASR_LEADERBOARD" for proper WER calculation +# Uses the audio evaluator which computes WER with Whisper-style text normalization +# Data samples should have task_type="ASR" for proper WER calculation DATASET_GROUP = "speechlm" METRICS_TYPE = "audio" -GENERATION_ARGS = "++prompt_format=openai ++eval_type=audio" +EVAL_ARGS = "++eval_type=audio" +GENERATION_ARGS = "++prompt_format=openai" diff --git a/nemo_skills/dataset/asr-leaderboard/prepare.py b/nemo_skills/dataset/asr-leaderboard/prepare.py index 25bbafd986..9fdc21ba19 100644 --- a/nemo_skills/dataset/asr-leaderboard/prepare.py +++ b/nemo_skills/dataset/asr-leaderboard/prepare.py @@ -32,7 +32,21 @@ from tqdm import tqdm SYSTEM_MESSAGE = "You are a helpful assistant. /no_think" -MIN_AUDIO_DURATION = 0.1 # Skip audio shorter than this (causes mel spectrogram errors) +USER_MESSAGE = "Transcribe the audio file into English text." 
+MIN_AUDIO_DURATION = 0.1  # Skip audio shorter than this + +# Speaker IDs to skip in Tedlium dataset +SKIP_SPEAKER_IDS = {"inter_segment_gap"} + +# Non-speech tokens to skip in GigaSpeech dataset +NONSPEECH_TOKENS = {"<MUSIC>", "<NOISE>", "<OTHER>", "<SIL>"} + + +def is_nonspeech_only(text): +    """Check if text contains only non-speech tokens.""" +    tokens = set(text.strip().split()) +    return tokens and tokens.issubset(NONSPEECH_TOKENS) + + # (hf_dataset, hf_config, hf_split, streaming) DATASET_CONFIGS = { @@ -59,12 +73,17 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with text = text.strip() if text else "" system_message = {"role": "system", "content": SYSTEM_MESSAGE} - user_message = {"role": "user", "content": "Transcribe the following audio."} + user_message = {"role": "user", "content": USER_MESSAGE} audio_info = entry.get("audio", {}) if isinstance(audio_info, dict) and "array" in audio_info and "sampling_rate" in audio_info: audio_array = audio_info["array"] sampling_rate = audio_info["sampling_rate"] + + # Skip if audio array is empty or invalid + if audio_array is None or len(audio_array) == 0: + return None + duration = len(audio_array) / sampling_rate if duration < MIN_AUDIO_DURATION: @@ -76,18 +95,24 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with if with_audio: sf.write(str(audio_dir / audio_filename), audio_array, sampling_rate) + audio_filepath = f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}" user_message["audio"] = { - "path": f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}", + "path": audio_filepath, "duration": float(duration), } formatted_entry = { - "task_type": "ASR_LEADERBOARD", + "task_type": "ASR", "expected_answer": text, "messages": [system_message, user_message], "subset_for_metrics": dataset_name, } + + # Add audio_filepath and duration as top-level fields + if "audio" in user_message: + formatted_entry["audio_filepath"] = user_message["audio"]["path"] + 
formatted_entry["duration"] = user_message["audio"]["duration"] + if "id" in entry: formatted_entry["id"] = entry["id"] if "speaker_id" in entry: @@ -133,12 +158,17 @@ def prepare_dataset(dataset_name, output_dir, with_audio=True): if formatted is None: skipped += 1 continue - if formatted["expected_answer"]: + # Skip empty answers, non-speech segments, and non-speech-only samples + speaker_id = entry.get("speaker_id", "") + expected = formatted["expected_answer"] + if expected and speaker_id not in SKIP_SPEAKER_IDS and not is_nonspeech_only(expected): fout.write(json.dumps(formatted) + "\n") count += 1 + else: + skipped += 1 if skipped > 0: - print(f"Skipped {skipped} samples with audio < {MIN_AUDIO_DURATION}s") + print(f"Skipped {skipped} samples (short audio, non-speech, or invalid)") print(f"Saved {count} samples to {output_file}") return count diff --git a/nemo_skills/evaluation/evaluator/audio.py b/nemo_skills/evaluation/evaluator/audio.py index ff97181bbb..35149ebd59 100644 --- a/nemo_skills/evaluation/evaluator/audio.py +++ b/nemo_skills/evaluation/evaluator/audio.py @@ -17,6 +17,7 @@ import asyncio import logging import re +from functools import lru_cache from typing import Any import numpy as np @@ -32,7 +33,6 @@ class AudioEvaluatorConfig(BaseEvaluatorConfig): """Configuration for audio evaluation.""" prompt_config: str = "eval/speechlm/audio" - apply_whisper_normalization: bool = True normalize_asr_pc_standard_wer: bool = True @@ -119,61 +119,34 @@ def evaluate_asr_pc(reference: str, hypothesis: str, normalize_standard_wer: boo "wer_pc": wer_pc, "per": per, "is_correct": wer_pc < 0.5, + "text": ref_std, + "pred_text": hyp_std, } -def preprocess_asr_text(text: str) -> str: - """Apply Whisper-style normalization: lowercase, normalize, remove brackets.""" - from whisper.normalizers import EnglishTextNormalizer - - text = text.lower() - text = EnglishTextNormalizer()(text) - text = re.sub(r"(\[|\(|\{|\<)[^\(\)\\n\[\]]*(\]|\)|\}|\>)", "", text) - text = 
re.sub(r"\s+", " ", text).strip() - return text +@lru_cache(maxsize=1) +def _get_english_normalizer(): + """Lazily initialize and cache the English text normalizer.""" + from whisper_normalizer.english import EnglishTextNormalizer + return EnglishTextNormalizer() -def preprocess_hf_leaderboard(text: str) -> str: - """Apply HuggingFace leaderboard normalization: lowercase, remove punctuation, normalize unicode.""" - import unicodedata - text = unicodedata.normalize("NFC", text) - text = text.lower() - text = re.sub(r"[^\w\s]", "", text) - text = re.sub(r"\s+", " ", text).strip() - return text +def preprocess_asr_text(text: str) -> str: + """Apply Whisper-style normalization (lowercase, remove brackets, normalize whitespace).""" + return _get_english_normalizer()(text) -def evaluate_asr(reference: str, hypothesis: str, apply_normalization: bool = True) -> dict[str, Any]: - """Evaluate ASR: computes WER with optional Whisper normalization.""" +def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]: + """Evaluate ASR: computes WER with Whisper normalization.""" import jiwer - if apply_normalization: - ref = preprocess_asr_text(reference) - hyp = preprocess_asr_text(hypothesis) - else: - ref = normalize_whitespace(reference) - hyp = normalize_whitespace(hypothesis) + ref = preprocess_asr_text(reference) + hyp = preprocess_asr_text(hypothesis) - if not ref: - ref = "empty" - if not hyp: - hyp = "empty" - - wer_score = jiwer.wer(ref, hyp) - - return { - "wer": wer_score, - "is_correct": wer_score < 0.5, - } - - -def evaluate_asr_leaderboard(reference: str, hypothesis: str) -> dict[str, Any]: - """Evaluate ASR with HuggingFace leaderboard preprocessing for direct comparison.""" - import jiwer - - ref = preprocess_hf_leaderboard(reference) - hyp = preprocess_hf_leaderboard(hypothesis) + # Store normalized texts before empty substitution + text = ref + pred_text = hyp if not ref: ref = "empty" @@ -185,6 +158,8 @@ def evaluate_asr_leaderboard(reference: str, 
hypothesis: str) -> dict[str, Any]: return { "wer": wer_score, "is_correct": wer_score < 0.5, + "text": text, + "pred_text": pred_text, } @@ -193,20 +168,25 @@ def evaluate_translation(reference: str, hypothesis: str) -> dict[str, Any]: try: import sacrebleu - ref = [reference.strip()] - hyp = hypothesis.strip() - bleu = sacrebleu.sentence_bleu(hyp, ref) + text = reference.strip() + pred_text = hypothesis.strip() + ref = [text] + bleu = sacrebleu.sentence_bleu(pred_text, ref) bleu_score = bleu.score / 100.0 return { "bleu": bleu_score, "is_correct": bleu_score > 0.3, + "text": text, + "pred_text": pred_text, } except Exception as e: return { "bleu": 0.0, "is_correct": False, "error": str(e), + "text": reference.strip(), + "pred_text": hypothesis.strip(), } @@ -218,6 +198,8 @@ def evaluate_cer(reference: str, hypothesis: str) -> dict[str, Any]: return { "cer": cer_score, "is_correct": cer_score < 0.5, + "text": reference, + "pred_text": hypothesis, } @@ -235,6 +217,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict "char_rate": 0.0, "is_correct": True, "error": "missing_audio_duration", + "text": reference, + "pred_text": hypothesis, } char_count = len(hypothesis) @@ -248,6 +232,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict "hallucination_rate": 1.0 if is_hallucinating else 0.0, "char_rate": round(char_rate, 2), "is_correct": not is_hallucinating, + "text": reference, + "pred_text": hypothesis, } @@ -295,6 +281,8 @@ def evaluate_pc_rate(reference: str, hypothesis: str) -> dict[str, Any]: "punct_f1": round(punct_f1, 3), "cap_accuracy": round(cap_accuracy, 3), "is_correct": pc_rate > 0.5, + "text": reference, + "pred_text": hypothesis, } @@ -326,17 +314,16 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic generation = sample.get("generation", "").strip() expected_answer = sample.get("expected_answer", "").strip() - if task_type in ["ASR", "ASR-PC", 
"ASR_LEADERBOARD", "AST", "Translation", "CER"] and not generation: + if task_type in ["ASR", "ASR-PC", "AST", "Translation", "CER"] and not generation: base = { "is_correct": False, "error": "missing_generation", - "predicted_answer": "", } if task_type in ["AST", "Translation"]: return {**base, "bleu": 0.0} if task_type == "CER": return {**base, "cer": 1.0} - # ASR / ASR-PC / ASR_LEADERBOARD + # ASR / ASR-PC return {**base, "wer": 1.0} if task_type == "ASR-PC": @@ -344,43 +331,31 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic expected_answer, generation, normalize_standard_wer=config.normalize_asr_pc_standard_wer ) updates.update(metrics) - updates["predicted_answer"] = generation elif task_type == "ASR": - metrics = evaluate_asr(expected_answer, generation, apply_normalization=config.apply_whisper_normalization) - updates.update(metrics) - updates["predicted_answer"] = generation - - elif task_type == "ASR_LEADERBOARD": - metrics = evaluate_asr_leaderboard(expected_answer, generation) + metrics = evaluate_asr(expected_answer, generation) updates.update(metrics) - updates["predicted_answer"] = generation elif task_type in ["AST", "Translation"]: metrics = evaluate_translation(expected_answer, generation) updates.update(metrics) - updates["predicted_answer"] = generation elif task_type == "CER": metrics = evaluate_cer(expected_answer, generation) updates.update(metrics) - updates["predicted_answer"] = generation elif task_type == "Hallucination": audio_context = {"audio_duration": sample.get("audio_duration")} metrics = evaluate_hallucination(expected_answer, generation, audio_context) updates.update(metrics) - updates["predicted_answer"] = generation elif task_type == "PC-Rate": metrics = evaluate_pc_rate(expected_answer, generation) updates.update(metrics) - updates["predicted_answer"] = generation else: if "requires_judge" not in sample: updates["requires_judge"] = True - updates["predicted_answer"] = generation if 
"is_correct" not in sample: updates["is_correct"] = False diff --git a/nemo_skills/evaluation/evaluator/compute_eval.py b/nemo_skills/evaluation/evaluator/compute_eval.py index eab2b2d6e0..0d803dbe66 100644 --- a/nemo_skills/evaluation/evaluator/compute_eval.py +++ b/nemo_skills/evaluation/evaluator/compute_eval.py @@ -20,7 +20,7 @@ from compute_eval.utils.eval_utils import get_nvcc_version, parse_semver from pydantic import Field, TypeAdapter -from nemo_skills.evaluation.evaluator import BaseEvaluator +from nemo_skills.evaluation.evaluator.base import BaseEvaluator from nemo_skills.utils import get_logger_name _LOG = logging.getLogger(get_logger_name(__file__)) diff --git a/requirements/audio.txt b/requirements/audio.txt new file mode 100644 index 0000000000..6cabfccf93 --- /dev/null +++ b/requirements/audio.txt @@ -0,0 +1,4 @@ +jiwer>=3.1.0,<4.0.0 # Word/Character Error Rate computation +sacrebleu # BLEU score computation +soundfile # Audio file I/O for dataset preparation +whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer)