-
Notifications
You must be signed in to change notification settings - Fork 163
Option to pass a template to format input #883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ec7a4f6
83dd6f2
a08fcdd
545f09e
c010c57
62164ed
081cd0e
1260e81
e4d35c5
ef650c6
2353efe
e0e9009
b6e3f46
0863b47
7b610ff
3c9970e
ef68fe7
0d655ed
e992227
8d6766d
66af4d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import asyncio | ||
| import logging | ||
| import re | ||
| from functools import lru_cache | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
|
|
@@ -168,8 +169,8 @@ def evaluate_asr_pc( | |
| wer_c = jiwer.wer(ref_c, hyp_c) | ||
|
|
||
| if normalize_standard_wer: | ||
| ref_std = preprocess_asr_text(reference, mode=normalization_mode) | ||
| hyp_std = preprocess_asr_text(hypothesis, mode=normalization_mode) | ||
| ref_std = preprocess_asr_text(reference) | ||
| hyp_std = preprocess_asr_text(hypothesis) | ||
|
Comment on lines
+172
to
+173
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
At Line 166 and Line 167, ASR-PC standard WER always uses Whisper normalization when enabled, regardless of `normalization_mode`.
Suggested fix (make config effective and fail on unsupported modes):
def evaluate_asr_pc(
reference: str, hypothesis: str, normalize_standard_wer: bool = True, normalization_mode: str = "standard"
) -> dict[str, Any]:
@@
- if normalize_standard_wer:
- ref_std = preprocess_asr_text(reference)
- hyp_std = preprocess_asr_text(hypothesis)
+ if normalize_standard_wer:
+ if normalization_mode == "standard":
+ ref_std = preprocess_asr_text(reference)
+ hyp_std = preprocess_asr_text(hypothesis)
+ elif normalization_mode == "none":
+ ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
+ hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
+ else:
+ raise ValueError(f"Unsupported normalization_mode: {normalization_mode}")
else:
ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
-def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]:
+def evaluate_asr(reference: str, hypothesis: str, apply_whisper_normalization: bool = True) -> dict[str, Any]:
@@
- ref = preprocess_asr_text(reference)
- hyp = preprocess_asr_text(hypothesis)
+ if apply_whisper_normalization:
+ ref = preprocess_asr_text(reference)
+ hyp = preprocess_asr_text(hypothesis)
+ else:
+ ref = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
hyp = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
- elif task_type == "ASR":
- metrics = evaluate_asr(expected_answer, generation)
+ elif task_type == "ASR":
+ metrics = evaluate_asr(
+ expected_answer,
+ generation,
+ apply_whisper_normalization=config.apply_whisper_normalization,
+ )
updates.update(metrics)
As per coding guidelines: "Avoid cases where user-passed parameters are unused; code should fail if user specifies an unsupported argument or if a required argument is missing. Use dataclass or **kwargs syntax to handle this automatically". Also applies to: 399-399 🤖 Prompt for AI Agents |
||
| else: | ||
| ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower())) | ||
| hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower())) | ||
|
|
@@ -276,57 +277,17 @@ def resolve_asr_normalization_mode(config: AudioEvaluatorConfig) -> str: | |
| return config.normalization_mode if config.apply_whisper_normalization else "none" | ||
|
|
||
|
|
||
| def preprocess_asr_text(text: str, mode: str = "standard") -> str: | ||
| """Normalize ASR text for WER calculation. | ||
|
|
||
| Args: | ||
| text: Raw text. | ||
| mode: Normalization mode: | ||
| - "standard": Whisper normalization (default) - converts number words to digits | ||
| - "audiobench": Full AudioBench normalization (whisper + digits to words + more) | ||
| - "hf_leaderboard": HuggingFace leaderboard style (whisper normalization) | ||
| - "none": No normalization (whitespace only) | ||
| - "no_tn_itn": Lowercase + remove punctuation, no number word conversion (for TN/ITN eval) | ||
| """ | ||
| if mode not in VALID_NORMALIZATION_MODES: | ||
| raise ValueError( | ||
| f"Invalid normalization_mode '{mode}'. Available options: {', '.join(VALID_NORMALIZATION_MODES)}" | ||
| ) | ||
|
|
||
| if mode == "none": | ||
| return re.sub(r"\s+", " ", text).strip() | ||
|
|
||
| if mode == "no_tn_itn": | ||
| # Lowercase + remove punctuation + whitespace normalization | ||
| text = text.lower() | ||
| text = re.sub(r"[^\w\s]", "", text) | ||
| return re.sub(r"\s+", " ", text).strip() | ||
|
|
||
| # "standard", "audiobench", and "hf_leaderboard" all use whisper normalization | ||
| @lru_cache(maxsize=1) | ||
| def _get_english_normalizer(): | ||
| """Lazily initialize and cache the English text normalizer.""" | ||
| from whisper_normalizer.english import EnglishTextNormalizer | ||
|
|
||
| text = text.lower() | ||
| text = EnglishTextNormalizer()(text) | ||
|
|
||
| if mode == "audiobench": | ||
| # Additional audiobench-specific normalization | ||
| import jiwer | ||
|
|
||
| text = _normalize_digits_to_words(text) | ||
| text = _expand_contractions(text) | ||
| text = re.sub(r"(\[|\(|\{|\<)[^\(\)\\n\[\]]*(\]|\)|\}|\>)", "", text) | ||
| jiwer_process = jiwer.Compose( | ||
| [ | ||
| jiwer.RemoveMultipleSpaces(), | ||
| jiwer.ExpandCommonEnglishContractions(), | ||
| jiwer.RemoveKaldiNonWords(), | ||
| jiwer.RemovePunctuation(), | ||
| ] | ||
| ) | ||
| text = jiwer_process(text) | ||
| text = _remove_non_speech_elements(text) | ||
| return EnglishTextNormalizer() | ||
|
|
||
| return re.sub(r"\s+", " ", text).strip() | ||
|
|
||
| def preprocess_asr_text(text: str) -> str: | ||
| """Apply Whisper-style normalization (lowercase, remove brackets, normalize whitespace).""" | ||
| return _get_english_normalizer()(text) | ||
|
|
||
|
|
||
| def evaluate_asr(reference: str, hypothesis: str, normalization_mode: str = "standard") -> dict[str, Any]: | ||
|
|
@@ -339,8 +300,12 @@ def evaluate_asr(reference: str, hypothesis: str, normalization_mode: str = "sta | |
| """ | ||
| import jiwer | ||
|
|
||
| ref = preprocess_asr_text(reference, mode=normalization_mode) | ||
| hyp = preprocess_asr_text(hypothesis, mode=normalization_mode) | ||
| ref = preprocess_asr_text(reference) | ||
| hyp = preprocess_asr_text(hypothesis) | ||
|
|
||
| # Store normalized texts before empty substitution | ||
| text = ref | ||
| pred_text = hyp | ||
|
|
||
| if not ref: | ||
| ref = "empty" | ||
|
|
@@ -352,8 +317,8 @@ def evaluate_asr(reference: str, hypothesis: str, normalization_mode: str = "sta | |
| return { | ||
| "wer": wer_score, | ||
| "is_correct": wer_score < 0.5, | ||
| "text": ref, | ||
| "pred_text": hyp, | ||
| "text": text, | ||
| "pred_text": pred_text, | ||
| } | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |||||||||||||||
| from transformers import AutoTokenizer | ||||||||||||||||
| from transformers.tokenization_utils_base import PreTrainedTokenizerBase | ||||||||||||||||
|
|
||||||||||||||||
| from nemo_skills.prompt.utils import load_config as load_prompt_config | ||||||||||||||||
| from nemo_skills.utils import setup_make_sequence_length_divisible_by | ||||||||||||||||
|
|
||||||||||||||||
| TokenizerType = PreTrainedTokenizerBase | ||||||||||||||||
|
|
@@ -88,6 +89,7 @@ def __init__( | |||||||||||||||
| output_key: str = "output", | ||||||||||||||||
| num_proc: int | None = None, | ||||||||||||||||
| force_reprocess: bool = False, | ||||||||||||||||
| input_template_path: str | None = None, | ||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, wouldn't this fail for typical cases when, e.g., you have keys like `problem` and `generation` in your file, but the logic for setting defaults checks for `input` / `output` or `messages`? So that part should also be updated.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @Kipok I didn't understand this. When a user is passing
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean the logic in the `detect_data_format` function |
||||||||||||||||
| ): | ||||||||||||||||
| self.input_key = input_key | ||||||||||||||||
| self.output_key = output_key | ||||||||||||||||
|
|
@@ -100,6 +102,13 @@ def __init__( | |||||||||||||||
| else: | ||||||||||||||||
| self.num_proc = num_proc | ||||||||||||||||
|
|
||||||||||||||||
| self.input_template = None | ||||||||||||||||
| if input_template_path: | ||||||||||||||||
| input_template_config = load_prompt_config(input_template_path) | ||||||||||||||||
| if "user" not in input_template_config: | ||||||||||||||||
| raise KeyError(f"'user' key is missing in the YAML file: {input_template_path}") | ||||||||||||||||
| self.input_template = input_template_config["user"] | ||||||||||||||||
|
Comment on lines
+105
to
+110
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major 🧩 Analysis chain
Verify error handling in `load_prompt_config`. The template loading logic is sound, but ensure that `load_prompt_config` has adequate error handling. Run the following script to examine the error handling in `load_prompt_config`.
Note: Line 108 has a long error message in the exception. Consider extracting it to a variable for better style compliance.
- if "user" not in input_template_config:
- raise KeyError(f"'user' key is missing in the YAML file: {input_template_path}")
+ if "user" not in input_template_config:
+ msg = f"'user' key is missing in the YAML file: {input_template_path}"
+ raise KeyError(msg)
🏁 Script executed:
#!/bin/bash
# Description: Check error handling in load_prompt_config function
# Find and display the load_config function implementation
ast-grep --pattern $'def load_config($$$) {
$$$
}'
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Locate the implementation of load_prompt_config in the Python codebase
rg -nP --type=py "def load_prompt_config" -C3
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Search for load_prompt_config references across the repository
rg -n "load_prompt_config"
Length of output: 237
🏁 Script executed:
#!/bin/bash
# Display the implementation of load_config in the prompt utils module
rg -n "def load_config" -C10 nemo_skills/prompt
Length of output: 1538
🏁 Script executed:
#!/bin/bash
# Show full implementation of load_config to inspect error handling
sed -n '350,450p' nemo_skills/prompt/utils.py
Length of output: 1814
Add error handling in
🧰 Tools🪛 Ruff (0.13.3)108-108: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||
|
|
||||||||||||||||
| # Train split | ||||||||||||||||
| self.formatted_ds = { | ||||||||||||||||
| "train": self.load_or_process_split(train_ds_path, "train"), | ||||||||||||||||
|
|
@@ -130,11 +139,22 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset: | |||||||||||||||
| print(f"[Map] Processing {split_name} dataset from: {path}") | ||||||||||||||||
| dataset = load_dataset("json", data_files=str(path))["train"] | ||||||||||||||||
|
|
||||||||||||||||
| current_input_key = self.input_key | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @wasiahmad this is an important thing to fix if you want to merge this |
||||||||||||||||
| if self.input_template: | ||||||||||||||||
| assert "messages" not in dataset.column_names | ||||||||||||||||
wasiahmad marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||
| dataset = dataset.map( | ||||||||||||||||
| self.apply_input_template, | ||||||||||||||||
| batched=True, | ||||||||||||||||
| num_proc=self.num_proc, | ||||||||||||||||
| ) | ||||||||||||||||
| current_input_key = "formatted_input" | ||||||||||||||||
|
|
||||||||||||||||
| if "messages" not in dataset.column_names: | ||||||||||||||||
| dataset = dataset.map( | ||||||||||||||||
| self.add_messages_key, | ||||||||||||||||
| batched=True, | ||||||||||||||||
| num_proc=self.num_proc, | ||||||||||||||||
| fn_kwargs={"input_key": current_input_key}, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| # Save dataset + new size signature | ||||||||||||||||
|
|
@@ -146,17 +166,26 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset: | |||||||||||||||
| print(f"[Cache] Saved {split_name} dataset to: {cache_dir}") | ||||||||||||||||
| return dataset | ||||||||||||||||
|
|
||||||||||||||||
| def add_messages_key(self, examples: dict[str, list[Any]]) -> dict[str, list[list[dict[str, Any]]]]: | ||||||||||||||||
| def add_messages_key( | ||||||||||||||||
| self, examples: dict[str, list[Any]], input_key: str | ||||||||||||||||
| ) -> dict[str, list[list[dict[str, Any]]]]: | ||||||||||||||||
| return { | ||||||||||||||||
| "messages": [ | ||||||||||||||||
| [ | ||||||||||||||||
| {"role": "user", "content": input_}, | ||||||||||||||||
| {"role": "assistant", "content": output}, | ||||||||||||||||
| ] | ||||||||||||||||
| for input_, output in zip(examples[self.input_key], examples[self.output_key]) | ||||||||||||||||
| for input_, output in zip(examples[input_key], examples[self.output_key]) | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Dataset Fails When Template Path Missing. If |
||||||||||||||||
| ] | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| def apply_input_template(self, examples: dict[str, list[Any]]) -> dict[str, list[str]]: | ||||||||||||||||
| keys = [k.strip() for k in self.input_key.split(";")] | ||||||||||||||||
| examples["formatted_input"] = [ | ||||||||||||||||
| self.input_template.format(**{k: examples[k][i] for k in keys}) for i in range(len(examples[keys[0]])) | ||||||||||||||||
|
Comment on lines
+184
to
+185
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. accessing
Suggested change
|
||||||||||||||||
| ] | ||||||||||||||||
| return examples | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Template Application Fails on Invalid Keys. The |
||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def parse_args(): | ||||||||||||||||
| """Parse command line arguments.""" | ||||||||||||||||
|
|
@@ -235,6 +264,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): | |||||||||||||||
| data_config["input_key"], | ||||||||||||||||
| data_config["output_key"], | ||||||||||||||||
| force_reprocess=data_config.get("force_reprocess", False), | ||||||||||||||||
| input_template_path=data_config.get("input_template_path", None), | ||||||||||||||||
| ) | ||||||||||||||||
| print(f" ✓ Training dataset loaded with {len(data.formatted_ds['train'])} samples.") | ||||||||||||||||
| if data.formatted_ds["validation"] is not None: | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 18378
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 4129
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 1613
Enforce `task_type="ASR"` validation in the evaluator, not only in comments.
The `evaluate_sample` function at `nemo_skills/evaluation/evaluator/audio.py:472` uses `.get("task_type", "unknown")` with a silent default. If task_type is missing or doesn't match expected values (ASR, ASR-PC, ASR_LEADERBOARD, etc.), the code silently falls through to the else clause (lines 528–531), which skips WER computation and returns minimal fields. This contradicts the documented requirement in `nemo_skills/dataset/asr-leaderboard/__init__.py:17-18` that data samples should have `task_type="ASR"` for proper WER calculation.
Use direct access `sample["task_type"]` instead of `.get()` and add explicit validation before metric computation to fail fast when task_type is missing or invalid, preventing silent metric loss.
🤖 Prompt for AI Agents