Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/evaluation/speech-audio.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ eval(
server_entrypoint="/workspace/megatron-lm/server.py",
server_container="/path/to/container.sqsh",
data_dir="/dataset",
installation_command="pip install sacrebleu jiwer openai-whisper"
installation_command="pip install -r requirements/audio.txt",
server_args="--inference-max-requests 1 --model-config /workspace/checkpoint/config.yaml",
)
```
Expand All @@ -98,8 +98,8 @@ eval(benchmarks="asr-leaderboard", split="librispeech_clean", ...)
--model=/workspace/path/to/checkpoint \
--server_entrypoint=/workspace/megatron-lm/server.py \
--server_container=/path/to/container.sqsh \
--data_dir=/dataset
--installation_command="pip install sacrebleu jiwer openai-whisper"
--data_dir=/dataset \
--installation_command="pip install -r requirements/audio.txt"
```

### MMAU-Pro
Expand Down
7 changes: 4 additions & 3 deletions nemo_skills/dataset/asr-leaderboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

# Settings that define how evaluation should be done by default (all can be changed from cmdline)
# Uses the audio evaluator which computes WER with HuggingFace leaderboard preprocessing
# Data samples should have task_type="ASR_LEADERBOARD" for proper WER calculation
# Uses the audio evaluator which computes WER with Whisper-style text normalization
# Data samples should have task_type="ASR" for proper WER calculation

DATASET_GROUP = "speechlm"
METRICS_TYPE = "audio"
GENERATION_ARGS = "++prompt_format=openai ++eval_type=audio"
EVAL_ARGS = "++eval_type=audio"
GENERATION_ARGS = "++prompt_format=openai"
42 changes: 36 additions & 6 deletions nemo_skills/dataset/asr-leaderboard/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,21 @@
from tqdm import tqdm

SYSTEM_MESSAGE = "You are a helpful assistant. /no_think"
MIN_AUDIO_DURATION = 0.1 # Skip audio shorter than this (causes mel spectrogram errors)
USER_MESSAGE = "Transcribe the audio file into English text."
MIN_AUDIO_DURATION = 0.1 # Skip audio shorter than this

# Speaker IDs to skip in Tedlium dataset
SKIP_SPEAKER_IDS = {"inter_segment_gap"}

# Non-speech tokens to skip in GigaSpeech dataset
NONSPEECH_TOKENS = {"<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"}


def is_nonspeech_only(text: str) -> bool:
    """Return True when ``text`` consists solely of non-speech tokens.

    The text is split on whitespace and every resulting token must be a
    member of ``NONSPEECH_TOKENS``.

    Args:
        text: Transcript text to inspect.

    Returns:
        True if there is at least one token and all tokens are non-speech
        markers; False otherwise, including for empty or whitespace-only text.
    """
    tokens = set(text.split())
    # bool() keeps the result a real boolean: ``tokens and ...`` on its own
    # would hand back the empty set itself for empty input.
    return bool(tokens) and tokens.issubset(NONSPEECH_TOKENS)


# (hf_dataset, hf_config, hf_split, streaming)
DATASET_CONFIGS = {
Expand All @@ -59,12 +73,17 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with
text = text.strip() if text else ""

system_message = {"role": "system", "content": SYSTEM_MESSAGE}
user_message = {"role": "user", "content": "Transcribe the following audio."}
user_message = {"role": "user", "content": USER_MESSAGE}

audio_info = entry.get("audio", {})
if isinstance(audio_info, dict) and "array" in audio_info and "sampling_rate" in audio_info:
audio_array = audio_info["array"]
sampling_rate = audio_info["sampling_rate"]

# Skip if audio array is empty or invalid
if audio_array is None or len(audio_array) == 0:
return None

duration = len(audio_array) / sampling_rate

if duration < MIN_AUDIO_DURATION:
Expand All @@ -76,18 +95,24 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with
if with_audio:
sf.write(str(audio_dir / audio_filename), audio_array, sampling_rate)

audio_filepath = f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}"
user_message["audio"] = {
"path": f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}",
"path": audio_filepath,
"duration": float(duration),
}

formatted_entry = {
"task_type": "ASR_LEADERBOARD",
"task_type": "ASR",
"expected_answer": text,
"messages": [system_message, user_message],
"subset_for_metrics": dataset_name,
}

# Add audio_filepath and duration as top-level fields
if "audio" in user_message:
formatted_entry["audio_filepath"] = user_message["audio"]["path"]
formatted_entry["duration"] = user_message["audio"]["duration"]

if "id" in entry:
formatted_entry["id"] = entry["id"]
if "speaker_id" in entry:
Expand Down Expand Up @@ -133,12 +158,17 @@ def prepare_dataset(dataset_name, output_dir, with_audio=True):
if formatted is None:
skipped += 1
continue
if formatted["expected_answer"]:
# Skip empty answers, non-speech segments, and non-speech-only samples
speaker_id = entry.get("speaker_id", "")
expected = formatted["expected_answer"]
if expected and speaker_id not in SKIP_SPEAKER_IDS and not is_nonspeech_only(expected):
Comment on lines +162 to +164
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition checks `expected` for truthiness, but is_nonspeech_only() already handles empty strings (it returns a falsy value for empty input), so the `expected` check is redundant for the is_nonspeech_only() branch.

