diff --git a/nemo_skills/dataset/asr-leaderboard/__init__.py b/nemo_skills/dataset/asr-leaderboard/__init__.py index 7c3c2661ee..47e9be94f3 100644 --- a/nemo_skills/dataset/asr-leaderboard/__init__.py +++ b/nemo_skills/dataset/asr-leaderboard/__init__.py @@ -13,10 +13,10 @@ # limitations under the License. # Settings that define how evaluation should be done by default (all can be changed from cmdline) -# Uses the audio evaluator which computes WER with Whisper-style text normalization -# Data samples should have task_type="ASR" for proper WER calculation + +# Uses the audio evaluator which computes WER with HuggingFace leaderboard preprocessing DATASET_GROUP = "speechlm" METRICS_TYPE = "audio" -EVAL_ARGS = "++eval_type=audio" -GENERATION_ARGS = "++prompt_format=openai" +EVAL_ARGS = "++eval_type=audio ++eval_config.normalization_mode=hf_leaderboard" +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" diff --git a/nemo_skills/dataset/audiobench/judge/__init__.py b/nemo_skills/dataset/audiobench/judge/__init__.py index 62e48d4ec6..2f444cff13 100644 --- a/nemo_skills/dataset/audiobench/judge/__init__.py +++ b/nemo_skills/dataset/audiobench/judge/__init__.py @@ -26,8 +26,8 @@ DATASET_GROUP = "speechlm" METRICS_TYPE = "audio" DEFAULT_SPLIT = "test" -GENERATION_ARGS = "++prompt_format=openai " -EVAL_ARGS = "++eval_type=audio " +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" +EVAL_ARGS = "++eval_type=audio ++eval_config.normalization_mode=audiobench" # Judge configuration matching AudioBench official implementation # Using Llama-3.1-70B with vllm (can be overridden in run scripts) diff --git a/nemo_skills/dataset/audiobench/nonjudge/__init__.py b/nemo_skills/dataset/audiobench/nonjudge/__init__.py index d26668ce8f..a8aff69039 100644 --- a/nemo_skills/dataset/audiobench/nonjudge/__init__.py +++ b/nemo_skills/dataset/audiobench/nonjudge/__init__.py @@ -25,7 +25,7 @@ METRICS_TYPE = "audio" # Evaluation settings -EVAL_ARGS = "++eval_type=audio " 
+EVAL_ARGS = "++eval_type=audio ++eval_config.normalization_mode=audiobench" # Generation settings - OpenAI format for audio-language models -GENERATION_ARGS = "++prompt_format=openai " +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" diff --git a/nemo_skills/dataset/librispeech-pc/__init__.py b/nemo_skills/dataset/librispeech-pc/__init__.py index 28b02d9656..2e77558ed0 100644 --- a/nemo_skills/dataset/librispeech-pc/__init__.py +++ b/nemo_skills/dataset/librispeech-pc/__init__.py @@ -25,5 +25,5 @@ EVAL_SPLIT = "test-clean" -EVAL_ARGS = "++eval_type=audio " -GENERATION_ARGS = "++prompt_format=openai " +EVAL_ARGS = "++eval_type=audio" +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" diff --git a/nemo_skills/dataset/mmau-pro/closed_form/__init__.py b/nemo_skills/dataset/mmau-pro/closed_form/__init__.py index 4e3b424d84..0181595a35 100644 --- a/nemo_skills/dataset/mmau-pro/closed_form/__init__.py +++ b/nemo_skills/dataset/mmau-pro/closed_form/__init__.py @@ -15,7 +15,8 @@ METRICS_TYPE = "mmau_pro_closed_form" SCORE_MODULE = "nemo_skills.evaluation.metrics.mmau_pro_metrics" -GENERATION_ARGS = "++prompt_format=openai" +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" +EVAL_ARGS = "++eval_type=mmau-pro" # NVEmbed judge configuration for closed-form evaluation JUDGE_PIPELINE_ARGS = { diff --git a/nemo_skills/dataset/mmau-pro/open_ended/__init__.py b/nemo_skills/dataset/mmau-pro/open_ended/__init__.py index 22773d6fed..48912ce82a 100644 --- a/nemo_skills/dataset/mmau-pro/open_ended/__init__.py +++ b/nemo_skills/dataset/mmau-pro/open_ended/__init__.py @@ -15,7 +15,7 @@ # Open-ended questions evaluated with LLM judge (Qwen) METRICS_TYPE = "mmau_pro_open_ended" SCORE_MODULE = "nemo_skills.evaluation.metrics.mmau_pro_metrics" -GENERATION_ARGS = "++prompt_format=openai" +GENERATION_ARGS = "++prompt_format=openai ++enable_audio=true" # Judge configuration for open-ended evaluation using NVIDIA API JUDGE_PIPELINE_ARGS = { @@ -23,4 +23,4 
# Known model failure responses that should be treated as empty transcriptions.
# Each entry is a regex searched case-insensitively against the model output.
_FAILURE_RESPONSES = [
    r"the speech is in audio format and needs to be transcribed",
    r"i do not have access to audio",
    r"i cannot access audio",
    r"i'm sorry.*i do not have access",
    r"as an ai language model.*i do not have access",
]


