-
Notifications
You must be signed in to change notification settings - Fork 163
HF ASR Leaderboard Fix #1140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HF ASR Leaderboard Fix #1140
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Jorjeous marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,21 @@ | |
| from tqdm import tqdm | ||
|
|
||
| SYSTEM_MESSAGE = "You are a helpful assistant. /no_think" | ||
| MIN_AUDIO_DURATION = 0.1 # Skip audio shorter than this (causes mel spectrogram errors) | ||
| USER_MESSAGE = "Transcribe the audio file into English text." | ||
| MIN_AUDIO_DURATION = 0.1 # Skip audio shorter than this | ||
|
|
||
| # Speaker IDs to skip in Tedlium dataset | ||
| SKIP_SPEAKER_IDS = {"inter_segment_gap"} | ||
|
|
||
| # Non-speech tokens to skip in GigaSpeech dataset | ||
| NONSPEECH_TOKENS = {"<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"} | ||
|
|
||
|
|
||
| def is_nonspeech_only(text): | ||
| """Check if text contains only non-speech tokens.""" | ||
| tokens = set(text.strip().split()) | ||
| return tokens and tokens.issubset(NONSPEECH_TOKENS) | ||
|
|
||
|
|
||
| # (hf_dataset, hf_config, hf_split, streaming) | ||
| DATASET_CONFIGS = { | ||
|
|
@@ -59,12 +73,17 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with | |
| text = text.strip() if text else "" | ||
|
|
||
| system_message = {"role": "system", "content": SYSTEM_MESSAGE} | ||
| user_message = {"role": "user", "content": "Transcribe the following audio."} | ||
| user_message = {"role": "user", "content": USER_MESSAGE} | ||
|
|
||
| audio_info = entry.get("audio", {}) | ||
| if isinstance(audio_info, dict) and "array" in audio_info and "sampling_rate" in audio_info: | ||
| audio_array = audio_info["array"] | ||
| sampling_rate = audio_info["sampling_rate"] | ||
|
|
||
| # Skip if audio array is empty or invalid | ||
| if audio_array is None or len(audio_array) == 0: | ||
| return None | ||
|
|
||
| duration = len(audio_array) / sampling_rate | ||
|
|
||
| if duration < MIN_AUDIO_DURATION: | ||
|
|
@@ -76,18 +95,24 @@ def save_audio_and_format_entry(entry, dataset_name, audio_dir, sample_idx, with | |
| if with_audio: | ||
| sf.write(str(audio_dir / audio_filename), audio_array, sampling_rate) | ||
|
|
||
| audio_filepath = f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}" | ||
| user_message["audio"] = { | ||
| "path": f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}", | ||
| "path": audio_filepath, | ||
| "duration": float(duration), | ||
| } | ||
|
|
||
| formatted_entry = { | ||
| "task_type": "ASR_LEADERBOARD", | ||
| "task_type": "ASR", | ||
| "expected_answer": text, | ||
| "messages": [system_message, user_message], | ||
| "subset_for_metrics": dataset_name, | ||
| } | ||
|
|
||
| # Add audio_filepath and duration as top-level fields | ||
| if "audio" in user_message: | ||
| formatted_entry["audio_filepath"] = user_message["audio"]["path"] | ||
| formatted_entry["duration"] = user_message["audio"]["duration"] | ||
|
|
||
| if "id" in entry: | ||
| formatted_entry["id"] = entry["id"] | ||
| if "speaker_id" in entry: | ||
|
|
@@ -133,12 +158,17 @@ def prepare_dataset(dataset_name, output_dir, with_audio=True): | |
| if formatted is None: | ||
| skipped += 1 | ||
| continue | ||
| if formatted["expected_answer"]: | ||
| # Skip empty answers, non-speech segments, and non-speech-only samples | ||
| speaker_id = entry.get("speaker_id", "") | ||
| expected = formatted["expected_answer"] | ||
| if expected and speaker_id not in SKIP_SPEAKER_IDS and not is_nonspeech_only(expected): | ||
|
Comment on lines
+162
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. logic checks |
||
| fout.write(json.dumps(formatted) + "\n") | ||
| count += 1 | ||
| else: | ||
| skipped += 1 | ||
|
|
||
| if skipped > 0: | ||
| print(f"Skipped {skipped} samples with audio < {MIN_AUDIO_DURATION}s") | ||
| print(f"Skipped {skipped} samples (short audio, non-speech, or invalid)") | ||
|
|
||
| print(f"Saved {count} samples to {output_file}") | ||
| return count | ||
|
|
||
Jorjeous marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import asyncio | ||
| import logging | ||
| import re | ||
| from functools import lru_cache | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
|
|
@@ -32,7 +33,6 @@ class AudioEvaluatorConfig(BaseEvaluatorConfig): | |
| """Configuration for audio evaluation.""" | ||
|
|
||
| prompt_config: str = "eval/speechlm/audio" | ||
| apply_whisper_normalization: bool = True | ||
| normalize_asr_pc_standard_wer: bool = True | ||
|
|
||
|
|
||
|
|
@@ -119,61 +119,34 @@ def evaluate_asr_pc(reference: str, hypothesis: str, normalize_standard_wer: boo | |
| "wer_pc": wer_pc, | ||
| "per": per, | ||
| "is_correct": wer_pc < 0.5, | ||
| "text": ref_std, | ||
| "pred_text": hyp_std, | ||
| } | ||
|
|
||
|
|
||
| def preprocess_asr_text(text: str) -> str: | ||
| """Apply Whisper-style normalization: lowercase, normalize, remove brackets.""" | ||
| from whisper.normalizers import EnglishTextNormalizer | ||
|
|
||
| text = text.lower() | ||
| text = EnglishTextNormalizer()(text) | ||
| text = re.sub(r"(\[|\(|\{|\<)[^\(\)\\n\[\]]*(\]|\)|\}|\>)", "", text) | ||
| text = re.sub(r"\s+", " ", text).strip() | ||
| return text | ||
| @lru_cache(maxsize=1) | ||
| def _get_english_normalizer(): | ||
| """Lazily initialize and cache the English text normalizer.""" | ||
| from whisper_normalizer.english import EnglishTextNormalizer | ||
|
|
||
| return EnglishTextNormalizer() | ||
melllinia marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+127
to
+132
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. import inside function with
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. that's a problem to be solved by the user |
||
|
|
||
| def preprocess_hf_leaderboard(text: str) -> str: | ||
| """Apply HuggingFace leaderboard normalization: lowercase, remove punctuation, normalize unicode.""" | ||
| import unicodedata | ||
|
|
||
| text = unicodedata.normalize("NFC", text) | ||
| text = text.lower() | ||
| text = re.sub(r"[^\w\s]", "", text) | ||
| text = re.sub(r"\s+", " ", text).strip() | ||
| return text | ||
| def preprocess_asr_text(text: str) -> str: | ||
| """Apply Whisper-style normalization (lowercase, remove brackets, normalize whitespace).""" | ||
| return _get_english_normalizer()(text) | ||
melllinia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def evaluate_asr(reference: str, hypothesis: str, apply_normalization: bool = True) -> dict[str, Any]: | ||
| """Evaluate ASR: computes WER with optional Whisper normalization.""" | ||
| def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]: | ||
| """Evaluate ASR: computes WER with Whisper normalization.""" | ||
| import jiwer | ||
|
|
||
| if apply_normalization: | ||
| ref = preprocess_asr_text(reference) | ||
| hyp = preprocess_asr_text(hypothesis) | ||
| else: | ||
| ref = normalize_whitespace(reference) | ||
| hyp = normalize_whitespace(hypothesis) | ||
| ref = preprocess_asr_text(reference) | ||
| hyp = preprocess_asr_text(hypothesis) | ||
|
|
||
| if not ref: | ||
| ref = "empty" | ||
| if not hyp: | ||
| hyp = "empty" | ||
|
|
||
| wer_score = jiwer.wer(ref, hyp) | ||
|
|
||
| return { | ||
| "wer": wer_score, | ||
| "is_correct": wer_score < 0.5, | ||
| } | ||
|
|
||
|
|
||
| def evaluate_asr_leaderboard(reference: str, hypothesis: str) -> dict[str, Any]: | ||
| """Evaluate ASR with HuggingFace leaderboard preprocessing for direct comparison.""" | ||
| import jiwer | ||
|
|
||
| ref = preprocess_hf_leaderboard(reference) | ||
| hyp = preprocess_hf_leaderboard(hypothesis) | ||
| # Store normalized texts before empty substitution | ||
| text = ref | ||
| pred_text = hyp | ||
|
Comment on lines
+147
to
+149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential issue with storing empty normalized text. The code stores the normalized text before replacing empty strings with "empty", which means the stored `text` and `pred_text` fields may be empty strings. This could cause issues downstream if code expects non-empty strings. The pattern is inconsistent with other evaluation functions in this file:
The inconsistency suggests this might not be intentional. If empty strings in |
||
|
|
||
| if not ref: | ||
| ref = "empty" | ||
|
|
@@ -185,6 +158,8 @@ def evaluate_asr_leaderboard(reference: str, hypothesis: str) -> dict[str, Any]: | |
| return { | ||
| "wer": wer_score, | ||
| "is_correct": wer_score < 0.5, | ||
| "text": text, | ||
| "pred_text": pred_text, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -193,20 +168,25 @@ def evaluate_translation(reference: str, hypothesis: str) -> dict[str, Any]: | |
| try: | ||
| import sacrebleu | ||
|
|
||
| ref = [reference.strip()] | ||
| hyp = hypothesis.strip() | ||
| bleu = sacrebleu.sentence_bleu(hyp, ref) | ||
| text = reference.strip() | ||
| pred_text = hypothesis.strip() | ||
| ref = [text] | ||
| bleu = sacrebleu.sentence_bleu(pred_text, ref) | ||
| bleu_score = bleu.score / 100.0 | ||
|
|
||
| return { | ||
| "bleu": bleu_score, | ||
| "is_correct": bleu_score > 0.3, | ||
| "text": text, | ||
| "pred_text": pred_text, | ||
| } | ||
| except Exception as e: | ||
| return { | ||
| "bleu": 0.0, | ||
| "is_correct": False, | ||
| "error": str(e), | ||
| "text": reference.strip(), | ||
| "pred_text": hypothesis.strip(), | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -218,6 +198,8 @@ def evaluate_cer(reference: str, hypothesis: str) -> dict[str, Any]: | |
| return { | ||
| "cer": cer_score, | ||
| "is_correct": cer_score < 0.5, | ||
| "text": reference, | ||
| "pred_text": hypothesis, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -235,6 +217,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict | |
| "char_rate": 0.0, | ||
| "is_correct": True, | ||
| "error": "missing_audio_duration", | ||
| "text": reference, | ||
| "pred_text": hypothesis, | ||
| } | ||
|
|
||
| char_count = len(hypothesis) | ||
|
|
@@ -248,6 +232,8 @@ def evaluate_hallucination(reference: str, hypothesis: str, audio_context: dict | |
| "hallucination_rate": 1.0 if is_hallucinating else 0.0, | ||
| "char_rate": round(char_rate, 2), | ||
| "is_correct": not is_hallucinating, | ||
| "text": reference, | ||
| "pred_text": hypothesis, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -295,6 +281,8 @@ def evaluate_pc_rate(reference: str, hypothesis: str) -> dict[str, Any]: | |
| "punct_f1": round(punct_f1, 3), | ||
| "cap_accuracy": round(cap_accuracy, 3), | ||
| "is_correct": pc_rate > 0.5, | ||
| "text": reference, | ||
| "pred_text": hypothesis, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -326,61 +314,48 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic | |
| generation = sample.get("generation", "").strip() | ||
| expected_answer = sample.get("expected_answer", "").strip() | ||
|
|
||
| if task_type in ["ASR", "ASR-PC", "ASR_LEADERBOARD", "AST", "Translation", "CER"] and not generation: | ||
| if task_type in ["ASR", "ASR-PC", "AST", "Translation", "CER"] and not generation: | ||
| base = { | ||
| "is_correct": False, | ||
| "error": "missing_generation", | ||
| "predicted_answer": "", | ||
| } | ||
| if task_type in ["AST", "Translation"]: | ||
| return {**base, "bleu": 0.0} | ||
| if task_type == "CER": | ||
| return {**base, "cer": 1.0} | ||
| # ASR / ASR-PC / ASR_LEADERBOARD | ||
| # ASR / ASR-PC | ||
| return {**base, "wer": 1.0} | ||
|
|
||
| if task_type == "ASR-PC": | ||
| metrics = evaluate_asr_pc( | ||
| expected_answer, generation, normalize_standard_wer=config.normalize_asr_pc_standard_wer | ||
| ) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type == "ASR": | ||
| metrics = evaluate_asr(expected_answer, generation, apply_normalization=config.apply_whisper_normalization) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type == "ASR_LEADERBOARD": | ||
| metrics = evaluate_asr_leaderboard(expected_answer, generation) | ||
| metrics = evaluate_asr(expected_answer, generation) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type in ["AST", "Translation"]: | ||
| metrics = evaluate_translation(expected_answer, generation) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type == "CER": | ||
| metrics = evaluate_cer(expected_answer, generation) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type == "Hallucination": | ||
| audio_context = {"audio_duration": sample.get("audio_duration")} | ||
| metrics = evaluate_hallucination(expected_answer, generation, audio_context) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| elif task_type == "PC-Rate": | ||
| metrics = evaluate_pc_rate(expected_answer, generation) | ||
| updates.update(metrics) | ||
| updates["predicted_answer"] = generation | ||
|
|
||
| else: | ||
| if "requires_judge" not in sample: | ||
| updates["requires_judge"] = True | ||
| updates["predicted_answer"] = generation | ||
| if "is_correct" not in sample: | ||
| updates["is_correct"] = False | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| jiwer>=3.1.0,<4.0.0 # Word/Character Error Rate computation | ||
Jorjeous marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| sacrebleu # BLEU score computation | ||
| soundfile # Audio file I/O for dataset preparation | ||
| whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer) | ||
Jorjeous marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.