fout.write(json.dumps(formatted) + "\n")
count += 1
else:
skipped += 1

if skipped > 0:
print(f"Skipped {skipped} samples with audio < {MIN_AUDIO_DURATION}s")
print(f"Skipped {skipped} samples (short audio, non-speech, or invalid)")

print(f"Saved {count} samples to {output_file}")
return count
Expand Down
103 changes: 39 additions & 64 deletions nemo_skills/evaluation/evaluator/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import asyncio
import logging
import re
from functools import lru_cache
from typing import Any

import numpy as np
Expand All @@ -32,7 +33,6 @@ class AudioEvaluatorConfig(BaseEvaluatorConfig):
"""Configuration for audio evaluation."""

prompt_config: str = "eval/speechlm/audio"
apply_whisper_normalization: bool = True
normalize_asr_pc_standard_wer: bool = True


Expand Down Expand Up @@ -119,61 +119,34 @@ def evaluate_asr_pc(reference: str, hypothesis: str, normalize_standard_wer: boo
"wer_pc": wer_pc,
"per": per,
"is_correct": wer_pc < 0.5,
"text": ref_std,
"pred_text": hyp_std,
}


def preprocess_asr_text(text: str) -> str:
    """Apply Whisper-style normalization to a transcript.

    Steps, in order: lowercase, run ``EnglishTextNormalizer``, delete
    bracketed spans (``[...]``, ``(...)``, ``{...}``, ``<...>``), and collapse
    whitespace runs to single spaces.

    Args:
        text: Raw reference or hypothesis text.

    Returns:
        The normalized text with leading/trailing whitespace removed.
    """
    # Function-scope import: whisper is only needed when this function runs.
    from whisper.normalizers import EnglishTextNormalizer

    text = text.lower()
    text = EnglishTextNormalizer()(text)
    # A bracketed span may not contain parentheses, square brackets, or a
    # newline. The class uses "\n"; the earlier "\\n" in a raw string meant
    # "backslash or the letter n", so spans such as "[noise]" were kept.
    text = re.sub(r"(\[|\(|\{|\<)[^\(\)\n\[\]]*(\]|\)|\}|\>)", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
@lru_cache(maxsize=1)
def _get_english_normalizer():
"""Lazily initialize and cache the English text normalizer."""
from whisper_normalizer.english import EnglishTextNormalizer

return EnglishTextNormalizer()
Comment on lines +127 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importing inside a function decorated with lru_cache could cause issues. If the import fails after the first call, the cached normalizer becomes stale. Move the import to module level or inside the returned function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a problem to be solved by the user.


def preprocess_hf_leaderboard(text: str) -> str:
    """Apply HuggingFace-leaderboard-style text normalization.

    Composes Unicode to NFC, lowercases, drops every character that is
    neither a word character nor whitespace, and squeezes whitespace runs
    into single spaces with the ends trimmed.
    """
    import unicodedata

    lowered = unicodedata.normalize("NFC", text).lower()
    without_punct = re.sub(r"[^\w\s]", "", lowered)
    collapsed = re.sub(r"\s+", " ", without_punct)
    return collapsed.strip()
def preprocess_asr_text(text: str) -> str:
    """Normalize ``text`` with the cached Whisper-style English normalizer."""
    normalize = _get_english_normalizer()
    return normalize(text)


def evaluate_asr(reference: str, hypothesis: str, apply_normalization: bool = True) -> dict[str, Any]:
"""Evaluate ASR: computes WER with optional Whisper normalization."""
def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]:
"""Evaluate ASR: computes WER with Whisper normalization."""
import jiwer

if apply_normalization:
ref = preprocess_asr_text(reference)
hyp = preprocess_asr_text(hypothesis)
else:
ref = normalize_whitespace(reference)
hyp = normalize_whitespace(hypothesis)
ref = preprocess_asr_text(reference)
hyp = preprocess_asr_text(hypothesis)

if not ref:
ref = "empty"
if not hyp:
hyp = "empty"

wer_score = jiwer.wer(ref, hyp)

return {
"wer": wer_score,
"is_correct": wer_score < 0.5,
}


def evaluate_asr_leaderboard(reference: str, hypothesis: str) -> dict[str, Any]:
"""Evaluate ASR with HuggingFace leaderboard preprocessing for direct comparison."""
import jiwer

ref = preprocess_hf_leaderboard(reference)
hyp = preprocess_hf_leaderboard(hypothesis)
# Store normalized texts before empty substitution
text = ref
pred_text = hyp
Comment on lines +147 to +149
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential issue with storing empty normalized text. The code stores the normalized text before replacing empty strings with "empty", which means text and pred_text could be empty strings in the returned dictionary.

This could cause issues downstream if code expects non-empty strings. The pattern is inconsistent with other evaluation functions in this file:

  • evaluate_asr_pc() (line 122-123) stores the normalized text BEFORE empty substitution
  • evaluate_translation() (line 171-172) stores the original .strip() text
  • evaluate_cer() (line 201-202) stores the original text without any empty handling