def strip_helpful_prefixes(text: str) -> str:
    """Strip ASR response wrappers like 'The audio says: ...' for accurate WER.

    Processing order:
      1. Return "" if the text matches one of ``_FAILURE_RESPONSES``.
      2. Remove SRT subtitle indices/timestamps (``1 00:00:00,000 --> 00:00:02,000``).
      3. If a double-quoted span exists, keep only the first one.
      4. If a ``:'`` / ``: '`` marker exists, keep the segment that follows it.
      5. If a single-quoted span exists, keep only its content.

    Args:
        text: Raw model generation.

    Returns:
        The extracted transcription, whitespace-stripped. Empty string for
        known failure responses.
    """
    result = text.strip()

    # Known refusals / failure messages count as an empty transcription.
    if any(re.search(pattern, result, flags=re.IGNORECASE) for pattern in _FAILURE_RESPONSES):
        return ""

    # Remove SRT subtitle artifacts: "<index> <start> --> <end> ", bare
    # "<start> --> <end>", and leftover "<index> " lines preceded by either an
    # escaped ("\\n") or a real newline.
    result = re.sub(r"\d+\s+\d{2}:\d{2}:\d{2},\d{3}\s+-->\s+\d{2}:\d{2}:\d{2},\d{3}\s+", "", result)
    result = re.sub(r"\d{2}:\d{2}:\d{2},\d{3}\s+-->\s+\d{2}:\d{2}:\d{2},\d{3}\s*", "", result)
    result = re.sub(r"\\n\d+\s+(?=\d{2}:\d{2})", " ", result)
    result = re.sub(r"\n\d+\s+(?=\d{2}:\d{2})", " ", result)

    # Extract from double quotes (first quoted span, backslash escapes allowed).
    double_quoted = re.search(r'"((?:\\.|[^"\\])*)"', result)
    if double_quoted:
        result = double_quoted.group(1)

    # Handle colon-quote patterns: keep what follows the marker and restore the
    # opening quote so the single-quote extraction below can see it.
    if ":'" in result:
        result = "'" + result.split(":'")[1]
    elif ": '" in result:
        result = "'" + result.split(": '")[1]

    # Greedy single-quote extraction. The quote marks must sit at word
    # boundaries: an apostrophe glued to letters on the relevant side
    # ("don't", "what's") is part of a word, not a quotation mark. Without
    # these guards a plain transcript such as "i don't know what's happening"
    # would be truncated to "t know what".
    single_quoted = re.search(r"(?<!\w)'(.*)'(?!\w)", result)
    if single_quoted:
        result = single_quoted.group(1)

    return result.strip()
def _normalize_digits_to_words(text: str) -> str:
    """Spell out standalone numeric tokens (e.g., '1' -> 'one').

    Only the tokens 0-20 and the round tens 30-90 are rewritten, and only when
    they stand alone between word boundaries; any other number (e.g. '21',
    '100') is left untouched.
    """
    small_numbers = (
        "zero one two three four five six seven eight nine ten eleven twelve "
        "thirteen fourteen fifteen sixteen seventeen eighteen nineteen twenty"
    ).split()
    round_tens = ("thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety")

    spoken = {str(value): word for value, word in enumerate(small_numbers)}
    spoken.update({str(30 + 10 * idx): word for idx, word in enumerate(round_tens)})

    # One pass over the text: the trailing \b forces the alternation to
    # backtrack from "1" to "10" etc., so every standalone token maps to its
    # own entry.
    token_re = r"\b(?:" + "|".join(spoken) + r")\b"
    return re.sub(token_re, lambda found: spoken[found.group(0)], text)
def _remove_non_speech_elements(text: str) -> str:
    """Remove filler words (uh, um, er, ah)."""
    non_speech_patterns = r"\b(uh|umm|um|er|ah)\b"
    return re.sub(non_speech_patterns, "", text)


VALID_NORMALIZATION_MODES = ("standard", "audiobench", "hf_leaderboard", "none")

# Lazily created, process-wide Whisper normalizer (see _get_english_normalizer).
_ENGLISH_NORMALIZER = None


def _get_english_normalizer():
    """Return a shared EnglishTextNormalizer, creating it on first use.

    The import stays local so that whisper_normalizer is only required for the
    modes that actually use it.
    """
    global _ENGLISH_NORMALIZER
    if _ENGLISH_NORMALIZER is None:
        from whisper_normalizer.english import EnglishTextNormalizer

        _ENGLISH_NORMALIZER = EnglishTextNormalizer()
    return _ENGLISH_NORMALIZER


def preprocess_asr_text(text: str, mode: str = "standard") -> str:
    """Normalize ASR text for WER calculation.

    Args:
        text: Raw text.
        mode: Normalization mode:
            - "standard": lowercase + Whisper EnglishTextNormalizer (default)
            - "audiobench": "standard" followed by digit spelling, contraction
              expansion, bracket removal, jiwer clean-up and filler removal
            - "hf_leaderboard": NFC normalization, lowercase, punctuation removal
            - "none": No normalization (whitespace only)

    Returns:
        Normalized text with runs of whitespace collapsed to single spaces and
        leading/trailing whitespace removed.

    Raises:
        ValueError: If ``mode`` is not one of ``VALID_NORMALIZATION_MODES``.
    """
    if mode not in VALID_NORMALIZATION_MODES:
        raise ValueError(
            f"Invalid normalization_mode '{mode}'. Available options: {', '.join(VALID_NORMALIZATION_MODES)}"
        )

    if mode == "none":
        return re.sub(r"\s+", " ", text).strip()

    if mode == "hf_leaderboard":
        import unicodedata

        # NFC first so that decomposed accents are composed into single letters
        # and survive the punctuation filter below.
        text = unicodedata.normalize("NFC", text)
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)
        return re.sub(r"\s+", " ", text).strip()

    # "standard" and "audiobench" both start with whisper normalization.
    # The normalizer is built once and reused instead of once per call.
    text = _get_english_normalizer()(text.lower())

    if mode == "audiobench":
        # Additional audiobench-specific normalization
        import jiwer

        text = _normalize_digits_to_words(text)
        text = _expand_contractions(text)
        # NOTE(review): inside this raw string "\\n" excludes a literal
        # backslash and the letter "n" (not a newline) from the bracket body;
        # presumably kept to mirror the upstream AudioBench regex - confirm.
        text = re.sub(r"(\[|\(|\{|\<)[^\(\)\\n\[\]]*(\]|\)|\}|\>)", "", text)
        jiwer_process = jiwer.Compose(
            [
                jiwer.RemoveMultipleSpaces(),
                jiwer.ExpandCommonEnglishContractions(),
                jiwer.RemoveKaldiNonWords(),
                jiwer.RemovePunctuation(),
            ]
        )
        text = jiwer_process(text)
        text = _remove_non_speech_elements(text)

    return re.sub(r"\s+", " ", text).strip()
+ """ import jiwer - ref = preprocess_asr_text(reference) - hyp = preprocess_asr_text(hypothesis) - - # Store normalized texts before empty substitution - text = ref - pred_text = hyp + ref = preprocess_asr_text(reference, mode=normalization_mode) + hyp = preprocess_asr_text(hypothesis, mode=normalization_mode) if not ref: ref = "empty" @@ -158,8 +335,8 @@ def evaluate_asr(reference: str, hypothesis: str) -> dict[str, Any]: return { "wer": wer_score, "is_correct": wer_score < 0.5, - "text": text, - "pred_text": pred_text, + "text": ref, + "pred_text": hyp, } @@ -311,10 +488,14 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic """Evaluate single sample based on task_type. Returns dict of updates to merge.""" updates = {} task_type = sample.get("task_type", "unknown") - generation = sample.get("generation", "").strip() + generation = sample["generation"].strip() expected_answer = sample.get("expected_answer", "").strip() - if task_type in ["ASR", "ASR-PC", "AST", "Translation", "CER"] and not generation: + # Strip helpful prefixes for ASR tasks (e.g., "The audio says: ...") + if config.strip_helpful_prefixes: + generation = strip_helpful_prefixes(generation) + + if task_type in ["ASR", "ASR-PC", "ASR_LEADERBOARD", "AST", "Translation", "CER"] and not generation: base = { "is_correct": False, "error": "missing_generation", @@ -327,13 +508,25 @@ def evaluate_sample(sample: dict[str, Any], config: AudioEvaluatorConfig) -> dic return {**base, "wer": 1.0} if task_type == "ASR-PC": + mode = config.normalization_mode if config.apply_whisper_normalization else "none" metrics = evaluate_asr_pc( - expected_answer, generation, normalize_standard_wer=config.normalize_asr_pc_standard_wer + expected_answer, + generation, + normalize_standard_wer=config.normalize_asr_pc_standard_wer, + normalization_mode=mode, ) updates.update(metrics) elif task_type == "ASR": - metrics = evaluate_asr(expected_answer, generation) + mode = config.normalization_mode 
if config.apply_whisper_normalization else "none" + metrics = evaluate_asr(expected_answer, generation, normalization_mode=mode) + updates.update(metrics) + updates["predicted_answer"] = generation + + elif task_type == "ASR_LEADERBOARD": + # ASR_LEADERBOARD uses normalization_mode from config (default hf_leaderboard set in dataset init) + mode = config.normalization_mode if config.apply_whisper_normalization else "none" + metrics = evaluate_asr(expected_answer, generation, normalization_mode=mode) updates.update(metrics) elif task_type in ["AST", "Translation"]: diff --git a/nemo_skills/evaluation/evaluator/mmau_pro.py b/nemo_skills/evaluation/evaluator/mmau_pro.py index b78f9f47ee..88f4bd2675 100644 --- a/nemo_skills/evaluation/evaluator/mmau_pro.py +++ b/nemo_skills/evaluation/evaluator/mmau_pro.py @@ -31,7 +31,7 @@ def eval_mmau_pro(cfg): This evaluator handles instruction following evaluation for MMAU-Pro benchmark. Other question types are handled by different evaluation methods: - Closed-form questions: Evaluated by nvembed_judge.py using NVEmbed similarity matching - - Open-ended questions: Evaluated by LLM judge (Qwen) using judge/speechlm prompt config + - Open-ended questions: Evaluated by LLM judge (Qwen) using judge/mmau-pro prompt config """ eval_config = BaseEvaluatorConfig(**cfg) diff --git a/nemo_skills/evaluation/metrics/map_metrics.py b/nemo_skills/evaluation/metrics/map_metrics.py index f7dfafb85e..4cd3c2d4a7 100644 --- a/nemo_skills/evaluation/metrics/map_metrics.py +++ b/nemo_skills/evaluation/metrics/map_metrics.py @@ -55,6 +55,7 @@ "answer-judgement": AnswerJudgementMetrics, "arena": ArenaMetrics, "audio": AudioMetrics, + "speechlm": AudioMetrics, # Alias for backward compatibility "bfcl": BFCLMetrics, "bird": BirdMetrics, "evalplus": EvalPlusMetrics, diff --git a/nemo_skills/evaluation/metrics/mmau_pro_metrics.py b/nemo_skills/evaluation/metrics/mmau_pro_metrics.py index f079049cc1..acc7076ad6 100644 --- 
a/nemo_skills/evaluation/metrics/mmau_pro_metrics.py +++ b/nemo_skills/evaluation/metrics/mmau_pro_metrics.py @@ -13,14 +13,58 @@ # limitations under the License. import logging +import re + +import numpy as np from nemo_skills.evaluation.metrics.base import BaseMetrics, as_int, as_percentage -from nemo_skills.evaluation.metrics.utils import is_correct_judgement from nemo_skills.utils import get_logger_name LOG = logging.getLogger(get_logger_name(__file__)) +def extract_multicriteria_scores(judgement_text: str) -> dict[str, float]: + """Extract multi-criteria scores (1-5 scale) from LLM judge evaluation. + + Expected format: + CORRECTNESS: [score] - [justification] + RELEVANCE: [score] - [justification] + COMPLETENESS: [score] - [justification] + CLARITY: [score] - [justification] + OVERALL: [score] - [overall assessment] + + Returns: + Dictionary with keys: correctness, relevance, completeness, clarity, overall + Defaults to 3.0 if score not found. + """ + scores = {} + found_overall = False + + patterns = { + "correctness": r"CORRECTNESS:\s*(\d+(?:\.\d+)?)", + "relevance": r"RELEVANCE:\s*(\d+(?:\.\d+)?)", + "completeness": r"COMPLETENESS:\s*(\d+(?:\.\d+)?)", + "clarity": r"CLARITY:\s*(\d+(?:\.\d+)?)", + "overall": r"OVERALL:\s*(\d+(?:\.\d+)?)", + } + + for criterion, pattern in patterns.items(): + match = re.search(pattern, judgement_text, re.IGNORECASE) + if match: + scores[criterion] = float(match.group(1)) + if criterion == "overall": + found_overall = True + else: + scores[criterion] = 3.0 + + # Fallback: compute overall only if not explicitly provided by judge + if not found_overall: + criteria_scores = [scores.get(k, 3.0) for k in ["correctness", "relevance", "completeness", "clarity"]] + scores["overall"] = sum(criteria_scores) / len(criteria_scores) + + return scores + + class MMAUProMetrics(BaseMetrics): """Metrics class for MMAU-Pro benchmark (all subgroups).""" @@ -28,16 +72,24 @@ def __init__(self, compute_no_answer: bool = True, max_k: int = 1): 
super().__init__(compute_no_answer=compute_no_answer) self.max_k = max_k + # Track multi-criteria scores for open-ended questions (1-5 scale) + self.multicriteria_scores = { + "correctness": [], + "relevance": [], + "completeness": [], + "clarity": [], + "overall": [], + } + def _get_score_dict(self, prediction: dict) -> dict[str, bool | int | float]: """Extract correctness scores from prediction.""" score_dict = {} - # Open-ended: extract from judge result + # Open-ended: use LLM judge correctness score >= 3 as correct if "judgement" in prediction: - judge_result = is_correct_judgement(prediction["judgement"]) - score_dict["judge_correct"] = judge_result - score_dict["correct"] = judge_result - # Closed-form and instruction following: use is_correct + multicriteria = extract_multicriteria_scores(prediction["judgement"]) + score_dict["correct"] = multicriteria.get("correctness", 3.0) >= 3.0 + # Closed-form / instruction-following: use binary correctness elif "is_correct" in prediction: score_dict["correct"] = prediction["is_correct"] else: @@ -58,24 +110,61 @@ def get_incorrect_sample(self, prediction: dict) -> dict: def update(self, predictions): """Update metrics with new predictions.""" super().update(predictions) - predicted_answers = [pred.get("generation", None).strip() or None for pred in predictions] + + predicted_answers = [(pred.get("generation") or "").strip() or None for pred in predictions] self._compute_pass_at_k(predictions=predictions, predicted_answers=predicted_answers) self._compute_majority_at_k(predictions=predictions, predicted_answers=predicted_answers) + # Collect multi-criteria scores for open-ended questions + for pred in predictions: + if "judgement" in pred: + multicriteria = extract_multicriteria_scores(pred["judgement"]) + for criterion in self.multicriteria_scores: + self.multicriteria_scores[criterion].append(multicriteria.get(criterion, 3.0)) + def get_metrics(self): """Get computed metrics.""" metrics_dict = super().get_metrics() + 
for agg_mode, agg_metrics in metrics_dict.items(): - # Ensure avg_tokens is always present for MMAU-Pro + # Ensure avg_tokens is present if "avg_tokens" not in agg_metrics: agg_metrics["avg_tokens"] = 0 if "no_answer" in agg_metrics: agg_metrics["no_answer"] = agg_metrics["no_answer"] / 2.0 - # Set success_rate from correct or judge_correct - if "judge_correct" in agg_metrics: - agg_metrics["success_rate"] = agg_metrics["judge_correct"] + + # Add multi-criteria averages for open-ended (convert 1-5 scale to percentage) + if self.multicriteria_scores["overall"]: + for criterion in self.multicriteria_scores: + scores = self.multicriteria_scores[criterion] + if scores: + # Convert 1-5 scale to 0-100 percentage scale + avg_score = np.mean(scores) + std_score = np.std(scores) + agg_metrics[f"avg_{criterion}"] = (avg_score / 5.0) * 100 + agg_metrics[f"std_{criterion}"] = (std_score / 5.0) * 100 + + # Set correct and success_rate to avg_correctness for open-ended + agg_metrics["correct"] = agg_metrics["avg_correctness"] + agg_metrics["success_rate"] = agg_metrics["avg_correctness"] + + # Calculate good/poor response rates based on overall >= 4 or <= 2 + overall_scores = self.multicriteria_scores["overall"] + good_responses = sum(1 for score in overall_scores if score >= 4.0) + poor_responses = sum(1 for score in overall_scores if score <= 2.0) + + agg_metrics["good_response_rate"] = (good_responses / len(overall_scores)) * 100 + agg_metrics["poor_response_rate"] = (poor_responses / len(overall_scores)) * 100 + + # For closed-form / instruction-following: use binary correctness elif "correct" in agg_metrics: agg_metrics["success_rate"] = agg_metrics["correct"] + + # Round all numeric values to 2 decimal places + for key, value in agg_metrics.items(): + if isinstance(value, float) and not isinstance(value, bool): + agg_metrics[key] = round(value, 2) + return metrics_dict def metrics_to_print(self): @@ -87,5 +176,20 @@ def metrics_to_print(self): } if self.compute_no_answer: 
base_metrics["no_answer"] = as_percentage + + # Add multi-criteria metrics for open-ended questions (now in percentage format) + if self.multicriteria_scores["overall"]: + base_metrics.update( + { + "avg_overall": as_percentage, + "avg_correctness": as_percentage, + "avg_relevance": as_percentage, + "avg_completeness": as_percentage, + "avg_clarity": as_percentage, + "good_response_rate": as_percentage, + "poor_response_rate": as_percentage, + } + ) + base_metrics["num_entries"] = as_int return base_metrics diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index 8ba3dd7555..5cf69af488 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -205,6 +205,15 @@ class GenerationTaskConfig: # If True, will enable litellm disk cache (useful for keeping intermediate results in case of job timelimit failures) enable_litellm_cache: bool = False + # List of content types to drop from messages (e.g., base64 audio) to keep output files smaller + drop_content_types: list[str] = field(default_factory=lambda: ["audio_url"]) + + # Audio configuration - set by benchmarks that need audio processing (mmau-pro, audiobench, etc.) + enable_audio: bool = False # Enable audio preprocessing (set by benchmark configs) + enable_audio_chunking: bool = True + audio_chunk_task_types: list[str] | None = None # If None, chunk all task types; if specified, only chunk these + chunk_audio_threshold_sec: int = 30 # Duration in seconds for each audio chunk + # Evaluation setup if requested. If eval_type is set to None, evaluation is skipped eval_type: str | None = None # "lean4-proof", "math", etc. 
eval_config: dict = field(default_factory=dict) # Config for the evaluator @@ -408,9 +417,37 @@ def setup_llm(self): output_dir = str(Path(self.cfg.output_file).parent) + # Determine if audio processing is needed + # Benchmarks that need audio set enable_audio=true in their GENERATION_ARGS + needs_audio = self.cfg.enable_audio + + # Build server config, potentially switching to vllm_multimodal for audio tasks + server_config = dict(self.cfg.server) + if needs_audio and server_config.get("server_type") not in ["vllm", "vllm_multimodal"]: + LOG.warning( + f"enable_audio is set but server_type is '{server_config.get('server_type')}'. " + "Audio processing is only supported for vllm_multimodal server types. " + "Audio will not be processed." + ) + if needs_audio and server_config.get("server_type") in [ + "vllm", + "vllm_multimodal", + ]: # helps with backward compatibility + if server_config.get("server_type") == "vllm": + LOG.warning("Auto-switching server_type from 'vllm' to 'vllm_multimodal' for audio processing") + server_config["server_type"] = "vllm_multimodal" + # Pass audio chunking config + server_config.update( + { + "enable_audio_chunking": self.cfg.enable_audio_chunking, + "audio_chunk_task_types": self.cfg.audio_chunk_task_types, + "chunk_audio_threshold_sec": self.cfg.chunk_audio_threshold_sec, + } + ) + if self.cfg.code_execution: llm = get_code_execution_model( - **self.cfg.server, + **server_config, tokenizer=self.tokenizer, sandbox=self.sandbox, data_dir=self.data_dir or "", @@ -418,7 +455,7 @@ def setup_llm(self): ) elif self.cfg.tool_modules is not None: llm = get_tool_calling_model( - **self.cfg.server, + **server_config, tool_modules=self.cfg.tool_modules, tool_overrides=self.cfg.tool_overrides, schema_overrides=self.cfg.schema_overrides, @@ -429,7 +466,7 @@ def setup_llm(self): ) else: llm = get_model( - **self.cfg.server, tokenizer=self.tokenizer, data_dir=self.data_dir or "", output_dir=output_dir + **server_config, tokenizer=self.tokenizer, 
def drop_fields_from_messages(self, output):
    """Remove specified content types from messages to keep output files smaller.

    Filters, in place, every list-valued message ``content`` so that items whose
    ``"type"`` is listed in ``self.cfg.drop_content_types`` (e.g. base64 audio)
    are removed. Nothing is returned.

    Left untouched:
      - outputs without a ``"messages"`` entry (e.g. text completion mode or
        error cases), or whose entry is empty/None;
      - messages whose content is not a list (e.g. string system prompts);
      - content items that are not dicts (they carry no "type" to match on).

    Args:
        output: Generation output dict, modified in place.
    """
    dropped_types = tuple(self.cfg.drop_content_types or ())
    messages = output.get("messages")
    if not dropped_types or not messages:
        return

    for message in messages:
        content = message.get("content")
        # Skip if content is not a list (e.g., string content in system messages)
        if not isinstance(content, list):
            continue

        # Keep every item unless it is a dict whose type is configured to be dropped
        message["content"] = [
            item for item in content if not (isinstance(item, dict) and item.get("type") in dropped_types)
        ]
