Changes from all commits (21 commits)
- ec7a4f6: adding an input template for formatting input (wasiahmad, Oct 3, 2025)
- 83dd6f2: adding an input template for formatting input (wasiahmad, Oct 3, 2025)
- a08fcdd: minor bug fix (wasiahmad, Oct 3, 2025)
- 545f09e: replacing comma with semicolon to avoid hydra issues (wasiahmad, Oct 3, 2025)
- c010c57: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 4, 2025)
- 62164ed: using nemo-skills load config (wasiahmad, Oct 4, 2025)
- 081cd0e: Merge branch 'main' into apply_input_template (wasiahmad, Oct 6, 2025)
- 1260e81: Merge branch 'main' into apply_input_template (wasiahmad, Oct 14, 2025)
- e4d35c5: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 15, 2025)
- ef650c6: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 17, 2025)
- 2353efe: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 17, 2025)
- e0e9009: resolving conflicts (wasiahmad, Oct 24, 2025)
- b6e3f46: Merge branch 'main' into apply_input_template (wasiahmad, Oct 27, 2025)
- 0863b47: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 28, 2025)
- 7b610ff: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 28, 2025)
- 3c9970e: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Oct 29, 2025)
- ef68fe7: Merge branch 'main' into apply_input_template (wasiahmad, Dec 8, 2025)
- 0d655ed: Merge remote-tracking branch 'origin/main' into apply_input_template (wasiahmad, Jan 30, 2026)
- e992227: HF ASR Leaderboard Fix (#1140) (melllinia, Jan 15, 2026)
- 8d6766d: Merge branch 'main' into apply_input_template (wasiahmad, Mar 3, 2026)
- 66af4d5: Merge branch 'main' into apply_input_template (wasiahmad, Mar 3, 2026)
7 changes: 4 additions & 3 deletions nemo_skills/dataset/asr-leaderboard/__init__.py
@@ -14,9 +14,10 @@

# Settings that define how evaluation should be done by default (all can be changed from cmdline)

-# Uses the audio evaluator which computes WER with HuggingFace leaderboard preprocessing
+# Uses the audio evaluator which computes WER with Whisper-style text normalization
+# Data samples should have task_type="ASR" for proper WER calculation
Comment on lines +17 to +18
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Locate task_type handling
echo "=== Search 1: task_type handling ==="
rg -n -C3 --type=py '\btask_type\b' 2>/dev/null | head -100

# 2) Locate audio evaluation / WER entry points
echo ""
echo "=== Search 2: Audio eval / WER entry points ==="
rg -n -C3 --type=py 'WER|word.?error|whisper|eval_type\s*=\s*audio|METRICS_TYPE\s*=\s*"audio"' 2>/dev/null | head -100

# 3) Look specifically for ASR guards/assertions
echo ""
echo "=== Search 3: ASR guards/assertions ==="
rg -n -C3 --type=py 'assert .*task_type|if .*task_type.*ASR|task_type\s*==\s*["'"'"']ASR["'"'"']' 2>/dev/null | head -100

Repository: NVIDIA-NeMo/Skills

Length of output: 18378


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the complete evaluate_sample function
echo "=== Full evaluate_sample function ==="
rg -A 70 'def evaluate_sample\(sample: dict' nemo_skills/evaluation/evaluator/audio.py

# Check if there's any else/default handling for unmatched task_type
echo ""
echo "=== Check for else/default handling ==="
rg -A 80 'elif task_type == "PC-Rate"' nemo_skills/evaluation/evaluator/audio.py

Repository: NVIDIA-NeMo/Skills

Length of output: 4129


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find where evaluate_sample is called
echo "=== Callers of evaluate_sample ==="
rg -B5 -A3 'evaluate_sample\(' nemo_skills/evaluation/evaluator/audio.py | head -60

# Check if there's any validation of task_type before calling evaluate_sample
echo ""
echo "=== Check for task_type validation before evaluate_sample call ==="
rg -B10 'evaluate_sample' nemo_skills/evaluation/evaluator/audio.py | grep -A 10 'def '

Repository: NVIDIA-NeMo/Skills

Length of output: 1613


Enforce task_type="ASR" validation in the evaluator, not only in comments.

The evaluate_sample function at nemo_skills/evaluation/evaluator/audio.py:472 uses .get("task_type", "unknown") with a silent default. If task_type is missing or doesn't match expected values (ASR, ASR-PC, ASR_LEADERBOARD, etc.), the code silently falls through to the else clause (lines 528–531), which skips WER computation and returns minimal fields. This contradicts the documented requirement in nemo_skills/dataset/asr-leaderboard/__init__.py:17-18 that data samples should have task_type="ASR" for proper WER calculation.

Use direct access sample["task_type"] instead of .get() and add explicit validation before metric computation to fail fast when task_type is missing or invalid, preventing silent metric loss.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/dataset/asr-leaderboard/__init__.py` around lines 17 - 18, The
evaluator currently uses sample.get("task_type", "unknown") which silently skips
WER computation; update evaluate_sample in
nemo_skills/evaluation/evaluator/audio.py to access sample["task_type"] directly
and add explicit validation (raise a clear exception) for allowed values (e.g.,
"ASR", "ASR-PC", "ASR_LEADERBOARD") before any metric computation so
missing/invalid task_type fails fast instead of falling through to the else
branch that omits WER.
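The fail-fast validation the comment asks for can be sketched in isolation. This is a hedged sketch, not the evaluator's actual code: the helper name `validate_task_type` and the allowed-value set (taken from the comment above) are assumptions.

```python
# Hypothetical fail-fast validation; the allowed set below comes from the
# review comment and may not match the evaluator's real list of task types.
ALLOWED_TASK_TYPES = {"ASR", "ASR-PC", "ASR_LEADERBOARD"}


def validate_task_type(sample: dict) -> str:
    # Direct access raises KeyError when the field is missing entirely,
    # instead of silently defaulting to "unknown".
    task_type = sample["task_type"]
    if task_type not in ALLOWED_TASK_TYPES:
        raise ValueError(
            f"Unsupported task_type {task_type!r}; expected one of {sorted(ALLOWED_TASK_TYPES)}"
        )
    return task_type
```

Calling this at the top of `evaluate_sample` would surface bad data immediately rather than returning minimal fields without WER.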


REQUIRES_DATA_DIR = True
METRICS_TYPE = "audio"
-EVAL_ARGS = "++eval_type=audio ++eval_config.normalization_mode=hf_leaderboard"
-GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true"
+EVAL_ARGS = "++eval_type=audio"
+GENERATION_ARGS = "++prompt_format=openai"
73 changes: 19 additions & 54 deletions nemo_skills/evaluation/evaluator/audio.py
@@ -17,6 +17,7 @@
import asyncio
import logging
import re
+from functools import lru_cache
from typing import Any

import numpy as np
@@ -168,8 +169,8 @@ def evaluate_asr_pc(
wer_c = jiwer.wer(ref_c, hyp_c)

if normalize_standard_wer:
-        ref_std = preprocess_asr_text(reference, mode=normalization_mode)
-        hyp_std = preprocess_asr_text(hypothesis, mode=normalization_mode)
+        ref_std = preprocess_asr_text(reference)
+        hyp_std = preprocess_asr_text(hypothesis)
Comment on lines +172 to +173
Contributor
⚠️ Potential issue | 🟠 Major

apply_whisper_normalization / normalization_mode are effectively ignored in the new flow.

At Line 166 and Line 167, ASR-PC standard WER always uses Whisper normalization when enabled, regardless of normalization_mode. At Line 399, ASR always calls evaluate_asr(...) without honoring config.apply_whisper_normalization. This silently ignores user-provided settings.

Suggested fix (make config effective and fail on unsupported modes)
 def evaluate_asr_pc(
     reference: str, hypothesis: str, normalize_standard_wer: bool = True, normalization_mode: str = "standard"
 ) -> dict[str, Any]:
@@
-    if normalize_standard_wer:
-        ref_std = preprocess_asr_text(reference)
-        hyp_std = preprocess_asr_text(hypothesis)
+    if normalize_standard_wer:
+        if normalization_mode == "standard":
+            ref_std = preprocess_asr_text(reference)
+            hyp_std = preprocess_asr_text(hypothesis)
+        elif normalization_mode == "none":
+            ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
+            hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
+        else:
+            raise ValueError(f"Unsupported normalization_mode: {normalization_mode}")
     else:
         ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
         hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
-def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]:
+def evaluate_asr(reference: str, hypothesis: str, apply_whisper_normalization: bool = True) -> dict[str, Any]:
@@
-    ref = preprocess_asr_text(reference)
-    hyp = preprocess_asr_text(hypothesis)
+    if apply_whisper_normalization:
+        ref = preprocess_asr_text(reference)
+        hyp = preprocess_asr_text(hypothesis)
+    else:
+        ref = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
+        hyp = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
-    elif task_type == "ASR":
-        metrics = evaluate_asr(expected_answer, generation)
+    elif task_type == "ASR":
+        metrics = evaluate_asr(
+            expected_answer,
+            generation,
+            apply_whisper_normalization=config.apply_whisper_normalization,
+        )
         updates.update(metrics)

As per coding guidelines "Avoid cases where user-passed parameters are unused; code should fail if user specifies an unsupported argument or if a required argument is missing. Use dataclass or **kwargs syntax to handle this automatically".

Also applies to: 399-399

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/evaluation/evaluator/audio.py` around lines 166 - 167, The code
currently always applies Whisper normalization by calling
preprocess_asr_text(reference) and preprocess_asr_text(hypothesis) and calls
evaluate_asr(...) without honoring config.apply_whisper_normalization or
config.normalization_mode; change the flow so that before preprocessing (or
before calling evaluate_asr) you check config.apply_whisper_normalization and
config.normalization_mode: if apply_whisper_normalization is True and
normalization_mode == "whisper" call the Whisper-specific normalization routine
(or call preprocess_asr_text with a mode parameter), if
apply_whisper_normalization is False skip Whisper normalization and use the
standard text preprocessing, and if normalization_mode is set to an unsupported
value raise an explicit error; also update the evaluate_asr(...) call site to
pass or respect these config flags rather than ignoring them so user settings
are enforced (refer to preprocess_asr_text, evaluate_asr,
config.apply_whisper_normalization, and config.normalization_mode).
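The mode dispatch the prompt describes can be sketched standalone. This is a hedged stand-in: the repository's `preprocess_asr_text` delegates to `whisper_normalizer`, while `_basic_english_normalize` below only approximates it with lowercase/punctuation/whitespace handling, and the function name `normalize_for_wer` and the two supported modes are illustrative.

```python
import re


def _basic_english_normalize(text: str) -> str:
    # Simplified stand-in for whisper_normalizer.english.EnglishTextNormalizer:
    # lowercase, strip punctuation, collapse whitespace.
    text = re.sub(r"[^\w\s]", "", text.lower())
    return re.sub(r"\s+", " ", text).strip()


def normalize_for_wer(text: str, mode: str = "standard") -> str:
    # Dispatch on the mode and fail fast on anything unsupported, instead of
    # silently applying the same normalization regardless of the setting.
    if mode == "standard":
        return _basic_english_normalize(text)
    if mode == "none":
        return re.sub(r"\s+", " ", text).strip()
    raise ValueError(f"Unsupported normalization_mode: {mode!r}")
```

The point of the `raise` branch is exactly the coding guideline quoted above: a user-supplied mode that the code cannot honor should be an error, not a no-op.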

else:
ref_std = normalize_whitespace(re.sub(r"[^\w\s]", "", reference.lower()))
hyp_std = normalize_whitespace(re.sub(r"[^\w\s]", "", hypothesis.lower()))
@@ -276,57 +277,17 @@ def resolve_asr_normalization_mode(config: AudioEvaluatorConfig) -> str:
return config.normalization_mode if config.apply_whisper_normalization else "none"


-def preprocess_asr_text(text: str, mode: str = "standard") -> str:
-    """Normalize ASR text for WER calculation.
-
-    Args:
-        text: Raw text.
-        mode: Normalization mode:
-            - "standard": Whisper normalization (default) - converts number words to digits
-            - "audiobench": Full AudioBench normalization (whisper + digits to words + more)
-            - "hf_leaderboard": HuggingFace leaderboard style (whisper normalization)
-            - "none": No normalization (whitespace only)
-            - "no_tn_itn": Lowercase + remove punctuation, no number word conversion (for TN/ITN eval)
-    """
-    if mode not in VALID_NORMALIZATION_MODES:
-        raise ValueError(
-            f"Invalid normalization_mode '{mode}'. Available options: {', '.join(VALID_NORMALIZATION_MODES)}"
-        )
-
-    if mode == "none":
-        return re.sub(r"\s+", " ", text).strip()
-
-    if mode == "no_tn_itn":
-        # Lowercase + remove punctuation + whitespace normalization
-        text = text.lower()
-        text = re.sub(r"[^\w\s]", "", text)
-        return re.sub(r"\s+", " ", text).strip()
-
-    # "standard", "audiobench", and "hf_leaderboard" all use whisper normalization
-    from whisper_normalizer.english import EnglishTextNormalizer
-
-    text = text.lower()
-    text = EnglishTextNormalizer()(text)
-
-    if mode == "audiobench":
-        # Additional audiobench-specific normalization
-        import jiwer
-
-        text = _normalize_digits_to_words(text)
-        text = _expand_contractions(text)
-        text = re.sub(r"(\[|\(|\{|\<)[^\(\)\\n\[\]]*(\]|\)|\}|\>)", "", text)
-        jiwer_process = jiwer.Compose(
-            [
-                jiwer.RemoveMultipleSpaces(),
-                jiwer.ExpandCommonEnglishContractions(),
-                jiwer.RemoveKaldiNonWords(),
-                jiwer.RemovePunctuation(),
-            ]
-        )
-        text = jiwer_process(text)
-        text = _remove_non_speech_elements(text)
-
-    return re.sub(r"\s+", " ", text).strip()
+@lru_cache(maxsize=1)
+def _get_english_normalizer():
+    """Lazily initialize and cache the English text normalizer."""
+    from whisper_normalizer.english import EnglishTextNormalizer
+
+    return EnglishTextNormalizer()
+
+
+def preprocess_asr_text(text: str) -> str:
+    """Apply Whisper-style normalization (lowercase, remove brackets, normalize whitespace)."""
+    return _get_english_normalizer()(text)


def evaluate_asr(reference: str, hypothesis: str, normalization_mode: str = "standard") -> dict[str, Any]:
@@ -339,8 +300,12 @@
"""
import jiwer

-    ref = preprocess_asr_text(reference, mode=normalization_mode)
-    hyp = preprocess_asr_text(hypothesis, mode=normalization_mode)
+    ref = preprocess_asr_text(reference)
+    hyp = preprocess_asr_text(hypothesis)

+    # Store normalized texts before empty substitution
+    text = ref
+    pred_text = hyp

if not ref:
ref = "empty"
@@ -352,8 +317,8 @@
return {
"wer": wer_score,
"is_correct": wer_score < 0.5,
-        "text": ref,
-        "pred_text": hyp,
+        "text": text,
+        "pred_text": pred_text,
}
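For context, the WER returned above is the word-level edit distance between reference and hypothesis divided by the reference length; the evaluator computes it via `jiwer.wer`. The sketch below is a self-contained approximation of that metric, not the project's code.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # Classic dynamic-programming edit distance over word tokens.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    # Guard against an empty reference, mirroring the "empty" substitution above.
    return prev[-1] / max(len(ref), 1)
```

With the diff's `is_correct` rule, a hypothesis with one wrong word out of two (WER 0.5) would already fail the `wer_score < 0.5` threshold.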


34 changes: 32 additions & 2 deletions nemo_skills/training/nemo_rl/start_sft.py
@@ -36,6 +36,7 @@
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from nemo_skills.prompt.utils import load_config as load_prompt_config
from nemo_skills.utils import setup_make_sequence_length_divisible_by

TokenizerType = PreTrainedTokenizerBase
@@ -88,6 +89,7 @@ def __init__(
output_key: str = "output",
num_proc: int | None = None,
force_reprocess: bool = False,
+        input_template_path: str | None = None,
Collaborator
also wouldn't this fail for typical cases when e.g. you have keys like problem and generation in your file, but the logic for setting defaults checks for input / output or messages? So that part should also be updated

Collaborator Author
@Kipok I didn't understand this. When a user passes input_template_path, the provided input_key fields must match the keys present in the input_template. Why would having keys like problem and generation cause a failure?

Collaborator
I mean the logic in detect_data_format function

):
self.input_key = input_key
self.output_key = output_key
@@ -100,6 +102,13 @@
else:
self.num_proc = num_proc

+        self.input_template = None
+        if input_template_path:
+            input_template_config = load_prompt_config(input_template_path)
+            if "user" not in input_template_config:
+                raise KeyError(f"'user' key is missing in the YAML file: {input_template_path}")
+            self.input_template = input_template_config["user"]
Comment on lines +105 to +110
Contributor
🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify error handling in load_prompt_config.

The template loading logic is sound, but ensure that load_prompt_config properly handles file errors (missing files, permission issues, invalid YAML) to provide clear error messages to users.

Run the following script to examine the error handling in load_prompt_config:

Note: Line 108 has a long error message in the exception. Consider extracting it to a variable for better style compliance.

-            if "user" not in input_template_config:
-                raise KeyError(f"'user' key is missing in the YAML file: {input_template_path}")
+            if "user" not in input_template_config:
+                msg = f"'user' key is missing in the YAML file: {input_template_path}"
+                raise KeyError(msg)

🏁 Script executed:

#!/bin/bash
# Description: Check error handling in load_prompt_config function

# Find and display the load_config function implementation
ast-grep --pattern $'def load_config($$$) {
  $$$
}'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Locate the implementation of load_prompt_config in the Python codebase
rg -nP --type=py "def load_prompt_config" -C3

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for load_prompt_config references across the repository
rg -n "load_prompt_config"

Length of output: 237


🏁 Script executed:

#!/bin/bash
# Display the implementation of load_config in the prompt utils module
rg -n "def load_config" -C10 nemo_skills/prompt

Length of output: 1538


🏁 Script executed:

#!/bin/bash
# Show full implementation of load_config to inspect error handling
sed -n '350,450p' nemo_skills/prompt/utils.py

Length of output: 1814


Add error handling in load_config and extract literal messages to variables

  • In nemo_skills/prompt/utils.py (load_config), wrap the open(config_path) and yaml.safe_load calls in a try/except that catches FileNotFoundError, PermissionError, and yaml.YAMLError, then rethrows with a clear message including config_path.
  • In nemo_skills/training/nemo_rl/start_sft.py (lines 108–109), extract the KeyError message into a msg variable before raising.
🧰 Tools
🪛 Ruff (0.13.3)

108-108: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In nemo_skills/prompt/utils.py (around the load_config function) wrap the file
open and yaml.safe_load calls in a try/except that catches FileNotFoundError,
PermissionError, and yaml.YAMLError and rethrows a new exception (or raise) with
a clear message that includes the config_path; ensure the original exception is
chained. In nemo_skills/training/nemo_rl/start_sft.py around lines 104–109,
extract the KeyError message into a local variable msg (e.g. msg = f"...") and
raise KeyError(msg) instead of inlining the formatted string in the raise
statement.
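The wrapping the bot suggests might look like the following sketch. It assumes PyYAML, and `load_config_safely` is a hypothetical name; the real `load_config` in `nemo_skills/prompt/utils.py` has a different signature and may validate differently.

```python
import yaml


def load_config_safely(config_path: str) -> dict:
    # Wrap file and parse errors so the failing path is always in the message,
    # chaining the original exception (via `from e`) for debugging.
    try:
        with open(config_path) as f:
            return yaml.safe_load(f)
    except (FileNotFoundError, PermissionError) as e:
        raise RuntimeError(f"Cannot read config file: {config_path}") from e
    except yaml.YAMLError as e:
        raise RuntimeError(f"Invalid YAML in config file: {config_path}") from e
```

Chaining with `from e` keeps the original traceback visible while the top-level message tells the user exactly which file to look at.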


# Train split
self.formatted_ds = {
"train": self.load_or_process_split(train_ds_path, "train"),
@@ -130,11 +139,22 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset:
print(f"[Map] Processing {split_name} dataset from: {path}")
dataset = load_dataset("json", data_files=str(path))["train"]

+        current_input_key = self.input_key
Bug: Dataset Cache Invalidation Overlooks Template and Key

The PromptResponseDataset cache invalidation is incomplete. It only uses dataset file size, but processing also depends on the input_template_path and input_key. This can cause stale processed data to be loaded if these parameters change.


Collaborator
@wasiahmad this is an important thing to fix if you want to merge this

+        if self.input_template:
+            assert "messages" not in dataset.column_names
+            dataset = dataset.map(
+                self.apply_input_template,
+                batched=True,
+                num_proc=self.num_proc,
+            )
+            current_input_key = "formatted_input"

        if "messages" not in dataset.column_names:
            dataset = dataset.map(
                self.add_messages_key,
                batched=True,
                num_proc=self.num_proc,
+                fn_kwargs={"input_key": current_input_key},
            )

# Save dataset + new size signature
@@ -146,17 +166,26 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset:
print(f"[Cache] Saved {split_name} dataset to: {cache_dir}")
return dataset

-    def add_messages_key(self, examples: dict[str, list[Any]]) -> dict[str, list[list[dict[str, Any]]]]:
+    def add_messages_key(
+        self, examples: dict[str, list[Any]], input_key: str
+    ) -> dict[str, list[list[dict[str, Any]]]]:
        return {
            "messages": [
                [
                    {"role": "user", "content": input_},
                    {"role": "assistant", "content": output},
                ]
-                for input_, output in zip(examples[self.input_key], examples[self.output_key])
+                for input_, output in zip(examples[input_key], examples[self.output_key])
Bug: Dataset Fails When Template Path Missing

If input_key specifies multiple fields using semicolons but input_template_path is not provided, PromptResponseDataset attempts to use the entire semicolon-separated string as a single column name in add_messages_key. This results in a KeyError as no such column exists in the dataset.


            ]
        }

+    def apply_input_template(self, examples: dict[str, list[Any]]) -> dict[str, list[str]]:
+        keys = [k.strip() for k in self.input_key.split(";")]
+        examples["formatted_input"] = [
+            self.input_template.format(**{k: examples[k][i] for k in keys}) for i in range(len(examples[keys[0]]))
Comment on lines +184 to +185
Contributor
accessing examples[k][i] will fail if any key k doesn't exist in the dataset

Suggested change
-        examples["formatted_input"] = [
-            self.input_template.format(**{k: examples[k][i] for k in keys}) for i in range(len(examples[keys[0]]))
+        formatted_inputs = []
+        for i in range(len(examples[keys[0]])):
+            format_dict = {k: examples[k][i] for k in keys}
+            formatted_inputs.append(self.input_template.format(**format_dict))
+        examples["formatted_input"] = formatted_inputs
+        ]
+        return examples
Bug: Template Application Fails on Invalid Keys

The apply_input_template method can raise KeyError or IndexError. This occurs because it doesn't validate the keys derived from input_key, leading to issues if input_key contains empty segments after splitting and stripping, or if a derived key is missing from the dataset's examples.

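Putting the two findings together, a defensive variant of the template application could look like this standalone sketch (the function is pulled out of the class for illustration, and the exact validation messages are assumptions):

```python
def apply_input_template(examples: dict, input_key: str, input_template: str) -> dict:
    # Reject empty segments such as "a;;b" or " ; " up front.
    keys = [k.strip() for k in input_key.split(";") if k.strip()]
    if not keys:
        raise ValueError(f"input_key produced no usable keys: {input_key!r}")
    # Fail fast with the full list of missing columns instead of a bare KeyError.
    missing = [k for k in keys if k not in examples]
    if missing:
        raise KeyError(f"Keys {missing} not found in dataset columns {sorted(examples)}")
    examples["formatted_input"] = [
        input_template.format(**{k: examples[k][i] for k in keys})
        for i in range(len(examples[keys[0]]))
    ]
    return examples
```

The semicolon-separated `input_key` convention here mirrors commit 545f09e, which replaced commas with semicolons to avoid hydra parsing issues.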



def parse_args():
"""Parse command line arguments."""
@@ -235,6 +264,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
data_config["input_key"],
data_config["output_key"],
force_reprocess=data_config.get("force_reprocess", False),
+        input_template_path=data_config.get("input_template_path", None),
)
print(f" ✓ Training dataset loaded with {len(data.formatted_ds['train'])} samples.")
if data.formatted_ds["validation"] is not None: