diff --git a/nemo_skills/inference/server/serve_unified.py b/nemo_skills/inference/server/serve_unified.py new file mode 100644 index 0000000000..48e3f34755 --- /dev/null +++ b/nemo_skills/inference/server/serve_unified.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CLI entrypoint for the Unified NeMo Inference Server. + +Configuration is YAML-based: provide a config file with backend type and +all backend-specific parameters. The config is validated against the +backend's config class. + +Usage: + python -m nemo_skills.inference.server.serve_unified \\ + --config /path/to/backend_config.yaml \\ + --port 8000 + + # Or with --model for nemo-skills pipeline compatibility: + python -m nemo_skills.inference.server.serve_unified \\ + --model /path/to/model \\ + --backend magpie_tts \\ + --codec_model /path/to/codec \\ + --port 8000 + +Backend-specific options are passed as extra CLI flags and forwarded to the +backend's config dataclass automatically. For example: + + --server_args "--backend magpie_tts --codec_model /path --use_cfg --cfg_scale 2.5" + +Any flag not recognized by the server itself is parsed generically: + --flag -> {"flag": True} + --key value -> {"key": } + --key=value -> {"key": } + --no_flag -> {"flag": False} + +See each backend's config class for available options (e.g. MagpieTTSConfig). 
+ +Example YAML config (backend_config.yaml): + backend: magpie_tts + model_path: /path/to/model + codec_model_path: /path/to/codec + device: cuda + dtype: bfloat16 + temperature: 0.6 + top_k: 80 + use_cfg: true + cfg_scale: 2.5 +""" + +import argparse +import inspect +import os +import shutil +import sys +from typing import Optional + + +def setup_pythonpath(code_path: Optional[str] = None): + """Set up PYTHONPATH for NeMo and the unified server. + + Args: + code_path: Single path or colon-separated paths to add to PYTHONPATH + """ + paths_to_add = [] + + if code_path: + for path in code_path.split(":"): + if path and path not in paths_to_add: + paths_to_add.append(path) + + # Add recipes path for unified server imports + this_dir = os.path.dirname(os.path.abspath(__file__)) + ns_eval_root = os.path.dirname(os.path.dirname(os.path.dirname(this_dir))) + if os.path.exists(os.path.join(ns_eval_root, "recipes")): + paths_to_add.append(ns_eval_root) + + # Container pattern + if os.path.exists("/nemo_run/code"): + paths_to_add.append("/nemo_run/code") + + current_path = os.environ.get("PYTHONPATH", "") + for path in paths_to_add: + if path not in current_path.split(":"): + current_path = f"{path}:{current_path}" if current_path else path + + os.environ["PYTHONPATH"] = current_path + + for path in paths_to_add: + if path not in sys.path: + sys.path.insert(0, path) + + +def apply_safetensors_patch(hack_path: Optional[str]): + """Apply safetensors patch if provided (for some NeMo models).""" + if not hack_path or not os.path.exists(hack_path): + return + + try: + import safetensors.torch as st_torch + + dest_path = inspect.getfile(st_torch) + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + shutil.copyfile(hack_path, dest_path) + print(f"[serve_unified] Applied safetensors patch: {hack_path} -> {dest_path}") + except Exception as e: + print(f"[serve_unified] Warning: Failed to apply safetensors patch: {e}") + + +def load_yaml_config(config_path: str) -> dict: + 
"""Load YAML config file.""" + import yaml + + with open(config_path) as f: + return yaml.safe_load(f) + + +def _coerce_value(value: str): + """Try to coerce a string value to int, float, or bool.""" + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + if value.lower() == "true": + return True + if value.lower() == "false": + return False + return value + + +def parse_extra_args(extra_args: list) -> dict: + """Convert unknown CLI args to a config dict. + + Handles these patterns: + --flag -> {"flag": True} + --key value -> {"key": } + --key=value -> {"key": } + --no_flag -> {"flag": False} (strip no_ prefix) + """ + result = {} + i = 0 + while i < len(extra_args): + arg = extra_args[i] + if not arg.startswith("--"): + i += 1 + continue + + # Handle --key=value + if "=" in arg: + key, value = arg[2:].split("=", 1) + result[key] = _coerce_value(value) + i += 1 + continue + + key = arg[2:] + + # Check if next token is a value (not another flag) + if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("--"): + result[key] = _coerce_value(extra_args[i + 1]) + i += 2 + continue + + # Bare flag: --no_X -> {X: False}, otherwise {key: True} + if key.startswith("no_"): + result[key[3:]] = False + else: + result[key] = True + i += 1 + + return result + + +def main(): + parser = argparse.ArgumentParser( + description="Unified NeMo Inference Server", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Primary: YAML config + parser.add_argument("--config", default=None, help="Path to YAML config file") + + # Standard args for nemo-skills pipeline compatibility + parser.add_argument("--model", default=None, help="Path to the model") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument("--host", default="0.0.0.0", help="Server host") + parser.add_argument("--backend", 
default="magpie_tts", help="Backend type") + + # Server configuration + parser.add_argument("--batch_size", type=int, default=8, help="Maximum batch size") + parser.add_argument("--batch_timeout", type=float, default=0.1, help="Batch timeout in seconds") + + # Generation defaults + parser.add_argument("--max_new_tokens", type=int, default=512, help="Max tokens to generate") + parser.add_argument("--temperature", type=float, default=1.0, help="Generation temperature") + parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling") + + # Model configuration + parser.add_argument("--device", default="cuda", help="Device to use") + parser.add_argument("--dtype", default="bfloat16", help="Model dtype") + + # Environment setup + parser.add_argument("--code_path", default=None, help="Path to add to PYTHONPATH") + parser.add_argument("--hack_path", default=None, help="Path to safetensors patch") + + # Debug + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + + # Parse known args; everything else is backend-specific + args, unknown = parser.parse_known_args() + extra_config = parse_extra_args(unknown) + + # Setup environment + setup_pythonpath(args.code_path) + apply_safetensors_patch(args.hack_path) + + if args.code_path: + os.environ["UNIFIED_SERVER_CODE_PATH"] = args.code_path + + if args.debug: + os.environ["DEBUG"] = "1" + + # Set CUDA devices + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(args.num_gpus)) + + # Build configuration + if args.config: + # YAML config mode + config_dict = load_yaml_config(args.config) + backend_type = config_dict.pop("backend", args.backend) + # CLI overrides + if args.model: + config_dict["model_path"] = args.model + # Merge any extra CLI args into YAML config (CLI wins) + config_dict.update(extra_config) + else: + # CLI args mode (backward compatible) + if not args.model: + parser.error("--model is required when not using 
--config") + backend_type = args.backend + config_dict = { + "model_path": args.model, + "device": args.device, + "dtype": args.dtype, + "max_new_tokens": args.max_new_tokens, + "temperature": args.temperature, + "top_p": args.top_p, + } + # Merge backend-specific args from extra CLI flags + config_dict.update(extra_config) + + # Print configuration + print("=" * 60) + print("[serve_unified] Starting Unified NeMo Inference Server") + print("=" * 60) + print(f" Backend: {backend_type}") + print(f" Model: {config_dict.get('model_path', 'N/A')}") + print(f" Port: {args.port}") + print(f" GPUs: {args.num_gpus}") + print(f" Batch Size: {args.batch_size}") + print(f" Batch Timeout: {args.batch_timeout}s") + if args.config: + print(f" Config: {args.config}") + if extra_config: + print(f" Extra CLI Config: {extra_config}") + print("=" * 60) + + # Import and run + try: + import uvicorn + + from recipes.multimodal.server.unified_server import create_app + + app = create_app( + backend_type=backend_type, + config_dict=config_dict, + batch_size=args.batch_size, + batch_timeout=args.batch_timeout, + ) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + except ImportError as e: + print(f"[serve_unified] Error: Failed to import unified server: {e}") + print("[serve_unified] Make sure the recipes.multimodal.server package is in PYTHONPATH") + sys.exit(1) + except Exception as e: + print(f"[serve_unified] Error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/recipes/multimodal/__init__.py b/recipes/multimodal/__init__.py new file mode 100644 index 0000000000..332d9e7cae --- /dev/null +++ b/recipes/multimodal/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified NeMo Inference Server package. + +Provides a pluggable FastAPI server that supports multiple NeMo model backends +through a backend-agnostic architecture. All backend-specific logic lives in +the backend modules under `backends/`. +""" + +from .server.backends import ( + BackendConfig, + GenerationRequest, + GenerationResult, + InferenceBackend, + get_backend, +) + +__all__ = [ + "InferenceBackend", + "GenerationRequest", + "GenerationResult", + "BackendConfig", + "get_backend", +] diff --git a/recipes/multimodal/server/README.md b/recipes/multimodal/server/README.md new file mode 100644 index 0000000000..3fa1a02cb1 --- /dev/null +++ b/recipes/multimodal/server/README.md @@ -0,0 +1,150 @@ +# Unified Server and Backends + +This directory contains a backend-agnostic inference server with an OpenAI Chat Completions compatible API, plus pluggable backends for different model families. + +## Purpose + +- Expose one HTTP surface (`/v1/chat/completions`) for multiple model types. +- Keep server concerns (request parsing, batching, response shaping) separate from model-specific logic. +- Let each backend define only model load/inference behavior behind a shared interface. + +## Main Components + +- `unified_server.py`: FastAPI server, OpenAI-compatible response format, batch scheduling. 
+`backends/base.py`: shared data models and abstract interface:
+  - `BackendConfig`
+  - `GenerationRequest`
+  - `GenerationResult`
+  - `InferenceBackend`
+  - `Modality`
+- `backends/__init__.py`: backend registry and lazy loading (`BACKEND_REGISTRY`, `get_backend()`).
+- `nemo_skills/inference/server/serve_unified.py`: CLI entrypoint used by local runs and cluster jobs.
+
+## Built-in Backends
+
+- `nemo_asr` -> `backends/nemo_asr_backend.py`
+- `magpie_tts` -> `backends/magpie_tts_backend.py`
+
+## How To Add a New Backend
+
+1. Create a config dataclass inheriting `BackendConfig`.
+2. Create a backend class inheriting `InferenceBackend`.
+3. Register it in `backends/__init__.py` (`BACKEND_REGISTRY`).
+4. Start server with `--backend <backend_name>` and pass backend-specific args.
+5. Add tests (unit + slurm where applicable).
+
+### Required Interface
+
+Your backend class must implement:
+
+- `@classmethod get_config_class(cls) -> type`
+- `name` property
+- `supported_modalities` property
+- `load_model(self) -> None`
+- `generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]`
+
+You should also override as needed:
+
+- `validate_request(self, request) -> Optional[str]` for strict input validation
+- `health_check(self) -> Dict[str, Any]` for backend-specific health metadata
+- `@classmethod get_extra_routes(cls, backend_instance)` if custom endpoints are needed
+
+### Minimal Skeleton
+
+```python
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set
+
+from recipes.multimodal.server.backends.base import (
+    BackendConfig,
+    GenerationRequest,
+    GenerationResult,
+    InferenceBackend,
+    Modality,
+)
+
+
+@dataclass
+class MyBackendConfig(BackendConfig):
+    my_flag: bool = True
+
+    @classmethod
+    def from_dict(cls, d: Dict[str, Any]) -> "MyBackendConfig":
+        known = {"model_path", "device", "dtype", "my_flag"}
+        return cls(
+            **{k: v for k, v in d.items() if k in known},
+            extra_config={k: v for k, v in d.items() if k not in
known}, + ) + + +class MyBackend(InferenceBackend): + @classmethod + def get_config_class(cls) -> type: + return MyBackendConfig + + @property + def name(self) -> str: + return "my_backend" + + @property + def supported_modalities(self) -> Set[Modality]: + return {Modality.TEXT} + + def load_model(self) -> None: + self._model = ... + self._is_loaded = True + + def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]: + results = [] + for req in requests: + try: + text_out = f"echo: {req.text or ''}" + results.append(GenerationResult(text=text_out, request_id=req.request_id)) + except Exception as e: + results.append(GenerationResult(error=str(e), request_id=req.request_id)) + return results +``` + +Then register in `recipes/multimodal/server/backends/__init__.py`: + +```python +BACKEND_REGISTRY = { + # ... + "my_backend": ("my_backend", "MyBackend"), +} +``` + +## Running Slurm Tests (Unified ASR and Unified TTS) + +Activate env first: + +```bash +source .venv/bin/activate +source ~/.env +``` + +Run unified ASR backend slurm test: + +```bash +python tests/slurm-tests/unified_asr/run_test.py \ + --workspace /lustre/fsw/portfolios/convai/users/$USER/experiments/slurm-tests/unified_asr \ + --cluster dfw \ + --expname_prefix unified_asr_test \ + --server_container /lustre/fsw/portfolios/convai/users/$USER/workspace/images/nemo-25.11.sqsh +``` + +Run unified TTS backend slurm test: + +```bash +python tests/slurm-tests/unified_tts/run_test.py \ + --workspace /lustre/fsw/portfolios/convai/users/$USER/experiments/slurm-tests/unified_tts \ + --cluster dfw \ + --expname_prefix unified_tts_test \ + --server_container /lustre/fsw/portfolios/convai/users/$USER/workspace/images/nemo-25.11.sqsh \ + --code_path /lustre/fsw/portfolios/convai/users/$USER/workspace/code/NeMo_tts +``` + +Optional flags for both: + +- `--skip_check` to skip checker job +- `--config_dir` to select a different cluster config directory diff --git 
a/recipes/multimodal/server/__init__.py b/recipes/multimodal/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipes/multimodal/server/backends/__init__.py b/recipes/multimodal/server/backends/__init__.py new file mode 100644 index 0000000000..e859175857 --- /dev/null +++ b/recipes/multimodal/server/backends/__init__.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Backend implementations for the Unified NeMo Inference Server. + +Available backends: +- magpie_tts: MagpieTTS text-to-speech (audio output from text input) +- nemo_asr: NeMo ASR speech-to-text (text output from audio input) + +Backends are lazily loaded to avoid importing heavy dependencies upfront. 
+""" + +import importlib + +from .base import BackendConfig, GenerationRequest, GenerationResult, InferenceBackend, Modality + +__all__ = [ + "InferenceBackend", + "GenerationRequest", + "GenerationResult", + "BackendConfig", + "Modality", + "get_backend", + "list_backends", +] + +# Registry of available backends: name -> (module_name, class_name) +BACKEND_REGISTRY = { + "magpie_tts": ("magpie_tts_backend", "MagpieTTSBackend"), + "nemo_asr": ("nemo_asr_backend", "NeMoASRBackend"), +} + + +def list_backends() -> list: + """Return list of available backend names.""" + return list(BACKEND_REGISTRY.keys()) + + +def get_backend(backend_name: str) -> type: + """Get backend class by name with lazy loading. + + Args: + backend_name: One of the registered backend names + + Returns: + Backend class (not instance) + + Raises: + ValueError: If backend name is unknown + ImportError: If backend dependencies are not available + """ + if backend_name not in BACKEND_REGISTRY: + available = ", ".join(BACKEND_REGISTRY.keys()) + raise ValueError(f"Unknown backend: '{backend_name}'. Available backends: {available}") + + module_name, class_name = BACKEND_REGISTRY[backend_name] + + try: + module = importlib.import_module(f".{module_name}", package=__name__) + return getattr(module, class_name) + except ImportError as e: + raise ImportError( + f"Failed to import backend '{backend_name}'. Make sure required dependencies are installed. Error: {e}" + ) from e diff --git a/recipes/multimodal/server/backends/base.py b/recipes/multimodal/server/backends/base.py new file mode 100644 index 0000000000..58fac6cdd9 --- /dev/null +++ b/recipes/multimodal/server/backends/base.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Abstract base class for inference backends.

All model backends must implement this interface to be usable with the
unified inference server.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set


class Modality(str, Enum):
    """Supported input/output modalities."""

    # str-valued members: health_check() serializes them via .value into JSON.
    TEXT = "text"
    AUDIO_IN = "audio_in"
    AUDIO_OUT = "audio_out"


@dataclass
class BackendConfig:
    """Base configuration for all backends.

    Subclasses should add their own fields. The from_dict() classmethod
    handles extracting known fields and putting the rest in extra_config.
    """

    model_path: str = ""
    device: str = "cuda"
    dtype: str = "bfloat16"

    # Generation defaults
    max_new_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: Optional[int] = None

    # Additional model-specific configs passed through
    extra_config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "BackendConfig":
        """Create config from dictionary, extracting known fields.

        Keys matching declared dataclass fields populate those fields;
        everything else is collected into extra_config. Works for subclasses
        too, since __dataclass_fields__ includes inherited fields.
        """
        known_fields = {f.name for f in cls.__dataclass_fields__.values()}
        known = {k: v for k, v in d.items() if k in known_fields and k != "extra_config"}
        extra = {k: v for k, v in d.items() if k not in known_fields}
        return cls(**known, extra_config=extra)


@dataclass
class GenerationRequest:
    """A single generation request.

    Supports text and/or audio inputs depending on the backend's capabilities.
    """

    # Text inputs
    text: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None

    # Audio input (raw bytes)
    audio_bytes: Optional[bytes] = None
    sample_rate: int = 16000

    # Multi-turn audio inputs (list of raw audio bytes)
    audio_bytes_list: Optional[List[bytes]] = None

    # Generation parameters (override backend defaults; None means "use the
    # backend config default" — see InferenceBackend.get_generation_params)
    max_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    seed: Optional[int] = None

    # Additional parameters
    extra_params: Dict[str, Any] = field(default_factory=dict)

    # Request tracking
    request_id: Optional[str] = None


@dataclass
class GenerationResult:
    """Result from a generation request.

    Contains text output and optionally audio output, plus metadata.
    """

    # Text output
    text: str = ""

    # Audio output (raw bytes, can be encoded to base64 for JSON)
    audio_bytes: Optional[bytes] = None
    audio_sample_rate: int = 16000
    audio_format: str = "wav"

    # Metadata
    request_id: Optional[str] = None
    num_tokens_generated: int = 0
    generation_time_ms: float = 0.0

    # Debug info (optional, backend-specific)
    debug_info: Optional[Dict[str, Any]] = None

    # Error handling
    error: Optional[str] = None

    def is_success(self) -> bool:
        # Success is simply "no error was recorded".
        return self.error is None


class InferenceBackend(ABC):
    """Abstract base class for inference backends.

    Implementations must provide:
    - get_config_class(): Return the config dataclass for this backend
    - load_model(): Initialize the model from config
    - generate(): Run inference on a batch of requests
    - supported_modalities: What input/output types are supported

    Optionally:
    - get_extra_routes(): Return additional FastAPI routes this backend needs
      (e.g., session management endpoints for S2S backends)
    """

    def __init__(self, config: BackendConfig):
        # Model object is populated by load_model(); until then the backend
        # reports not-loaded via is_loaded / health_check().
        self.config = config
        self._model = None
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_config_class(cls) -> type:
        """Return the config dataclass for this backend.

        The returned class must be a subclass of BackendConfig with a
        from_dict() classmethod.
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the backend name (e.g., 'salm', 'magpie_tts')."""
        ...

    @property
    @abstractmethod
    def supported_modalities(self) -> Set[Modality]:
        """Return the set of supported modalities."""
        ...

    @abstractmethod
    def load_model(self) -> None:
        """Load and initialize the model.

        Should set self._model and self._is_loaded = True on success.
        Called once during server startup.
        """
        ...

    @abstractmethod
    def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
        """Run inference on a batch of requests.

        Args:
            requests: List of generation requests to process

        Returns:
            List of generation results, one per request (same order)
        """
        ...

    @classmethod
    def get_extra_routes(cls, backend_instance: "InferenceBackend") -> list:
        """Return additional FastAPI routes this backend needs.

        Override to register custom endpoints (e.g., session management).
        Each item should be a dict with 'path', 'endpoint', 'methods'.

        Returns:
            List of route dicts, empty by default.
        """
        return []

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready."""
        return self._is_loaded

    def health_check(self) -> Dict[str, Any]:
        """Return health status information."""
        return {
            "backend": self.name,
            "model_loaded": self._is_loaded,
            "model_path": self.config.model_path,
            "device": self.config.device,
            "modalities": [m.value for m in self.supported_modalities],
        }

    def get_generation_params(self, request: GenerationRequest) -> Dict[str, Any]:
        """Get effective generation parameters, merging request with config defaults."""
        # Per-request values win; None falls back to the backend config.
        return {
            "max_new_tokens": request.max_new_tokens
            if request.max_new_tokens is not None
            else self.config.max_new_tokens,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "top_k": request.top_k if request.top_k is not None else self.config.top_k,
        }

    def validate_request(self, request: GenerationRequest) -> Optional[str]:
        """Validate a request against supported modalities.

        Returns:
            Error message if invalid, None if valid
        """
        modalities = self.supported_modalities

        has_text_input = request.text is not None
        has_audio_input = request.audio_bytes is not None or (
            request.audio_bytes_list is not None and len(request.audio_bytes_list) > 0
        )

        # NOTE(review): only audio input is checked against modalities here —
        # text input is accepted even when Modality.TEXT is absent. Confirm
        # this asymmetry is intentional (e.g. for ASR-style backends).
        if has_audio_input and Modality.AUDIO_IN not in modalities:
            return f"Backend '{self.name}' does not support audio input"

        if not has_text_input and not has_audio_input:
            return "Request must have either text or audio input"

        return None
diff --git a/recipes/multimodal/server/backends/magpie_tts_backend.py b/recipes/multimodal/server/backends/magpie_tts_backend.py
new file mode 100644
index 0000000000..6cd4bb9ae6
--- /dev/null
+++ b/recipes/multimodal/server/backends/magpie_tts_backend.py
@@ -0,0 +1,496 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +"""MagpieTTS backend using MagpieInferenceRunner with RTF metrics.""" + +import inspect +import io +import json +import logging +import os +import re +import shutil +import tempfile +import time +from dataclasses import dataclass, fields, is_dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import soundfile as sf + +from .base import BackendConfig, GenerationRequest, GenerationResult, InferenceBackend, Modality + +logger = logging.getLogger(__name__) + +try: + import nemo.collections.tts.modules.audio_codec_modules as _audio_codec_modules +except ImportError: + _audio_codec_modules = None + +try: + from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import ( + load_evalset_config as _load_evalset_config, + ) + from nemo.collections.tts.modules.magpietts_inference.inference import ( + InferenceConfig as _InferenceConfig, + ) + from nemo.collections.tts.modules.magpietts_inference.inference import ( + MagpieInferenceRunner as _MagpieInferenceRunner, + ) + from nemo.collections.tts.modules.magpietts_inference.utils import ( + ModelLoadConfig as _ModelLoadConfig, + ) + from nemo.collections.tts.modules.magpietts_inference.utils import ( + load_magpie_model as _load_magpie_model, + ) +except ImportError: + _load_evalset_config = None + _InferenceConfig = None + _MagpieInferenceRunner = None + _ModelLoadConfig = None + _load_magpie_model = None + +try: + from nemo.collections.tts.models.magpietts import ModelInferenceParameters as _ModelInferenceParameters +except ImportError: + _ModelInferenceParameters = None + +try: + from huggingface_hub import hf_hub_download as _hf_hub_download +except ImportError: + _hf_hub_download = None + +try: + from 
nemo_text_processing.text_normalization.normalize import Normalizer as _Normalizer +except ImportError: + _Normalizer = None + + +@dataclass +class MagpieTTSConfig(BackendConfig): + codec_model_path: Optional[str] = None + max_decoder_steps: Optional[int] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + cfg_scale: Optional[float] = None + use_cfg: bool = True + use_local_transformer: bool = True + apply_attention_prior: bool = True + output_sample_rate: int = 22050 + # Checkpoint loading options (alternative to model_path .nemo file) + hparams_file: Optional[str] = None + checkpoint_file: Optional[str] = None + legacy_codebooks: bool = False + legacy_text_conditioning: bool = False + hparams_from_wandb: bool = False + # Text normalization (expands numbers, abbreviations, etc. before TTS) + enable_normalization: bool = False + normalizer_lang: str = "en" + normalizer_input_case: str = "cased" + # Optional allowlist for request-provided context_audio_filepath values. 
+ context_audio_allowed_roots: Optional[List[str]] = None + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "MagpieTTSConfig": + # Handle CLI alias: --codec_model → codec_model_path + if "codec_model" in d and "codec_model_path" not in d: + d = {**d, "codec_model_path": d.pop("codec_model")} + if isinstance(d.get("context_audio_allowed_roots"), str): + d = { + **d, + "context_audio_allowed_roots": [p for p in d["context_audio_allowed_roots"].split(":") if p], + } + known = { + "model_path", + "device", + "dtype", + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "codec_model_path", + "use_cfg", + "cfg_scale", + "max_decoder_steps", + "use_local_transformer", + "apply_attention_prior", + "output_sample_rate", + "hparams_file", + "checkpoint_file", + "legacy_codebooks", + "legacy_text_conditioning", + "hparams_from_wandb", + "enable_normalization", + "normalizer_lang", + "normalizer_input_case", + "context_audio_allowed_roots", + } + return cls( + **{k: v for k, v in d.items() if k in known}, extra_config={k: v for k, v in d.items() if k not in known} + ) + + +class MagpieTTSBackend(InferenceBackend): + """MagpieTTS backend. 
    @classmethod
    def get_config_class(cls) -> type:
        """Config dataclass used to validate and construct this backend's config."""
        return MagpieTTSConfig

    @property
    def name(self) -> str:
        return "magpie_tts"

    @property
    def supported_modalities(self) -> Set[Modality]:
        return {Modality.TEXT, Modality.AUDIO_OUT}

    def __init__(self, config: BackendConfig):
        # Coerce a generic BackendConfig into MagpieTTSConfig by merging the
        # shared base fields with any backend-specific keys from extra_config.
        self.tts_config = (
            config
            if isinstance(config, MagpieTTSConfig)
            else MagpieTTSConfig.from_dict(
                {
                    **{
                        k: getattr(config, k)
                        for k in ["model_path", "device", "dtype", "max_new_tokens", "temperature", "top_p", "top_k"]
                        if hasattr(config, k)
                    },
                    **config.extra_config,
                }
            )
        )
        super().__init__(self.tts_config)
        # All heavy resources are created lazily in load_model().
        self._model = self._runner = self._temp_dir = self._checkpoint_name = None
        self._normalizer = None

    def _patch_hf_fsspec_loader(self) -> None:
        """Patch NeMo load_fsspec to use hf_hub_download for HF resolve URLs."""
        if _audio_codec_modules is None:
            logger.warning("nemo TTS audio codec modules are unavailable; skipping load_fsspec HF patch")
            return

        # Idempotent: skip if already patched or there is nothing to wrap.
        orig_load_fsspec = getattr(_audio_codec_modules, "load_fsspec", None)
        if not callable(orig_load_fsspec) or getattr(_audio_codec_modules, "_hf_load_fsspec_patched", False):
            return

        if _hf_hub_download is None:
            logger.warning("huggingface_hub is unavailable; skipping load_fsspec HF patch")
            return

        def _hf_resolve_to_local(url: str) -> str | None:
            # Map https://huggingface.co/<org>/<repo>/resolve/<rev>/<file>
            # to a locally cached file via hf_hub_download; None if no match.
            if not isinstance(url, str):
                return None
            url_no_q = url.split("?", 1)[0]
            match = re.match(r"^https?://huggingface\.co/([^/]+)/([^/]+)/resolve/([^/]+)/(.+)$", url_no_q)
            if not match:
                return None
            repo_id = f"{match.group(1)}/{match.group(2)}"
            revision = match.group(3)
            filename = match.group(4)
            token = os.environ.get("HF_TOKEN") or None
            return _hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)

        def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
            # Only HF resolve URLs are rerouted; everything else falls through.
            if isinstance(path, str) and path.startswith("http"):
                local_path = _hf_resolve_to_local(path)
                if local_path:
                    return orig_load_fsspec(local_path, map_location=map_location, **kwargs)
            return orig_load_fsspec(path, map_location=map_location, **kwargs)

        _audio_codec_modules.load_fsspec = _load_fsspec_patched
        _audio_codec_modules._hf_load_fsspec_patched = True

    def _resolve_context_audio_path(self, raw_path: str) -> str:
        """Resolve and validate request-provided context path against allowlisted roots.

        Raises ValueError if the feature is disabled or the resolved path
        escapes every allowed root; FileNotFoundError if the file is missing.
        """
        allowed_roots = self.tts_config.context_audio_allowed_roots or []
        if not allowed_roots:
            raise ValueError("context_audio_filepath is disabled; configure context_audio_allowed_roots to enable it.")

        # resolve() canonicalizes symlinks/.. so the containment check below
        # cannot be bypassed with path tricks.
        resolved = Path(raw_path).expanduser().resolve()
        if not resolved.exists():
            raise FileNotFoundError(f"context_audio_filepath not found: {resolved}")
        if not resolved.is_file():
            raise ValueError(f"context_audio_filepath is not a file: {resolved}")

        for root in allowed_roots:
            root_resolved = Path(root).expanduser().resolve()
            try:
                resolved.relative_to(root_resolved)
                return str(resolved)
            except ValueError:
                continue

        allowed = ", ".join(str(Path(r).expanduser().resolve()) for r in allowed_roots)
        raise ValueError(f"context_audio_filepath '{resolved}' is outside allowed roots: {allowed}")

    def load_model(self) -> None:
        """Load the Magpie model + codec, build the inference runner, init normalizer."""
        # Patch NeMo's load_fsspec() to route HuggingFace resolve URLs through
        # huggingface_hub.hf_hub_download() (uses file locks and local caching),
        # avoiding 429s when many ranks start concurrently.
        self._patch_hf_fsspec_loader()

        if (
            _InferenceConfig is None
            or _MagpieInferenceRunner is None
            or _ModelLoadConfig is None
            or _load_magpie_model is None
            or _load_evalset_config is None
        ):
            raise ImportError("Required NeMo MagpieTTS inference modules are not available in this environment.")

        if not self.tts_config.codec_model_path:
            raise ValueError("codec_model_path required")

        # Support both checkpoint mode (hparams + ckpt) and nemo mode
        has_ckpt_mode = self.tts_config.hparams_file and self.tts_config.checkpoint_file
        if has_ckpt_mode:
            cfg = _ModelLoadConfig(
                hparams_file=self.tts_config.hparams_file,
                checkpoint_file=self.tts_config.checkpoint_file,
                codecmodel_path=self.tts_config.codec_model_path,
                legacy_codebooks=self.tts_config.legacy_codebooks,
                legacy_text_conditioning=self.tts_config.legacy_text_conditioning,
                hparams_from_wandb=self.tts_config.hparams_from_wandb,
            )
        else:
            cfg = _ModelLoadConfig(
                nemo_file=self.config.model_path,
                codecmodel_path=self.tts_config.codec_model_path,
                legacy_codebooks=self.tts_config.legacy_codebooks,
                legacy_text_conditioning=self.tts_config.legacy_text_conditioning,
            )
        self._model, self._checkpoint_name = _load_magpie_model(cfg, device=self.config.device)

        # Merge args from MagpieTTSConfig into InferenceConfig. NeMo API differs
        # across builds: some use ModelInferenceParameters, others do not expose it.
        model_inference_candidates = dict(self.tts_config.extra_config)
        if self.tts_config.max_decoder_steps is not None:
            model_inference_candidates["max_decoder_steps"] = self.tts_config.max_decoder_steps
        if self.tts_config.temperature is not None:
            model_inference_candidates["temperature"] = self.tts_config.temperature
        if self.tts_config.top_k is not None:
            model_inference_candidates["top_k"] = self.tts_config.top_k
        if self.tts_config.cfg_scale is not None:
            model_inference_candidates["cfg_scale"] = self.tts_config.cfg_scale

        # Introspect the installed InferenceConfig signature and pass only
        # kwargs it actually accepts.
        inference_ctor_params = inspect.signature(_InferenceConfig).parameters
        inference_kwargs = {}
        if "batch_size" in inference_ctor_params:
            inference_kwargs["batch_size"] = 16
        if "use_cfg" in inference_ctor_params:
            inference_kwargs["use_cfg"] = self.tts_config.use_cfg
        if "use_local_transformer" in inference_ctor_params:
            inference_kwargs["use_local_transformer"] = self.tts_config.use_local_transformer
        if "apply_attention_prior" in inference_ctor_params:
            inference_kwargs["apply_attention_prior"] = self.tts_config.apply_attention_prior

        if "model_inference_parameters" in inference_ctor_params:
            if _ModelInferenceParameters is not None and is_dataclass(_ModelInferenceParameters):
                mip_fields = {f.name for f in fields(_ModelInferenceParameters)}
                mip_kwargs = {k: v for k, v in model_inference_candidates.items() if k in mip_fields}
                if hasattr(_ModelInferenceParameters, "from_dict"):
                    inference_kwargs["model_inference_parameters"] = _ModelInferenceParameters.from_dict(mip_kwargs)
                else:
                    inference_kwargs["model_inference_parameters"] = _ModelInferenceParameters(**mip_kwargs)
            else:
                # Older/newer NeMo variants can accept dict-like parameters here.
                inference_kwargs["model_inference_parameters"] = model_inference_candidates
        else:
            # Fallback API: inference params are top-level kwargs on InferenceConfig.
            for key, value in model_inference_candidates.items():
                if key in inference_ctor_params:
                    inference_kwargs[key] = value

        try:
            inference_config = _InferenceConfig(**inference_kwargs)
        except TypeError:
            # Minimal fallback for strict config signatures.
            minimal_kwargs = {
                k: v
                for k, v in inference_kwargs.items()
                if k in {"batch_size", "use_cfg", "use_local_transformer", "apply_attention_prior"}
            }
            inference_config = _InferenceConfig(**minimal_kwargs)

        self._runner = _MagpieInferenceRunner(self._model, inference_config)

        self._temp_dir = tempfile.mkdtemp(prefix="magpie_tts_")
        # The model's native sample rate overrides any configured value.
        self.tts_config.output_sample_rate = self._model.sample_rate
        self._is_loaded = True

        # Initialize text normalizer if enabled
        if self.tts_config.enable_normalization:
            if _Normalizer is None:
                raise RuntimeError(
                    "Failed to initialize text normalizer while enable_normalization=true: "
                    "nemo_text_processing is not available."
                )
            try:
                self._normalizer = _Normalizer(
                    lang=self.tts_config.normalizer_lang,
                    input_case=self.tts_config.normalizer_input_case,
                )
                logger.info("Text normalizer initialized (lang=%s)", self.tts_config.normalizer_lang)
            except Exception as e:
                raise RuntimeError("Failed to initialize text normalizer while enable_normalization=true.") from e

        logger.info(
            "Loaded MagpieTTS checkpoint=%s sample_rate=%s cfg=%s",
            self._checkpoint_name,
            self._model.sample_rate,
            self.tts_config.use_cfg,
        )
    def _extract_json(self, text: str) -> dict:
        """Extract JSON object from text, skipping non-JSON parts.

        Falls back to wrapping the raw text as {"text": ...} if no parseable
        JSON object is found.
        """
        if not text:
            return {"text": ""}
        idx = text.find("{")
        if idx >= 0:
            try:
                return json.loads(text[idx:])
            except json.JSONDecodeError:
                pass
        return {"text": text}

    def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
        """Synthesize audio for a batch of requests via a temp manifest + runner.

        Each request's text may embed a JSON payload with keys like 'text',
        'context_audio_filepath', 'duration', 'speaker_index'. Returns one
        GenerationResult per request, in order; per-request errors are reported
        in the result's `error` field rather than raised.
        """
        if not self._is_loaded:
            return [GenerationResult(error="Model not loaded", request_id=r.request_id) for r in requests]
        if not requests:
            return []

        start_time = time.time()
        batch_dir = os.path.join(self._temp_dir, f"batch_{int(time.time() * 1000)}")
        output_dir = os.path.join(batch_dir, "output")
        os.makedirs(output_dir, exist_ok=True)

        try:
            # Reset KV caches to avoid cross-request shape mismatches
            if self._model is not None:
                decoder = getattr(self._model, "decoder", None)
                if decoder is not None and hasattr(decoder, "reset_cache"):
                    decoder.reset_cache(use_cache=False)

            # Parse requests, extracting JSON from text
            parsed = [self._extract_json(r.text) for r in requests]

            # Create audio_dir with symlinks to all context audio files
            audio_dir = os.path.join(batch_dir, "audio")
            os.makedirs(audio_dir, exist_ok=True)

            manifest_path = os.path.join(batch_dir, "manifest.json")
            with open(manifest_path, "w") as f:
                for i, p in enumerate(parsed):
                    ctx = p.get("context_audio_filepath", "")
                    if ctx:
                        # Validated against allowlisted roots before use.
                        resolved_ctx = self._resolve_context_audio_path(str(ctx))
                        link_name = f"ctx_{i}_{os.path.basename(resolved_ctx)}"
                        link_path = os.path.join(audio_dir, link_name)
                        if not os.path.exists(link_path):
                            os.symlink(resolved_ctx, link_path)
                    else:
                        # No context audio provided: write a short silent wav
                        # as a placeholder so the manifest stays well-formed.
                        link_name = f"d{i}.wav"
                        link_path = os.path.join(audio_dir, link_name)
                        if not os.path.exists(link_path):
                            sr = int(getattr(self.tts_config, "output_sample_rate", 22050) or 22050)
                            dur_s = 0.1
                            n = max(1, int(sr * dur_s))
                            sf.write(link_path, [0.0] * n, sr)
                    text = p.get("text", "")
                    if self._normalizer:
                        try:
                            text = self._normalizer.normalize(text, punct_pre_process=True, punct_post_process=True)
                        except Exception as e:
                            raise RuntimeError(f"Failed to normalize text for sample index {i}") from e
                    entry = {
                        "text": text,
                        "audio_filepath": link_name,
                        "context_audio_filepath": link_name,
                        "duration": p.get("duration", 5.0),
                        "context_audio_duration": p.get("context_audio_duration", 5.0),
                    }
                    if p.get("speaker_index") is not None:
                        entry["speaker_index"] = int(p["speaker_index"])
                    f.write(json.dumps(entry) + "\n")

            config_path = os.path.join(batch_dir, "config.json")
            with open(config_path, "w") as f:
                json.dump({"batch": {"manifest_path": manifest_path, "audio_dir": audio_dir}}, f)

            # Run inference
            dataset = self._runner.create_dataset(_load_evalset_config(config_path))
            rtf_list, _, _ = self._runner.run_inference_on_dataset(
                dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False
            )

            gen_time = time.time() - start_time
            batch_metrics = {
                "total_time_sec": gen_time,
                "num_samples": len(requests),
                **self._runner.compute_mean_rtf_metrics(rtf_list),
            }

            # Build results: runner is expected to write predicted_audio_<i>.wav
            # per manifest row; a missing file becomes a per-request error.
            results = []
            for i, req in enumerate(requests):
                path = os.path.join(output_dir, f"predicted_audio_{i}.wav")
                if os.path.exists(path):
                    audio, sr = sf.read(path)
                    buf = io.BytesIO()
                    sf.write(buf, audio, sr, format="WAV")
                    buf.seek(0)
                    dur = len(audio) / sr
                    results.append(
                        GenerationResult(
                            text=parsed[i].get("text", ""),
                            audio_bytes=buf.read(),
                            # NOTE(review): reported rate is the configured
                            # output_sample_rate, not `sr` read from the file —
                            # these should match after load_model(); confirm.
                            audio_sample_rate=self.tts_config.output_sample_rate,
                            audio_format="wav",
                            request_id=req.request_id,
                            # Wall-clock time is amortized evenly over the batch.
                            generation_time_ms=gen_time * 1000 / len(requests),
                            debug_info={
                                "checkpoint": self._checkpoint_name,
                                "audio_duration_sec": dur,
                                "rtf": gen_time / len(requests) / dur if dur else 0,
                                "config": {
                                    "temp": self.tts_config.temperature,
                                    "top_k": self.tts_config.top_k,
                                    "cfg": self.tts_config.use_cfg,
                                    "cfg_scale": self.tts_config.cfg_scale,
                                },
                                "batch_metrics": batch_metrics,
                            },
                        )
                    )
                else:
                    results.append(GenerationResult(error=f"Audio not found: {path}", request_id=req.request_id))
            return results
        except Exception as e:
            logger.exception("Magpie generation failed")
            return [GenerationResult(error=str(e), request_id=r.request_id) for r in requests]
        finally:
            # Scratch files are always removed, success or failure.
            shutil.rmtree(batch_dir, ignore_errors=True)

    def validate_request(self, request: GenerationRequest) -> Optional[str]:
        """Return an error string if the request is invalid, else None."""
        return "Text required" if not request.text else None

    def health_check(self) -> Dict[str, Any]:
        """Extend the base health payload with TTS-specific details once loaded."""
        h = super().health_check()
        if self._is_loaded:
            h.update(
                {
                    "checkpoint": self._checkpoint_name,
                    "codec": self.tts_config.codec_model_path,
                    "cfg": self.tts_config.use_cfg,
                    "cfg_scale": self.tts_config.cfg_scale,
                    "sample_rate": self.tts_config.output_sample_rate,
                }
            )
        return h
    def __del__(self):
        # Best-effort cleanup of the scratch directory created in load_model().
        if getattr(self, "_temp_dir", None) and os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)
diff --git a/recipes/multimodal/server/backends/nemo_asr_backend.py b/recipes/multimodal/server/backends/nemo_asr_backend.py
new file mode 100644
index 0000000000..41571d7b5c
--- /dev/null
+++ b/recipes/multimodal/server/backends/nemo_asr_backend.py
@@ -0,0 +1,332 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

"""NeMo ASR backend for unified_server.

This backend performs offline transcription in batches using NeMo ASR models.
It expects audio input and returns transcript text plus optional word metadata
in debug_info for downstream mapping.
"""

from __future__ import annotations

import logging
import os
import tempfile
import time
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from .base import BackendConfig, GenerationRequest, GenerationResult, InferenceBackend, Modality

logger = logging.getLogger(__name__)


@dataclass
class NeMoASRConfig(BackendConfig):
    """Configuration for NeMo ASR backend."""

    # Optional alias for model_path when using --model_name style setup.
    model_name: Optional[str] = None

    # Runtime and batching.
    batch_size: int = 16
    num_workers: int = 0
    return_hypotheses: bool = True
    warmup: bool = True

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "NeMoASRConfig":
        """Build a config from a raw dict; unrecognized keys go to extra_config."""
        # Allow --model_name to override empty model_path.
        if d.get("model_name") and not d.get("model_path"):
            d = {**d, "model_path": d["model_name"]}
        known = {
            "model_path",
            "model_name",
            "device",
            "dtype",
            "max_new_tokens",
            "temperature",
            "top_p",
            "top_k",
            "batch_size",
            "num_workers",
            "return_hypotheses",
            "warmup",
        }
        return cls(
            **{k: v for k, v in d.items() if k in known},
            extra_config={k: v for k, v in d.items() if k not in known},
        )
class NeMoASRBackend(InferenceBackend):
    """Unified-server backend for offline ASR using NeMo models."""

    @classmethod
    def get_config_class(cls) -> type:
        """Config dataclass used to validate and construct this backend's config."""
        return NeMoASRConfig

    @property
    def name(self) -> str:
        return "nemo_asr"

    @property
    def supported_modalities(self) -> Set[Modality]:
        return {Modality.AUDIO_IN, Modality.TEXT}

    def __init__(self, config: BackendConfig):
        # Coerce a generic BackendConfig into NeMoASRConfig via extra_config.
        self.asr_config = config if isinstance(config, NeMoASRConfig) else NeMoASRConfig.from_dict(config.extra_config)
        super().__init__(self.asr_config)
        self._model_name = self.asr_config.model_path or self.asr_config.model_name
        self._model = None

    def load_model(self) -> None:
        """Load the ASR model (local .nemo file or pretrained name) onto the device.

        Raises:
            ValueError: if no model path/name was configured.
            RuntimeError: if the model cannot be placed on the target device.
        """
        if not self._model_name:
            raise ValueError("NeMo ASR backend requires model_path (or model_name).")

        import torch
        from nemo.collections.asr.models import ASRModel

        model_ref = self._model_name
        map_location = torch.device(self.config.device)

        # Filesystem path -> restore_from; otherwise treat as a pretrained name.
        if Path(model_ref).exists():
            self._model = ASRModel.restore_from(model_ref, map_location=map_location)
        else:
            self._model = ASRModel.from_pretrained(model_name=model_ref, map_location=map_location)

        if not hasattr(self._model, "to"):
            raise RuntimeError(f"Loaded ASR model '{self._model_name}' does not support `.to(device)` placement.")
        try:
            self._model.to(self.config.device)
        except Exception as e:
            raise RuntimeError(
                f"Failed to move ASR model '{self._model_name}' to device '{self.config.device}'."
            ) from e
        self._model.eval()

        # Optional warmup so first real request doesn't pay init cost.
        if self.asr_config.warmup:
            self._run_warmup()

        self._is_loaded = True
        logger.info("Loaded NeMo ASR model=%s on device=%s", self._model_name, self.config.device)
    def _run_warmup(self) -> None:
        """Run one short warmup call so runtime init happens before traffic."""
        fd, path = tempfile.mkstemp(suffix=".wav", prefix="nemo_asr_warmup_")
        os.close(fd)
        try:
            with wave.open(path, "wb") as wavf:
                wavf.setnchannels(1)
                wavf.setsampwidth(2)
                wavf.setframerate(16000)
                wavf.writeframes(b"\x00\x00" * 1600)  # 0.1 sec silence
            self._transcribe_paths([path], return_hypotheses=False, batch_size=1)
        except Exception as e:
            # Warmup is best-effort; failure here must not block serving.
            logger.warning("ASR warmup skipped due to error: %s", e)
        finally:
            Path(path).unlink(missing_ok=True)

    def _transcribe_paths(
        self,
        audio_paths: List[str],
        *,
        return_hypotheses: bool,
        batch_size: int,
        extra: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Call NeMo transcribe with compatibility across signatures.

        Tries the keyword names used by different NeMo releases in order:
        audio=..., paths2audio_files=..., then positional.
        """
        kwargs = {
            "batch_size": batch_size,
            "return_hypotheses": return_hypotheses,
            "num_workers": self.asr_config.num_workers,
        }
        if extra:
            kwargs.update(extra)

        try:
            return self._model.transcribe(audio=audio_paths, **kwargs)
        except TypeError:
            try:
                return self._model.transcribe(paths2audio_files=audio_paths, **kwargs)
            except TypeError:
                return self._model.transcribe(audio_paths, **kwargs)

    @staticmethod
    def _normalize_words(words_obj: Any) -> List[Dict[str, Any]]:
        """Normalize various word/timestamp schemas to list[dict].

        Accepts: list of dicts / tuples / strings / objects with word attrs,
        or a column-oriented dict of parallel lists. Output dicts always carry
        keys: word, start_time, end_time, confidence.
        """
        if words_obj is None:
            return []

        if isinstance(words_obj, list):
            normalized = []
            for item in words_obj:
                if isinstance(item, dict):
                    normalized.append(
                        {
                            "word": item.get("word", item.get("text", "")),
                            "start_time": item.get("start_time", item.get("start", None)),
                            "end_time": item.get("end_time", item.get("end", None)),
                            "confidence": item.get("confidence", None),
                        }
                    )
                    continue
                if isinstance(item, (tuple, list)):
                    # Positional schema: (word, start, end)
                    word = item[0] if len(item) > 0 else ""
                    start = item[1] if len(item) > 1 else None
                    end = item[2] if len(item) > 2 else None
                    normalized.append({"word": word, "start_time": start, "end_time": end, "confidence": None})
                    continue
                if isinstance(item, str):
                    normalized.append({"word": item, "start_time": None, "end_time": None, "confidence": None})
                    continue

                # Fallback: object with word/text and timing attributes.
                word = getattr(item, "word", getattr(item, "text", ""))
                normalized.append(
                    {
                        "word": word,
                        "start_time": getattr(item, "start_time", getattr(item, "start", None)),
                        "end_time": getattr(item, "end_time", getattr(item, "end", None)),
                        "confidence": getattr(item, "confidence", None),
                    }
                )
            return normalized

        if isinstance(words_obj, dict):
            # Column-oriented: parallel lists of words/starts/ends.
            words = words_obj.get("word") or words_obj.get("words")
            starts = words_obj.get("start") or words_obj.get("start_time")
            ends = words_obj.get("end") or words_obj.get("end_time")
            if isinstance(words, list):
                out = []
                for idx, word in enumerate(words):
                    start = starts[idx] if isinstance(starts, list) and idx < len(starts) else None
                    end = ends[idx] if isinstance(ends, list) and idx < len(ends) else None
                    out.append({"word": word, "start_time": start, "end_time": end, "confidence": None})
                return out

        return []

    def _parse_single_hypothesis(self, hyp: Any) -> tuple[str, List[Dict[str, Any]]]:
        """Extract transcript and words from heterogeneous NeMo outputs.

        Handles plain strings, dicts, and hypothesis objects; word timing may
        live under 'words' or under a 'timestamp' mapping depending on build.
        """
        if isinstance(hyp, str):
            return hyp, []

        if isinstance(hyp, dict):
            text = hyp.get("text") or hyp.get("pred_text") or hyp.get("transcript") or ""
            words = hyp.get("words")
            if words is None:
                ts = hyp.get("timestamp")
                if isinstance(ts, dict):
                    words = ts.get("word")
            elif isinstance(words, list) and all(isinstance(w, str) for w in words):
                # Bare word strings: prefer timestamped entries when available.
                ts = hyp.get("timestamp")
                if isinstance(ts, dict) and isinstance(ts.get("word"), list):
                    words = ts["word"]
            return text, self._normalize_words(words)

        text = getattr(hyp, "text", None) or getattr(hyp, "pred_text", None) or str(hyp)
        words = getattr(hyp, "words", None)
        if words is None:
            ts = getattr(hyp, "timestamp", None)
            if isinstance(ts, dict):
                words = ts.get("word")
        elif isinstance(words, list) and all(isinstance(w, str) for w in words):
            ts = getattr(hyp, "timestamp", None)
            if isinstance(ts, dict) and isinstance(ts.get("word"), list):
                words = ts["word"]
        return text, self._normalize_words(words)
    def _get_request_audio_bytes(self, request: GenerationRequest) -> bytes:
        """Return the single audio payload of a request; raise if absent or multiple."""
        if request.audio_bytes:
            return request.audio_bytes
        if request.audio_bytes_list:
            if len(request.audio_bytes_list) > 1:
                raise ValueError("nemo_asr backend currently supports one audio input per request.")
            return request.audio_bytes_list[0]
        raise ValueError("Request must contain audio_bytes/audio_bytes_list")

    def validate_request(self, request: GenerationRequest) -> Optional[str]:
        """Return an error string if the request carries no audio, else None."""
        has_audio = request.audio_bytes is not None or (
            request.audio_bytes_list is not None and len(request.audio_bytes_list) > 0
        )
        if not has_audio:
            return "nemo_asr backend requires audio input"
        return None

    def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
        """Transcribe a batch of audio requests.

        Per-request input errors produce per-request error results; a failure
        of the batched transcription itself fails the whole batch. Batch-level
        options (batch_size, return_hypotheses, timestamps, ...) are taken
        from the FIRST valid request's extra_params.
        """
        if not self._is_loaded:
            return [GenerationResult(error="Model not loaded", request_id=r.request_id) for r in requests]
        if not requests:
            return []

        tmp_dir = Path(tempfile.mkdtemp(prefix="nemo_asr_batch_"))
        start = time.time()
        temp_paths: List[str] = []
        valid_indices: List[int] = []
        results: List[Optional[GenerationResult]] = [None] * len(requests)

        try:
            # Spill each request's audio to a temp wav; invalid requests get
            # an immediate error result and are excluded from the batch.
            for idx, req in enumerate(requests):
                try:
                    audio_bytes = self._get_request_audio_bytes(req)
                    p = tmp_dir / f"req_{idx:04d}.wav"
                    p.write_bytes(audio_bytes)
                    temp_paths.append(str(p))
                    valid_indices.append(idx)
                except Exception as e:
                    results[idx] = GenerationResult(error=str(e), request_id=req.request_id)

            if temp_paths:
                first_extra = requests[valid_indices[0]].extra_params or {}
                return_hypotheses = bool(first_extra.get("return_hypotheses", self.asr_config.return_hypotheses))
                transcribe_batch_size = int(first_extra.get("batch_size", self.asr_config.batch_size))
                transcribe_batch_size = max(1, min(transcribe_batch_size, len(temp_paths)))

                # Pass through optional ASR params if present (useful for canary-style models).
                optional_keys = ["timestamps", "task", "source_lang", "target_lang", "pnc", "channel_selector"]
                extra = {k: first_extra[k] for k in optional_keys if k in first_extra}

                hyps = self._transcribe_paths(
                    temp_paths,
                    return_hypotheses=return_hypotheses,
                    batch_size=transcribe_batch_size,
                    extra=extra,
                )

                if not isinstance(hyps, list):
                    hyps = [hyps]
                if len(hyps) != len(temp_paths):
                    raise RuntimeError(f"ASR output size mismatch: got {len(hyps)} for {len(temp_paths)} inputs.")

                # Wall-clock time amortized evenly over successful inputs.
                per_req_ms = (time.time() - start) * 1000.0 / max(len(temp_paths), 1)
                for out_idx, hyp in enumerate(hyps):
                    req_idx = valid_indices[out_idx]
                    req = requests[req_idx]
                    text, words = self._parse_single_hypothesis(hyp)
                    results[req_idx] = GenerationResult(
                        text=text,
                        request_id=req.request_id,
                        generation_time_ms=per_req_ms,
                        debug_info={
                            "words": words,
                            "backend": "nemo_asr",
                            "model": self._model_name,
                            "batch_size": transcribe_batch_size,
                        },
                    )

            return [r if r is not None else GenerationResult(error="Unknown ASR backend error") for r in results]
        except Exception as e:
            return [GenerationResult(error=str(e), request_id=r.request_id) for r in requests]
        finally:
            for p in temp_paths:
                Path(p).unlink(missing_ok=True)
            try:
                tmp_dir.rmdir()
            except OSError:
                pass
diff --git a/recipes/multimodal/server/unified_server.py b/recipes/multimodal/server/unified_server.py
new file mode 100644
index 0000000000..1cf24c9de8
--- /dev/null
+++ b/recipes/multimodal/server/unified_server.py
@@ -0,0 +1,495 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unified NeMo Inference Server with OpenAI-compatible API.

Backend-agnostic: all backend-specific logic lives in backend modules.
The server only knows about GenerationRequest/GenerationResult and the
InferenceBackend interface.

Exposes /v1/chat/completions endpoint for OpenAI compatibility.
Backends may register additional routes via get_extra_routes().
"""
b/recipes/multimodal/server/unified_server.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified NeMo Inference Server with OpenAI-compatible API. + +Backend-agnostic: all backend-specific logic lives in backend modules. +The server only knows about GenerationRequest/GenerationResult and the +InferenceBackend interface. + +Exposes /v1/chat/completions endpoint for OpenAI compatibility. +Backends may register additional routes via get_extra_routes(). 
import asyncio
import base64
import hashlib
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse

from .backends import GenerationRequest, GenerationResult, get_backend

# Debug flag
DEBUG = os.getenv("DEBUG", "").lower() in ("true", "1", "yes", "on")
logger = logging.getLogger(__name__)


@dataclass
class PendingRequest:
    """Container for a pending batched request."""

    # The parsed generation request waiting to be batched.
    request: GenerationRequest
    # Future resolved with the GenerationResult (or exception) for this request.
    future: asyncio.Future
    # Enqueue time (time.time()), kept for diagnostics.
    timestamp: float


class RequestBatcher:
    """Manages request batching with configurable delay.

    Requests accumulate until either batch_size is reached or batch_timeout
    elapses; the batch is then handed to the backend's blocking generate()
    via a thread-pool executor. batch_timeout == 0 means dispatch immediately.
    """

    def __init__(self, backend, batch_size: int, batch_timeout: float):
        self.backend = backend
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self.pending_requests: List[PendingRequest] = []
        self.lock = asyncio.Lock()
        self.timeout_task: Optional[asyncio.Task] = None
        # True while a batch is running; prevents concurrent dispatch.
        self.processing = False

        # Stats
        self.total_requests = 0
        self.total_batches = 0

    async def add_request(self, request: GenerationRequest) -> GenerationResult:
        """Add a request and wait for result."""
        future = asyncio.get_running_loop().create_future()
        pending = PendingRequest(request=request, future=future, timestamp=time.time())

        async with self.lock:
            self.pending_requests.append(pending)

            if len(self.pending_requests) >= self.batch_size:
                if DEBUG:
                    print(f"[Batcher] Batch full ({self.batch_size}), processing immediately")
                asyncio.create_task(self._process_batch())
            elif self.batch_timeout == 0:
                asyncio.create_task(self._process_batch())
            elif self.timeout_task is None or self.timeout_task.done():
                # Only one timeout watcher is kept alive at a time.
                self.timeout_task = asyncio.create_task(self._timeout_handler())

        return await future

    async def _timeout_handler(self):
        """Handle batch timeout."""
        await asyncio.sleep(self.batch_timeout)
        async with self.lock:
            if self.pending_requests and not self.processing:
                if DEBUG:
                    print(f"[Batcher] Timeout, processing {len(self.pending_requests)} requests")
                asyncio.create_task(self._process_batch())

    async def _process_batch(self):
        """Process pending requests as a batch."""
        async with self.lock:
            # Bail out if another batch is already in flight or queue drained.
            if not self.pending_requests or self.processing:
                return

            self.processing = True
            batch = self.pending_requests[: self.batch_size]
            self.pending_requests = self.pending_requests[self.batch_size :]

        try:
            requests = [p.request for p in batch]

            if DEBUG:
                print(f"[Batcher] Processing batch of {len(requests)} requests")

            # backend.generate is blocking; run it off the event loop.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(None, self.backend.generate, requests)

            if len(results) != len(batch):
                raise RuntimeError(f"Backend returned {len(results)} results for {len(batch)} requests")

            for pending, result in zip(batch, results, strict=True):
                if not pending.future.done():
                    pending.future.set_result(result)

            self.total_requests += len(batch)
            self.total_batches += 1

        except Exception as e:
            # Propagate one failure to every waiter in the batch.
            for pending in batch:
                if not pending.future.done():
                    pending.future.set_exception(e)
        finally:
            async with self.lock:
                self.processing = False
                # Requests may have queued while we were busy; keep draining.
                if self.pending_requests:
                    if self.batch_timeout == 0 or len(self.pending_requests) >= self.batch_size:
                        asyncio.create_task(self._process_batch())
                    elif self.timeout_task is None or self.timeout_task.done():
                        self.timeout_task = asyncio.create_task(self._timeout_handler())
# Global state shared across the FastAPI app lifecycle.
backend_instance = None
request_batcher = None
server_config = {}


def extract_audio_from_messages(messages: List[Dict[str, Any]]) -> List[bytes]:
    """Collect every audio payload embedded in OpenAI-style chat messages.

    Recognized content-block shapes:
    - {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}}
    - {"type": "input_audio", "input_audio": {"data": "...", "format": "wav"}}
    """
    found: List[bytes] = []
    uri_re = re.compile(r"^data:audio/[^;]+;base64,(.+)$")

    for msg in messages:
        blocks = msg.get("content")
        if not isinstance(blocks, list):
            continue
        for block in blocks:
            if not isinstance(block, dict):
                continue
            try:
                kind = block.get("type")
                if kind == "audio_url":
                    m = uri_re.match(block.get("audio_url", {}).get("url", ""))
                    if m is not None:
                        found.append(base64.b64decode(m.group(1)))
                elif kind == "input_audio":
                    payload = block.get("input_audio", {}).get("data", "")
                    if payload:
                        found.append(base64.b64decode(payload))
            except Exception as e:
                # Malformed base64 in one block must not sink the others.
                if DEBUG:
                    print(f"[Server] Warning: Failed to decode audio block: {e}")
    return found


def extract_text_from_messages(messages: List[Dict[str, Any]]) -> str:
    """Concatenate all non-system text content, space-separated."""
    pieces: List[str] = []
    for msg in messages:
        if msg.get("role") == "system":
            continue
        body = msg.get("content")
        if isinstance(body, str):
            if body:
                pieces.append(body)
        elif isinstance(body, list):
            for part in body:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    snippet = part.get("text", "")
                    if snippet:
                        pieces.append(snippet)
    return " ".join(pieces)


def extract_system_prompt(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Return the text of the first system message with usable content, if any."""
    for msg in messages:
        if msg.get("role") != "system":
            continue
        body = msg.get("content")
        if isinstance(body, str):
            return body
        if isinstance(body, list):
            texts = [part.get("text", "") for part in body if isinstance(part, dict) and part.get("type") == "text"]
            return " ".join(texts) if texts else None
    return None
".join(texts) if texts else None + return None + + +def create_app( + backend_type: str, + config_dict: Dict[str, Any], + batch_size: int = 8, + batch_timeout: float = 0.1, +) -> FastAPI: + """Create and configure the FastAPI app. + + Args: + backend_type: Name of the backend to use (e.g., 'magpie_tts'). + config_dict: Full configuration dict for the backend's config class. + batch_size: Maximum batch size for request batching. + batch_timeout: Seconds to wait before processing an incomplete batch. + """ + global backend_instance, request_batcher, server_config + + app = FastAPI( + title="Unified NeMo Inference Server", + description=f"OpenAI-compatible API for NeMo model inference ({backend_type} backend)", + version="1.0.0", + ) + + server_config = { + "backend_type": backend_type, + "model_path": config_dict.get("model_path", ""), + "batch_size": batch_size, + "batch_timeout": batch_timeout, + } + + @app.on_event("startup") + async def startup(): + global backend_instance, request_batcher + + # Look up backend class and its config class + BackendClass = get_backend(backend_type) + ConfigClass = BackendClass.get_config_class() + + # Validate and create config + config = ConfigClass.from_dict(config_dict) + + # Instantiate and load backend + print(f"[Server] Initializing {backend_type} backend...") + backend_instance = BackendClass(config) + backend_instance.load_model() + + # Create batcher + request_batcher = RequestBatcher(backend_instance, batch_size, batch_timeout) + + # Register any extra routes from the backend + extra_routes = BackendClass.get_extra_routes(backend_instance) + for route in extra_routes: + app.add_api_route( + route["path"], + route["endpoint"], + methods=route.get("methods", ["GET"]), + ) + print(f"[Server] Registered extra route: {route['path']}") + + print("[Server] Ready!") + print(f" Backend: {backend_type}") + print(f" Model: {config.model_path}") + print(f" Batch size: {batch_size}") + print(f" Batch timeout: {batch_timeout}s") + + 
@app.get("/") + async def root(): + """Root endpoint with server info.""" + return { + "service": "Unified NeMo Inference Server", + "version": "1.0.0", + "backend": server_config.get("backend_type"), + "model": server_config.get("model_path"), + "endpoints": ["/v1/chat/completions", "/health", "/v1/models"], + } + + @app.get("/health") + async def health(): + """Health check endpoint.""" + if backend_instance is None: + return JSONResponse(status_code=503, content={"status": "not_ready", "error": "Backend not initialized"}) + + health_info = backend_instance.health_check() + health_info["status"] = "healthy" if backend_instance.is_loaded else "not_ready" + health_info["timestamp"] = datetime.now().isoformat() + + return health_info + + @app.get("/v1/models") + async def list_models(): + """OpenAI-compatible models endpoint.""" + model_id = server_config.get("model_path", "unknown") + return { + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "nvidia", + } + ], + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: Dict[str, Any]): + """OpenAI-compatible chat completions endpoint with audio support.""" + if backend_instance is None or not backend_instance.is_loaded: + raise HTTPException(status_code=503, detail="Model not loaded") + + try: + messages = request.get("messages", []) + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Extract components from messages + audio_bytes_list = extract_audio_from_messages(messages) + text = extract_text_from_messages(messages) + system_prompt = extract_system_prompt(messages) + + # Get generation parameters + max_tokens = request.get("max_tokens", 512) + temperature = request.get("temperature", 1.0) + top_p = request.get("top_p", 1.0) + seed = request.get("seed") + + # Create generation request + gen_request = GenerationRequest( + text=text if text else None, + system_prompt=system_prompt, 
+ audio_bytes=audio_bytes_list[0] if len(audio_bytes_list) == 1 else None, + audio_bytes_list=audio_bytes_list if len(audio_bytes_list) > 1 else None, + max_new_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + seed=seed, + request_id=hashlib.md5(f"{time.time()}".encode()).hexdigest()[:8], + extra_params=request.get("extra_body", {}), + ) + + # Validate request + error = backend_instance.validate_request(gen_request) + if error: + raise HTTPException(status_code=400, detail=error) + + # Process through batcher + result = await request_batcher.add_request(gen_request) + + if not result.is_success(): + error_id = hashlib.md5(f"{time.time_ns()}".encode()).hexdigest()[:8] + logger.error( + "Backend generation failed [error_id=%s, request_id=%s]: %s", + error_id, + gen_request.request_id, + result.error, + ) + raise HTTPException(status_code=500, detail=f"Internal server error (error_id={error_id})") + + # Build OpenAI-compatible response + response_id = f"chatcmpl-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}" + message_content = result.text or "" + + # Save outputs if AUDIO_SAVE_DIR is set + save_dir = os.environ.get("AUDIO_SAVE_DIR", "") + if save_dir: + try: + os.makedirs(save_dir, exist_ok=True) + except Exception as e: + error_id = hashlib.md5(f"{time.time_ns()}".encode()).hexdigest()[:8] + logger.error( + "Failed to prepare AUDIO_SAVE_DIR [error_id=%s, save_dir=%s]: %s", + error_id, + save_dir, + e, + ) + raise HTTPException(status_code=500, detail=f"Internal server error (error_id={error_id})") + saved_audio_path = None + save_failures = [] + + if save_dir: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + base_filename = f"response_{timestamp}_{response_id}" + + try: + saved_json_path = os.path.join(save_dir, f"{base_filename}.json") + json_output = { + "response_id": response_id, + "timestamp": timestamp, + "text": message_content, + "debug_info": result.debug_info, + "generation_time_ms": result.generation_time_ms, + 
"num_tokens_generated": result.num_tokens_generated, + } + with open(saved_json_path, "w") as f: + json.dump(json_output, f, indent=2) + except Exception as e: + save_failures.append(f"json:{type(e).__name__}:{e}") + + if result.audio_bytes: + try: + saved_audio_path = os.path.join(save_dir, f"{base_filename}.wav") + with open(saved_audio_path, "wb") as f: + f.write(result.audio_bytes) + except Exception as e: + save_failures.append(f"audio:{type(e).__name__}:{e}") + + if save_failures: + error_id = hashlib.md5(f"{time.time_ns()}".encode()).hexdigest()[:8] + logger.error( + "Failed to save response artifacts [error_id=%s, request_id=%s, response_id=%s, save_dir=%s, failures=%s]", + error_id, + gen_request.request_id, + response_id, + save_dir, + save_failures, + ) + raise HTTPException(status_code=500, detail=f"Internal server error (error_id={error_id})") + + # Build audio output if available + audio_output = None + if result.audio_bytes: + audio_output = { + "data": base64.b64encode(result.audio_bytes).decode("utf-8"), + "format": result.audio_format or "wav", + "sample_rate": result.audio_sample_rate, + "expires_at": int(time.time()) + 3600, + "transcript": result.text or "", + } + + # Embed debug_info in content as JSON + final_content = message_content + if result.debug_info: + final_content = f"{message_content}\n{json.dumps(result.debug_info)}" + + response = { + "id": response_id, + "object": "chat.completion", + "created": int(time.time()), + "model": server_config.get("model_path"), + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": final_content, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": -1, + "completion_tokens": result.num_tokens_generated or -1, + "total_tokens": -1, + }, + } + + if audio_output: + response["choices"][0]["message"]["audio"] = audio_output + + if result.debug_info: + response["debug_info"] = result.debug_info + + if saved_audio_path: + response["saved_audio_path"] = 
saved_audio_path + + return response + + except HTTPException: + raise + except Exception: + error_id = hashlib.md5(f"{time.time_ns()}".encode()).hexdigest()[:8] + logger.exception("Unhandled chat completion error [error_id=%s]", error_id) + raise HTTPException(status_code=500, detail=f"Internal server error (error_id={error_id})") + + return app diff --git a/tests/slurm-tests/run_all.sh b/tests/slurm-tests/run_all.sh index 200cf5e3fc..8f1ebfda4c 100755 --- a/tests/slurm-tests/run_all.sh +++ b/tests/slurm-tests/run_all.sh @@ -14,4 +14,8 @@ python tests/slurm-tests/qwen3_4b_evals/run_test.py --cluster $CLUSTER --workspa python tests/slurm-tests/omr_simple_recipe/run_test.py --cluster $CLUSTER --workspace /workspace/nemo-skills-slurm-ci/$RUN_NAME/omr_simple_recipe/nemo-rl --expname_prefix omr_simple_recipe_nemo_rl_$RUN_NAME # sleep 10 python tests/slurm-tests/qwen3coder_30b_swebench/run_test.py --cluster $CLUSTER --workspace /workspace/nemo-skills-slurm-ci/$RUN_NAME/qwen3coder_30b_swebench --expname_prefix qwen3coder_30b_swebench_$RUN_NAME --container_formatter '/swe-bench-images/swebench_sweb.eval.x86_64.{instance_id}.sif' +# sleep 10 +python tests/slurm-tests/unified_asr/run_test.py --cluster $CLUSTER --workspace /workspace/nemo-skills-slurm-ci/$RUN_NAME/unified_asr --expname_prefix unified_asr_$RUN_NAME +# sleep 10 +python tests/slurm-tests/unified_tts/run_test.py --cluster $CLUSTER --workspace /workspace/nemo-skills-slurm-ci/$RUN_NAME/unified_tts --expname_prefix unified_tts_$RUN_NAME # wait diff --git a/tests/slurm-tests/unified_asr/asr_openai.test b/tests/slurm-tests/unified_asr/asr_openai.test new file mode 100644 index 0000000000..b84b03f219 --- /dev/null +++ b/tests/slurm-tests/unified_asr/asr_openai.test @@ -0,0 +1,2 @@ +{"id": "sample_2", "_reference": "sample 2 this is a test of text to speech synthesis", "messages": [{"role": "user", "content": "Transcribe this audio accurately.", "audio": {"path": 
"/nemo_run/code/tests/slurm-tests/asr_nim/wavs/t2_16.wav"}}]} +{"id": "sample_3", "_reference": "sample 3 hello how are you today", "messages": [{"role": "user", "content": "Transcribe this audio accurately.", "audio": {"path": "/nemo_run/code/tests/slurm-tests/asr_nim/wavs/t3_16.wav"}}]} diff --git a/tests/slurm-tests/unified_asr/check_results.py b/tests/slurm-tests/unified_asr/check_results.py new file mode 100644 index 0000000000..b5c2c6e5b6 --- /dev/null +++ b/tests/slurm-tests/unified_asr/check_results.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import json
import re
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))  # for utils.py
from utils import assert_all, soft_assert  # noqa: E402

# Single digits are spelled out so numeric tokens compare equal to their
# spelled-out references; multi-digit tokens are deliberately left as-is.
_DIGIT_TO_WORD = {
    "0": "zero",
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}


def normalize_text(text: str) -> str:
    """Normalize a transcript for word-overlap comparison.

    Lowercases, treats hyphens as spaces, strips punctuation (anything
    outside ``\\w``/``\\s``), and maps single-digit tokens to words.
    """
    cleaned = re.sub(r"[^\w\s]", "", text.lower().replace("-", " "))
    return " ".join(_DIGIT_TO_WORD.get(token, token) for token in cleaned.split())


def load_references() -> dict[str, str]:
    """Load sample-id -> reference transcript pairs from the .test JSONL.

    Prefers the in-container path; falls back to the copy that ships next
    to this script when running outside the container.
    """
    container_path = Path("/nemo_run/code/tests/slurm-tests/unified_asr/asr_openai.test")
    local_path = Path(__file__).resolve().parent / "asr_openai.test"
    reference_path = container_path if container_path.exists() else local_path
    refs: dict[str, str] = {}
    with reference_path.open("rt", encoding="utf-8") as fin:
        for raw_line in fin:
            stripped = raw_line.strip()
            if not stripped:
                continue
            row = json.loads(stripped)
            refs[row["id"]] = row["_reference"]
    return refs


def load_outputs(output_dir: Path) -> list[dict]:
    """Read every row from output*.jsonl files; soft-fails when none exist."""
    files = sorted(output_dir.glob("output*.jsonl"))
    soft_assert(len(files) > 0, f"No output JSONL files found in {output_dir}")
    rows: list[dict] = []
    for fpath in files:
        with fpath.open("rt", encoding="utf-8") as fin:
            rows.extend(json.loads(line) for line in fin if line.strip())
    return rows
soft_assert(sample_id in references, f"Unexpected sample id in output: {sample_id}") + if sample_id not in references: + continue + + transcript = (row.get("generation") or "").strip() + soft_assert(bool(transcript), f"Empty transcript for sample {sample_id}") + if not transcript: + continue + + ref_words = set(normalize_text(references[sample_id]).split()) + hyp_words = set(normalize_text(transcript).split()) + missing = sorted(ref_words - hyp_words) + soft_assert(not missing, f"Sample {sample_id}: missing reference words: {', '.join(missing)}") + + debug_info = row.get("debug_info") + if isinstance(debug_info, dict) and isinstance(debug_info.get("words"), list) and debug_info["words"]: + found_debug_words += 1 + + soft_assert(found_debug_words > 0, "No outputs contained debug_info.words from nemo_asr backend") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", required=True, help="Workspace directory containing results") + args = parser.parse_args() + + check_asr_results(args.workspace) + assert_all() + + +if __name__ == "__main__": + main() diff --git a/tests/slurm-tests/unified_asr/run_test.py b/tests/slurm-tests/unified_asr/run_test.py new file mode 100644 index 0000000000..701bba8117 --- /dev/null +++ b/tests/slurm-tests/unified_asr/run_test.py @@ -0,0 +1,122 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import argparse
from pathlib import Path

from nemo_skills.pipeline.cli import generate, run_cmd, wrap_arguments
from nemo_skills.pipeline.utils import create_remote_directory, get_cluster_config

DEFAULT_SERVER_CONTAINER = "nvcr.io/nvidia/nemo:25.11"
DEFAULT_MODEL = "nvidia/parakeet-tdt-0.6b-v2"
# NOTE(review): this pins compute-eval for the generation container; it
# mirrors the other slurm tests' bootstrap — confirm it is actually needed
# for the ASR flow.
DEFAULT_INSTALLATION_COMMAND = (
    "pip install func_timeout 'compute-eval @ git+https://github.com/NVIDIA/compute-eval.git@2d14770'"
)


def ensure_workspace_exists(workspace: str, cluster: str, config_dir: str | None = None) -> None:
    """Create the remote workspace directory on the target cluster."""
    cluster_config = get_cluster_config(cluster, config_dir=config_dir)
    create_remote_directory(workspace, cluster_config)


def run_unified_asr_test(
    workspace: str,
    cluster: str,
    expname_prefix: str,
    server_container: str,
    model: str,
    config_dir: str | None = None,
    installation_command: str | None = DEFAULT_INSTALLATION_COMMAND,
) -> str:
    """Schedule the ASR generation job against the unified server.

    Launches `serve_unified` with the `nemo_asr` backend inside
    *server_container* and runs nemo-skills generation over the bundled
    JSONL input. Returns the experiment name so callers can chain jobs
    via `run_after`.
    """
    # Input lives inside the mounted code tree of the container.
    input_file = "/nemo_run/code/tests/slurm-tests/unified_asr/asr_openai.test"
    output_dir = f"{workspace}/asr_outputs"
    # Mount the workspace's parent so output paths resolve identically
    # inside and outside the container.
    mount_paths = f"{Path(workspace).parent}:{Path(workspace).parent}"

    generate(
        ctx=wrap_arguments(
            "++prompt_format=openai "
            "++prompt_config=null "
            "++enable_audio=true "
            # presumably required so the client sends audio blocks to the
            # generic server — TODO confirm against the generation module
            "++server.server_type=vllm_multimodal "
            "++max_concurrent_requests=2 "
            "++inference.temperature=0.0 "
            "++inference.top_p=1.0 "
            "++inference.top_k=-1 "
            "++inference.tokens_to_generate=256"
        ),
        cluster=cluster,
        generation_module="nemo_skills.inference.generate",
        input_file=input_file,
        output_dir=output_dir,
        model=model,
        server_type="generic",
        num_chunks=1,
        server_gpus=1,
        server_nodes=1,
        # HF_HUB_OFFLINE=0 lets the server download the model on first use.
        server_entrypoint="HF_HUB_OFFLINE=0 python -m nemo_skills.inference.server.serve_unified",
        server_container=server_container,
        server_args="--backend nemo_asr --batch_size 2",
        mount_paths=mount_paths,
        config_dir=config_dir,
        installation_command=installation_command,
        expname=expname_prefix,
    )
    return expname_prefix


def main():
    """Parse CLI args, run the ASR generation job, then schedule the checker."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--workspace", required=True, help="Workspace directory containing all experiment data")
    parser.add_argument("--cluster", required=True, help="Cluster name")
    parser.add_argument("--expname_prefix", required=True, help="Experiment name prefix")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="ASR model path/name")
    parser.add_argument("--server_container", default=DEFAULT_SERVER_CONTAINER, help="Container image for server job")
    parser.add_argument("--config_dir", default=None, help="Optional directory containing cluster config YAMLs")
    parser.add_argument(
        "--installation_command",
        default=DEFAULT_INSTALLATION_COMMAND,
        help="Optional install command for generation container bootstrap",
    )
    parser.add_argument("--skip_check", action="store_true", help="Skip scheduling results checker")
    args = parser.parse_args()

    ensure_workspace_exists(args.workspace, args.cluster, config_dir=args.config_dir)

    asr_expname = run_unified_asr_test(
        workspace=args.workspace,
        cluster=args.cluster,
        expname_prefix=args.expname_prefix,
        server_container=args.server_container,
        model=args.model,
        config_dir=args.config_dir,
        installation_command=args.installation_command,
    )

    if args.skip_check:
        return

    # The checker runs as a follow-up job gated on the generation job.
    checker_cmd = f"python tests/slurm-tests/unified_asr/check_results.py --workspace {args.workspace}"
    run_cmd(
        ctx=wrap_arguments(checker_cmd),
        cluster=args.cluster,
        expname=f"{args.expname_prefix}-check-results",
        log_dir=f"{args.workspace}/check-results-logs",
        mount_paths=f"{Path(args.workspace).parent}:{Path(args.workspace).parent}",
        config_dir=args.config_dir,
        run_after=asr_expname,
    )


if __name__ == "__main__":
    main()
+`run_test.py` defaults to: + +- `--server_container nvcr.io/nvidia/nemo:25.11` +- `--model nvidia/magpie_tts_multilingual_357m` +- `--codec_model nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps` + +## Temporary note about `--code_path` + +The current `magpie_tts` unified backend imports `magpietts_inference` from NeMo. +In current test environments, this module is not consistently available from the stock container alone. + +For now, pass a NeMo code tree explicitly via `--code_path`. +When a newer NeMo image includes the required Magpie TTS modules, this manual `--code_path` override should no longer be necessary. + +## Local image usage (recommended for current cluster runs) + +Keep the default in code as NVCR, but override `--server_container` at runtime. + +### DFW example + +```bash +python tests/slurm-tests/unified_tts/run_test.py \ + --cluster dfw \ + --config_dir "$PWD/cluster_configs" \ + --workspace /lustre/fsw/portfolios/convai/users//experiments/dialog_scripts2tts/unified_tts_dfw \ + --expname_prefix unified_tts_dfw \ + --server_container /lustre/fsw/portfolios/convai/users//workspace/images/nemo-25.11.sqsh \ + --code_path /lustre/fsw/portfolios/convai/users//workspace/code/NeMo_tts +``` + +### IAD (draco_oci) example + +```bash +python tests/slurm-tests/unified_tts/run_test.py \ + --cluster iad \ + --config_dir "$PWD/cluster_configs" \ + --workspace /lustre/fsw/portfolios/llmservice/users//experiments/dialog_scripts2tts/unified_tts_iad \ + --expname_prefix unified_tts_iad \ + --server_container /lustre/fsw/portfolios/llmservice/users//workspace/images/nemo-25.11.sqsh \ + --code_path /lustre/fsw/portfolios/llmservice/users//workspace/code/NeMo_tts +``` diff --git a/tests/slurm-tests/unified_tts/check_results.py b/tests/slurm-tests/unified_tts/check_results.py new file mode 100644 index 0000000000..a87f9fe228 --- /dev/null +++ b/tests/slurm-tests/unified_tts/check_results.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))  # for utils.py
from utils import assert_all, soft_assert  # noqa: E402

# Number of rows in tts_openai.test that the generation job should produce.
EXPECTED_NUM_SAMPLES = 6


def load_outputs(output_dir: Path) -> list[dict]:
    """Read every row from output*.jsonl files under *output_dir*.

    Soft-fails (rather than raising) when no output files exist so the
    remaining checks can still report.
    """
    rows: list[dict] = []
    files = sorted(output_dir.glob("output*.jsonl"))
    soft_assert(len(files) > 0, f"No output JSONL files found in {output_dir}")
    for fpath in files:
        with fpath.open("rt", encoding="utf-8") as fin:
            for line in fin:
                if line.strip():
                    rows.append(json.loads(line))
    return rows


def resolve_audio_path(audio_path: str, workspace: str) -> Path:
    """Resolve a row's audio path.

    Absolute paths pass through unchanged; relative ones are looked up
    under <workspace>/tts_outputs/audio by file name only.
    """
    path = Path(audio_path)
    if path.is_absolute():
        return path
    return Path(workspace) / "tts_outputs" / "audio" / path.name


def check_tts_results(workspace: str):
    """Soft-assert that each TTS output row references a non-empty wav file."""
    output_dir = Path(workspace) / "tts_outputs"
    soft_assert(output_dir.exists(), f"Missing output directory: {output_dir}")
    if not output_dir.exists():
        return

    rows = load_outputs(output_dir)
    soft_assert(len(rows) == EXPECTED_NUM_SAMPLES, f"Expected {EXPECTED_NUM_SAMPLES} outputs, found {len(rows)}")

    for row in rows:
        sample_id = row.get("id", "")
        audio_info = row.get("audio")
        soft_assert(isinstance(audio_info, dict), f"Missing 'audio' block in output row {sample_id}")
        if not isinstance(audio_info, dict):
            continue

        audio_path = audio_info.get("path")
        soft_assert(bool(audio_path), f"Missing audio path for row {sample_id}")
        if not audio_path:
            continue

        resolved = resolve_audio_path(audio_path, workspace)
        soft_assert(resolved.exists(), f"Audio file does not exist for {sample_id}: {resolved}")
        if not resolved.exists():
            continue

        # Non-empty file is the minimal proxy for "synthesis produced
        # audio"; content quality is not validated here.
        size_bytes = resolved.stat().st_size
        soft_assert(size_bytes > 0, f"Audio file is empty for {sample_id}: {resolved}")


def main():
    """CLI entrypoint: run the checks, then fail hard if any soft assert failed."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--workspace", required=True, help="Workspace directory containing results")
    args = parser.parse_args()

    check_tts_results(args.workspace)
    assert_all()


if __name__ == "__main__":
    main()

import argparse
from pathlib import Path

from nemo_skills.pipeline.cli import generate, run_cmd, wrap_arguments
from nemo_skills.pipeline.utils import create_remote_directory, get_cluster_config

DEFAULT_SERVER_CONTAINER = "nvcr.io/nvidia/nemo:25.11"
DEFAULT_MODEL = "nvidia/magpie_tts_multilingual_357m"
DEFAULT_CODEC = "nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps"
# NOTE(review): mirrors the bootstrap used by the other slurm tests;
# confirm compute-eval is really needed for the TTS flow.
DEFAULT_INSTALLATION_COMMAND = (
    "pip install func_timeout 'compute-eval @ git+https://github.com/NVIDIA/compute-eval.git@2d14770'"
)


def ensure_workspace_exists(workspace: str, cluster: str, config_dir: str | None = None) -> None:
    """Create the remote workspace directory on the target cluster."""
    cluster_config = get_cluster_config(cluster, config_dir=config_dir)
    create_remote_directory(workspace, cluster_config)


def run_unified_tts_test(
    workspace: str,
    cluster: str,
    expname_prefix: str,
    server_container: str,
    model: str,
    codec_model: str,
    config_dir: str | None = None,
    installation_command: str | None = DEFAULT_INSTALLATION_COMMAND,
    code_path: str | None = None,
) -> str:
    """Schedule the TTS generation job against the unified server.

    Launches `serve_unified` with the `magpie_tts` backend and the given
    codec model, running generation over the bundled JSONL input. The
    optional *code_path* mounts extra code roots and forwards them to the
    server (see the sibling README for why this is currently needed).
    Returns the experiment name for `run_after` chaining.
    """
    input_file = "/nemo_run/code/tests/slurm-tests/unified_tts/tts_openai.test"
    output_dir = f"{workspace}/tts_outputs"
    # Mount the workspace's parent so output paths resolve identically
    # inside and outside the container.
    mount_paths = f"{Path(workspace).parent}:{Path(workspace).parent}"
    server_args = f"--backend magpie_tts --codec_model {codec_model} --batch_size 6"

    if code_path:
        # Allow passing one or more colon-separated code roots.
        code_paths = [p for p in code_path.split(":") if p]
        mount_paths = ",".join([mount_paths] + [f"{p}:{p}" for p in code_paths])
        server_args += f" --code_path {code_path}"

    generate(
        ctx=wrap_arguments(
            "++prompt_format=openai "
            "++prompt_config=null "
            # presumably routes requests through the multimodal client path —
            # TODO confirm against the generation module
            "++server.server_type=vllm_multimodal "
            "++max_concurrent_requests=2 "
            "++inference.temperature=0.0 "
            "++inference.top_p=1.0 "
            "++inference.top_k=-1 "
            "++inference.tokens_to_generate=256"
        ),
        cluster=cluster,
        generation_module="nemo_skills.inference.generate",
        input_file=input_file,
        output_dir=output_dir,
        model=model,
        server_type="generic",
        num_chunks=1,
        server_gpus=1,
        server_nodes=1,
        # HF_HUB_OFFLINE=0 lets the server download the model on first use.
        server_entrypoint="HF_HUB_OFFLINE=0 python -m nemo_skills.inference.server.serve_unified",
        server_container=server_container,
        server_args=server_args,
        mount_paths=mount_paths,
        config_dir=config_dir,
        installation_command=installation_command,
        expname=expname_prefix,
    )
    return expname_prefix


def main():
    """Parse CLI args, run the TTS generation job, then schedule the checker."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--workspace", required=True, help="Workspace directory containing all experiment data")
    parser.add_argument("--cluster", required=True, help="Cluster name")
    parser.add_argument("--expname_prefix", required=True, help="Experiment name prefix")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Magpie model path/name")
    parser.add_argument("--codec_model", default=DEFAULT_CODEC, help="Codec model path/name")
    parser.add_argument("--server_container", default=DEFAULT_SERVER_CONTAINER, help="Container image for server job")
    parser.add_argument("--config_dir", default=None, help="Optional directory containing cluster config YAMLs")
    parser.add_argument(
        "--code_path",
        default=None,
        help="Optional colon-separated path(s) to prepend to PYTHONPATH in server container",
    )
    parser.add_argument(
        "--installation_command",
        default=DEFAULT_INSTALLATION_COMMAND,
        help="Optional install command for generation container bootstrap",
    )
    parser.add_argument("--skip_check", action="store_true", help="Skip scheduling results checker")
    args = parser.parse_args()

    ensure_workspace_exists(args.workspace, args.cluster, config_dir=args.config_dir)

    tts_expname = run_unified_tts_test(
        workspace=args.workspace,
        cluster=args.cluster,
        expname_prefix=args.expname_prefix,
        server_container=args.server_container,
        model=args.model,
        codec_model=args.codec_model,
        config_dir=args.config_dir,
        installation_command=args.installation_command,
        code_path=args.code_path,
    )

    if args.skip_check:
        return

    # The checker runs as a follow-up job gated on the generation job.
    checker_cmd = f"python tests/slurm-tests/unified_tts/check_results.py --workspace {args.workspace}"
    run_cmd(
        ctx=wrap_arguments(checker_cmd),
        cluster=args.cluster,
        expname=f"{args.expname_prefix}-check-results",
        log_dir=f"{args.workspace}/check-results-logs",
        mount_paths=f"{Path(args.workspace).parent}:{Path(args.workspace).parent}",
        config_dir=args.config_dir,
        run_after=tts_expname,
    )


if __name__ == "__main__":
    main()
This is a test of text-to-speech synthesis."}]} diff --git a/tests/test_magpie_tts_backend.py b/tests/test_magpie_tts_backend.py new file mode 100644 index 0000000000..a6d9ac0b2e --- /dev/null +++ b/tests/test_magpie_tts_backend.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + +from recipes.multimodal.server.backends.magpie_tts_backend import MagpieTTSBackend, MagpieTTSConfig + + +def test_context_audio_path_is_disabled_without_allowlist(tmp_path: Path): + backend = MagpieTTSBackend( + MagpieTTSConfig(model_path="dummy", codec_model_path="dummy", context_audio_allowed_roots=None) + ) + target = tmp_path / "ctx.wav" + target.write_bytes(b"fake") + + with pytest.raises(ValueError, match="context_audio_filepath is disabled"): + backend._resolve_context_audio_path(str(target)) + + +def test_context_audio_path_must_be_under_allowed_roots(tmp_path: Path): + allowed = tmp_path / "allowed" + outside = tmp_path / "outside" + allowed.mkdir() + outside.mkdir() + in_root = allowed / "in_root.wav" + in_root.write_bytes(b"fake") + out_of_root = outside / "out_of_root.wav" + out_of_root.write_bytes(b"fake") + + backend = MagpieTTSBackend( + MagpieTTSConfig( + model_path="dummy", + codec_model_path="dummy", + context_audio_allowed_roots=[str(allowed)], + ) + ) + + assert backend._resolve_context_audio_path(str(in_root)) == str(in_root.resolve()) + with 
pytest.raises(ValueError, match="outside allowed roots"): + backend._resolve_context_audio_path(str(out_of_root)) diff --git a/tests/test_nemo_asr_backend.py b/tests/test_nemo_asr_backend.py new file mode 100644 index 0000000000..f6ad5c287d --- /dev/null +++ b/tests/test_nemo_asr_backend.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from recipes.multimodal.server.backends import GenerationRequest +from recipes.multimodal.server.backends.nemo_asr_backend import NeMoASRBackend, NeMoASRConfig + + +class _FakeHypothesis: + def __init__(self, text: str): + self.text = text + self.words = [{"word": text, "start_time": 0.0, "end_time": 0.2, "confidence": 1.0}] + + +class _FakeTimestampHypothesis: + def __init__(self): + self.text = "hello world" + self.words = ["hello", "world"] + self.timestamp = { + "word": [ + {"word": "hello", "start": 0.1, "end": 0.5}, + {"word": "world", "start": 0.5, "end": 0.9}, + ] + } + + +class _FakeASRModel: + def __init__(self): + self.calls = [] + + def transcribe(self, audio=None, **kwargs): + self.calls.append((audio, kwargs)) + return [_FakeHypothesis(f"transcript_{idx}") for idx, _ in enumerate(audio)] + + +def test_nemo_asr_backend_validate_request_requires_audio(): + backend = NeMoASRBackend(NeMoASRConfig(model_path="dummy")) + err = backend.validate_request(GenerationRequest(text="x")) + assert err is not None + + +def 
test_generation_params_preserve_explicit_zero_values(): + backend = NeMoASRBackend( + NeMoASRConfig(model_path="dummy", max_new_tokens=128, temperature=0.8, top_p=0.95, top_k=40) + ) + + params = backend.get_generation_params(GenerationRequest(temperature=0.0, top_p=0.0, top_k=0)) + + assert params["max_new_tokens"] == 128 + assert params["temperature"] == 0.0 + assert params["top_p"] == 0.0 + assert params["top_k"] == 0 + + +def test_nemo_asr_backend_generate_batched_with_words(): + backend = NeMoASRBackend(NeMoASRConfig(model_path="dummy", batch_size=4)) + backend._model = _FakeASRModel() + backend._is_loaded = True + + reqs = [ + GenerationRequest( + audio_bytes=b"RIFF" + b"\x00" * 64, request_id="r1", extra_params={"return_hypotheses": True} + ), + GenerationRequest( + audio_bytes=b"RIFF" + b"\x00" * 64, request_id="r2", extra_params={"return_hypotheses": True} + ), + ] + results = backend.generate(reqs) + + assert len(results) == 2 + assert results[0].text == "transcript_0" + assert results[1].text == "transcript_1" + assert results[0].debug_info["words"][0]["word"] == "transcript_0" + assert results[1].debug_info["words"][0]["word"] == "transcript_1" + assert results[0].request_id == "r1" + assert results[1].request_id == "r2" + + +def test_nemo_asr_backend_prefers_timestamp_words_when_words_are_strings(): + backend = NeMoASRBackend(NeMoASRConfig(model_path="dummy")) + text, words = backend._parse_single_hypothesis(_FakeTimestampHypothesis()) + + assert text == "hello world" + assert words == [ + {"word": "hello", "start_time": 0.1, "end_time": 0.5, "confidence": None}, + {"word": "world", "start_time": 0.5, "end_time": 0.9, "confidence": None}, + ] diff --git a/tests/test_unified_server_audio_parser.py b/tests/test_unified_server_audio_parser.py new file mode 100644 index 0000000000..d01aa07ee7 --- /dev/null +++ b/tests/test_unified_server_audio_parser.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for unified_server audio extraction from chat-completion messages.""" + +import base64 +import importlib + +import pytest + +pytest.importorskip("fastapi") +pytest.importorskip("uvicorn") + +extract_audio_from_messages = importlib.import_module( + "recipes.multimodal.server.unified_server" +).extract_audio_from_messages +extract_text_from_messages = importlib.import_module( + "recipes.multimodal.server.unified_server" +).extract_text_from_messages + + +def _b64(data: bytes) -> str: + return base64.b64encode(data).decode("utf-8") + + +def test_extract_audio_from_messages_audio_url_only(): + raw = b"audio-from-audio-url" + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": f"data:audio/wav;base64,{_b64(raw)}"}, + } + ], + } + ] + + assert extract_audio_from_messages(messages) == [raw] + + +def test_extract_audio_from_messages_input_audio_only(): + raw = b"audio-from-input-audio" + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {"data": _b64(raw), "format": "wav"}, + } + ], + } + ] + + assert extract_audio_from_messages(messages) == [raw] + + +def test_extract_audio_from_messages_mixed_order_is_preserved(): + first = b"first-audio" + second = b"second-audio" + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {"data": _b64(first), "format": "wav"}, + }, + { + "type": 
"audio_url", + "audio_url": {"url": f"data:audio/wav;base64,{_b64(second)}"}, + }, + ], + } + ] + + assert extract_audio_from_messages(messages) == [first, second] + + +def test_extract_audio_from_messages_skips_non_audio_or_malformed_blocks(): + valid = b"valid-audio" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "ignore me"}, + {"type": "audio_url", "audio_url": {"url": "http://example.com/a.wav"}}, + {"type": "input_audio", "input_audio": {"format": "wav"}}, + { + "type": "audio_url", + "audio_url": {"url": f"data:audio/wav;base64,{_b64(valid)}"}, + }, + ], + } + ] + + assert extract_audio_from_messages(messages) == [valid] + + +def test_extract_text_from_messages_ignores_system_role(): + messages = [ + {"role": "system", "content": "You are a strict assistant."}, + {"role": "user", "content": "Say hello."}, + {"role": "assistant", "content": "Hello."}, + ] + + assert extract_text_from_messages(messages) == "Say hello. Hello." diff --git a/tests/test_unified_server_batcher.py b/tests/test_unified_server_batcher.py new file mode 100644 index 0000000000..d69fdf1b9e --- /dev/null +++ b/tests/test_unified_server_batcher.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import importlib + +import pytest + +pytest.importorskip("fastapi") +pytest.importorskip("uvicorn") + +unified_server = importlib.import_module("recipes.multimodal.server.unified_server") +GenerationRequest = importlib.import_module("recipes.multimodal.server.backends").GenerationRequest + + +class _MismatchedBackend: + def generate(self, requests): + del requests + return [] + + +def test_request_batcher_fails_on_batch_result_length_mismatch(): + async def _run(): + batcher = unified_server.RequestBatcher(_MismatchedBackend(), batch_size=1, batch_timeout=0) + with pytest.raises(RuntimeError, match="Backend returned 0 results for 1 requests"): + await batcher.add_request(GenerationRequest(text="hello")) + + asyncio.run(_run()) diff --git a/tests/test_unified_server_error_handling.py b/tests/test_unified_server_error_handling.py new file mode 100644 index 0000000000..302c22d6b8 --- /dev/null +++ b/tests/test_unified_server_error_handling.py @@ -0,0 +1,129 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import importlib

import pytest

pytest.importorskip("fastapi")
pytest.importorskip("uvicorn")

from fastapi.testclient import TestClient

from recipes.multimodal.server.backends import (
    BackendConfig,
    GenerationRequest,
    GenerationResult,
    InferenceBackend,
    Modality,
)

unified_server = importlib.import_module("recipes.multimodal.server.unified_server")


class _ErrorBackend(InferenceBackend):
    """Backend double whose results always carry a sensitive-looking error string."""

    @classmethod
    def get_config_class(cls) -> type:
        return BackendConfig

    @property
    def name(self) -> str:
        return "error_backend"

    @property
    def supported_modalities(self):
        return {Modality.TEXT}

    def load_model(self) -> None:
        self._is_loaded = True

    def generate(self, requests: list[GenerationRequest]) -> list[GenerationResult]:
        return [
            GenerationResult(error=f"sensitive backend error for {request.request_id}: /tmp/secret/path")
            for request in requests
        ]


class _OkBackend(InferenceBackend):
    """Backend double that trivially succeeds with a fixed text result."""

    @classmethod
    def get_config_class(cls) -> type:
        return BackendConfig

    @property
    def name(self) -> str:
        return "ok_backend"

    @property
    def supported_modalities(self):
        return {Modality.TEXT}

    def load_model(self) -> None:
        self._is_loaded = True

    def generate(self, requests: list[GenerationRequest]) -> list[GenerationResult]:
        del requests
        return [GenerationResult(text="ok")]


def test_chat_completion_does_not_leak_raw_backend_error(monkeypatch):
    monkeypatch.setattr(unified_server, "get_backend", lambda backend_type: _ErrorBackend)
    app = unified_server.create_app(
        backend_type="error_backend",
        config_dict={"model_path": "dummy"},
        batch_size=1,
        batch_timeout=0,
    )

    with TestClient(app) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hello"}]},
        )

    assert resp.status_code == 500
    body = resp.json()["detail"]
    # Clients get an opaque error id; the raw backend message (with its paths) stays server-side.
    assert body.startswith("Internal server error (error_id=")
    assert "/tmp/secret/path" not in body


def test_chat_completion_returns_500_if_audio_save_dir_cannot_be_prepared(monkeypatch):
    monkeypatch.setattr(unified_server, "get_backend", lambda backend_type: _OkBackend)
    monkeypatch.setenv("AUDIO_SAVE_DIR", "/forbidden/save/path")

    real_makedirs = unified_server.os.makedirs

    def _deny_forbidden(path, exist_ok=False):
        # Only the configured save dir fails; all other directory creation behaves normally.
        if path == "/forbidden/save/path":
            raise PermissionError("no write permission")
        return real_makedirs(path, exist_ok=exist_ok)

    monkeypatch.setattr(unified_server.os, "makedirs", _deny_forbidden)

    app = unified_server.create_app(
        backend_type="ok_backend",
        config_dict={"model_path": "dummy"},
        batch_size=1,
        batch_timeout=0,
    )

    with TestClient(app) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hello"}]},
        )

    assert resp.status_code == 500
    body = resp.json()["detail"]
    # The OS-level failure detail must not be echoed back to the client.
    assert body.startswith("Internal server error (error_id=")
    assert "no write permission" not in body