def audio_file_to_base64(audio_file_path: str) -> str:
    """Read an audio file and return its bytes as base64 text.

    Args:
        audio_file_path: Location of the audio file on disk.

    Returns:
        The file contents, base64-encoded and decoded to a UTF-8 ``str``.
    """
    with open(audio_file_path, "rb") as handle:
        encoded_bytes = base64.b64encode(handle.read())
    return encoded_bytes.decode("utf-8")
def chunk_audio(audio_array, sampling_rate, chunk_duration_sec=30, min_chunk_duration_sec=0.5):
    """Chunk audio array into segments of specified duration.

    Args:
        audio_array: Audio data as numpy array.
        sampling_rate: Sampling rate in Hz.
        chunk_duration_sec: Duration of each chunk in seconds.
        min_chunk_duration_sec: Minimum duration for last chunk (shorter chunks are merged).

    Returns:
        List of audio chunks (numpy arrays). Every chunk holds
        ``int(chunk_duration_sec * sampling_rate)`` samples except the last,
        which holds the remainder; a remainder shorter than the minimum is
        appended to the previous chunk instead of standing alone.

    Raises:
        ValueError: If the chunk length works out to less than one sample, or
            if the whole audio is shorter than ``min_chunk_duration_sec``.
    """
    import numpy as np

    chunk_samples = int(chunk_duration_sec * sampling_rate)
    min_chunk_samples = int(min_chunk_duration_sec * sampling_rate)
    total_samples = len(audio_array)

    # A zero/negative chunk length would otherwise surface as a
    # ZeroDivisionError (or an empty result) further down.
    if chunk_samples <= 0:
        raise ValueError(
            f"Invalid chunk size: chunk_duration_sec={chunk_duration_sec} at sampling_rate={sampling_rate} "
            "yields less than one sample per chunk"
        )

    # Validate minimum audio length
    if total_samples < min_chunk_samples:
        raise ValueError(
            f"Audio too short: {len(audio_array) / sampling_rate:.2f}s < minimum {min_chunk_duration_sec}s"
        )

    chunks = []
    # Integer stride over the samples; slicing clamps the final chunk to the array end.
    for start in range(0, total_samples, chunk_samples):
        chunk = audio_array[start : start + chunk_samples]

        # Merge tiny trailing chunks with previous chunk to avoid empty audio errors
        if chunks and len(chunk) < min_chunk_samples:
            chunks[-1] = np.concatenate([chunks[-1], chunk])
        else:
            chunks.append(chunk)

    return chunks