The inconsistency suggests this might not be intentional. If empty strings in text/pred_text fields are acceptable, this is fine. Otherwise, consider storing after the empty substitution or storing the original text like other functions do.


if not ref:
ref = "empty"
Expand All @@ -185,6 +158,8 @@ def evaluate_asr_leaderboard(reference: str, hypothesis: str) -> dict[str, Any]:
return {
"wer": wer_score,
"is_correct": wer_score < 0.5,
"text": text,
"pred_text": pred_text,
}


Expand All @@ -193,20 +168,25 @@ def evaluate_translation(reference: str, hypothesis: str) -> dict[str, Any]:
try:
import sacrebleu

ref = [reference.strip()]
hyp = hypothesis.strip()
bleu = sacrebleu.sentence_bleu(hyp, ref)
text = reference.strip()
pred_text = hypothesis.strip()
ref = [text]
bleu = sacrebleu.sentence_bleu(pred_text, ref)
bleu_score = bleu.score / 100.0

return {
"bleu": bleu_score,
"is_correct": bleu_score > 0.3,
"text": text,
"pred_text": pred_text,
}
except Exception as e:
return {
"bleu": 0.0,
"is_correct": False,
"error": str(e),
"text": reference.strip(),
"pred_text": hypothesis.strip(),
}


Expand All @@ -218,6 +198,8 @@ def evaluate_cer(reference: str, hypothesis: str) -> dict[str, Any]:
return {
"cer": cer_score,
"is_correct": cer_score < 0.5,
"text": reference,
"pred_text": hypothesis,
}


Expand All @@ -235,6 +217,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict
"char_rate": 0.0,
"is_correct": True,
"error": "missing_audio_duration",
"text": reference,
"pred_text": hypothesis,
}

char_count = len(hypothesis)
Expand All @@ -248,6 +232,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict
"hallucination_rate": 1.0 if is_hallucinating else 0.0,
"char_rate": round(char_rate, 2),
"is_correct": not is_hallucinating,
"text": reference,
"pred_text": hypothesis,
}


Expand Down Expand Up @@ -295,6 +281,8 @@ def evaluate_pc_rate(reference: str, hypothesis: str) -> dict[str, Any]:
"punct_f1": round(punct_f1, 3),
"cap_accuracy": round(cap_accuracy, 3),
"is_correct": pc_rate > 0.5,
"text": reference,
"pred_text": hypothesis,
}


Expand Down Expand Up @@ -326,61 +314,48 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic
generation = sample.get("generation", "").strip()
expected_answer = sample.get("expected_answer", "").strip()

if task_type in ["ASR", "ASR-PC", "ASR_LEADERBOARD", "AST", "Translation", "CER"] and not generation:
if task_type in ["ASR", "ASR-PC", "AST", "Translation", "CER"] and not generation:
base = {
"is_correct": False,
"error": "missing_generation",
"predicted_answer": "",
}
if task_type in ["AST", "Translation"]:
return {**base, "bleu": 0.0}
if task_type == "CER":
return {**base, "cer": 1.0}
# ASR / ASR-PC / ASR_LEADERBOARD
# ASR / ASR-PC
return {**base, "wer": 1.0}

if task_type == "ASR-PC":
metrics = evaluate_asr_pc(
expected_answer, generation, normalize_standard_wer=config.normalize_asr_pc_standard_wer
)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type == "ASR":
metrics = evaluate_asr(expected_answer, generation, apply_normalization=config.apply_whisper_normalization)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type == "ASR_LEADERBOARD":
metrics = evaluate_asr_leaderboard(expected_answer, generation)
metrics = evaluate_asr(expected_answer, generation)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type in ["AST", "Translation"]:
metrics = evaluate_translation(expected_answer, generation)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type == "CER":
metrics = evaluate_cer(expected_answer, generation)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type == "Hallucination":
audio_context = {"audio_duration": sample.get("audio_duration")}
metrics = evaluate_hallucination(expected_answer, generation, audio_context)
updates.update(metrics)
updates["predicted_answer"] = generation

elif task_type == "PC-Rate":
metrics = evaluate_pc_rate(expected_answer, generation)
updates.update(metrics)
updates["predicted_answer"] = generation

else:
if "requires_judge" not in sample:
updates["requires_judge"] = True
updates["predicted_answer"] = generation
if "is_correct" not in sample:
updates["is_correct"] = False

Expand Down
2 changes: 1 addition & 1 deletion nemo_skills/evaluation/evaluator/compute_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from compute_eval.utils.eval_utils import get_nvcc_version, parse_semver
from pydantic import Field, TypeAdapter

from nemo_skills.evaluation.evaluator import BaseEvaluator
from nemo_skills.evaluation.evaluator.base import BaseEvaluator
from nemo_skills.utils import get_logger_name

_LOG = logging.getLogger(get_logger_name(__file__))
Expand Down
4 changes: 4 additions & 0 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
jiwer>=3.1.0,<4.0.0 # Word/Character Error Rate computation
sacrebleu # BLEU score computation
soundfile # Audio file I/O for dataset preparation
whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer)