#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
CLI entrypoint for the Unified NeMo Inference Server.

Configuration is YAML-based: provide a config file with a ``backend`` key and
all backend-specific parameters, or use ``--model``/``--backend`` CLI flags
for nemo-skills pipeline compatibility.

Usage:
    python -m nemo_skills.inference.server.serve_unified \\
        --config /path/to/backend_config.yaml \\
        --port 8000

    # Or with --model for nemo-skills pipeline compatibility:
    python -m nemo_skills.inference.server.serve_unified \\
        --model /path/to/model \\
        --backend magpie_tts \\
        --codec_model /path/to/codec \\
        --port 8000

Any flag not recognized by the server itself is parsed generically and
forwarded to the backend's config dataclass:
    --flag       -> {"flag": True}
    --key value  -> {"key": <coerced value>}
    --key=value  -> {"key": <coerced value>}
    --no_flag    -> {"flag": False}

See each backend's config class for available options (e.g. MagpieTTSConfig).

Example YAML config (backend_config.yaml):
    backend: magpie_tts
    model_path: /path/to/model
    codec_model_path: /path/to/codec
    device: cuda
    dtype: bfloat16
    temperature: 0.6
    top_k: 80
    use_cfg: true
    cfg_scale: 2.5
"""

import argparse
import inspect
import os
import shutil
import sys
from typing import Optional


def setup_pythonpath(code_path: Optional[str] = None) -> None:
    """Prepend code locations to PYTHONPATH and sys.path.

    Args:
        code_path: Single path or colon-separated paths to add to PYTHONPATH.
    """
    paths_to_add = []

    if code_path:
        for path in code_path.split(":"):
            if path and path not in paths_to_add:
                paths_to_add.append(path)

    # Add the repo root so `recipes.*` imports resolve when running from a checkout.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    ns_eval_root = os.path.dirname(os.path.dirname(os.path.dirname(this_dir)))
    if os.path.exists(os.path.join(ns_eval_root, "recipes")):
        paths_to_add.append(ns_eval_root)

    # Container layout used by nemo_run jobs.
    if os.path.exists("/nemo_run/code"):
        paths_to_add.append("/nemo_run/code")

    # Prepend (not append) so our paths win over anything already installed.
    current_path = os.environ.get("PYTHONPATH", "")
    for path in paths_to_add:
        if path not in current_path.split(":"):
            current_path = f"{path}:{current_path}" if current_path else path

    os.environ["PYTHONPATH"] = current_path

    for path in paths_to_add:
        if path not in sys.path:
            sys.path.insert(0, path)


def apply_safetensors_patch(hack_path: Optional[str]) -> None:
    """Overwrite the installed ``safetensors.torch`` module with a patched copy.

    Required by some NeMo models. A failure here is deliberately non-fatal:
    the server may still work for backends that do not need the patch.

    Args:
        hack_path: Path to the replacement module file; ignored if missing.
    """
    if not hack_path or not os.path.exists(hack_path):
        return

    try:
        import safetensors.torch as st_torch

        dest_path = inspect.getfile(st_torch)
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        shutil.copyfile(hack_path, dest_path)
        print(f"[serve_unified] Applied safetensors patch: {hack_path} -> {dest_path}")
    except Exception as e:
        print(f"[serve_unified] Warning: Failed to apply safetensors patch: {e}")


def load_yaml_config(config_path: str) -> dict:
    """Load a YAML config file.

    Returns an empty dict for an empty file: ``yaml.safe_load`` yields ``None``
    there, which would otherwise crash the caller's ``.pop()``/``.update()``.
    """
    import yaml

    with open(config_path) as f:
        return yaml.safe_load(f) or {}


def _coerce_value(value: str):
    """Coerce a CLI string to int, float, or bool; fall back to the string."""
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    return value


def parse_extra_args(extra_args: list) -> dict:
    """Convert unknown CLI args to a config dict.

    Handles these patterns:
        --flag       -> {"flag": True}
        --key value  -> {"key": <coerced value>}
        --key=value  -> {"key": <coerced value>}
        --no_flag    -> {"flag": False}  (strips the ``no_`` prefix)

    Tokens that do not start with ``--`` and do not follow a key are skipped.
    """
    result = {}
    i = 0
    while i < len(extra_args):
        arg = extra_args[i]
        if not arg.startswith("--"):
            # Stray positional token: ignore it.
            i += 1
            continue

        # Handle --key=value (value may itself contain '=').
        if "=" in arg:
            key, value = arg[2:].split("=", 1)
            result[key] = _coerce_value(value)
            i += 1
            continue

        key = arg[2:]

        # Check if the next token is a value (not another flag).
        if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("--"):
            result[key] = _coerce_value(extra_args[i + 1])
            i += 2
            continue

        # Bare flag: --no_X -> {X: False}, otherwise {key: True}.
        if key.startswith("no_"):
            result[key[3:]] = False
        else:
            result[key] = True
        i += 1

    return result


def main():
    """Parse CLI args, build the backend config, and launch the server."""
    parser = argparse.ArgumentParser(
        description="Unified NeMo Inference Server",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Primary: YAML config
    parser.add_argument("--config", default=None, help="Path to YAML config file")

    # Standard args for nemo-skills pipeline compatibility
    parser.add_argument("--model", default=None, help="Path to the model")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--backend", default="magpie_tts", help="Backend type")

    # Server configuration
    parser.add_argument("--batch_size", type=int, default=8, help="Maximum batch size")
    parser.add_argument("--batch_timeout", type=float, default=0.1, help="Batch timeout in seconds")

    # Generation defaults
    parser.add_argument("--max_new_tokens", type=int, default=512, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Generation temperature")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling")

    # Model configuration
    parser.add_argument("--device", default="cuda", help="Device to use")
    parser.add_argument("--dtype", default="bfloat16", help="Model dtype")

    # Environment setup
    parser.add_argument("--code_path", default=None, help="Path to add to PYTHONPATH")
    parser.add_argument("--hack_path", default=None, help="Path to safetensors patch")

    # Debug
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # Parse known args; everything else is backend-specific.
    args, unknown = parser.parse_known_args()
    extra_config = parse_extra_args(unknown)

    # Setup environment before any heavy imports.
    setup_pythonpath(args.code_path)
    apply_safetensors_patch(args.hack_path)

    if args.code_path:
        os.environ["UNIFIED_SERVER_CODE_PATH"] = args.code_path

    if args.debug:
        os.environ["DEBUG"] = "1"

    # Respect an externally-set CUDA_VISIBLE_DEVICES; otherwise expose num_gpus devices.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(args.num_gpus))

    # Build configuration.
    if args.config:
        # YAML config mode.
        config_dict = load_yaml_config(args.config)
        backend_type = config_dict.pop("backend", args.backend)
        # CLI overrides.
        if args.model:
            config_dict["model_path"] = args.model
        # Merge any extra CLI args into YAML config (CLI wins).
        config_dict.update(extra_config)
    else:
        # CLI args mode (backward compatible).
        if not args.model:
            parser.error("--model is required when not using --config")
        backend_type = args.backend
        config_dict = {
            "model_path": args.model,
            "device": args.device,
            "dtype": args.dtype,
            "max_new_tokens": args.max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
        }
        # Merge backend-specific args from extra CLI flags.
        config_dict.update(extra_config)

    # Print configuration.
    print("=" * 60)
    print("[serve_unified] Starting Unified NeMo Inference Server")
    print("=" * 60)
    print(f"  Backend: {backend_type}")
    print(f"  Model: {config_dict.get('model_path', 'N/A')}")
    print(f"  Port: {args.port}")
    print(f"  GPUs: {args.num_gpus}")
    print(f"  Batch Size: {args.batch_size}")
    print(f"  Batch Timeout: {args.batch_timeout}s")
    if args.config:
        print(f"  Config: {args.config}")
    if extra_config:
        print(f"  Extra CLI Config: {extra_config}")
    print("=" * 60)

    # Import and run. Imports are deferred so env setup above takes effect.
    try:
        import uvicorn

        from recipes.multimodal.server.unified_server import create_app

        app = create_app(
            backend_type=backend_type,
            config_dict=config_dict,
            batch_size=args.batch_size,
            batch_timeout=args.batch_timeout,
        )

        uvicorn.run(app, host=args.host, port=args.port, log_level="info")

    except ImportError as e:
        print(f"[serve_unified] Error: Failed to import unified server: {e}")
        print("[serve_unified] Make sure the recipes.multimodal.server package is in PYTHONPATH")
        sys.exit(1)
    except Exception as e:
        print(f"[serve_unified] Error: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

# --- diff continues: recipes/multimodal/__init__.py (empty) and
# recipes/multimodal/server/__init__.py ---
# --- recipes/multimodal/server/__init__.py ---
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unified NeMo Inference Server package.

Provides a pluggable FastAPI server that supports multiple NeMo model backends
through a backend-agnostic architecture. All backend-specific logic lives in
the backend modules under `backends/`.
"""

from .backends import (
    BackendConfig,
    GenerationRequest,
    GenerationResult,
    InferenceBackend,
    get_backend,
)

__all__ = [
    "InferenceBackend",
    "GenerationRequest",
    "GenerationResult",
    "BackendConfig",
    "get_backend",
]

# --- recipes/multimodal/server/backends/__init__.py ---
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Backend implementations for the Unified NeMo Inference Server.

Available backends:
- magpie_tts: MagpieTTS text-to-speech (audio output from text input)

Backends are lazily loaded to avoid importing heavy dependencies upfront.
"""

from .base import BackendConfig, GenerationRequest, GenerationResult, InferenceBackend, Modality

__all__ = [
    "InferenceBackend",
    "GenerationRequest",
    "GenerationResult",
    "BackendConfig",
    "Modality",
    "get_backend",
    "list_backends",
]

# Registry of available backends: name -> (module_name, class_name).
# The module is only imported when the backend is actually requested.
BACKEND_REGISTRY = {
    "magpie_tts": ("magpie_tts_backend", "MagpieTTSBackend"),
}


def list_backends() -> list:
    """Return list of available backend names."""
    return [backend for backend in BACKEND_REGISTRY]


def get_backend(backend_name: str) -> type:
    """Resolve a backend class by name, importing its module on demand.

    Args:
        backend_name: One of the registered backend names

    Returns:
        Backend class (not instance)

    Raises:
        ValueError: If backend name is unknown
        ImportError: If backend dependencies are not available
    """
    entry = BACKEND_REGISTRY.get(backend_name)
    if entry is None:
        available = ", ".join(BACKEND_REGISTRY.keys())
        raise ValueError(f"Unknown backend: '{backend_name}'. Available backends: {available}")

    module_name, class_name = entry

    import importlib

    try:
        backend_module = importlib.import_module(f".{module_name}", package=__name__)
    except ImportError as e:
        raise ImportError(
            f"Failed to import backend '{backend_name}'. "
            f"Make sure required dependencies are installed. Error: {e}"
        ) from e
    return getattr(backend_module, class_name)

# --- diff continues: recipes/multimodal/server/backends/base.py ---
# --- recipes/multimodal/server/backends/base.py ---
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Abstract base class for inference backends.

All model backends must implement this interface to be usable with the
unified inference server.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set


class Modality(str, Enum):
    """Supported input/output modalities."""

    TEXT = "text"
    AUDIO_IN = "audio_in"
    AUDIO_OUT = "audio_out"


@dataclass
class BackendConfig:
    """Base configuration for all backends.

    Subclasses should add their own fields. The from_dict() classmethod
    handles extracting known fields and putting the rest in extra_config.
    """

    model_path: str = ""
    device: str = "cuda"
    dtype: str = "bfloat16"

    # Generation defaults
    max_new_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: Optional[int] = None

    # Additional model-specific configs passed through
    extra_config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "BackendConfig":
        """Create a config from a dict, routing unknown keys to extra_config.

        An explicit "extra_config" entry in `d` is merged into the collected
        extras rather than silently dropped (it is a known dataclass field
        but must not be passed twice to the constructor).
        """
        known_fields = {f.name for f in cls.__dataclass_fields__.values()}
        known = {k: v for k, v in d.items() if k in known_fields and k != "extra_config"}
        extra = {k: v for k, v in d.items() if k not in known_fields}
        extra.update(d.get("extra_config") or {})
        return cls(**known, extra_config=extra)


@dataclass
class GenerationRequest:
    """A single generation request.

    Supports text and/or audio inputs depending on the backend's capabilities.
    """

    # Text inputs
    text: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None

    # Audio input (raw bytes or file path)
    audio_bytes: Optional[bytes] = None
    audio_path: Optional[str] = None
    sample_rate: int = 16000

    # Multi-turn audio inputs (list of audio bytes or paths)
    audio_bytes_list: Optional[List[bytes]] = None
    audio_paths: Optional[List[str]] = None

    # Generation parameters (override backend defaults when not None)
    max_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    seed: Optional[int] = None

    # Additional parameters
    extra_params: Dict[str, Any] = field(default_factory=dict)

    # Request tracking
    request_id: Optional[str] = None


@dataclass
class GenerationResult:
    """Result from a generation request.

    Contains text output and optionally audio output, plus metadata.
    """

    # Text output
    text: str = ""

    # Audio output (raw bytes, can be encoded to base64 for JSON)
    audio_bytes: Optional[bytes] = None
    audio_sample_rate: int = 16000
    audio_format: str = "wav"

    # Metadata
    request_id: Optional[str] = None
    num_tokens_generated: int = 0
    generation_time_ms: float = 0.0

    # Debug info (optional, backend-specific)
    debug_info: Optional[Dict[str, Any]] = None

    # Error handling: None means success
    error: Optional[str] = None

    def is_success(self) -> bool:
        """True when no error was recorded."""
        return self.error is None


class InferenceBackend(ABC):
    """Abstract base class for inference backends.

    Implementations must provide:
    - get_config_class(): Return the config dataclass for this backend
    - load_model(): Initialize the model from config
    - generate(): Run inference on a batch of requests
    - supported_modalities: What input/output types are supported

    Optionally:
    - get_extra_routes(): Return additional FastAPI routes this backend needs
      (e.g., session management endpoints for S2S backends)
    """

    def __init__(self, config: BackendConfig):
        self.config = config
        self._model = None
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_config_class(cls) -> type:
        """Return the config dataclass for this backend.

        The returned class must be a subclass of BackendConfig with a
        from_dict() classmethod.
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the backend name (e.g., 'salm', 'magpie_tts')."""
        ...

    @property
    @abstractmethod
    def supported_modalities(self) -> Set[Modality]:
        """Return the set of supported modalities."""
        ...

    @abstractmethod
    def load_model(self) -> None:
        """Load and initialize the model.

        Should set self._model and self._is_loaded = True on success.
        Called once during server startup.
        """
        ...

    @abstractmethod
    def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
        """Run inference on a batch of requests.

        Args:
            requests: List of generation requests to process

        Returns:
            List of generation results, one per request (same order)
        """
        ...

    @classmethod
    def get_extra_routes(cls, backend_instance: "InferenceBackend") -> list:
        """Return additional FastAPI routes this backend needs.

        Override to register custom endpoints (e.g., session management).
        Each item should be a dict with 'path', 'endpoint', 'methods'.

        Returns:
            List of route dicts, empty by default.
        """
        return []

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready."""
        return self._is_loaded

    def health_check(self) -> Dict[str, Any]:
        """Return health status information."""
        return {
            "backend": self.name,
            "model_loaded": self._is_loaded,
            "model_path": self.config.model_path,
            "device": self.config.device,
            "modalities": [m.value for m in self.supported_modalities],
        }

    def get_generation_params(self, request: GenerationRequest) -> Dict[str, Any]:
        """Get effective generation parameters, merging request with config defaults.

        Uses explicit `is None` checks so that legitimate zero-valued
        overrides (e.g. temperature=0.0 for greedy decoding, top_p=0.0)
        take effect instead of silently falling back to the defaults,
        which a truthiness-based `or` merge would do.
        """
        return {
            "max_new_tokens": (
                request.max_new_tokens if request.max_new_tokens is not None else self.config.max_new_tokens
            ),
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "top_k": request.top_k if request.top_k is not None else self.config.top_k,
        }

    def validate_request(self, request: GenerationRequest) -> Optional[str]:
        """Validate a request against supported modalities.

        Returns:
            Error message if invalid, None if valid
        """
        modalities = self.supported_modalities

        has_text_input = request.text is not None
        has_audio_input = request.audio_bytes is not None or request.audio_path is not None

        if has_audio_input and Modality.AUDIO_IN not in modalities:
            return f"Backend '{self.name}' does not support audio input"

        if not has_text_input and not has_audio_input:
            return "Request must have either text or audio input"

        return None

# --- diff continues: recipes/multimodal/server/backends/magpie_tts_backend.py ---
# --- recipes/multimodal/server/backends/magpie_tts_backend.py (complete file) ---
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

"""MagpieTTS backend using MagpieInferenceRunner with RTF metrics."""

import io
import json
import os
import shutil
import tempfile
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set

import soundfile as sf

from .base import BackendConfig, GenerationRequest, GenerationResult, InferenceBackend, Modality


@dataclass
class MagpieTTSConfig(BackendConfig):
    """Configuration for the MagpieTTS backend (extends BackendConfig)."""

    codec_model_path: Optional[str] = None
    top_k: int = 80
    temperature: float = 0.6
    use_cfg: bool = True
    cfg_scale: float = 2.5
    max_decoder_steps: int = 440
    use_local_transformer: bool = False
    output_sample_rate: int = 22050
    # Checkpoint loading options (alternative to model_path .nemo file)
    hparams_file: Optional[str] = None
    checkpoint_file: Optional[str] = None
    legacy_codebooks: bool = False
    legacy_text_conditioning: bool = False
    hparams_from_wandb: bool = False

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "MagpieTTSConfig":
        """Build a config from a dict; unknown keys land in extra_config."""
        known = {
            "model_path",
            "device",
            "dtype",
            "max_new_tokens",
            "temperature",
            "top_p",
            "top_k",
            "codec_model_path",
            "use_cfg",
            "cfg_scale",
            "max_decoder_steps",
            "use_local_transformer",
            "output_sample_rate",
            "hparams_file",
            "checkpoint_file",
            "legacy_codebooks",
            "legacy_text_conditioning",
            "hparams_from_wandb",
        }
        return cls(
            **{k: v for k, v in d.items() if k in known}, extra_config={k: v for k, v in d.items() if k not in known}
        )


class MagpieTTSBackend(InferenceBackend):
    """MagpieTTS backend. Input: JSON with 'text' and 'context_audio_filepath'."""

    @classmethod
    def get_config_class(cls) -> type:
        return MagpieTTSConfig

    @property
    def name(self) -> str:
        return "magpie_tts"

    @property
    def supported_modalities(self) -> Set[Modality]:
        return {Modality.TEXT, Modality.AUDIO_OUT}

    def __init__(self, config: BackendConfig):
        # Promote a generic BackendConfig to MagpieTTSConfig, carrying over
        # the shared fields plus anything stashed in extra_config.
        self.tts_config = (
            config
            if isinstance(config, MagpieTTSConfig)
            else MagpieTTSConfig.from_dict(
                {
                    **{
                        k: getattr(config, k)
                        for k in ["model_path", "device", "dtype", "max_new_tokens", "temperature", "top_p", "top_k"]
                        if hasattr(config, k)
                    },
                    **config.extra_config,
                }
            )
        )
        super().__init__(self.tts_config)
        self._model = self._runner = self._temp_dir = self._checkpoint_name = None

    def load_model(self) -> None:
        """Load the MagpieTTS model and set up the inference runner.

        Also installs (best-effort) a patch that routes HuggingFace resolve
        URLs in NeMo's load_fsspec() through huggingface_hub.hf_hub_download()
        (file locks + local caching) to avoid 429s when many ranks start
        concurrently.
        """
        try:
            import os
            import re

            import nemo.collections.tts.modules.audio_codec_modules as _acm

            _orig_load_fsspec = getattr(_acm, "load_fsspec", None)
            # Guard against double-patching on repeated load_model() calls.
            if callable(_orig_load_fsspec) and not getattr(_acm, "_hf_load_fsspec_patched", False):
                try:
                    from huggingface_hub import hf_hub_download

                    def _hf_resolve_to_local(url: str) -> str | None:
                        # Map https://huggingface.co/<org>/<repo>/resolve/<rev>/<file>
                        # to a locally cached file path; None if not an HF URL.
                        if not isinstance(url, str):
                            return None
                        url_no_q = url.split("?", 1)[0]
                        m = re.match(r"^https?://huggingface\.co/([^/]+)/([^/]+)/resolve/([^/]+)/(.+)$", url_no_q)
                        if not m:
                            return None
                        repo_id = f"{m.group(1)}/{m.group(2)}"
                        revision = m.group(3)
                        filename = m.group(4)
                        token = os.environ.get("HF_TOKEN") or None
                        return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)

                    def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
                        if isinstance(path, str) and path.startswith("http"):
                            local = _hf_resolve_to_local(path)
                            if local:
                                return _orig_load_fsspec(local, map_location=map_location, **kwargs)
                        return _orig_load_fsspec(path, map_location=map_location, **kwargs)

                    _acm.load_fsspec = _load_fsspec_patched
                    _acm._hf_load_fsspec_patched = True
                except Exception:
                    # huggingface_hub unavailable: fall through to unpatched loader.
                    pass
        except Exception:
            # Patch is purely an optimization; never block model loading on it.
            pass

        from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner
        from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_magpie_model

        if not self.tts_config.codec_model_path:
            raise ValueError("codec_model_path required")

        # Support both checkpoint mode (hparams + ckpt) and nemo mode.
        has_ckpt_mode = self.tts_config.hparams_file and self.tts_config.checkpoint_file
        if has_ckpt_mode:
            cfg = ModelLoadConfig(
                hparams_file=self.tts_config.hparams_file,
                checkpoint_file=self.tts_config.checkpoint_file,
                codecmodel_path=self.tts_config.codec_model_path,
                legacy_codebooks=self.tts_config.legacy_codebooks,
                legacy_text_conditioning=self.tts_config.legacy_text_conditioning,
                hparams_from_wandb=self.tts_config.hparams_from_wandb,
            )
        else:
            cfg = ModelLoadConfig(
                nemo_file=self.config.model_path,
                codecmodel_path=self.tts_config.codec_model_path,
                legacy_codebooks=self.tts_config.legacy_codebooks,
                legacy_text_conditioning=self.tts_config.legacy_text_conditioning,
            )
        self._model, self._checkpoint_name = load_magpie_model(cfg, device=self.config.device)

        self._runner = MagpieInferenceRunner(
            self._model,
            InferenceConfig(
                temperature=self.tts_config.temperature,
                topk=self.tts_config.top_k,
                max_decoder_steps=self.tts_config.max_decoder_steps,
                use_cfg=self.tts_config.use_cfg,
                cfg_scale=self.tts_config.cfg_scale,
                use_local_transformer=self.tts_config.use_local_transformer,
                batch_size=16,
            ),
        )

        self._temp_dir = tempfile.mkdtemp(prefix="magpie_tts_")
        # Advertise the model's real sample rate rather than the configured default.
        self.tts_config.output_sample_rate = self._model.sample_rate
        self._is_loaded = True
        print(
            f"[MagpieTTSBackend] Loaded: {self._checkpoint_name}, sr={self._model.sample_rate}, cfg={self.tts_config.use_cfg}"
        )

    def _extract_json(self, text: str) -> dict:
        """Extract a JSON object from text, skipping any non-JSON prefix."""
        if not text:
            return {"text": ""}
        idx = text.find("{")
        if idx >= 0:
            try:
                return json.loads(text[idx:])
            except json.JSONDecodeError:
                pass
        # Not JSON: treat the whole string as the text to synthesize.
        return {"text": text}

    def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
        """Synthesize audio for a batch of requests.

        Each request's `text` is either plain text or a JSON payload with
        'text' and 'context_audio_filepath'. Results carry WAV bytes plus
        RTF/debug metadata; per-request errors are reported in-band via
        GenerationResult.error.
        """
        if not self._is_loaded:
            return [GenerationResult(error="Model not loaded", request_id=r.request_id) for r in requests]
        if not requests:
            return []

        start_time = time.time()
        # mkdtemp instead of a millisecond-timestamp name: concurrent
        # generate() calls could otherwise collide on the same directory.
        batch_dir = tempfile.mkdtemp(prefix="batch_", dir=self._temp_dir)
        output_dir = os.path.join(batch_dir, "output")
        os.makedirs(output_dir, exist_ok=True)

        try:
            # Reset KV caches to avoid cross-request shape mismatches.
            try:
                if self._model is not None:
                    decoder = getattr(self._model, "decoder", None)
                    if decoder is not None and hasattr(decoder, "reset_cache"):
                        decoder.reset_cache(use_cache=False)
            except Exception:
                pass

            # Parse requests, extracting JSON from text.
            parsed = [self._extract_json(r.text) for r in requests]

            # Create audio_dir with symlinks to all context audio files.
            audio_dir = os.path.join(batch_dir, "audio")
            os.makedirs(audio_dir, exist_ok=True)

            manifest_path = os.path.join(batch_dir, "manifest.json")
            with open(manifest_path, "w") as f:
                for i, p in enumerate(parsed):
                    ctx = p.get("context_audio_filepath", "")
                    if ctx and os.path.exists(ctx):
                        # Index prefix keeps link names unique across requests.
                        link_name = f"ctx_{i}_{os.path.basename(ctx)}"
                        link_path = os.path.join(audio_dir, link_name)
                        if not os.path.exists(link_path):
                            os.symlink(ctx, link_path)
                    else:
                        # No context clip: write a short silent placeholder.
                        link_name = f"d{i}.wav"
                        link_path = os.path.join(audio_dir, link_name)
                        if not os.path.exists(link_path):
                            sr = int(getattr(self.tts_config, "output_sample_rate", 22050) or 22050)
                            dur_s = 0.1
                            n = max(1, int(sr * dur_s))
                            sf.write(link_path, [0.0] * n, sr)
                    # NOTE(review): audio_filepath intentionally mirrors the
                    # context clip here — it looks like a placeholder for the
                    # dataset's duration bookkeeping; confirm against the
                    # MagpieInferenceRunner dataset contract.
                    f.write(
                        json.dumps(
                            {
                                "text": p.get("text", ""),
                                "audio_filepath": link_name,
                                "context_audio_filepath": link_name,
                                "duration": p.get("duration", 5.0),
                                "context_audio_duration": p.get("context_audio_duration", 5.0),
                            }
                        )
                        + "\n"
                    )

            config_path = os.path.join(batch_dir, "config.json")
            with open(config_path, "w") as f:
                json.dump({"batch": {"manifest_path": manifest_path, "audio_dir": audio_dir}}, f)

            # Run inference.
            from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config

            dataset = self._runner.create_dataset(load_evalset_config(config_path))
            rtf_list, _ = self._runner.run_inference_on_dataset(
                dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False
            )

            gen_time = time.time() - start_time
            batch_metrics = {
                "total_time_sec": gen_time,
                "num_samples": len(requests),
                **self._runner.compute_mean_rtf_metrics(rtf_list),
            }

            # Build results, one per request in order.
            results = []
            for i, req in enumerate(requests):
                path = os.path.join(output_dir, f"predicted_audio_{i}.wav")
                if os.path.exists(path):
                    audio, sr = sf.read(path)
                    buf = io.BytesIO()
                    sf.write(buf, audio, sr, format="WAV")
                    buf.seek(0)
                    dur = len(audio) / sr
                    results.append(
                        GenerationResult(
                            text=parsed[i].get("text", ""),
                            audio_bytes=buf.read(),
                            # Report the rate actually read from the generated
                            # file, not the configured default, so the payload
                            # can never disagree with its own audio bytes.
                            audio_sample_rate=sr,
                            audio_format="wav",
                            request_id=req.request_id,
                            generation_time_ms=gen_time * 1000 / len(requests),
                            debug_info={
                                "checkpoint": self._checkpoint_name,
                                "audio_duration_sec": dur,
                                "rtf": gen_time / len(requests) / dur if dur else 0,
                                "config": {
                                    "temp": self.tts_config.temperature,
                                    "top_k": self.tts_config.top_k,
                                    "cfg": self.tts_config.use_cfg,
                                    "cfg_scale": self.tts_config.cfg_scale,
                                },
                                "batch_metrics": batch_metrics,
                            },
                        )
                    )
                else:
                    results.append(GenerationResult(error=f"Audio not found: {path}", request_id=req.request_id))
            return results
        except Exception as e:
            import traceback

            traceback.print_exc()
            # Fail the whole batch in-band rather than raising to the server.
            return [GenerationResult(error=str(e), request_id=r.request_id) for r in requests]
        finally:
            shutil.rmtree(batch_dir, ignore_errors=True)

    def validate_request(self, request: GenerationRequest) -> Optional[str]:
        """TTS needs text input; context audio is optional."""
        return "Text required" if not request.text else None

    def health_check(self) -> Dict[str, Any]:
        """Extend the base health payload with TTS-specific details."""
        h = super().health_check()
        if self._is_loaded:
            h.update(
                {
                    "checkpoint": self._checkpoint_name,
                    "codec": self.tts_config.codec_model_path,
                    "cfg": self.tts_config.use_cfg,
                    "cfg_scale": self.tts_config.cfg_scale,
                    "sample_rate": self.tts_config.output_sample_rate,
                }
            )
        return h

    def __del__(self):
        # Best-effort cleanup of the working directory on garbage collection.
        if getattr(self, "_temp_dir", None) and os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)

# --- diff continues: recipes/multimodal/server/unified_server.py ---
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unified NeMo Inference Server with OpenAI-compatible API.

Backend-agnostic: all backend-specific logic lives in backend modules.
The server only knows about GenerationRequest/GenerationResult and the
InferenceBackend interface.

Exposes /v1/chat/completions endpoint for OpenAI compatibility.
Backends may register additional routes via get_extra_routes().
"""

import asyncio
import base64
import hashlib
import json
import os
import re
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse

from .backends import GenerationRequest, GenerationResult, get_backend

# Debug flag: enables verbose batcher logging when DEBUG env var is truthy.
DEBUG = os.getenv("DEBUG", "").lower() in ("true", "1", "yes", "on")


@dataclass
class PendingRequest:
    """Container for a pending batched request.

    Attributes:
        request: The generation request waiting to be processed.
        future: Resolved with the GenerationResult (or an exception) once
            the batch containing this request has been processed.
        timestamp: time.time() at enqueue, for diagnostics.
    """

    request: GenerationRequest
    future: asyncio.Future
    timestamp: float


class RequestBatcher:
    """Manages request batching with a configurable delay.

    Requests accumulate until either ``batch_size`` is reached or
    ``batch_timeout`` seconds elapse, then the whole batch is handed to
    ``backend.generate`` in a thread-pool executor.
    """

    def __init__(self, backend, batch_size: int, batch_timeout: float):
        """
        Args:
            backend: Object exposing ``generate(requests) -> results``.
            batch_size: Maximum number of requests per batch.
            batch_timeout: Seconds to wait before flushing a partial batch.
                A value of 0 disables batching delay (process immediately).
        """
        self.backend = backend
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self.pending_requests: List[PendingRequest] = []
        self.lock = asyncio.Lock()
        self.timeout_task: Optional[asyncio.Task] = None
        self.processing = False

        # Stats
        self.total_requests = 0
        self.total_batches = 0

    async def add_request(self, request: GenerationRequest) -> GenerationResult:
        """Add a request and wait for its result.

        Returns the GenerationResult produced for this request, or raises
        whatever exception the batch processing raised.
        """
        # Create the future on the running loop; bare asyncio.Future() uses
        # the deprecated implicit event-loop lookup and can bind the wrong
        # loop when called off the main thread.
        future = asyncio.get_running_loop().create_future()
        pending = PendingRequest(request=request, future=future, timestamp=time.time())

        async with self.lock:
            self.pending_requests.append(pending)

            if len(self.pending_requests) >= self.batch_size:
                if DEBUG:
                    print(f"[Batcher] Batch full ({self.batch_size}), processing immediately")
                asyncio.create_task(self._process_batch())
            elif self.batch_timeout == 0:
                # No batching delay configured: flush right away.
                asyncio.create_task(self._process_batch())
            elif self.timeout_task is None or self.timeout_task.done():
                self.timeout_task = asyncio.create_task(self._timeout_handler())

        return await future

    async def _timeout_handler(self):
        """Flush a partial batch once the batch timeout elapses."""
        await asyncio.sleep(self.batch_timeout)
        async with self.lock:
            if self.pending_requests and not self.processing:
                if DEBUG:
                    print(f"[Batcher] Timeout, processing {len(self.pending_requests)} requests")
                asyncio.create_task(self._process_batch())

    async def _process_batch(self):
        """Process up to ``batch_size`` pending requests as one batch.

        Guarded by ``self.processing`` so only one batch runs at a time;
        any requests left over are rescheduled in the ``finally`` clause.
        """
        async with self.lock:
            if not self.pending_requests or self.processing:
                return

            self.processing = True
            batch = self.pending_requests[: self.batch_size]
            self.pending_requests = self.pending_requests[self.batch_size :]

        try:
            requests = [p.request for p in batch]

            if DEBUG:
                print(f"[Batcher] Processing batch of {len(requests)} requests")

            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(None, self.backend.generate, requests)

            # Resolve each caller's future. If the backend misbehaves and
            # returns fewer results than requests, fail the unmatched
            # futures explicitly instead of leaving their callers hanging
            # forever (zip() would silently drop them).
            for i, pending in enumerate(batch):
                if pending.future.done():
                    continue
                if i < len(results):
                    pending.future.set_result(results[i])
                else:
                    pending.future.set_exception(
                        RuntimeError("Backend returned fewer results than requests in batch")
                    )

            self.total_requests += len(batch)
            self.total_batches += 1

        except Exception as e:
            # Propagate the batch failure to every waiting caller.
            for pending in batch:
                if not pending.future.done():
                    pending.future.set_exception(e)
        finally:
            async with self.lock:
                self.processing = False
                if self.pending_requests:
                    if self.batch_timeout == 0 or len(self.pending_requests) >= self.batch_size:
                        asyncio.create_task(self._process_batch())
                    elif self.timeout_task is None or self.timeout_task.done():
                        self.timeout_task = asyncio.create_task(self._timeout_handler())


# Global state (populated by create_app's startup hook)
backend_instance = None
request_batcher = None
server_config = {}


def extract_audio_from_messages(messages: List[Dict[str, Any]]) -> List[bytes]:
    """Extract all audio bytes from OpenAI-format messages.

    Looks for audio_url in message content with format:
    {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}}
    """
    audio_list = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "audio_url":
                    audio_url = item.get("audio_url", {})
                    url = audio_url.get("url", "")
                    match = re.match(r"data:audio/\w+;base64,(.+)", url)
                    if match:
                        audio_list.append(base64.b64decode(match.group(1)))
    return audio_list


def extract_text_from_messages(messages: List[Dict[str, Any]]) -> str:
    """Extract and concatenate all text content from OpenAI-format messages."""
    texts = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            if content:
                texts.append(content)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text", "")
                    if text:
                        texts.append(text)
                elif isinstance(item, str):
                    texts.append(item)
    return " ".join(texts)


def extract_system_prompt(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Extract the system prompt from the first system-role message, if any."""
    for message in messages:
        if message.get("role") == "system":
            content = message.get("content")
            if isinstance(content, str):
                return content
            elif isinstance(content, list):
                texts = [
                    item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "text"
                ]
                return " ".join(texts) if texts else None
    return None


def create_app(
    backend_type: str,
    config_dict: Dict[str, Any],
    batch_size: int = 8,
    batch_timeout: float = 0.1,
) -> FastAPI:
    """Create and configure the FastAPI app.

    Args:
        backend_type: Name of the backend to use (e.g., 'magpie_tts').
        config_dict: Full configuration dict for the backend's config class.
        batch_size: Maximum batch size for request batching.
        batch_timeout: Seconds to wait before processing an incomplete batch.
    """
    global backend_instance, request_batcher, server_config

    app = FastAPI(
        title="Unified NeMo Inference Server",
        description=f"OpenAI-compatible API for NeMo model inference ({backend_type} backend)",
        version="1.0.0",
    )

    server_config = {
        "backend_type": backend_type,
        "model_path": config_dict.get("model_path", ""),
        "batch_size": batch_size,
        "batch_timeout": batch_timeout,
    }

    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor
    # of lifespan handlers; kept as-is for compatibility with the installed
    # version — confirm before migrating.
    @app.on_event("startup")
    async def startup():
        global backend_instance, request_batcher

        # Look up backend class and its config class
        BackendClass = get_backend(backend_type)
        ConfigClass = BackendClass.get_config_class()

        # Validate and create config
        config = ConfigClass.from_dict(config_dict)

        # Instantiate and load backend
        print(f"[Server] Initializing {backend_type} backend...")
        backend_instance = BackendClass(config)
        backend_instance.load_model()

        # Create batcher
        request_batcher = RequestBatcher(backend_instance, batch_size, batch_timeout)

        # Register any extra routes from the backend
        extra_routes = BackendClass.get_extra_routes(backend_instance)
        for route in extra_routes:
            app.add_api_route(
                route["path"],
                route["endpoint"],
                methods=route.get("methods", ["GET"]),
            )
            print(f"[Server] Registered extra route: {route['path']}")

        print("[Server] Ready!")
        print(f"  Backend: {backend_type}")
        print(f"  Model: {config.model_path}")
        print(f"  Batch size: {batch_size}")
        print(f"  Batch timeout: {batch_timeout}s")

    @app.get("/")
    async def root():
        """Root endpoint with server info."""
        return {
            "service": "Unified NeMo Inference Server",
            "version": "1.0.0",
            "backend": server_config.get("backend_type"),
            "model": server_config.get("model_path"),
            "endpoints": ["/v1/chat/completions", "/health", "/v1/models"],
        }

    @app.get("/health")
    async def health():
        """Health check endpoint."""
        if backend_instance is None:
            return JSONResponse(status_code=503, content={"status": "not_ready", "error": "Backend not initialized"})

        health_info = backend_instance.health_check()
        health_info["status"] = "healthy" if backend_instance.is_loaded else "not_ready"
        health_info["timestamp"] = datetime.now().isoformat()

        return health_info

    @app.get("/v1/models")
    async def list_models():
        """OpenAI-compatible models endpoint."""
        model_id = server_config.get("model_path", "unknown")
        return {
            "object": "list",
            "data": [
                {
                    "id": model_id,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "nvidia",
                }
            ],
        }

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Dict[str, Any]):
        """OpenAI-compatible chat completions endpoint with audio support."""
        if backend_instance is None or not backend_instance.is_loaded:
            raise HTTPException(status_code=503, detail="Model not loaded")

        try:
            messages = request.get("messages", [])
            if not messages:
                raise HTTPException(status_code=400, detail="No messages provided")

            # Extract components from messages
            audio_bytes_list = extract_audio_from_messages(messages)
            text = extract_text_from_messages(messages)
            system_prompt = extract_system_prompt(messages)

            # Get generation parameters
            max_tokens = request.get("max_tokens", 512)
            temperature = request.get("temperature", 1.0)
            top_p = request.get("top_p", 1.0)
            seed = request.get("seed")

            # Create generation request
            gen_request = GenerationRequest(
                text=text if text else None,
                system_prompt=system_prompt,
                audio_bytes=audio_bytes_list[0] if len(audio_bytes_list) == 1 else None,
                audio_bytes_list=audio_bytes_list if len(audio_bytes_list) > 1 else None,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                seed=seed,
                request_id=hashlib.md5(f"{time.time()}".encode()).hexdigest()[:8],
                extra_params=request.get("extra_body", {}),
            )

            # Validate request
            error = backend_instance.validate_request(gen_request)
            if error:
                raise HTTPException(status_code=400, detail=error)

            # Process through batcher
            result = await request_batcher.add_request(gen_request)

            if not result.is_success():
                raise HTTPException(status_code=500, detail=result.error)

            # Build OpenAI-compatible response
            response_id = f"chatcmpl-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}"
            message_content = result.text or ""

            # Save outputs if AUDIO_SAVE_DIR is set
            save_dir = os.environ.get("AUDIO_SAVE_DIR", "")
            if save_dir:
                try:
                    os.makedirs(save_dir, exist_ok=True)
                except PermissionError:
                    # Best-effort persistence only; disable saving silently.
                    save_dir = ""
            saved_audio_path = None

            if save_dir:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                base_filename = f"response_{timestamp}_{response_id}"

                try:
                    saved_json_path = os.path.join(save_dir, f"{base_filename}.json")
                    json_output = {
                        "response_id": response_id,
                        "timestamp": timestamp,
                        "text": message_content,
                        "debug_info": result.debug_info,
                        "generation_time_ms": result.generation_time_ms,
                        "num_tokens_generated": result.num_tokens_generated,
                    }
                    with open(saved_json_path, "w") as f:
                        json.dump(json_output, f, indent=2)
                except Exception as e:
                    print(f"[Server] Warning: Failed to save JSON: {e}")

                if result.audio_bytes:
                    try:
                        saved_audio_path = os.path.join(save_dir, f"{base_filename}.wav")
                        with open(saved_audio_path, "wb") as f:
                            f.write(result.audio_bytes)
                    except Exception as e:
                        print(f"[Server] Warning: Failed to save audio: {e}")

            # Build audio output if available
            audio_output = None
            if result.audio_bytes:
                audio_output = {
                    "data": base64.b64encode(result.audio_bytes).decode("utf-8"),
                    "format": result.audio_format or "wav",
                    "sample_rate": result.audio_sample_rate,
                    "expires_at": int(time.time()) + 3600,
                    "transcript": result.text or "",
                }

            # Embed debug_info in content as JSON
            final_content = message_content
            if result.debug_info:
                final_content = f"{message_content}\n{json.dumps(result.debug_info)}"

            response = {
                "id": response_id,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": server_config.get("model_path"),
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": final_content,
                        },
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": -1,
                    "completion_tokens": result.num_tokens_generated or -1,
                    "total_tokens": -1,
                },
            }

            if audio_output:
                response["choices"][0]["message"]["audio"] = audio_output

            if result.debug_info:
                response["debug_info"] = result.debug_info

            if saved_audio_path:
                response["saved_audio_path"] = saved_audio_path

            return response

        except HTTPException:
            raise
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e))

    return app