def make_audio_content_block(base64_audio: str, audio_format: str = "audio_url") -> dict:
    """Build the request content block that carries base64 audio.

    Args:
        base64_audio: Audio payload, already base64-encoded.
        audio_format: Which block layout to produce:
            - "input_audio": OpenAI native layout (NVIDIA API / Gemini / Azure)
            - anything else, including the default "audio_url": data-URI
              layout (vLLM / Qwen)

    Returns:
        A content block dict ready to be placed in a message's content list.
    """
    if audio_format != "input_audio":
        # Data URI layout; the payload is always labelled as WAV.
        data_uri = f"data:audio/wav;base64,{base64_audio}"
        return {"type": "audio_url", "audio_url": {"url": data_uri}}

    # OpenAI native layout
    payload = {"data": base64_audio, "format": "wav"}
    return {"type": "input_audio", "input_audio": payload}
+ +This module provides a multimodal model class that handles: +- Audio INPUT: encoding audio files to base64, chunking long audio +- Audio OUTPUT: saving audio responses from the server to disk +""" + import base64 +import copy import json import logging import os @@ -20,6 +28,12 @@ from nemo_skills.utils import get_logger_name +from .audio_utils import ( + audio_file_to_base64, + chunk_audio, + load_audio_file, + save_audio_chunk_to_base64, +) from .vllm import VLLMModel LOG = logging.getLogger(get_logger_name(__file__)) @@ -29,15 +43,42 @@ class VLLMMultimodalModel(VLLMModel): - """VLLMModel with support for saving audio responses to disk. + """VLLMModel with support for audio input and output. + + Audio INPUT capabilities: + 1. Converts audio file paths to base64-encoded audio_url format + 2. Chunks long audio files for models with duration limits + 3. Aggregates results from chunked audio processing + + Audio OUTPUT capabilities: + 1. Saves audio responses from the server to disk / output_dir/audio/ + 2. Replaces the base64 data with the file path in the result - When the server returns audio in the response, this model will: - 1. Save the audio bytes to a file in output_dir/audio/ - 2. Replace the base64 data with the file path in the result """ - def __init__(self, **kwargs): + def __init__( + self, + enable_audio_chunking: bool = True, + audio_chunk_task_types: list[str] | None = None, + chunk_audio_threshold_sec: int = 30, + **kwargs, + ): + """Initialize VLLMMultimodalModel with audio I/O support. + + Args: + enable_audio_chunking: Master switch for audio chunking. + audio_chunk_task_types: If None, chunk all task types; if specified, only chunk these. + chunk_audio_threshold_sec: Audio duration threshold for chunking (in seconds). + **kwargs: Other parameters passed to VLLMModel/BaseModel. 
+ """ super().__init__(**kwargs) + + # Audio INPUT config + self.enable_audio_chunking = enable_audio_chunking + self.audio_chunk_task_types = audio_chunk_task_types + self.chunk_audio_threshold_sec = chunk_audio_threshold_sec + + # Audio OUTPUT config self.output_audio_dir = None if self.output_dir: self.output_audio_dir = os.path.join(self.output_dir, "audio") @@ -108,3 +149,258 @@ def _process_audio_response(self, audio_data, response_id: str) -> dict: audio_info["data"] = audio_base64 return audio_info + + # ===================== + # Audio INPUT methods + # ===================== + + def content_text_to_list(self, message: dict) -> dict: + """Convert message content with audio to proper list format. + + Handles 'audio' or 'audios' keys in messages and converts them to + base64-encoded audio_url content items. + + CRITICAL: Audio must come BEFORE text for Qwen models to transcribe correctly. + + Args: + message: Message dict that may contain 'audio' or 'audios' fields. + + Returns: + Message dict with content converted to list format including audio. 
+ """ + if "audio" not in message and "audios" not in message: + return message + + content = message.get("content", "") + if isinstance(content, str): + message["content"] = [{"type": "text", "text": content}] + elif isinstance(content, list): + message["content"] = content + else: + raise TypeError(f"Unexpected content type: {type(content)}") + + audio_items = [] + + if "audio" in message: + audio = message["audio"] + audio_path = os.path.join(self.data_dir, audio["path"]) + base64_audio = audio_file_to_base64(audio_path) + audio_message = {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}} + audio_items.append(audio_message) + del message["audio"] # Remove original audio field after conversion + elif "audios" in message: + for audio in message["audios"]: + audio_path = os.path.join(self.data_dir, audio["path"]) + base64_audio = audio_file_to_base64(audio_path) + audio_message = {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}} + audio_items.append(audio_message) + del message["audios"] # Remove original audios field after conversion + + # Insert audio items at the BEGINNING of content list (before text) + if audio_items: + message["content"] = audio_items + message["content"] + + return message + + def _preprocess_messages_for_model(self, messages: list[dict]) -> list[dict]: + """Preprocess messages - creates copies to avoid mutation. + + Note: /no_think suffix is passed through unchanged (handled by the model). + + Args: + messages: List of message dicts. + + Returns: + Copy of message dicts. + """ + return [copy.deepcopy(msg) for msg in messages] + + def _needs_audio_chunking(self, messages: list[dict], task_type: str = None) -> tuple[bool, str, float]: + """Check if audio in messages needs chunking. + + Args: + messages: List of message dicts. + task_type: Optional task type for chunking filtering. + + Returns: + Tuple of (needs_chunking, audio_path, duration). 
+ """ + if not self.enable_audio_chunking: + return False, None, 0.0 + + # Check if task type should be chunked (if filter is specified) + if self.audio_chunk_task_types is not None: + if task_type not in self.audio_chunk_task_types: + return False, None, 0.0 + + # Find audio in messages + for msg in messages: + if msg.get("role") == "user": + audio_info = msg.get("audio") + if not audio_info: + audios = msg.get("audios", []) + audio_info = audios[0] if audios else {} + if audio_info and "path" in audio_info: + audio_path = os.path.join(self.data_dir, audio_info["path"]) + + if not os.path.exists(audio_path): + return False, None, 0.0 + + # Load audio to check duration + audio_array, sampling_rate = load_audio_file(audio_path) + duration = len(audio_array) / sampling_rate + + if duration > self.chunk_audio_threshold_sec: + return True, audio_path, duration + + return False, None, 0.0 + + async def _generate_with_chunking( + self, + messages: list[dict], + audio_path: str, + duration: float, + tokens_to_generate: int | None = None, + **kwargs, + ) -> dict: + """Generate by chunking long audio and aggregating results. + + Args: + messages: Original messages containing audio reference. + audio_path: Path to the audio file to chunk. + duration: Duration of audio in seconds. + tokens_to_generate: Max tokens per chunk. + **kwargs: Additional generation parameters. + + Returns: + Aggregated result with combined generation from all chunks. 
+ """ + audio_array, sampling_rate = load_audio_file(audio_path) + chunks = chunk_audio(audio_array, sampling_rate, self.chunk_audio_threshold_sec) + + LOG.info(f"Chunking audio ({duration:.1f}s) into {len(chunks)} chunks of {self.chunk_audio_threshold_sec}s") + + if not chunks: + raise RuntimeError("No audio chunks generated - audio may be too short or invalid") + + chunk_results = [] + result = None + + # Track cumulative statistics across chunks + total_input_tokens = 0 + total_generated_tokens = 0 + total_time = 0.0 + + for chunk_idx, audio_chunk in enumerate(chunks): + chunk_messages = [] + + for msg in messages: + msg_copy = copy.deepcopy(msg) + + if msg_copy["role"] == "user" and ("audio" in msg_copy or "audios" in msg_copy): + chunk_base64 = save_audio_chunk_to_base64(audio_chunk, sampling_rate) + + content = msg_copy.get("content", "") + if isinstance(content, str): + text_content = [{"type": "text", "text": content}] + else: + text_content = content + + # Add audio chunk at the beginning (before text) + msg_copy["content"] = [ + {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{chunk_base64}"}} + ] + text_content + + # Remove original audio fields + msg_copy.pop("audio", None) + msg_copy.pop("audios", None) + + chunk_messages.append(msg_copy) + + # Preprocess messages (strip /no_think for Qwen) + chunk_messages = self._preprocess_messages_for_model(chunk_messages) + + # Generate for this chunk using parent's generate_async + result = await super().generate_async( + prompt=chunk_messages, tokens_to_generate=tokens_to_generate, **kwargs + ) + + # Sum statistics from each chunk + total_input_tokens += result.get("input_tokens", 0) + total_generated_tokens += result.get("generated_tokens", 0) + total_time += result.get("time_elapsed", 0.0) + + generation = result["generation"] + chunk_results.append(generation.strip()) + + # Aggregate results + aggregated_text = " ".join(chunk_results) + + if not result: + raise RuntimeError("Audio chunk 
generation returned no result") + + final_result = result.copy() + final_result["generation"] = aggregated_text + final_result["num_audio_chunks"] = len(chunks) + final_result["audio_duration"] = duration + # Update with summed statistics + final_result["input_tokens"] = total_input_tokens + final_result["generated_tokens"] = total_generated_tokens + final_result["time_elapsed"] = total_time + + return final_result + + async def generate_async( + self, + prompt: str | list[dict] | None = None, + tokens_to_generate: int | None = None, + task_type: str = None, + **kwargs, + ) -> dict: + """Generate with automatic audio chunking for long audio files. + + This override checks if the prompt (messages) contains long audio. + If so, it chunks the audio, processes each chunk separately, and aggregates results. + + Args: + prompt: Either a string (text completion) or list of messages (chat). + tokens_to_generate: Max tokens to generate. + task_type: Optional task type for chunking filtering. + **kwargs: Additional arguments passed to the underlying model. + + Returns: + Generation result dict with 'generation' key and optional metadata. 
+ """ + if isinstance(prompt, list): + messages = prompt + needs_chunking, audio_path, duration = self._needs_audio_chunking(messages, task_type) + + if needs_chunking: + return await self._generate_with_chunking(messages, audio_path, duration, tokens_to_generate, **kwargs) + + # No chunking needed - convert audio fields to base64 format + messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages] + messages = self._preprocess_messages_for_model(messages) + prompt = messages + + # Call parent's generate_async (which handles audio OUTPUT via _parse_chat_completion_response) + return await super().generate_async(prompt=prompt, tokens_to_generate=tokens_to_generate, **kwargs) + + def _build_chat_request_params( + self, + messages: list[dict], + **kwargs, + ) -> dict: + """Build chat request parameters with audio preprocessing. + + Args: + messages: List of message dicts. + **kwargs: Additional parameters for the request. + + Returns: + Request parameters dict. + """ + # content_text_to_list THEN preprocess + messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages] + messages = self._preprocess_messages_for_model(messages) + return super()._build_chat_request_params(messages=messages, **kwargs) diff --git a/nemo_skills/prompt/config/judge/mmau-pro.yaml b/nemo_skills/prompt/config/judge/mmau-pro.yaml new file mode 100644 index 0000000000..5339e4ab0d --- /dev/null +++ b/nemo_skills/prompt/config/judge/mmau-pro.yaml @@ -0,0 +1,30 @@ +# Judge prompt configuration for Speech/Audio Language Model evaluation +# Used for evaluating open-ended responses in MMAU-Pro benchmark +# Uses multi-criteria scoring on 1-5 scale + +user: |- + You are an expert evaluator for audio and speech-related questions. Please evaluate the quality of a model's response to a question. 
+ + Question: {question} + + Reference Answer: {expected_answer} + + Model Response: {generation} + + Please evaluate the model response on the following criteria and provide scores from 1-5 (where 5 is best): + + 1. **Correctness**: How factually accurate is the response compared to the reference? + 2. **Relevance**: How well does the response address the specific question asked? + 3. **Completeness**: Does the response cover all important aspects mentioned in the reference? + 4. **Clarity**: How clear and well-structured is the response? + + For each criterion, provide: + - A score from 1-5 + - A brief justification (1-2 sentences) + + Format your response as: + CORRECTNESS: [score] - [justification] + RELEVANCE: [score] - [justification] + COMPLETENESS: [score] - [justification] + CLARITY: [score] - [justification] + OVERALL: [average score] - [overall assessment] diff --git a/nemo_skills/prompt/config/judge/speechlm.yaml b/nemo_skills/prompt/config/judge/speechlm.yaml deleted file mode 100644 index 4862558145..0000000000 --- a/nemo_skills/prompt/config/judge/speechlm.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# Judge prompt configuration for Speech/Audio Language Model evaluation -# Used for evaluating open-ended responses in MMAU-Pro benchmark -# Follows nemo-skills standard Yes/No judgement pattern - -user: |- - You are an expert evaluator for audio and speech-related questions. Please evaluate whether the model's response correctly answers the question. - - Question: {question} - - Reference Answer: {expected_answer} - - Model Response: {generation} - - Your task is to determine if the model's response is correct based on the reference answer. Consider: - - 1. **Factual Accuracy**: Is the information in the response factually correct? - 2. **Relevance**: Does the response address the specific question asked? - 3. **Completeness**: Does the response cover the key points from the reference answer? 
- - Please first explain your reasoning in 2-3 sentences, then provide your final judgement. - - Your final judgement must be either "Yes" or "No": - - "Yes" if the model response is correct and adequately answers the question - - "No" if the model response is incorrect, irrelevant, or inadequate - - Format your response as: - Reasoning: [Your explanation] - Judgement: [Yes or No] diff --git a/tests/gpu-tests/test_vllm_audio.py b/tests/gpu-tests/test_vllm_audio.py new file mode 100644 index 0000000000..dcb1539e27 --- /dev/null +++ b/tests/gpu-tests/test_vllm_audio.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU integration tests for vLLM audio generation with VLLMMultimodalModel.""" + +import json +import shutil +import subprocess +import tempfile +from pathlib import Path + +import pytest +from utils import require_env_var + + +@pytest.mark.gpu +def test_vllm_audio_generation(): + """Integration test: Generate with vLLM server using audio input.""" + model_path = require_env_var("NEMO_SKILLS_TEST_HF_MODEL") + model_type = require_env_var("NEMO_SKILLS_TEST_MODEL_TYPE") + + output_dir = f"/tmp/nemo-skills-tests/{model_type}/vllm-audio-generation" + # Clean up output directory + if Path(output_dir).exists(): + shutil.rmtree(output_dir) + + # Create test input file with audio + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + test_data = [ + { + "problem": "Transcribe this audio", + "audio": {"path": "/nemo_run/code/tests/slurm-tests/asr_nim/wavs/t2_16.wav"}, + }, + { + "problem": "What is in this audio?", + "audio": {"path": "/nemo_run/code/tests/slurm-tests/asr_nim/wavs/t3_16.wav"}, + }, + ] + for item in test_data: + f.write(json.dumps(item) + "\n") + input_file = f.name + + try: + cmd = ( + f"ns generate " + f" --cluster test-local --config_dir {Path(__file__).absolute().parent} " + f" --model {model_path} " + f" --output_dir {output_dir} " + f" --server_type vllm_multimodal " + f" --server_gpus 1 " + f" --server_nodes 1 " + f" --server_args '--enforce-eager' " + f" --input_file={input_file} " + f" ++prompt_format=openai " + f" ++skip_filled=False " + ) + subprocess.run(cmd, shell=True, check=True) + + # Verify output exists and has audio-related generation + with open(f"{output_dir}/output.jsonl") as fin: + lines = fin.readlines() + + assert len(lines) == 2, "Should have 2 output lines" + + for line in lines: + data = json.loads(line) + assert "generation" in data, "Should have generation field" + assert len(data["generation"]) > 0, "Generation should not be empty" + # If model supports audio, generation should contain 
something + print(f"Generated: {data['generation']}") + + finally: + # Cleanup temp file + Path(input_file).unlink(missing_ok=True) diff --git a/tests/test_vllm_audio.py b/tests/test_vllm_audio.py new file mode 100644 index 0000000000..0c8ca1b89a --- /dev/null +++ b/tests/test_vllm_audio.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for audio utilities and VLLMMultimodalModel audio input handling.""" + +import base64 +import os +import tempfile +from unittest.mock import patch + +import pytest + +from nemo_skills.inference.model.audio_utils import audio_file_to_base64 +from nemo_skills.inference.model.vllm_multimodal import VLLMMultimodalModel + + +def test_audio_file_to_base64(): + """Test basic audio file encoding to base64.""" + with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f: + test_content = b"RIFF" + b"\x00" * 100 + f.write(test_content) + temp_path = f.name + + try: + result = audio_file_to_base64(temp_path) + assert isinstance(result, str) + assert len(result) > 0 + decoded = base64.b64decode(result) + assert decoded == test_content + finally: + os.unlink(temp_path) + + +@pytest.fixture +def mock_vllm_multimodal_model(tmp_path): + """Create a mock VLLMMultimodalModel for testing audio preprocessing.""" + with patch.object(VLLMMultimodalModel, "__init__", lambda self, **kwargs: None): + model = VLLMMultimodalModel() + model.data_dir = 
str(tmp_path) + model.output_dir = None + model.output_audio_dir = None + model.enable_audio_chunking = True + model.audio_chunk_task_types = None + model.chunk_audio_threshold_sec = 30 + model._tunnel = None + return model + + +def test_content_text_to_list_with_audio(mock_vllm_multimodal_model, tmp_path): + """Test converting string content with audio to list format. + + CRITICAL: Audio must come BEFORE text for Qwen Audio to transcribe correctly. + """ + audio_path = tmp_path / "test.wav" + with open(audio_path, "wb") as f: + f.write(b"RIFF" + b"\x00" * 100) + + message = {"role": "user", "content": "Describe this audio", "audio": {"path": "test.wav"}} + + result = mock_vllm_multimodal_model.content_text_to_list(message) + + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 + assert result["content"][0]["type"] == "audio_url" + assert result["content"][0]["audio_url"]["url"].startswith("data:audio/wav;base64,") + assert result["content"][1]["type"] == "text" + + +def test_content_text_to_list_with_multiple_audios(mock_vllm_multimodal_model, tmp_path): + """Test handling message with multiple audio files. + + CRITICAL: Audio must come BEFORE text for Qwen Audio to transcribe correctly. 
+ """ + audio_paths = [] + for i in range(2): + audio_path = tmp_path / f"test_{i}.wav" + with open(audio_path, "wb") as f: + f.write(b"RIFF" + b"\x00" * 100) + audio_paths.append(f"test_{i}.wav") + + message = { + "role": "user", + "content": "Compare these", + "audios": [{"path": audio_paths[0]}, {"path": audio_paths[1]}], + } + + result = mock_vllm_multimodal_model.content_text_to_list(message) + + assert isinstance(result["content"], list) + assert len(result["content"]) == 3 + # Audio MUST come before text for Qwen Audio + assert result["content"][0]["type"] == "audio_url" + assert result["content"][1]["type"] == "audio_url" + assert result["content"][2]["type"] == "text" + + +def test_content_text_to_list_no_audio(mock_vllm_multimodal_model): + """Test that messages without audio are returned unchanged.""" + message = {"role": "user", "content": "Hello, world!"} + result = mock_vllm_multimodal_model.content_text_to_list(message) + + assert result["content"] == "Hello, world!" + assert "audio" not in result + + +def test_preprocess_messages_preserves_no_think(mock_vllm_multimodal_model): + """Test that /no_think is preserved in system messages.""" + messages = [ + {"role": "system", "content": "You are a helpful assistant. /no_think"}, + {"role": "user", "content": "Hello"}, + ] + + result = mock_vllm_multimodal_model._preprocess_messages_for_model(messages) + + # /no_think should be preserved, not stripped + assert result[0]["content"] == "You are a helpful assistant. 
/no_think" + assert result[1]["content"] == "Hello" + + +def test_needs_audio_chunking_disabled(mock_vllm_multimodal_model): + """Test that chunking is skipped when disabled.""" + mock_vllm_multimodal_model.enable_audio_chunking = False + + messages = [{"role": "user", "content": "Test", "audio": {"path": "test.wav"}}] + needs_chunking, audio_path, duration = mock_vllm_multimodal_model._needs_audio_chunking(messages) + + assert needs_chunking is False + assert audio_path is None + assert duration == 0.0 + + +def test_needs_audio_chunking_task_type_filter(mock_vllm_multimodal_model): + """Test that chunking respects task type filter.""" + mock_vllm_multimodal_model.audio_chunk_task_types = ["transcription"] + + messages = [{"role": "user", "content": "Test", "audio": {"path": "test.wav"}}] + + # Task type not in filter - should not chunk + needs_chunking, _, _ = mock_vllm_multimodal_model._needs_audio_chunking(messages, task_type="qa") + assert needs_chunking is False + + # Task type in filter but file doesn't exist - should return False gracefully + needs_chunking, _, _ = mock_vllm_multimodal_model._needs_audio_chunking(messages, task_type="transcription") + assert needs_chunking is False