From 9cd05d06e068f01c957938490c3e663918577208 Mon Sep 17 00:00:00 2001 From: AndyZhou952 Date: Fri, 21 Nov 2025 17:14:48 +0800 Subject: [PATCH 01/11] npu initial --- vllm_omni/entrypoints/omni_llm.py | 261 ++-- vllm_omni/entrypoints/omni_stage.py | 6 +- vllm_omni/entrypoints/stage_utils.py | 187 +-- vllm_omni/entrypoints/utils.py | 25 +- .../model_executor/models/qwen2_5_omni.py | 195 +-- .../models/qwen2_5_omni_token2wav.py | 345 +++-- vllm_omni/utils/platform_utils.py | 35 + vllm_omni/worker/npu_ar_model_runner.py | 641 +++++++++ vllm_omni/worker/npu_ar_worker.py | 14 + .../worker/npu_diffusion_model_runner.py | 587 +++++++++ vllm_omni/worker/npu_diffusion_worker.py | 13 + vllm_omni/worker/npu_model_runner.py | 1154 +++++++++++++++++ 12 files changed, 3034 insertions(+), 429 deletions(-) create mode 100644 vllm_omni/utils/platform_utils.py create mode 100644 vllm_omni/worker/npu_ar_model_runner.py create mode 100644 vllm_omni/worker/npu_ar_worker.py create mode 100644 vllm_omni/worker/npu_diffusion_model_runner.py create mode 100644 vllm_omni/worker/npu_diffusion_worker.py create mode 100644 vllm_omni/worker/npu_model_runner.py diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 655d0307fbf..69083cb26de 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -1,10 +1,11 @@ +from typing import Any, Optional, Sequence, Union +import logging import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor, as_completed import os +import queue import sys import time -from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional, Union import cloudpickle from pydantic import ValidationError @@ -25,17 +26,19 @@ from vllm_omni.engine.arg_utils import OmniEngineArgs from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.engine.processor import OmniProcessor +from 
vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.entrypoints.utils import load_stage_configs_from_model, select_worker_class from vllm_omni.entrypoints.log_utils import ( - OrchestratorMetrics, + remove_old_logs, configure_orchestrator_logger, init_stats_paths, - remove_old_logs, + OrchestratorMetrics, +) +from vllm_omni.entrypoints.stage_utils import ( + maybe_load_from_ipc as _load, + encode_for_ipc as _encode, + serialize_obj as _ser, ) -from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import encode_for_ipc as _encode -from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load -from vllm_omni.entrypoints.stage_utils import serialize_obj as _set -from vllm_omni.entrypoints.utils import load_stage_configs_from_model, load_stage_configs_from_yaml from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -49,25 +52,23 @@ class OmniLLM: - max_inflight=1 per stage (serial within a stage), but pipeline across stages """ - def __init__( - self, - model: str, - stage_configs_path: Optional[str] = None, - log_stats: bool = False, - log_file: Optional[str] = None, - init_sleep_seconds: int = 20, - shm_threshold_bytes: int = 65536, - batch_timeout: int = 10, - init_timeout: int = 300, - **kwargs, - ): + def __init__(self, model: str, + stage_configs=None, + log_stats: bool = False, + log_file: Optional[str] = None, + init_sleep_seconds: int = 20, + shm_threshold_bytes: int = 65536, + batch_timeout: int = 10, + init_timeout: int = 300, + **kwargs): self.batch_timeout = batch_timeout self._enable_stats: bool = bool(log_stats) # Do NOT call super().__init__ to avoid creating OmniStageLLM instances in parent. 
- if stage_configs_path is None: + if stage_configs is None: self.stage_configs = load_stage_configs_from_model(model) else: - self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) + self.stage_configs = stage_configs + # Optional file handler for orchestrator self._log_file = log_file @@ -78,15 +79,8 @@ def __init__( self._stats_file, self._overall_stats_file = init_stats_paths(self._enable_stats, self._log_file) self._initialize_stages(model, init_sleep_seconds, shm_threshold_bytes, init_timeout) - def _initialize_stages( - self, - model: str, - init_sleep_seconds: int, - shm_threshold_bytes: int, - init_timeout: int, - ) -> None: + def _initialize_stages(self, model: str, init_sleep_seconds: int, shm_threshold_bytes: int, init_timeout: int) -> None: self.stage_list: list[OmniStage] = [] - # Build OmniStage instances in parallel, preserve original order def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: idx, cfg = idx_cfg @@ -110,6 +104,7 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: # Wait for all stages to report readiness before seeding self._stages_ready: set[int] = set() self._wait_for_stages_ready(timeout=init_timeout) + def _start_stage_processes(self, model: str) -> None: for stage_id, stage in enumerate(self.stage_list): @@ -136,10 +131,7 @@ def close(self) -> None: try: q.put_nowait(None) except Exception as e: - logger.warning( - "[Orchestrator] Failed to send shutdown signal to stage input queue: %s", - e, - ) + logger.warning("[Orchestrator] Failed to send shutdown signal to stage input queue: %s", e) for stage in self.stage_list: try: stage.stop_stage_worker() @@ -155,26 +147,17 @@ def __del__(self) -> None: # best-effort def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, - ) -> list[OmniRequestOutput]: - try: - return self._run_generation(prompts, sampling_params_list) - except 
Exception as e: - logger.exception("[Orchestrator] Failed to run generation: %s", e) - raise e - finally: - self.close() - - def _run_generation( - self, - prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, + sampling_params_list: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, ) -> list[OmniRequestOutput]: logger.debug("[Orchestrator] generate() called") if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + raise ValueError( + f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}" + ) # Normalize prompts to a list for per-request iteration if not isinstance(prompts, (list, tuple)): @@ -193,6 +176,7 @@ def _run_generation( # Track per-request start time for end-to-end timing _req_start_ts: dict[int, float] = {} _wall_start_ts: float = time.time() + _last_finish_ts: float = _wall_start_ts # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) final_stage_id_for_e2e = -1 @@ -203,20 +187,10 @@ def _run_generation( if final_stage_id_for_e2e < 0: final_stage_id_for_e2e = len(self.stage_list) - 1 except Exception as e: - logger.debug( - "[Orchestrator] Failed to determine final stage for E2E; falling back to last: %s", - e, - exc_info=True, - ) + logger.debug("[Orchestrator] Failed to determine final stage for E2E; falling back to last: %s", e, exc_info=True) final_stage_id_for_e2e = len(self.stage_list) - 1 # Metrics/aggregation helper - metrics = OrchestratorMetrics( - num_stages, - self._enable_stats, - self._stats_file, - self._overall_stats_file, - _wall_start_ts, - ) + metrics = OrchestratorMetrics(num_stages, self._enable_stats, self._stats_file, 
self._overall_stats_file, _wall_start_ts) # Seed stage-0 queue with all requests logger.debug("[Orchestrator] Seeding %d requests into stage-0", len(request_prompts)) @@ -240,11 +214,7 @@ def _run_generation( completed_requests = 0 total_requests = len(request_prompts) - logger.debug( - "[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", - total_requests, - num_stages, - ) + logger.debug("[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", total_requests, num_stages) while completed_requests < total_requests: made_progress = False for stage_id, stage in enumerate(self.stage_list): @@ -255,17 +225,11 @@ def _run_generation( made_progress = True req_id = result.get("request_id") if "error" in result: - logger.error( - "Stage %s error on request %s: %s", - stage_id, - req_id, - result["error"], - ) + logger.error("Stage %s error on request %s: %s", stage_id, req_id, result["error"]) continue - + if result.get("type") == "stage_ready": - # Only happens when stage is initialized slower than expected, - # so we wait for a short time and try again + #Only happens when stage is initialized slower than expected, so we wait for a short time and try again time.sleep(0.05) continue @@ -277,17 +241,8 @@ def _run_generation( if _m is not None: metrics.on_stage_metrics(stage_id, req_id, _m) except Exception as e: - logger.exception( - "[Orchestrator] Failed to process metrics for stage %s, req %s: %s", - stage_id, - req_id, - e, - ) - logger.debug( - "[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", - stage_id, - req_id, - ) + logger.exception("[Orchestrator] Failed to process metrics for stage %s, req %s: %s", stage_id, req_id, e) + logger.debug("[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", stage_id, req_id) stage.set_engine_outputs(engine_outputs) if getattr(stage, "final_output", False): @@ -298,30 +253,15 @@ def _run_generation( request_output=engine_outputs, ) ) - logger.debug( - 
"[Orchestrator] Request %s finalized at stage-%s", - req_id, - stage_id, - ) + logger.debug("[Orchestrator] Request %s finalized at stage-%s", req_id, stage_id) - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) + # End-to-end timing and time-per-token for final output (only once per request at the designated final stage) try: rid_int = int(req_id) if isinstance(req_id, (int, str)) and str(req_id).isdigit() else req_id if stage_id == final_stage_id_for_e2e and rid_int not in metrics.e2e_done: - metrics.on_finalize_request( - stage_id, - req_id, - engine_outputs, - _req_start_ts.get(req_id, _wall_start_ts), - ) + metrics.on_finalize_request(stage_id, req_id, engine_outputs, _req_start_ts.get(req_id, _wall_start_ts)) except Exception as e: - logger.exception( - "[Orchestrator] Finalize request handling error for req %s at stage %s: %s", - req_id, - stage_id, - e, - ) + logger.exception("[Orchestrator] Finalize request handling error for req %s at stage %s: %s", req_id, stage_id, e) next_stage_id = stage_id + 1 if next_stage_id < num_stages: @@ -332,7 +272,7 @@ def _run_generation( # Measure transfer size and time (encode + enqueue) size_bytes = 0 try: - size_bytes = len(_set(next_inputs)) + size_bytes = len(_ser(next_inputs)) except Exception: size_bytes = 0 t0 = time.time() @@ -342,51 +282,27 @@ def _run_generation( obj_key="engine_inputs", shm_key="engine_inputs_shm", ) - ipc_payload.update( - { - "request_id": req_id, - "sampling_params": sp_next, - "sent_ts": time.time(), - } - ) + ipc_payload.update({ + "request_id": req_id, + "sampling_params": sp_next, + "sent_ts": time.time(), + }) self.stage_list[next_stage_id].submit(ipc_payload) t1 = time.time() tx_ms = (t1 - t0) * 1000.0 - metrics.on_forward( - stage_id, - next_stage_id, - req_id, - int(size_bytes), - float(tx_ms), - bool("engine_inputs_shm" in ipc_payload), - ) + metrics.on_forward(stage_id, next_stage_id, req_id, int(size_bytes), 
float(tx_ms), bool("engine_inputs_shm" in ipc_payload)) except Exception as e: - logger.warning( - "[Orchestrator] IPC encode failed for req %s: %s; falling back to inline payload", - req_id, - e, - ) - self.stage_list[next_stage_id].submit( - { - "request_id": req_id, - "engine_inputs": next_inputs, - "sampling_params": sp_next, - } - ) - logger.debug( - "[Orchestrator] Forwarded request %s to stage-%s", - req_id, - next_stage_id, - ) + logger.warning("[Orchestrator] IPC encode failed for req %s: %s; falling back to inline payload", req_id, e) + self.stage_list[next_stage_id].submit({ + "request_id": req_id, + "engine_inputs": next_inputs, + "sampling_params": sp_next, + }) + logger.debug("[Orchestrator] Forwarded request %s to stage-%s", req_id, next_stage_id) remaining_by_stage[next_stage_id] += 1 else: completed_requests += 1 - logger.debug( - "[Orchestrator] Request %s fully completed (%d/%d)", - req_id, - completed_requests, - total_requests, - ) + logger.debug("[Orchestrator] Request %s fully completed (%d/%d)", req_id, completed_requests, total_requests) if not made_progress: time.sleep(0.005) @@ -425,9 +341,7 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) logger.warning( "[Orchestrator] Initialization timeout: only %s/%s stages are ready; not ready: %s", - len(self._stages_ready), - num_stages, - not_ready, + len(self._stages_ready), num_stages, not_ready, ) # Provide actionable suggestions before shutdown try: @@ -438,7 +352,9 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: "Increase initialization wait time (init_sleep_seconds or call-site timeout).", ] if getattr(self, "_log_file", None): - suggestions.append(f"Inspect per-stage log files for details: {self._log_file}.stage.log") + suggestions.append( + f"Inspect per-stage log files for details: {self._log_file}.stage.log" + ) logger.error( "[Orchestrator] Stage initialization failed, shutting down. 
Suggestions:\n- %s", "\n- ".join(suggestions), @@ -468,9 +384,13 @@ class OmniStageLLM(LLM): def __init__( self, model: str, - compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + compilation_config: Optional[ + Union[int, dict[str, Any], CompilationConfig] + ] = None, hf_overrides: Optional[HfOverrides] = None, - structured_outputs_config: Optional[Union[dict[str, Any], StructuredOutputsConfig]] = None, + structured_outputs_config: Optional[ + Union[dict[str, Any], StructuredOutputsConfig] + ] = None, **kwargs, ): """LLM constructor.""" @@ -479,12 +399,18 @@ def __init__( if "worker_cls" in kwargs: worker_cls = kwargs["worker_cls"] + # Select appropriate worker class based on device type + if isinstance(worker_cls, str): + worker_cls = select_worker_class(worker_cls) + kwargs["worker_cls"] = worker_cls # if the worker_cls is not qualified string name, # we serialize it using cloudpickle to avoid pickling issues - if isinstance(worker_cls, type): + elif isinstance(worker_cls, type): kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) - if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict): + if "kv_transfer_config" in kwargs and isinstance( + kwargs["kv_transfer_config"], dict + ): from vllm.config.kv_transfer import KVTransferConfig raw_config_dict = kwargs["kv_transfer_config"] @@ -492,7 +418,8 @@ def __init__( kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict) except ValidationError as e: logger.error( - "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. Error: %s", + "Failed to convert 'kv_transfer_config' dict to " + "KVTransferConfig object. Dict: %s. 
Error: %s", raw_config_dict, e, ) @@ -505,10 +432,16 @@ def __init__( if compilation_config is not None: if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig(level=compilation_config) + compilation_config_instance = CompilationConfig( + level=compilation_config + ) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( - **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)} + **{ + k: v + for k, v in compilation_config.items() + if is_init_field(CompilationConfig, k) + } ) else: compilation_config_instance = compilation_config @@ -518,7 +451,11 @@ def __init__( if structured_outputs_config is not None: if isinstance(structured_outputs_config, dict): structured_outputs_instance = StructuredOutputsConfig( - **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)} + **{ + k: v + for k, v in structured_outputs_config.items() + if is_init_field(StructuredOutputsConfig, k) + } ) else: structured_outputs_instance = structured_outputs_config @@ -534,7 +471,9 @@ def __init__( ) # Create the Engine (autoselects V0 vs V1) - self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS + ) self.llm_engine.output_processor = MultimodalOutputProcessor( tokenizer=self.llm_engine.tokenizer, log_stats=self.llm_engine.log_stats, @@ -556,4 +495,6 @@ def __init__( # Load the Input/Output processor plugin if any io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) + self.io_processor = get_io_processor( + self.llm_engine.vllm_config, io_processor_plugin + ) \ No newline at end of file diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 
994dfa423ca..b1e8fed7522 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -31,7 +31,7 @@ _to_dict, maybe_dump_to_shm, maybe_load_from_ipc_with_metrics, - set_stage_gpu_devices, + set_stage_devices, ) from vllm_omni.inputs.data import OmniTokensPrompt @@ -274,7 +274,9 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: _logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 36ad2b0ab1c..fa7fedd6946 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -1,122 +1,144 @@ from __future__ import annotations -import json +from typing import Any, Dict, Optional, Tuple, Union + import logging +import json import os import pickle from multiprocessing import shared_memory as _shm -from typing import Any import cloudpickle -from omegaconf import OmegaConf logger = logging.getLogger(__name__) -def set_stage_gpu_devices(stage_id: int, devices: str | int | None) -> None: - """Configure per-stage CUDA visibility and current device. - - Behavior - - Comma-separated string (e.g. "2,5,7"): set CUDA_VISIBLE_DEVICES exactly - to this list; logical index 0 is used as current device. - - Integer or digit-string: treat as logical index (0-based) into the current - CUDA_VISIBLE_DEVICES mapping; map to the physical device, and then set - CUDA_VISIBLE_DEVICES to this single device. - - None/"cpu": keep default visibility. - - Otherwise: set CUDA_VISIBLE_DEVICES to the provided single device string. 
+def set_stage_devices( + stage_id: int, + devices: Optional[Union[str, int]], + device_type: Optional[str] = None, +) -> None: + """Configure per-stage device visibility and current device (CUDA or NPU). + + This function sets environment variables that control which devices are visible + to the process, and sets the current device. It must be called BEFORE worker + initialization so that workers see the correct devices. + + Args: + stage_id: Stage identifier for logging + devices: Device specification: + - Comma-separated string (e.g. "2,5,7"): set device visibility env var + exactly to this list; logical index 0 is used as current device. + - Integer or digit-string: treat as logical index (0-based) into the + current device visibility mapping; map to physical device, then set + env var to this single device. + - None/"cpu": keep default visibility. + - Otherwise: set env var to the provided single device string. + device_type: Device type ("cuda" or "npu"). If None, auto-detects. + + Behavior: + - CUDA: Sets CUDA_VISIBLE_DEVICES and calls torch.cuda.set_device() + - NPU: Sets ASCEND_RT_VISIBLE_DEVICES and calls torch.npu.set_device() """ + from vllm_omni.utils import detect_device_type, get_device_control_env_var + + if device_type is None: + device_type = detect_device_type() + + env_var = get_device_control_env_var() + + # Select device-specific torch functions + if device_type == "npu": + try: + import torch.npu # type: ignore[import-untyped] + except ImportError: + logger.debug("[Stage-%s] torch.npu not available, skipping NPU device setup", stage_id) + return + + is_available_fn = torch.npu.is_available + set_device_fn = torch.npu.set_device + device_count_fn = torch.npu.device_count + get_device_properties_fn = torch.npu.get_device_properties + mem_get_info_fn = torch.npu.mem_get_info + get_device_name_fn = torch.npu.get_device_name + device_type_label = "NPU" + elif device_type == "cuda": + import torch # noqa: WPS433 + + is_available_fn = 
torch.cuda.is_available + set_device_fn = torch.cuda.set_device + device_count_fn = torch.cuda.device_count + get_device_properties_fn = torch.cuda.get_device_properties + mem_get_info_fn = torch.cuda.mem_get_info + get_device_name_fn = torch.cuda.get_device_name + device_type_label = "CUDA" + else: + logger.debug("[Stage-%s] Unsupported device type: %s", stage_id, device_type) + return + try: - selected_physical: int | None = None - logical_idx: int | None = None + selected_physical: Optional[int] = None + logical_idx: Optional[int] = None if isinstance(devices, str) and "," in devices: - os.environ["CUDA_VISIBLE_DEVICES"] = devices + os.environ[env_var] = devices toks = [t.strip() for t in devices.split(",") if t.strip() != ""] if toks: try: selected_physical = int(toks[0]) logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to %s; logical 0 -> physical %s", - stage_id, - devices, - selected_physical, + "[Stage-%s] Set %s to %s; logical 0 -> physical %s", + stage_id, env_var, devices, selected_physical, ) except Exception as e: - logger.debug("[Stage-%s] Failed to parse first CUDA device: %s", stage_id, e) + logger.debug("[Stage-%s] Failed to parse first %s device: %s", stage_id, device_type_label, e) selected_physical = None elif isinstance(devices, (int, str)) and (isinstance(devices, int) or str(devices).isdigit()): logical_idx = max(0, int(devices)) - vis = os.environ.get("CUDA_VISIBLE_DEVICES") + vis = os.environ.get(env_var) if vis: try: mapping = [int(x) for x in vis.split(",") if x.strip() != ""] if 0 <= logical_idx < len(mapping): selected_physical = mapping[logical_idx] except Exception as e: - logger.debug( - "[Stage-%s] Failed to map logical index via CUDA_VISIBLE_DEVICES: %s", - stage_id, - e, - ) + logger.debug("[Stage-%s] Failed to map logical index via %s: %s", stage_id, env_var, e) selected_physical = None if selected_physical is None: selected_physical = int(logical_idx) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) + 
os.environ[env_var] = str(selected_physical) logger.debug( - "[Stage-%s] Logical index %d -> physical %s; set CUDA_VISIBLE_DEVICES to single device", - stage_id, - logical_idx + 1, - selected_physical, + "[Stage-%s] Logical index %d -> physical %s; set %s to single device", + stage_id, logical_idx + 1, selected_physical, env_var, ) elif devices in (None, "cpu"): - logger.debug( - "[Stage-%s] Using default device visibility (devices=%s)", - stage_id, - devices, - ) + logger.debug("[Stage-%s] Using default device visibility (devices=%s)", stage_id, devices) else: selected_physical = int(str(devices)) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) - logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to single device %s (fallback)", - stage_id, - selected_physical, - ) + os.environ[env_var] = str(selected_physical) + logger.debug("[Stage-%s] Set %s to single device %s (fallback)", stage_id, env_var, selected_physical) try: import torch # noqa: WPS433 - - if torch.cuda.is_available(): + if is_available_fn(): try: - torch.cuda.set_device(0) + set_device_fn(0) except Exception as e: - logger.debug( - "[Stage-%s] torch.cuda.set_device(0) failed: %s", - stage_id, - e, - exc_info=True, - ) - num = torch.cuda.device_count() + logger.debug("[Stage-%s] %s set_device(0) failed: %s", stage_id, device_type_label, e, exc_info=True) + num = device_count_fn() info = [] for i in range(num): - total = torch.cuda.get_device_properties(i).total_memory - free, _ = torch.cuda.mem_get_info(i) - info.append( - { - "idx": i, - "name": torch.cuda.get_device_name(i), - "total": int(total), - "free": int(free), - } - ) - logger.debug("[Stage-%s] CUDA devices visible=%s info=%s", stage_id, num, info) + total = get_device_properties_fn(i).total_memory + free, _ = mem_get_info_fn(i) + info.append({ + "idx": i, + "name": get_device_name_fn(i), + "total": int(total), + "free": int(free), + }) + logger.debug("[Stage-%s] %s devices visible=%s info=%s", stage_id, device_type_label, num, 
info) except Exception as e: - logger.debug( - "[Stage-%s] Failed to query CUDA devices: %s", - stage_id, - e, - exc_info=True, - ) + logger.debug("[Stage-%s] Failed to query %s devices: %s", stage_id, device_type_label, e, exc_info=True) except Exception as e: logger.warning("Failed to interpret devices for stage %s: %s", stage_id, e) @@ -126,7 +148,7 @@ def serialize_obj(obj: Any) -> bytes: return cloudpickle.dumps(obj) -def shm_write_bytes(payload: bytes) -> dict[str, Any]: +def shm_write_bytes(payload: bytes) -> Dict[str, Any]: """Write bytes into SharedMemory and return meta dict {name,size}. Caller should close the segment; the receiver should unlink. @@ -143,7 +165,7 @@ def shm_write_bytes(payload: bytes) -> dict[str, Any]: return meta -def shm_read_bytes(meta: dict[str, Any]) -> bytes: +def shm_read_bytes(meta: Dict[str, Any]) -> bytes: """Read bytes from SharedMemory by meta {name,size} and cleanup.""" shm = _shm.SharedMemory(name=meta["name"]) # type: ignore[index] mv = memoryview(shm.buf) @@ -170,7 +192,7 @@ def _ensure_parent_dir(path: str) -> None: pass -def append_jsonl(path: str, record: dict[str, Any]) -> None: +def append_jsonl(path: str, record: Dict[str, Any]) -> None: """Append a JSON record as one line to a JSONL file (best-effort). This is safe to call from multiple processes when each process writes @@ -187,7 +209,7 @@ def append_jsonl(path: str, record: dict[str, Any]) -> None: logger.exception("Failed to append JSONL to %s", path) -def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: +def maybe_dump_to_shm(obj: Any, threshold: int) -> Tuple[bool, Any]: """Dump object to SHM if serialized size exceeds threshold. Returns (True, meta) when dumped; otherwise (False, original_obj). 
@@ -198,7 +220,7 @@ def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: return False, obj -def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any: +def maybe_load_from_ipc(container: Dict[str, Any], obj_key: str, shm_key: str) -> Any: """Load object from container that may carry SHM or inline object. Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain @@ -209,9 +231,7 @@ def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) - return container[obj_key] -def maybe_load_from_ipc_with_metrics( - container: dict[str, Any], obj_key: str, shm_key: str -) -> tuple[Any, dict[str, float]]: +def maybe_load_from_ipc_with_metrics(container: Dict[str, Any], obj_key: str, shm_key: str) -> tuple[Any, Dict[str, float]]: """Load object and return (object, metrics) with RX bytes and decode time. Metrics keys: @@ -219,7 +239,6 @@ def maybe_load_from_ipc_with_metrics( - rx_decode_time_ms: float """ import time as _time # local import to avoid overhead at module import - t0 = _time.time() if shm_key in container: meta = container[shm_key] # type: ignore[index] @@ -243,13 +262,13 @@ def maybe_load_from_ipc_with_metrics( } -def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict[str, Any]: +def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> Dict[str, Any]: """Return a dict payload for IPC: inline (obj_key) or SHM (shm_key). When serialized size exceeds threshold, returns {shm_key: {name,size}}; otherwise returns {obj_key: obj}. 
""" - payload: dict[str, Any] = {} + payload: Dict[str, Any] = {} use_shm, data = maybe_dump_to_shm(obj, threshold) if use_shm: payload[shm_key] = data @@ -259,7 +278,7 @@ def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict # Convert OmegaConf/objects to plain dicts -def _to_dict(x: Any) -> dict[str, Any]: +def _to_dict(x: Any) -> Dict[str, Any]: try: if isinstance(x, dict): return dict(x) @@ -268,4 +287,4 @@ def _to_dict(x: Any) -> dict[str, Any]: try: return dict(x) except Exception: - return {} + return {} \ No newline at end of file diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 7662304c068..a806a9df6c4 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -1,13 +1,11 @@ -from __future__ import annotations - -import logging import os from pathlib import Path +from typing import Optional from omegaconf import OmegaConf -from vllm.transformers_utils.config import get_config -logger = logging.getLogger(__name__) +from vllm.transformers_utils.config import get_config +from vllm_omni.utils import detect_device_type # Get the project root directory (2 levels up from this file) PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -29,3 +27,20 @@ def load_stage_configs_from_yaml(config_path: str): """Load stage configs from yaml file.""" config_data = OmegaConf.load(config_path) return config_data.stage_args + + +def select_worker_class(worker_cls: Optional[str], device_type: Optional[str] = None) -> Optional[str]: + """Select appropriate worker class based on device type.""" + if worker_cls is None: + return None + + if device_type is None: + device_type = detect_device_type() + + if device_type == "npu": + if "gpu_ar_worker" in worker_cls: + worker_cls = worker_cls.replace("gpu_ar_worker", "npu_ar_worker") + elif "gpu_diffusion_worker" in worker_cls: + worker_cls = worker_cls.replace("gpu_diffusion_worker", "npu_diffusion_worker") + + return worker_cls diff --git 
a/vllm_omni/model_executor/models/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni.py index ae5ff25ca58..4b29d266f6a 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni.py @@ -1,8 +1,7 @@ import glob import os -from collections.abc import Iterable from functools import cached_property -from typing import NamedTuple, Optional, Union +from typing import Dict, Iterable, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch @@ -12,6 +11,7 @@ Qwen2_5OmniTalkerConfig, Qwen2_5OmniThinkerConfig, ) + from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP @@ -29,8 +29,9 @@ from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler - -from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights TALKER_CODEC_EOS_TOKEN_ID = 8294 @@ -56,6 +57,7 @@ class OmniOutput(NamedTuple): class Qwen2_5OmniForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.have_multimodal_outputs = True @@ -115,15 +117,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): architectures=["Qwen2_5OmniToken2WavModel"], ) # voice resources (loaded on demand) - self._token2wav_conds: dict[str, torch.Tensor] = {} - self._token2wav_ref_mels: dict[str, torch.Tensor] = {} + self._token2wav_conds: Dict[str, torch.Tensor] = {} + self._token2wav_ref_mels: Dict[str, torch.Tensor] = {} self.model = self.token2wav else: raise ValueError("Invalid model stage") # Set up 
intermediate tensors self.make_empty_intermediate_tensors = ( - (self.thinker.make_empty_intermediate_tensors) if self.model_stage == "thinker" else lambda: None + (self.thinker.make_empty_intermediate_tensors) + if self.model_stage == "thinker" + else lambda: None ) # -------------------- Device utilities -------------------- @@ -172,7 +176,11 @@ def get_input_embeddings( multimodal_embeddings=None, ) -> torch.Tensor: if self.model_stage == "code2wav": - return torch.zeros_like(input_ids).reshape(-1, 1).repeat(1, self.vllm_config.model_config.get_hidden_size()) + return ( + torch.zeros_like(input_ids) + .reshape(-1, 1) + .repeat(1, self.vllm_config.model_config.get_hidden_size()) + ) return self.model.get_input_embeddings(input_ids, multimodal_embeddings) def get_multimodal_embeddings(self, **kwargs): @@ -221,7 +229,9 @@ def forward( # if input_ids is None, set it to a zero tensor, in the length of the # same as the embedding seq length if input_ids is None: - input_ids = torch.zeros(inputs_embeds.shape[1], dtype=torch.long, device=thinker_dev).unsqueeze( + input_ids = torch.zeros( + inputs_embeds.shape[1], dtype=torch.long, device=thinker_dev + ).unsqueeze( 0 ) # (1, 0) added_batch_dim = True @@ -249,7 +259,11 @@ def forward( # Text-only path return OmniOutput( - text_hidden_states=(text_hidden_states.squeeze(0) if added_batch_dim else text_hidden_states), + text_hidden_states=( + text_hidden_states.squeeze(0) + if added_batch_dim + else text_hidden_states + ), multimodal_outputs=None, ) @@ -259,17 +273,13 @@ def forward( # Rules: # - Prefill segments are wrapped with special tokens: [BOS][PAD...][EOS] # - Decode segments consist of a single non-special token. 
- # - If additional_information is provided - # (can be a list split by request or a concatenated tensor plus a list of shapes), + # - If additional_information is provided (can be a list split by request or a concatenated tensor plus a list of shapes), # then for each request, reconstruct the thinker→talker input embeddings for the Prefill segments; - # - For Decode segments, if per-request auxiliary decode embeddings are provided (optional), add them; - # otherwise, keep the original embedding. + # - For Decode segments, if per-request auxiliary decode embeddings are provided (optional), add them; otherwise, keep the original embedding. if input_ids is None and additional_information is None: input_ids = torch.zeros( - inputs_embeds.shape[0], - dtype=torch.long, - device=inputs_embeds.device, + inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device ) additional_information = {} self.thinker_reply_part = torch.zeros_like(inputs_embeds) @@ -283,17 +293,13 @@ def forward( # ------- Request-scoped additional information (no cross-request concat) ------- request_ids: Optional[list[str]] = kwargs.get("request_ids") # ordered - request_token_spans: Optional[list[tuple[int, int]]] = kwargs.get("request_token_spans") + request_token_spans: Optional[list[tuple[int,int]]] = kwargs.get("request_token_spans") addi_by_req: Optional[dict] = kwargs.get("additional_information_by_req_id") runtime_addi = kwargs.get("runtime_additional_information") # Normalize runtime_addi into a mapping by request_id for convenience runtime_addi_by_req: dict[str, dict] = {} - if ( - isinstance(request_ids, list) - and isinstance(runtime_addi, list) - and len(runtime_addi) == len(request_ids) - ): + if isinstance(request_ids, list) and isinstance(runtime_addi, list) and len(runtime_addi) == len(request_ids): for i, rid in enumerate(request_ids): if isinstance(rid, str) and isinstance(runtime_addi[i], dict): runtime_addi_by_req[rid] = runtime_addi[i] @@ -326,19 +332,9 @@ def forward( 
otoks = info.get("thinker_output_token_ids") # list[int] if not isinstance(pe, torch.Tensor): - pe = torch.zeros( - 0, - self.talker.config.hidden_size, - dtype=inputs_embeds.dtype, - device=self._module_device(self.model), - ) + pe = torch.zeros(0, self.talker.config.hidden_size, dtype=inputs_embeds.dtype, device=self._module_device(self.model)) if not isinstance(tr, torch.Tensor): - tr = torch.zeros( - 0, - self.talker.config.hidden_size, - dtype=inputs_embeds.dtype, - device=self._module_device(self.model), - ) + tr = torch.zeros(0, self.talker.config.hidden_size, dtype=inputs_embeds.dtype, device=self._module_device(self.model)) if not isinstance(ptoks, (list, torch.Tensor)): ptoks = [] if not isinstance(otoks, (list, torch.Tensor)): @@ -358,9 +354,7 @@ def forward( # Prepare per-request reply queue for subsequent decode: drop first row if tr.ndim == 2 and tr.shape[0] > 0: - update_by_req_id.setdefault(rid, {})["thinker_reply_part_per_request"] = ( - tr[1:].detach().to("cpu").contiguous() - ) + update_by_req_id.setdefault(rid, {})["thinker_reply_part_per_request"] = tr[1:].detach().to("cpu").contiguous() # ------- Decode: span_len == 1 ------- if not is_profile and isinstance(request_ids, list) and isinstance(request_token_spans, list): @@ -384,11 +378,7 @@ def forward( dv = info.get("decode_output_prompt_embeds") if isinstance(info, dict) else None if isinstance(dv, torch.Tensor) and dv.numel() > 0: step_vec = dv[0:1] if dv.ndim == 2 else dv.view(1, -1) - elif ( - hasattr(self, "thinker_reply_part") - and isinstance(self.thinker_reply_part, torch.Tensor) - and self.thinker_reply_part.numel() > 0 - ): + elif hasattr(self, "thinker_reply_part") and isinstance(self.thinker_reply_part, torch.Tensor) and self.thinker_reply_part.numel() > 0: # C) fallback shared pool step_vec = self.thinker_reply_part[0:1] self.thinker_reply_part = self.thinker_reply_part[1:] @@ -413,8 +403,7 @@ def forward( multimodal_outputs = {"additional_information_update_by_req_id": 
update_by_req_id} if sampling_metadata is not None: - # the padding token id is set to text model's pad token id, - # which do not match with the talker model's word embedding size + # the padding token id is set to text model's pad token id, which do not match with the talker model's word embedding size sampling_metadata.prompt_token_ids[sampling_metadata.prompt_token_ids == 152064] = 8448 return OmniOutput( @@ -432,12 +421,16 @@ def forward( device=inputs_embeds.device, ) ) - + code = code[:-1] if code[-1] == TALKER_CODEC_EOS_TOKEN_ID else code code = code[1:] if code[0] == TALKER_CODEC_BOS_TOKEN_ID else code - - audio_tensor = self.generate_audio(code, voice_type) - return OmniOutput(text_hidden_states=None, multimodal_outputs={"audio": audio_tensor}) + + audio_tensor = self.generate_audio( + code, voice_type + ) + return OmniOutput( + text_hidden_states=None, multimodal_outputs={"audio": audio_tensor} + ) return OmniOutput( text_hidden_states=torch.cat( @@ -448,7 +441,9 @@ def forward( ).to(self._module_device(self.model)), self.talker.thinker_to_talker_proj( self.talker.get_input_embeddings( - torch.tensor([TALKER_CODEC_BOS_TOKEN_ID, TALKER_CODEC_EOS_TOKEN_ID]) + torch.tensor( + [TALKER_CODEC_BOS_TOKEN_ID, TALKER_CODEC_EOS_TOKEN_ID] + ) .to(torch.bfloat16) .to(self._module_device(self.model)) ) @@ -575,6 +570,7 @@ def _thinker_to_talker_prefill( thinker_prompt_embeds, prompt_token_ids, ): + talker_hf_config = self.talker_config if hasattr(talker_hf_config, "talker_config"): talker_hf_config = talker_hf_config.talker_config @@ -597,7 +593,9 @@ def _thinker_to_talker_prefill( input_tokens_len = len(prompt_token_ids_processed) # the code below is from model runner in Qwen, may need to further discuss later if input_tokens_len > 2: - prompt_token_ids_processed = [self.talker_config.tts_codec_mask_token_id] * (input_tokens_len - 2) + [ + prompt_token_ids_processed = [ + self.talker_config.tts_codec_mask_token_id + ] * (input_tokens_len - 2) + [ 
self.talker_config.tts_codec_pad_token_id, self.talker_config.tts_codec_start_token_id, ] @@ -608,7 +606,9 @@ def _thinker_to_talker_prefill( ][-input_tokens_len:] if isinstance(prompt_token_ids_processed, list): prompt_token_ids_processed = ( - torch.Tensor(prompt_token_ids_processed).to(torch.int64).to(self._module_device(self.talker)) + torch.Tensor(prompt_token_ids_processed) + .to(torch.int64) + .to(self._module_device(self.talker)) ) return prompt_token_ids_processed, prompt_embeds @@ -617,12 +617,14 @@ def _thinker_to_talker_decode_one_step( output_prompt_embeds, output_token_ids, ): - processed_output_token_embeds = output_prompt_embeds + self.talker.get_input_embeddings( - output_token_ids + processed_output_token_embeds = ( + output_prompt_embeds + self.talker.get_input_embeddings(output_token_ids) ) # for decode return output_token_ids, processed_output_token_embeds - def compute_logits(self, hidden_states: Union[torch.Tensor, OmniOutput]) -> Optional[torch.Tensor]: + def compute_logits( + self, hidden_states: Union[torch.Tensor, OmniOutput] + ) -> Optional[torch.Tensor]: # Handle OmniOutput type if isinstance(hidden_states, OmniOutput): hidden_states = hidden_states.text_hidden_states @@ -638,7 +640,7 @@ def sample( # Use thinker model for sampling return self.model.sample(logits, sampling_metadata) - def generate_speech(self, text_tokens: torch.Tensor, voice_type: str = "default") -> torch.Tensor: + def generate_speech(self, text_tokens: torch.Tensor, voice_type: str = "default"): """ Generate speech from text tokens using the talker and token2wav models. This method is kept for backward compatibility and direct speech generation. 
@@ -651,7 +653,9 @@ def generate_speech(self, text_tokens: torch.Tensor, voice_type: str = "default" Audio tensor """ # Generate codec tokens using talker model - talker_output = self.talker(input_ids=None, positions=None, inputs_embeds=text_tokens) + talker_output = self.talker( + input_ids=None, positions=None, inputs_embeds=text_tokens + ) # Convert talker output to codec tokens codec_tokens = self._convert_to_codec_tokens(talker_output) @@ -678,7 +682,9 @@ def _convert_to_codec_tokens( # Suppress only codec_bos, consistent with HF generate's # suppress_tokens behavior bos_id = None - if hasattr(self, "talker_config") and hasattr(self.talker_config, "tts_codec_start_token_id"): + if hasattr(self, "talker_config") and hasattr( + self.talker_config, "tts_codec_start_token_id" + ): bos_id = int(getattr(self.talker_config, "tts_codec_start_token_id")) if bos_id is not None: logits[..., bos_id] = -1e9 @@ -692,13 +698,17 @@ def _init_token2wav_model(self, hf_model_folder): __init__.""" if self.token2wav is None or self.token2wav_config is None: return - device = "cuda" if torch.cuda.is_available() else "cpu" + device = self._module_device(self.token2wav) # optional speaker resources conds = getattr(self.token2wav_config, "conds", None) ref_mels = getattr(self.token2wav_config, "ref_mels", None) if isinstance(conds, dict) and isinstance(ref_mels, dict): - self._token2wav_conds = {k: torch.as_tensor(v, device=device) for k, v in conds.items()} - self._token2wav_ref_mels = {k: torch.as_tensor(v, device=device) for k, v in ref_mels.items()} + self._token2wav_conds = { + k: torch.as_tensor(v, device=device) for k, v in conds.items() + } + self._token2wav_ref_mels = { + k: torch.as_tensor(v, device=device) for k, v in ref_mels.items() + } # legacy: load from directory if provided model_path = hf_model_folder if isinstance(model_path, str) and os.path.isdir(model_path): @@ -710,14 +720,24 @@ def _init_token2wav_model(self, hf_model_folder): self._token2wav_ref_mels[key] = 
value["ref_mel"].to(device) else: # legacy npy inputs - for f in sorted(glob.glob(os.path.join(model_path, "inputs", "*spk_emb.npy"))): + for f in sorted( + glob.glob(os.path.join(model_path, "inputs", "*spk_emb.npy")) + ): key = os.path.basename(f).split("_")[0].lower() - self._token2wav_conds[key] = torch.as_tensor(np.load(f), device=device) - for f in sorted(glob.glob(os.path.join(model_path, "inputs", "*ref_mel.npy"))): + self._token2wav_conds[key] = torch.as_tensor( + np.load(f), device=device + ) + for f in sorted( + glob.glob(os.path.join(model_path, "inputs", "*ref_mel.npy")) + ): key = os.path.basename(f).split("_")[0].lower() - self._token2wav_ref_mels[key] = torch.as_tensor(np.load(f), device=device) + self._token2wav_ref_mels[key] = torch.as_tensor( + np.load(f), device=device + ) - def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default") -> Optional[torch.Tensor]: + def _codec_to_audio( + self, codec_tokens: torch.Tensor, voice_type: str = "default" + ) -> Optional[torch.Tensor]: if self.token2wav is None: self._init_token2wav_model() if self.token2wav is None: @@ -751,7 +771,9 @@ def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default if codec.ndim == 1: codec = codec.unsqueeze(0) else: - codec = torch.as_tensor(codec_tokens, dtype=torch.long, device=token2wav_dev).unsqueeze(0) + codec = torch.as_tensor( + codec_tokens, dtype=torch.long, device=token2wav_dev + ).unsqueeze(0) # Streaming with chunked process and boundary alignment # (rely on token2wav.process_chunk) @@ -766,7 +788,9 @@ def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default steps = 10 # Prepare initial noise for the whole sequence - y_all = torch.randn((1, total_mel, mel_dim), dtype=ref_mel.dtype, device=token2wav_dev) + y_all = torch.randn( + (1, total_mel, mel_dim), dtype=ref_mel.dtype, device=token2wav_dev + ) logger.info( "Currently, we do not use the chunked process, we only use the " @@ -778,7 +802,9 @@ 
def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default for i in range(codec.shape[1]): chunk_code_length = i * 2 - 24 finished = i == (codec.shape[1] - 1) - if (chunk_code_length > 0 and chunk_code_length % chunk_size == 0) or finished: + if ( + chunk_code_length > 0 and chunk_code_length % chunk_size == 0 + ) or finished: chunk_ends.append(i) # Number of chunks in mel domain @@ -807,7 +833,7 @@ def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default waveform = np.concatenate(wav_chunks) return torch.as_tensor(waveform, device=token2wav_dev) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: """Load weights for all components of the omni model.""" loaded_weights = set() thinker_weights = [] @@ -836,14 +862,20 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if talker_weights and self.talker is not None: # Map talker weights to appropriate components if self.thinker is None: - thinker_embedding_weights = [w for n, w in thinker_weights if n == "thinker.model.embed_tokens.weight"] + thinker_embedding_weights = [ + w + for n, w in thinker_weights + if n == "thinker.model.embed_tokens.weight" + ] if thinker_embedding_weights: self.thinker_embedding = nn.Embedding( thinker_embedding_weights[0].shape[0], thinker_embedding_weights[0].shape[1], ) self.thinker_embedding.weight = nn.Parameter( - thinker_embedding_weights[0].to(self._module_device(self.talker)) + thinker_embedding_weights[0].to( + self._module_device(self.talker) + ) ) talker_loaded = self.talker.load_weights(talker_weights) talker_loaded = add_prefix_to_loaded_weights(talker_loaded, "talker") @@ -853,18 +885,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Load token2wav weights (if any) if token2wav_weights and self.token2wav is not None: # download weights from 
huggingface for spk_dict.pt - model_path = self.vllm_config.model_config.model - download_dir = self.vllm_config.load_config.download_dir - if os.path.exists(model_path): - hf_model_folder = model_path - else: - hf_model_folder = download_weights_from_hf_specific( - model_path, - download_dir, - allow_patterns=["*.pt"], - ) + hf_model_folder = download_weights_from_hf_specific( + self.vllm_config.model_config.model, + self.vllm_config.load_config.download_dir, + allow_patterns=["*.pt"], + ) self._init_token2wav_model(hf_model_folder) - t2w_loaded = self.token2wav.load_weights(token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt")) + t2w_loaded = self.token2wav.load_weights( + token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt") + ) t2w_loaded = add_prefix_to_loaded_weights(t2w_loaded, "token2wav") loaded_weights.update(t2w_loaded) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py index f1109a9a7b1..8c541cc2a09 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py @@ -3,8 +3,7 @@ ############################ import math -from collections.abc import Iterable -from typing import Optional, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -17,17 +16,24 @@ Qwen2_5OmniDiTConfig, Qwen2_5OmniToken2WavConfig, ) -from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniPreTrainedModel +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniPreTrainedModel, +) # Bring in HF base classes, configs and utilities used below from transformers.utils.logging import get_logger as _hf_get_logger + from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.utils import AutoWeightsLoader as 
_Vllm_AutoWeightsLoader +from vllm.model_executor.models.utils import ( + AutoWeightsLoader as _Vllm_AutoWeightsLoader, +) from vllm.model_executor.models.utils import WeightsMapper as _Vllm_WeightsMapper -from vllm.model_executor.models.utils import init_vllm_registered_model as _vllm_init_vllm_registered_model -from vllm.model_executor.models.utils import maybe_prefix as _vllm_maybe_prefix +from vllm.model_executor.models.utils import ( + init_vllm_registered_model as _Vllm_init_vllm_registered_model, +) +from vllm.model_executor.models.utils import maybe_prefix as _Vllm_maybe_prefix from vllm.sequence import IntermediateTensors from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -61,7 +67,8 @@ def forward(self, x): batch_size, seq_len = x.shape[0], x.shape[1] t = torch.arange(seq_len, device=x.device) device_type = x.device.type - device_type = device_type if device_type != "mps" else "cpu" + if device_type in ("mps", "npu"): + device_type = "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() freqs = torch.stack((freqs, freqs), dim=-1) @@ -212,7 +219,9 @@ def _length_to_mask(self, length, max_len=None, dtype=None, device=None): def _compute_statistics(self, x, m, dim=2): mean = (m * x).sum(dim) - std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + std = torch.sqrt( + (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps) + ) return mean, std def forward(self, hidden_states): @@ -274,7 +283,9 @@ def __init__( kernel_size=1, dilation=1, ) - self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.res2net_block = Res2NetBlock( + out_channels, out_channels, res2net_scale, kernel_size, dilation + ) self.tdnn2 = TimeDelayNetBlock( out_channels, out_channels, @@ -302,10 +313,13 @@ class ECAPA_TimeDelayNet(torch.nn.Module): def __init__(self, config: 
Qwen2_5OmniDiTConfig): super().__init__() - if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( - config.enc_dilations - ): - raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len( + config.enc_channels + ) != len(config.enc_dilations): + raise ValueError( + "enc_channels, enc_kernel_sizes and enc_dilations should have " + "same length" + ) self.channels = config.enc_channels self.blocks = nn.ModuleList() @@ -399,14 +413,26 @@ def forward( ): if apply_cfg: hidden_states = torch.cat([hidden_states, hidden_states], dim=0) - speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) - condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + speaker_embedding = torch.cat( + [speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0 + ) + condition_vector = torch.cat( + [condition_vector, torch.zeros_like(condition_vector)], dim=0 + ) code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) elif drop_audio_cond: # cfg for cond audio condition_vector = torch.zeros_like(condition_vector) speaker_embedding = torch.zeros_like(speaker_embedding) - condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) - hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + condition_vector = ( + self.spk_encoder(condition_vector) + .unsqueeze(1) + .repeat(1, hidden_states.size(1), 1) + ) + hidden_states = self.proj( + torch.cat( + (hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1 + ) + ) return hidden_states @@ -440,9 +466,13 @@ def __init__(self, dim): def forward(self, hidden_states, emb=None): emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + 
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) - hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + hidden_states = ( + self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp @@ -461,7 +491,9 @@ def forward(self, hidden_states, emb): emb = self.linear(self.silu(emb)) scale, shift = torch.chunk(emb, 2, dim=1) - hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + hidden_states = ( + self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + ) return hidden_states @@ -541,7 +573,9 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.to_k = nn.Linear(config.hidden_size, self.inner_dim) self.to_v = nn.Linear(config.hidden_size, self.inner_dim) - self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + self.to_out = nn.ModuleList( + [nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)] + ) def forward( self, @@ -567,7 +601,9 @@ def forward( # Due to training process, only first head is applied with RoPE, # will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = apply_rotary_pos_emb( + query[:, :1], key[:, :1], cos, sin + ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( @@ -581,7 +617,9 @@ def forward( # mask. e.g. 
inference got a batch with different target durations, # mask out the padding - attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.reshape( + batch_size, -1, self.heads * head_dim + ) attention_weights = attention_weights.to(query.dtype) # linear proj @@ -611,7 +649,9 @@ class DiTTimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + self.time_mlp = nn.ModuleList( + [nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)] + ) def forward(self, timestep): # noqa: F821 time_hidden = self.time_embed(timestep) @@ -622,21 +662,29 @@ def forward(self, timestep): # noqa: F821 class DiTDecoderLayer(nn.Module): - def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + def __init__( + self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0 + ): super().__init__() self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) self.attn = DiTAttention(config) self.look_ahead_block = look_ahead_block self.look_backward_block = look_backward_block - self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) - self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + self.ff_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, eps=1e-6 + ) + self.ff = DiTMLP( + dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout + ) def forward( self, hidden_states, timestep, position_embeddings=None, block_diff=None ): # x: noised input, t: time embedding # pre-norm & modulation for attention input - norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm( + hidden_states, 
emb=timestep + ) # attention attn_output = self.attn( @@ -649,7 +697,9 @@ def forward( # process attention output for input x hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output - norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm = ( + self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) ff_output = self.ff(norm) hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output @@ -692,14 +742,14 @@ def forward(self, hidden_states): beta = self.beta.unsqueeze(0).unsqueeze(-1) alpha = torch.exp(alpha) beta = torch.exp(beta) - hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( - torch.sin(hidden_states * alpha), 2 - ) + hidden_states = hidden_states + ( + 1.0 / (beta + self.no_div_by_zero) + ) * torch.pow(torch.sin(hidden_states * alpha), 2) return hidden_states -def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): """Generates a 1D Kaiser-windowed sinc filter. 
Args: @@ -724,7 +774,9 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> else: beta = 0.0 - kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + kaiser_window = torch.kaiser_window( + kernel_size, beta=beta, periodic=False, dtype=torch.float32 + ) # Compute time indices if is_even: @@ -734,7 +786,9 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> # Compute sinc filter if cutoff == 0: - return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + return torch.zeros( + (1, 1, kernel_size), dtype=torch.float32 + ) # Ensures correct shape sinc_filter = torch.sinc(2 * cutoff * time_indices) normalized_filter = 2 * cutoff * kaiser_window * sinc_filter @@ -749,19 +803,27 @@ class UpSample1d(nn.Module): def __init__(self, ratio=2, kernel_size=None): super().__init__() self.ratio = ratio - self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) self.stride = ratio self.pad = self.kernel_size // ratio - 1 self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 - self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) - filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) self.register_buffer("filter", filter, persistent=False) def forward(self, hidden_states): channels = hidden_states.shape[1] hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to(self.filter.dtype) + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to( + 
self.filter.dtype + ) hidden_states = self.ratio * F.conv_transpose1d( hidden_states, self.filter.expand(channels, -1, -1), @@ -794,7 +856,9 @@ def __init__(self, ratio=2, kernel_size=None): def forward(self, hidden_states): channels = hidden_states.shape[1] hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate").to(self.filter.dtype) + hidden_states = F.pad( + hidden_states, (self.pad_left, self.pad_right), mode="replicate" + ).to(self.filter.dtype) out = F.conv1d( hidden_states, self.filter.expand(channels, -1, -1), @@ -895,10 +959,15 @@ def __init__( ] ) - self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # total number of conv layers self.activations = nn.ModuleList( - [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + [ + TorchActivation1d(activation=SnakeBeta(channels)) + for _ in range(self.num_layers) + ] ) def _get_padding(self, kernel_size, dilation=1): @@ -931,7 +1000,9 @@ def __init__(self, config: Qwen2_5OmniBigVGANConfig): self.num_residual_blocks = len(config.resblock_kernel_sizes) self.num_upsample_layers = len(config.upsample_rates) - self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + self.conv_pre = nn.Conv1d( + config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3 + ) # Removing extra ModuleList breaks official state dict ups = [ @@ -946,7 +1017,9 @@ def __init__(self, config: Qwen2_5OmniBigVGANConfig): ) ] ) - for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + for layer_idx, (stride, kernel_size) in enumerate( + zip(config.upsample_rates, config.upsample_kernel_sizes) + ) ] self.ups = nn.ModuleList(ups) @@ -958,12 +1031,16 @@ def __init__(self, config: Qwen2_5OmniBigVGANConfig): dilation, ) for layer_idx in 
range(self.num_upsample_layers) - for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + for kernel_size, dilation in zip( + config.resblock_kernel_sizes, config.resblock_dilation_sizes + ) ] ) self.activation_post = TorchActivation1d( - activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + activation=SnakeBeta( + config.upsample_initial_channel // (2**self.num_upsample_layers) + ) ) self.conv_post = nn.Conv1d( config.upsample_initial_channel // (2**self.num_upsample_layers), @@ -1003,7 +1080,9 @@ def forward(self, mel_spectrogram): for layer_index in range(self.num_upsample_layers): hidden_representation = self.ups[layer_index][0](hidden_representation) residual_output = sum( - self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + self.resblocks[layer_index * self.num_residual_blocks + block_index]( + hidden_representation + ) for block_index in range(self.num_residual_blocks) ) residual_output = residual_output / self.num_residual_blocks @@ -1031,7 +1110,11 @@ def _rk4_step( value_start, function_value_start=None, ): - k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k1 = ( + function_value_start + if function_value_start is not None + else function(time_start, value_start) + ) k2 = function( time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third, @@ -1057,7 +1140,9 @@ def _compute_step(self, function, time_start, time_step, time_end, value_start): function_value_start, ) - def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + def _linear_interpolation( + self, time_start, time_end, value_start, value_end, time_point + ): if time_point == time_start: return value_start if time_point == time_end: @@ -1078,10 +1163,15 @@ def integrate(self, time_points): current_value = self.initial_value for time_start, time_end in 
zip(time_points[:-1], time_points[1:]): time_step = time_end - time_start - delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + delta_value, _ = self._compute_step( + self.function, time_start, time_step, time_end, current_value + ) next_value = current_value + delta_value - while current_index < len(time_points) and time_end >= time_points[current_index]: + while ( + current_index < len(time_points) + and time_end >= time_points[current_index] + ): solution[current_index] = self._linear_interpolation( time_start, time_end, @@ -1112,7 +1202,9 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.repeats = config.repeats self.time_embed = DiTTimestepEmbedding(config.hidden_size) - self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.text_embed = DiTCodecEmbedding( + config.num_embeds, config.emb_dim, config.repeats + ) self.input_embed = DiTInputEmbedding(config) self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) @@ -1132,12 +1224,16 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): ) ) - self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final( + config.hidden_size + ) # final modulation self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) def _create_block_diff(self, hidden_states): batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] - block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + block_indices = ( + torch.arange(seq_len, device=hidden_states.device) // self.block_size + ) # [seq_length] block_i = block_indices.unsqueeze(1) # [seq_length, 1] block_j = block_indices.unsqueeze(0) # [1, seq_length] @@ -1162,8 +1258,12 @@ def forward( # Compute embeddings time_embedding = self.time_embed(time_step) - text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) - 
text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + text_embedding = self.text_embed( + quantized_code, drop_code=False if apply_cfg else drop_code + ) + text_embedding_unconditioned = ( + self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + ) hidden_states = self.input_embed( hidden_states, @@ -1203,11 +1303,17 @@ def sample( guidance_scale=0.5, sway_coefficient=-1.0, ): - noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + noise_initialization = torch.randn( + [1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype + ) maximum_duration = quantized_code.shape[1] * self.repeats - initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + initial_state = noise_initialization[:, :maximum_duration].to( + quantized_code.device + ) batch_size = reference_mel_spectrogram.shape[0] - conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + conditioning_vector = conditioning_vector.unsqueeze(1).repeat( + 1, maximum_duration, 1 + ) if batch_size != 1: raise ValueError("Only batch size = 1 is currently supported") @@ -1234,7 +1340,10 @@ def ode_function(time_step, hidden_states): apply_cfg=True, ) guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) - return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + return ( + guided_prediction + + (guided_prediction - null_prediction) * guidance_scale + ) initial_time = 0 time_embedding = torch.linspace( @@ -1246,9 +1355,13 @@ def ode_function(time_step, hidden_states): ) if sway_coefficient is not None: - time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + time_embedding += sway_coefficient * ( + torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding + ) - ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + 
ode_solver = RungeKutta4ODESolver( + function=ode_function, initial_value=initial_state + ) solution_trajectory = ode_solver.integrate(time_embedding) generated_waveform = solution_trajectory[-1] @@ -1258,14 +1371,14 @@ def ode_function(time_step, hidden_states): @torch.no_grad() def fast_block_sample( self, - conditioning_vector: torch.Tensor, - reference_mel_spectrogram: torch.Tensor, - quantized_code: torch.Tensor, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, y0: torch.Tensor, - num_steps: int = 10, - guidance_scale: float = 0.5, + num_steps=10, + guidance_scale=0.5, sway_coefficient: Optional[float] = -1.0, - ) -> torch.Tensor: + ): """ Block-wise ODE sampling starting from provided initial state y0. @@ -1279,7 +1392,9 @@ def fast_block_sample( """ initial_state = y0.to(quantized_code.device) batch_size = reference_mel_spectrogram.shape[0] - conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, initial_state.shape[1], 1) + conditioning_vector = conditioning_vector.unsqueeze(1).repeat( + 1, initial_state.shape[1], 1 + ) if batch_size != 1: raise ValueError("Only batch size = 1 is currently supported") @@ -1306,7 +1421,10 @@ def ode_function(time_step, hidden_states): apply_cfg=True, ) guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) - return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + return ( + guided_prediction + + (guided_prediction - null_prediction) * guidance_scale + ) initial_time = 0 time_embedding = torch.linspace( @@ -1318,9 +1436,13 @@ def ode_function(time_step, hidden_states): ) if sway_coefficient is not None: - time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + time_embedding += sway_coefficient * ( + torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding + ) - ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + ode_solver = RungeKutta4ODESolver( + 
function=ode_function, initial_value=initial_state + ) solution_trajectory = ode_solver.integrate(time_embedding) generated_waveform = solution_trajectory[-1] @@ -1356,7 +1478,8 @@ def __init__(self, config: Qwen2_5OmniToken2WavConfig): attn_impl = "sdpa" elif config._attn_implementation == "eager": logger.warning_once( - "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + "Qwen2_5OmniToken2WavModel does not support eager attention " + "implementation, fall back to sdpa" ) attn_impl = "sdpa" self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( @@ -1379,7 +1502,9 @@ def __init__(self, config: Qwen2_5OmniToken2WavConfig): # codec embedding size for masking EOS out-of-range try: - self.codec_embed_size = self.code2wav_dit_model.text_embed.codec_embed.weight.size(0) + self.codec_embed_size = ( + self.code2wav_dit_model.text_embed.codec_embed.weight.size(0) + ) except Exception: self.codec_embed_size = -1 @@ -1461,7 +1586,7 @@ def process_little_chunk( steps: int, prev_generated: torch.Tensor, finished: bool = False, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """Streaming per small chunk: returns (mel_or_None, audio_slice).""" start_index = max(i * self.chunk_size - self.past_cache_size, 0) end_index = min( @@ -1469,8 +1594,16 @@ def process_little_chunk( codec_all.shape[1] * self.factor, ) - y0 = y_all[:, start_index:end_index].reshape(1, -1, self.code2wav_dit_model.mel_dim).contiguous() - codec = codec_all[:, start_index // self.factor : end_index // self.factor].reshape(1, -1).contiguous() + y0 = ( + y_all[:, start_index:end_index] + .reshape(1, -1, self.code2wav_dit_model.mel_dim) + .contiguous() + ) + codec = ( + codec_all[:, start_index // self.factor : end_index // self.factor] + .reshape(1, -1) + .contiguous() + ) # generate mel for current window (B, mel_dim, T) generated = self.process_chunk_dit_batch( @@ -1500,9 +1633,9 @@ def 
process_chunk( y_all: torch.Tensor, i: int, steps: int, - prev_generated: Union[torch.Tensor, list[torch.Tensor]], + prev_generated: Union[torch.Tensor, List[torch.Tensor]], finished: bool = False, - ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]: + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: """High-level chunk API aligning to qwen2_code2wav_dit signature.""" if not isinstance(prev_generated, torch.Tensor): prev_generated = prev_generated[0] if len(prev_generated) > 0 else None @@ -1527,7 +1660,7 @@ def _process_chunk_for_50hz( finished: bool, prev_generated: Optional[torch.Tensor], generated: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Align mel and audio boundaries for 50Hz-like streaming. @@ -1544,13 +1677,21 @@ def _process_chunk_for_50hz( mel = generated[:, :, : self.chunk_size] elif finished: mel_trim = generated[:, :, self.past_cache_size :] - mel = torch.cat([prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2) + mel = torch.cat( + [prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2 + ) else: if start_index == 0: - mel_trim = generated[:, :, i * self.chunk_size : -self.future_cache_size] + mel_trim = generated[ + :, :, i * self.chunk_size : -self.future_cache_size + ] else: - mel_trim = generated[:, :, self.past_cache_size : -self.future_cache_size] - mel = torch.cat([prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2) + mel_trim = generated[ + :, :, self.past_cache_size : -self.future_cache_size + ] + mel = torch.cat( + [prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2 + ) audio = self.code2wav_bigvgan_model(mel) if i == 0: @@ -1558,7 +1699,11 @@ def _process_chunk_for_50hz( elif finished: audio_output = audio[self.future_size * self.vocoder_hop :] else: - audio_output = audio[self.future_size * self.vocoder_hop : -self.future_size * self.vocoder_hop] + audio_output = audio[ + self.future_size + * 
self.vocoder_hop : -self.future_size + * self.vocoder_hop + ] return mel, audio_output @@ -1582,9 +1727,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = vllm_config.model_config.hf_config # Initialize underlying HF Token2Wav model via registry - self.token2wav = _vllm_init_vllm_registered_model( + self.token2wav = _Vllm_init_vllm_registered_model( vllm_config=vllm_config, - prefix=_vllm_maybe_prefix(prefix, "token2wav_model"), + prefix=_Vllm_maybe_prefix(prefix, "token2wav_model"), hf_config=self.config, architectures=["Qwen2_5OmniToken2WavDiTModel"], ) @@ -1636,7 +1781,9 @@ def sample( ) -> Optional[SamplerOutput]: return None - def load_weights_without_buffers(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights_without_buffers( + self, weights: Iterable[Tuple[str, torch.Tensor]] + ) -> Set[str]: loader = _Vllm_AutoWeightsLoader(self) loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) # Log load summary @@ -1674,7 +1821,9 @@ def find_all_registers(self): return registers # remove buffers from the weights and reload them after loading weights - def remove_buffers_from_weights(self, weights: Iterable[tuple[str, torch.Tensor]], buffers: dict): + def remove_buffers_from_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], buffers: dict + ): weights_to_load = [] for key, value in weights: if key in buffers: @@ -1694,7 +1843,9 @@ def reload_buffers_to_model(self, buffers: dict): loaded_buffers.add(name) return loaded_buffers - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]], spk_dict_path: str) -> set[str]: + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], spk_dict_path: str + ) -> Set[str]: buffers = self.find_all_registers() weights_to_load = self.remove_buffers_from_weights(weights, buffers) loaded = self.load_weights_without_buffers(weights_to_load) @@ -1722,7 +1873,9 @@ def process_chunk_dit_batch( ) @torch.inference_mode() - def 
process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> Optional[torch.Tensor]: + def process_chunk_bigvgan_batch( + self, mel_batch: torch.Tensor + ) -> Optional[torch.Tensor]: # BigVGAN is not part of this wrapper; return None for parity. return None @@ -1737,7 +1890,7 @@ def process_little_chunk( steps: int, prev_generated: torch.Tensor, finished: bool = False, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: mel = self.token2wav( code=codec_all, conditioning=conditioning, @@ -1755,9 +1908,9 @@ def process_chunk( y_all: torch.Tensor, i: int, steps: int, - prev_generated: Union[torch.Tensor, list[torch.Tensor]], + prev_generated: Union[torch.Tensor, List[torch.Tensor]], finished: bool = False, - ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]: + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: _mel, out = self.process_little_chunk( conditioning=conditioning, reference_mel=reference_mel, @@ -1765,7 +1918,9 @@ def process_chunk( y_all=y_all, i=i, steps=steps, - prev_generated=(prev_generated if isinstance(prev_generated, torch.Tensor) else None), + prev_generated=( + prev_generated if isinstance(prev_generated, torch.Tensor) else None + ), finished=finished, ) return _mel if _mel is not None else prev_generated, out diff --git a/vllm_omni/utils/platform_utils.py b/vllm_omni/utils/platform_utils.py new file mode 100644 index 00000000000..1379ad61248 --- /dev/null +++ b/vllm_omni/utils/platform_utils.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Callable + +import torch + +from vllm.platforms import current_platform + + +def detect_device_type() -> str: + device_type = getattr(current_platform, "device_type", None) + if isinstance(device_type, str) and device_type: + return device_type.lower() + if torch.cuda.is_available(): + return "cuda" + if hasattr(torch, "npu") and torch.npu.is_available(): # type: ignore[attr-defined] + return "npu" 
+ return "cpu" + + +def is_npu() -> bool: + return detect_device_type() == "npu" + + +def get_device_control_env_var() -> str: + """Return the environment variable name for device visibility control.""" + if hasattr(current_platform, "device_control_env_var"): + env_var = getattr(current_platform, "device_control_env_var", None) + if isinstance(env_var, str) and env_var: + return env_var + + device_type = detect_device_type() + if device_type == "npu": + return "ASCEND_RT_VISIBLE_DEVICES" + return "CUDA_VISIBLE_DEVICES" # fallback \ No newline at end of file diff --git a/vllm_omni/worker/npu_ar_model_runner.py b/vllm_omni/worker/npu_ar_model_runner.py new file mode 100644 index 00000000000..032433a1545 --- /dev/null +++ b/vllm_omni/worker/npu_ar_model_runner.py @@ -0,0 +1,641 @@ +"""AR NPU Model Runner for vLLM-omni.""" + +from __future__ import annotations + +from typing import Any, Optional, Union + +import numpy as np +import torch + +from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_model_runner import ( + EMPTY_MODEL_RUNNER_OUTPUT, + IntermediateTensors, + get_pp_group, + get_tp_group, + has_kv_transfer_group, +) +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.worker.model_runner_v1 import AsyncNPUModelRunnerOutput +from vllm.v1.worker.ubatch_utils import UBatchSlices +from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.npu_model_runner import OmniNPUModelRunner + +logger = init_logger(__name__) + + +class NPUARModelRunner(OmniNPUModelRunner): + """Autoregressive NPU model runner that returns hidden states per request.""" + + def __init__(self, 
*args, **kwargs): + super().__init__(*args, **kwargs) + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + num_scheduled_tokens_np: np.ndarray, + intermediate_tensors: Optional[IntermediateTensors] = None, + ubatch_slices: Optional[UBatchSlices] = None, + num_tokens_after_padding: Optional[torch.Tensor] = None, + ) -> tuple[ + int, + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + Optional[dict[str, dict]], + ]: + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if ubatch_slices: + assert num_tokens_after_padding is not None + num_input_tokens = int(num_tokens_after_padding[0].item() * 2) + if hasattr(self, "pad_out_ubatch_slice"): + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif ubatch_slices is None: + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + per_req_additional_information: Optional[dict[str, dict]] = None + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + inputs_embeds_scheduled = self.model.get_input_embeddings( + input_ids=self.input_ids[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds or None, + ) + + # TODO(woosuk): Avoid the copy. Optimize. 
+ self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) + + if hasattr(self, "_forward_additional_information"): + self._forward_additional_information = None + per_req_additional_information = {} + + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + addi_cpu = getattr(req_state, "additional_information_cpu", None) + num_computed_tokens = int( + self.input_batch.num_computed_tokens_cpu[req_index] + ) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + if pe_cpu is not None: + src = pe_cpu[ + num_computed_tokens : num_computed_tokens + overlay_len + ].to(dtype=self.dtype, device=self.device, non_blocking=True) + start_offset = int(self.query_start_loc_cpu[req_index]) + self.inputs_embeds[start_offset : start_offset + overlay_len].copy_( + src + ) + if addi_cpu is not None and isinstance(addi_cpu, dict): + req_info: dict[str, object] = {} + for k, v in addi_cpu.items(): + if isinstance(v, torch.Tensor): + if overlay_len > 0: + try: + seg = v[ + num_computed_tokens : num_computed_tokens + overlay_len + ].detach().to("cpu").contiguous() + except Exception: + seg = v.detach().to("cpu").contiguous() + req_info[k] = seg + else: + req_info[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + req_info[k] = v + else: + req_info[k] = v + per_req_additional_information[req_id] = req_info + + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_kwargs = { + **self._init_model_kwargs(num_scheduled_tokens), + **self._extract_mm_kwargs(scheduler_output), + } + elif self.enable_prompt_embeds and get_pp_group().is_first_rank: + if hasattr(self, "is_token_ids"): + token_ids_idx = ( + 
self.is_token_ids[:num_scheduled_tokens] + .nonzero(as_tuple=False) + .squeeze(1) + ) + if token_ids_idx.numel() > 0: + token_ids = self.input_ids[token_ids_idx] + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) + self.inputs_embeds[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_kwargs = self._init_model_kwargs(num_input_tokens) + input_ids = None + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + model_kwargs = self._init_model_kwargs(num_input_tokens) + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True + ) + + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) + + return ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_after_padding, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + per_req_additional_information, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[OmniModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with record_function_or_nullcontext("Preprocess"): + with self.synchronize_input_prep(): + super()._update_states(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + return EMPTY_MODEL_RUNNER_OUTPUT + if hasattr(self, "kv_connector_no_forward"): + return self.kv_connector_no_forward(scheduler_output) + return EMPTY_MODEL_RUNNER_OUTPUT + if self.cache_config.kv_sharing_fast_prefill: + assert 
not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + + ( + attn_metadata, + positions, + num_scheduled_tokens_np, + num_input_tokens, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, + input_ids_prep, + inputs_embeds_prep, + intermediate_tensors_prep, + max_query_len, + ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + + if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.forward_before() + + if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.take_update_info_from_eplb_process() + + with_prefill = getattr(self, "with_prefill", True) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + input_ids = input_ids_prep + inputs_embeds = inputs_embeds_prep + intermediate_tensors = intermediate_tensors_prep + + per_req_additional_information: Optional[dict[str, dict]] = None + model_kwargs = self._init_model_kwargs(num_input_tokens) + + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + + per_req_additional_information = {} + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + addi_cpu = getattr(req_state, "additional_information_cpu", None) + num_computed_tokens = int( + self.input_batch.num_computed_tokens_cpu[req_index] + ) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if 
overlay_len <= 0: + continue + if pe_cpu is not None and inputs_embeds is not None: + src = pe_cpu[ + num_computed_tokens : num_computed_tokens + overlay_len + ].to(dtype=self.dtype, device=self.device, non_blocking=True) + start_offset = int(self.query_start_loc_cpu[req_index]) + inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + if addi_cpu is not None and isinstance(addi_cpu, dict): + req_info: dict[str, object] = {} + for k, v in addi_cpu.items(): + if isinstance(v, torch.Tensor): + if overlay_len > 0: + try: + seg = v[ + num_computed_tokens : num_computed_tokens + overlay_len + ].detach().to("cpu").contiguous() + except Exception: + seg = v.detach().to("cpu").contiguous() + req_info[k] = seg + else: + req_info[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + req_info[k] = v + else: + req_info[k] = v + per_req_additional_information[req_id] = req_info + + model_kwargs.update(self._extract_mm_kwargs(scheduler_output)) + + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + if hasattr(self, "aclgraph_dispatcher"): + aclgraph_runtime_mode, batch_descriptor = ( + self.aclgraph_dispatcher.dispatch(batch_descriptor) + ) + elif hasattr(self, "cudagraph_dispatcher"): + aclgraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor) + ) + else: + from vllm.config import CUDAGraphMode + aclgraph_runtime_mode = CUDAGraphMode.NONE + + moe_comm_type = None + if hasattr(self, "_select_moe_comm_method"): + moe_comm_type = self._select_moe_comm_method(num_input_tokens, with_prefill) + 
+ with ( + set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + reserved_mc2_mask=getattr(self, "reserved_mc2_mask", None), + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + prefetch_stream=getattr(self, "prefetch_stream", None), + model_instance=self.model, + weight_prefetch_method=getattr(self, "weight_prefetch_method", None), + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + + model_kwargs_extra = {} + if per_req_additional_information: + model_kwargs_extra["additional_information_by_req_id"] = per_req_additional_information + try: + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = ( + getattr(req_state, "additional_information_cpu", None) + if req_state is not None + else None + ) + per_req_runtime_info.append(info if isinstance(info, dict) else {}) + model_kwargs_extra["runtime_additional_information"] = ( + per_req_runtime_info + ) + model_kwargs_extra["request_ids"] = self.input_batch.req_ids + req_token_spans = [] + for req_index in range(len(self.input_batch.req_ids)): + start_offset = int(self.query_start_loc_cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + model_kwargs_extra["request_token_spans"] = req_token_spans + except Exception: + pass + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + **model_kwargs_extra, + ) + + with 
record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + hidden_states, multimodal_outputs = self.extract_multimodal_outputs( + hidden_states + ) + try: + if ( + isinstance(multimodal_outputs, dict) + and ( + "additional_information_update" in multimodal_outputs + or "additional_information_update_by_req_id" in multimodal_outputs + ) + ): + updates_list = multimodal_outputs.get( + "additional_information_update" + ) + if isinstance(updates_list, list): + for idx, upd in enumerate(updates_list): + if not isinstance(upd, dict) or idx >= len(self.input_batch.req_ids): + continue + req_id = self.input_batch.req_ids[idx] + self._merge_additional_information_update(req_id, upd) + updates_map = multimodal_outputs.get( + "additional_information_update_by_req_id" + ) + if isinstance(updates_map, dict): + for req_id, upd in updates_map.items(): + if not isinstance(upd, dict): + continue + if req_id not in self.requests: + continue + self._merge_additional_information_update(req_id, upd) + except Exception as e: + logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}, with the multimodal_outputs as {multimodal_outputs}") + broadcast_pp_output = getattr(self, "broadcast_pp_output", False) + if not broadcast_pp_output: + + if not get_pp_group().is_last_rank: + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np, + finished_sending=None, finished_recving=None, + kv_connector_output=kv_connector_output + ) + return output + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + else: + assert not self.is_pooling_model + + if not 
get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + sample_hidden_states = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + if sample_hidden_states is None: + sample_hidden_states = hidden_states[logits_indices] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) + + with record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + def propose_draft_token_ids(sampled_token_ids): + if not hasattr(self, "propose_draft_token_ids") or spec_decode_metadata is None: + return + with record_function_or_nullcontext("Draft"): + if isinstance(sampled_token_ids, torch.Tensor): + if sampled_token_ids.dim() == 1: + sampled_token_ids_list = [[t.item()] for t in sampled_token_ids] + else: + sampled_token_ids_list = sampled_token_ids.tolist() + elif isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: + if not isinstance(sampled_token_ids[0], list): + sampled_token_ids_list = [[t] for t in sampled_token_ids] + else: + sampled_token_ids_list = sampled_token_ids + else: + sampled_token_ids_list = sampled_token_ids + + self._draft_token_ids = self.propose_draft_token_ids( + sampled_token_ids_list, + 
self.input_batch.sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + effective_drafter_max_model_len = self.max_model_len + if effective_drafter_max_model_len is None: + effective_drafter_max_model_len = self.model_config.max_model_len + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): + effective_drafter_max_model_len = ( + self.speculative_config.draft_model_config.max_model_len + ) + input_fits_in_drafter = False + if spec_decode_metadata is not None and attn_metadata: + try: + first_layer_metadata = next(iter(attn_metadata.values())) if attn_metadata else None + if first_layer_metadata and hasattr(first_layer_metadata, "seq_lens"): + max_seq_len = first_layer_metadata.seq_lens.max().item() + input_fits_in_drafter = ( + max_seq_len + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) + except Exception: + pass + + if use_padded_batch_for_eagle and input_fits_in_drafter: + sampled_tokens = sampler_output.sampled_token_ids + if isinstance(sampled_tokens, torch.Tensor): + if sampled_tokens.dim() == 1: + sampled_tokens_list = [[t.item()] for t in sampled_tokens] + else: + sampled_tokens_list = sampled_tokens.tolist() + else: + sampled_tokens_list = sampled_tokens + propose_draft_token_ids(sampled_tokens_list) + + with record_function_or_nullcontext("Bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + 
def _merge_additional_information_update(self, req_id: str, upd: dict) -> None:
    """Merge a per-request update into ``additional_information_cpu``.

    Looks up the cached request state in ``self.requests``; unknown request
    ids are ignored. The existing ``additional_information_cpu`` dict (if
    any) is shallow-copied, the entries of ``upd`` are written on top of it,
    and the result is stored back on the request state.

    Every ``torch.Tensor`` found in ``upd`` is detached, moved to CPU and
    made contiguous. Unlike the previous implementation this also applies to
    tensors nested inside tuples, dicts and nested lists, so nothing stored
    under the ``*_cpu`` attribute keeps a reference to device memory or to an
    autograd graph.

    Args:
        req_id: Request id used as key into ``self.requests``.
        upd: Mapping of keys to new values for that request.
    """
    req_state = self.requests.get(req_id)
    if req_state is None:
        return

    def _to_cpu(value):
        # Recursively replace tensors by detached, contiguous CPU copies and
        # rebuild the surrounding containers; other values pass through.
        if isinstance(value, torch.Tensor):
            return value.detach().to("cpu").contiguous()
        if isinstance(value, list):
            return [_to_cpu(item) for item in value]
        if isinstance(value, tuple):
            items = [_to_cpu(item) for item in value]
            # Named tuples cannot be rebuilt from a single iterable argument.
            return value._make(items) if hasattr(value, "_make") else tuple(items)
        if isinstance(value, dict):
            return {key: _to_cpu(item) for key, item in value.items()}
        return value

    existing = getattr(req_state, "additional_information_cpu", None)
    # Work on a copy so a previously handed-out dict is never mutated.
    merged = dict(existing) if isinstance(existing, dict) else {}
    for key, value in upd.items():
        merged[key] = _to_cpu(value)
    req_state.additional_information_cpu = merged
from vllm_omni.worker.npu_model_runner import OmniNPUModelRunner

logger = logging.getLogger(__name__)


class NPUDiffusionModelRunner(OmniNPUModelRunner):
    """Diffusion model runner for vLLM-omni on NPU (non-autoregressive).

    ``execute_model`` runs one forward pass of the diffusion model over all
    scheduled tokens and returns the model's multimodal outputs through
    ``pooler_output``; no tokens are sampled.
    """

    def _active_eplb_updator(self):
        """Return ``self.eplb_updator`` if dynamic EPLB is enabled, else ``None``."""
        if getattr(self, "dynamic_eplb", False) and hasattr(self, "eplb_updator"):
            return self.eplb_updator
        return None

    def _preprocess(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        ubatch_slices: Optional[UBatchSlices] = None,
        num_tokens_after_padding: Optional[torch.Tensor] = None,
    ) -> tuple[
        int,
        int,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
        Optional[IntermediateTensors],
        dict[str, Any],
    ]:
        """Build the model inputs for one diffusion step.

        ``ubatch_slices`` and ``num_tokens_after_padding`` are accepted for
        signature compatibility but are not read; the DP padding is
        recomputed here via ``get_dp_padding``.

        Returns:
            ``(num_scheduled_tokens, num_input_tokens, num_tokens_across_dp,
            input_ids, inputs_embeds, positions, intermediate_tensors,
            model_kwargs)`` where ``num_scheduled_tokens`` is the unpadded
            count from the scheduler and ``num_input_tokens`` includes the
            DP padding.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_scheduled_tokens)
        num_input_tokens = num_scheduled_tokens + num_pad

        is_first_rank = get_pp_group().is_first_rank
        if (
            self.supports_mm_inputs
            and is_first_rank
            and not self.model_config.is_encoder_decoder
        ):
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            inputs_embeds_scheduled = self.model.get_input_embeddings(
                input_ids=self.input_ids[:num_input_tokens],
                multimodal_embeddings=mm_embeds or None,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds_scheduled)

            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            model_kwargs = {
                **self._init_model_kwargs(num_input_tokens),
                **self._extract_mm_kwargs(scheduler_output),
            }
        elif self.enable_prompt_embeds and is_first_rank:
            if hasattr(self, "is_token_ids"):
                # Embed only the positions that still hold raw token ids.
                token_ids_idx = (
                    self.is_token_ids[:num_input_tokens]
                    .nonzero(as_tuple=False)
                    .squeeze(1)
                )
                if token_ids_idx.numel() > 0:
                    token_ids = self.input_ids[token_ids_idx]
                    tokens_to_embeds = self.model.get_input_embeddings(
                        input_ids=token_ids
                    )
                    self.inputs_embeds[token_ids_idx] = tokens_to_embeds

            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            model_kwargs = self._init_model_kwargs(num_input_tokens)
            input_ids = None
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs(num_input_tokens)

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True
            )

        if (
            self.model_config.is_encoder_decoder
            and scheduler_output.scheduled_encoder_inputs
        ):
            encoder_inputs = self._extract_encoder_inputs(scheduler_output)
            model_kwargs.update(encoder_inputs)

        return (
            num_scheduled_tokens,
            num_input_tokens,
            num_tokens_across_dp,
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[OmniModelRunnerOutput, IntermediateTensors]:
        """Run one diffusion forward pass for the scheduled requests.

        Returns ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing is scheduled.
        Otherwise the multimodal outputs extracted from the model result are
        copied to CPU and returned as ``pooler_output``. When
        ``self.use_async_scheduling`` is set the output is wrapped in an
        ``AsyncNPUModelRunnerOutput``.

        Raises:
            RuntimeError: If the multimodal output is not a tensor, list or
                dict.
        """
        # NOTE(review): the nesting of the context managers below was
        # reconstructed from a whitespace-collapsed patch - confirm against
        # the original file.
        with record_function_or_nullcontext("Preprocess"):
            with self.synchronize_input_prep():
                super()._update_states(scheduler_output)
                if not scheduler_output.total_num_scheduled_tokens:
                    return EMPTY_MODEL_RUNNER_OUTPUT

                (
                    attn_metadata,
                    logits_indices,
                    _spec_decode_metadata,
                    _num_scheduled_tokens_np,
                    _spec_decode_common_attn_metadata,
                    _max_query_len,
                    ubatch_slices,
                    num_tokens_after_padding,
                ) = self._prepare_inputs(scheduler_output)

            eplb_updator = self._active_eplb_updator()
            if eplb_updator is not None:
                eplb_updator.forward_before()
                eplb_updator.take_update_info_from_eplb_process()

            (
                _num_scheduled_tokens,
                num_input_tokens,
                num_tokens_across_dp,
                input_ids,
                inputs_embeds,
                positions,
                intermediate_tensors,
                model_kwargs,
            ) = self._preprocess(
                scheduler_output,
                intermediate_tensors,
                ubatch_slices,
                num_tokens_after_padding,
            )

        aclgraph_runtime_mode = CUDAGraphMode.NONE
        with (
            set_ascend_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                with_prefill=True,  # Diffusion models process all tokens at once
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=None,
                num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
            ),
            record_function_or_nullcontext("Forward"),
            self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
        ):
            outputs = self._run_diffusion(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                multimodal_kwargs=model_kwargs,
                logits_indices=logits_indices,
            )

        eplb_updator = self._active_eplb_updator()
        if eplb_updator is not None:
            eplb_updator.forward_end()

        _, multimodal_outputs = self.extract_multimodal_outputs(outputs)
        pooler_output: List[Optional[torch.Tensor]] = []
        if isinstance(multimodal_outputs, torch.Tensor):
            # A tensor is split along dim 0 into one entry per request.
            assert multimodal_outputs.shape[0] == self.input_batch.num_reqs
            for i in range(self.input_batch.num_reqs):
                pooler_output.append(
                    multimodal_outputs[i].detach().to("cpu").contiguous()
                )
        elif isinstance(multimodal_outputs, (list, dict)):
            # For a dict only the values are forwarded, in iteration order.
            items = (
                multimodal_outputs.values()
                if isinstance(multimodal_outputs, dict)
                else multimodal_outputs
            )
            for out in items:
                pooler_output.append(
                    out.detach().to("cpu").contiguous() if out is not None else None
                )
        else:
            raise RuntimeError("Unsupported diffusion output type")

        output = OmniModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
            kv_connector_output=kv_connector_output,
            num_nans_in_logits={},
        )

        if not self.use_async_scheduling:
            return output

        return AsyncNPUModelRunnerOutput(
            model_runner_output=output,
            sampled_token_ids=[],
            invalid_req_indices=[],
            async_output_copy_stream=self.async_output_copy_stream,
        )

    def _run_diffusion(
        self,
        *,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
        multimodal_kwargs: dict,
        logits_indices: torch.Tensor,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run the diffusion model and return whatever ``model.forward`` returns.

        Only ``model.forward`` is supported at the moment. It is called with
        the token inputs, the multimodal kwargs (moved to ``self.device``),
        the batch sampling metadata, ``logits_index`` and ``self.sampler``.

        Raises:
            RuntimeError: If the loaded model has no ``forward`` attribute.
        """
        kwargs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **MultiModalKwargs.as_kwargs(multimodal_kwargs, device=self.device),
            sampling_metadata=self.input_batch.sampling_metadata,
            logits_index=logits_indices,
            sampler=self.sampler,
        )

        if hasattr(self.model, "forward"):
            return self.model.forward(**kwargs)
        # TODO: add the diffuse method for other models

        raise RuntimeError(
            "The loaded model does not expose diffusion interfaces 'sample', "
            "'forward', or 'diffuse'. Please implement one of them or adapt the runner."
        )

    def _build_dummy_attn_metadata(
        self,
        num_reqs: int,
        num_tokens: int,
        max_query_len: int,
        num_scheduled_tokens: np.ndarray,
    ) -> dict:
        """Fill the persistent length buffers and build per-layer dummy metadata.

        Used by ``_dummy_run`` when attention is forced or a FULL graph is
        captured. Returns a dict mapping layer name to the metadata produced
        by that layer group's builder (``build_for_cudagraph_capture``).
        """
        attn_metadata: dict = {}

        seq_lens = max_query_len
        if hasattr(self, "seq_lens_np"):
            self.seq_lens_np[:num_reqs] = seq_lens
            self.seq_lens_np[num_reqs:] = 0
        if hasattr(self, "seq_lens_cpu"):
            self.seq_lens_cpu[:num_reqs] = seq_lens
            self.seq_lens_cpu[num_reqs:] = 0
        if hasattr(self, "seq_lens"):
            self.seq_lens[:num_reqs].copy_(
                self.seq_lens_cpu[:num_reqs], non_blocking=True
            )
            self.seq_lens[num_reqs:].fill_(0)

        cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
        if hasattr(self, "query_start_loc_np"):
            self.query_start_loc_np[0] = 0
            self.query_start_loc_np[1 : num_reqs + 1] = cum_num_tokens
        if hasattr(self, "query_start_loc_cpu"):
            self.query_start_loc_cpu[0] = 0
            self.query_start_loc_cpu[1 : num_reqs + 1] = cum_num_tokens
        if hasattr(self, "query_start_loc"):
            self.query_start_loc[: num_reqs + 1].copy_(
                self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True
            )

        for kv_cache_group_id, _kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups
        ):
            block_table = self.input_batch.block_table[kv_cache_group_id]
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=self.query_start_loc[: num_reqs + 1],
                query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
                seq_lens=self.seq_lens[:num_reqs],
                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
                num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
                    :num_reqs
                ],
                num_reqs=num_reqs,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                max_seq_len=self.max_model_len,
                block_table_tensor=block_table.get_device_tensor(num_reqs),
                slot_mapping=block_table.slot_mapping[:num_tokens],
                causal=True,
            )
            for attn_group in self.attn_groups[kv_cache_group_id]:
                attn_metadata_i = (
                    attn_group.get_metadata_builder().build_for_cudagraph_capture(
                        common_attn_metadata
                    )
                )
                for layer_name in attn_group.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

        return attn_metadata

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
        force_attention: bool = False,
        uniform_decode: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
        create_mixed_batch: bool = False,
        remove_lora: bool = True,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run a dummy forward pass to warm up/profile run or capture the ACL graph for the model.

        ``self.in_profile_run`` is set to ``is_profile`` for the duration of
        the call and restored afterwards, also when the run raises.
        ``create_mixed_batch`` is accepted but not used.

        Returns:
            ``(hidden_states, None)`` where ``hidden_states`` is the first
            element of ``extract_multimodal_outputs`` applied to the dummy
            forward result.
        """
        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
            CUDAGraphMode.NONE,
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        }

        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens

        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs

        # Spread the tokens evenly over the requests; the last request takes
        # the remainder.
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs

        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)

        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens_after_padding = num_tokens + num_pad

        original_in_profile_run = self.in_profile_run
        self.in_profile_run = is_profile
        # Everything after flipping the flag runs under try/finally so the
        # flag is restored even if metadata building or the forward fails.
        try:
            if not self.in_profile_run:
                eplb_updator = self._active_eplb_updator()
                if eplb_updator is not None:
                    eplb_updator.forward_before()

            need_dummy_logits = not self.in_profile_run and lmhead_tp_enable()
            dummy_indices = None
            if need_dummy_logits:
                # Diffusion always uses with_prefill=True
                dummy_indices = torch.zeros(
                    num_tokens, dtype=torch.int32, device=self.device
                )

            attn_metadata: Optional[PerLayerAttnMetadata] = None
            if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
                attn_metadata = self._build_dummy_attn_metadata(
                    num_reqs, num_tokens, max_query_len, num_scheduled_tokens
                )

            with self.maybe_dummy_run_with_lora(
                self.lora_config, num_scheduled_tokens, remove_lora
            ):
                model_kwargs = self._init_model_kwargs(num_tokens)
                if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_tokens]
                    model_kwargs = {
                        **model_kwargs,
                        **self._dummy_mm_kwargs(num_reqs),
                    }
                elif self.enable_prompt_embeds:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_tokens]
                    model_kwargs = self._init_model_kwargs(num_tokens)
                else:
                    input_ids = self.input_ids[:num_tokens]
                    inputs_embeds = None

                if self.uses_mrope:
                    positions = self.mrope_positions[:, :num_tokens]
                else:
                    positions = self.positions[:num_tokens]

                if get_pp_group().is_first_rank:
                    intermediate_tensors = None
                else:
                    if self.intermediate_tensors is None:
                        self.intermediate_tensors = (
                            self.model.make_empty_intermediate_tensors(
                                batch_size=self.max_num_tokens,
                                dtype=self.model_config.dtype,
                                device=self.device,
                            )
                        )

                    intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                        num_tokens, None, False
                    )

                _cg_mode, batch_descriptor = (
                    self.aclgraph_dispatcher.dispatch(
                        BatchDescriptor(
                            num_tokens=num_tokens_after_padding,
                            uniform_decode=uniform_decode,
                        )
                    )
                    if hasattr(self, "aclgraph_dispatcher") and not is_profile
                    else (CUDAGraphMode.NONE, None)
                )
                if cudagraph_runtime_mode is not None:
                    assert (
                        cudagraph_runtime_mode == CUDAGraphMode.NONE
                        or cudagraph_runtime_mode == _cg_mode
                    ), (
                        f"ACL graph runtime mode mismatch at dummy_run. "
                        f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
                    )
                else:
                    cudagraph_runtime_mode = _cg_mode

                with self.maybe_randomize_inputs(input_ids), set_ascend_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens_after_padding,
                    num_tokens_across_dp=num_tokens_across_dp,
                    with_prefill=True,  # Diffusion models process all tokens at once
                    in_profile_run=self.in_profile_run,
                    aclgraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=batch_descriptor,
                ):
                    hidden_states = self._generate_dummy_run_hidden_states(
                        with_prefill=True,
                        is_torchair_compile=False,
                        input_ids=input_ids,
                        positions=positions,
                        attn_metadata=attn_metadata,
                        num_tokens=num_tokens,
                        intermediate_tensors=intermediate_tensors,
                        inputs_embeds=inputs_embeds,
                        model_kwargs=model_kwargs,
                    )

                    # NOTE(review): placed inside the forward context; the
                    # original nesting was lost with the indentation -
                    # confirm against the original file.
                    if need_dummy_logits:
                        self.model.compute_logits(hidden_states[dummy_indices])

                if self.drafter:
                    self.drafter.dummy_run(num_tokens)
                    if need_dummy_logits:
                        self.drafter.model.compute_logits(hidden_states[dummy_indices])

                if self.in_profile_run:
                    if getattr(self, "dynamic_eplb", False) and hasattr(self, "model"):
                        self.model.clear_all_moe_loads()
                else:
                    eplb_updator = self._active_eplb_updator()
                    if eplb_updator is not None:
                        eplb_updator.take_update_info_from_eplb_process()
        finally:
            self.in_profile_run = original_in_profile_run

        # Guarded like the autoregressive NPU runner, which only calls
        # eplb_step when the base runner provides it.
        if not skip_eplb and hasattr(self, "eplb_step"):
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        hidden_states, _ = self.extract_multimodal_outputs(hidden_states)
        return hidden_states, None

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """No-op: only logs a warning, since this runner does not sample."""
        logger.warning("Dummy sampler run is not implemented for diffusion model")
        return None

    def profile_run(self) -> None:
        """Profile peak memory with a dummy encoder batch and a dummy forward.

        Unless multimodal profiling is skipped by config, the multimodal
        encoder is run on a dummy batch of the modality with the most tokens
        and its outputs are stored in ``self.encoder_cache`` for the duration
        of the dummy forward. The cache is cleared at the end.
        """
        if self.supports_mm_inputs:
            if self.model_config.multimodal_config.skip_mm_profiling:
                logger.info(
                    "Skipping memory profiling for multimodal encoder and "
                    "encoder cache."
                )
            else:
                mm_budget = self.mm_budget
                assert mm_budget is not None

                if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
                    dummy_modality = mm_budget.get_modality_with_max_tokens()
                    max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
                        dummy_modality
                    ]

                    logger.info(
                        "Encoder cache will be initialized with a budget of "
                        "%s tokens, and profiled with %s %s items of the "
                        "maximum feature size.",
                        encoder_budget,
                        max_mm_items_per_batch,
                        dummy_modality,
                    )

                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                        dummy_modality,
                        max_mm_items_per_batch,
                    )

                    dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                        **batched_dummy_mm_inputs
                    )

                    sanity_check_mm_encoder_outputs(
                        dummy_encoder_outputs,
                        expected_num_items=max_mm_items_per_batch,
                    )

                    encoder_output_shape = dummy_encoder_outputs[0].shape
                    if encoder_output_shape[0] < encoder_budget:
                        # Zero-pad every output up to the encoder budget so
                        # the cached tensors have the budgeted size.
                        expanded_outputs = []
                        for output in dummy_encoder_outputs:
                            expanded = output.new_zeros(
                                (encoder_budget, encoder_output_shape[-1])
                            )
                            num_tokens = output.shape[0]
                            expanded[:num_tokens].copy_(output)
                            expanded_outputs.append(expanded)

                        dummy_encoder_outputs = expanded_outputs

                    self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        hidden_states, _ = self._dummy_run(self.max_num_tokens, is_profile=True)
        self._sync_device()
        del hidden_states
        self.encoder_cache.clear()
        gc.collect()
+""" + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, List, Optional, Union, cast + +import math +import numpy as np +import torch + +import vllm.envs as envs +from vllm.config import CUDAGraphMode +from vllm.distributed.kv_transfer import has_kv_transfer_group +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed.kv_transfer import get_kv_transfer_group +from vllm.forward_context import BatchDescriptor, DPMetadata, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItem +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.sampling_params import SamplingType +from vllm.utils import cdiv, round_up +from vllm.v1.attention.backends.utils import CommonAttentionMetadata, split_attn_metadata +from vllm.v1.outputs import KVConnectorOutput, LogprobsLists, LogprobsTensors, SamplerOutput +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.worker.utils import is_residual_scattered_for_sp, MultiModalBudget +from vllm.v1.worker.gpu_model_runner import IntermediateTensors, PerLayerAttnMetadata +from vllm.v1.worker.ubatch_splitting import ubatch_split +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from vllm_ascend.worker.npu_input_batch import CachedRequestState +from vllm_ascend.utils import enable_sp, vllm_version_is, lmhead_tp_enable + +from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = 
init_logger(__name__) + + +class OmniNPUModelRunner(NPUModelRunner): + """Base class for NPU model runners with multimodality support. + + Extends NPUModelRunner with: + - Payload decoding (prompt_embeds, additional_information) + - Multimodal output extraction + - Additional information update merging + - Multimodal initialization (mm_budget, mm_registry, supports_mm_inputs) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.is_multimodal_raw_input_only_model = ( + self.model_config.is_multimodal_raw_input_only_model) + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + self.model_config) + + self._mm_budget = None + + @property + def mm_budget(self): + if self._mm_budget is None: + self._mm_budget = MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) if self.supports_mm_inputs else None + return self._mm_budget + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input NPU tensors for the model. + + The SamplingMetadata is updated and copied to the NPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. 
+ for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. 
+ for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + # Handle backward compatibility for mm_features/mm_kwargs + backward_kwargs = {} + if vllm_version_is("0.11.0"): + backward_kwargs["mm_features"] = getattr(new_req_data, "mm_features", None) + else: + backward_kwargs["mm_kwargs"] = getattr(new_req_data, "mm_kwargs", None) + backward_kwargs["mm_hashes"] = getattr(new_req_data, "mm_hashes", None) + backward_kwargs["mm_positions"] = getattr(new_req_data, "mm_positions", None) + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + **backward_kwargs, + ) + self.requests[req_id] = req_state + + # If prompt embeddings are provided, decode and attach to request state + try: + if getattr(new_req_data, "prompt_embeds", None) is not None: + payload = new_req_data.prompt_embeds + if isinstance(payload, PromptEmbedsPayload): + dtype = getattr(np, payload.dtype) + arr = np.frombuffer(payload.data, dtype=dtype) + arr = arr.reshape(payload.shape) + pe_cpu = torch.from_numpy(arr) + elif isinstance(payload, torch.Tensor): + pe_cpu = 
payload.detach().to("cpu").contiguous() + else: + pe_cpu = None + # Store temporarily on CPU; later moved to device in builder + if pe_cpu is not None: + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # Also replace payload with Tensor for user visibility in + # scheduler_output + try: + new_req_data.prompt_embeds = pe_cpu # type: ignore[assignment] + except Exception: + pass + except Exception as e: + logger.error(f"Error decoding prompt embeds: {e}") + # Decode additional_information payloads (dictionary) + try: + if getattr(new_req_data, "additional_information", None) is not None: + payload_info = new_req_data.additional_information + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype( + getattr(entry, "tensor_dtype", "float32") + ) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding additional information: {e}") + pass + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + if vllm_version_is("0.11.0"): + self._init_mrope_positions(self.requests[req_id]) + else: + self._init_mrope_positions_0102(self.requests[req_id]) + + reqs_to_add.append(self.requests[req_id]) + + # Update the states of the running/resumed requests. 
+ is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + else: + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + reqs_to_add.append(req_state) + continue + + # Update the persistent batch. 
+ self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens + if new_block_ids is not None: + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) + if spec_token_ids: + num_spec_tokens = len(spec_token_ids) + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index + ] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. 
+ self.input_batch.refresh_metadata() + + @torch.inference_mode() + def extract_multimodal_outputs( + self, hidden_states: Union[torch.Tensor, List[torch.Tensor]] + ) -> tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor], dict]]: + """Extract multimodal outputs from hidden states.""" + if ( + hasattr(self.model, "have_multimodal_outputs") + and self.model.have_multimodal_outputs + ): + text_hidden_states = hidden_states.text_hidden_states + multimodal_outputs = hidden_states.multimodal_outputs + + elif isinstance(hidden_states, torch.Tensor): + text_hidden_states = hidden_states + multimodal_outputs = {} + elif isinstance(hidden_states, List): + text_hidden_states = hidden_states[0] + multimodal_outputs = {} + else: + raise ValueError(f"Invalid hidden states type: {type(hidden_states)}") + return text_hidden_states, multimodal_outputs + + def _sample( + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] + ) -> SamplerOutput: + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if hasattr(self, "_update_states_after_model_execute"): + self._update_states_after_model_execute(output_token_ids) + + return sampler_output + + def _bookkeeping_sync( + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + 
hidden_states: torch.Tensor, num_scheduled_tokens: int + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: + num_nans_in_logits = {} + # placeholder for now, TODO: _get_nans_in_logits() + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + invalid_req_indices = [] + + if not self.use_async_scheduling: + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + discard_sampled_tokens_req_indices.append(i) + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + 
discard_sampled_tokens_req_indices.append(i) + invalid_req_indices = discard_sampled_tokens_req_indices + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + + req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + def get_dp_padding( + self, num_tokens: int + ) -> tuple[int, Optional[torch.Tensor]]: + """Determines the total number of tokens that each rank will run.""" + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + return 0, None + + num_tokens_across_dp = 
DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + """Calculate input token count with padding if needed.""" + if (self.use_aclgraph and num_scheduled_tokens + <= self.aclgraph_batch_sizes[-1]): + # Add padding to the batch size for ACLGraph. + # Note: pad_for_cudagraph works for both CUDA graphs and ACLGraph + return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) + + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if (self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1): + return round_up(num_scheduled_tokens, tp_size) + return num_scheduled_tokens + + def _generate_dummy_run_hidden_states(self, with_prefill, + is_torchair_compile, input_ids, + positions, attn_metadata, num_tokens, + intermediate_tensors, inputs_embeds, + model_kwargs=None): + """Override to support model_kwargs for multimodal/pooling models.""" + if model_kwargs is None: + model_kwargs = {} + + hidden_states = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs) + + from vllm_ascend.compilation.acl_graph import update_attn_params, update_mla_attn_params + + forward_context = get_forward_context() + assert forward_context is not None + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + 
positions.shape[0], + self.speculative_config) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) + + from vllm_ascend.spec_decode.interface import SpecDcodeType + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, _ = hidden_states + else: + hidden_states = hidden_states + return hidden_states + + def _init_model_kwargs(self, num_tokens: int): + """Initialize model kwargs.""" + model_kwargs = dict[str, Any]() + + if not self.is_pooling_model: + return model_kwargs + + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens_cpu[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + + def sync_and_slice_intermediate_tensors( + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool + ) -> IntermediateTensors: + """Sync and slice intermediate tensors for pipeline parallelism.""" + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens) + + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + is_scattered = k == "residual" and is_rs + copy_len = num_tokens // tp if is_scattered else num_tokens + self.intermediate_tensors[k][:copy_len].copy_( + v[:copy_len], non_blocking=True) 
+ + return IntermediateTensors({ + k: v[:num_tokens // tp] if k == "residual" and is_rs else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + def _extract_mm_kwargs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict: + """Extract multimodal kwargs.""" + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} + + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + # Handle version compatibility + if vllm_version_is("0.11.0"): + mm_features = getattr(req, "mm_features", None) + else: + mm_features = getattr(req, "mm_kwargs", None) + + if mm_features: + if isinstance(mm_features, list): + for feature in mm_features: + if hasattr(feature, "data") and feature.data is not None: + mm_kwargs.append(feature.data) + elif isinstance(mm_features, dict): + return mm_features + + if not mm_kwargs: + return {} + + model = cast(SupportsMultiModal, self.model) + mm_kwargs_combined: dict = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ): + mm_kwargs_combined.update(mm_kwargs_group) + + return mm_kwargs_combined + + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models.""" + if hasattr(self, "_batch_mm_kwargs_from_scheduler"): + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + else: + return {} + + if not mm_kwargs: + return {} + + model = cast(SupportsMultiModal, self.model) + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ): + encoder_features.update(mm_kwargs_group) + + return encoder_features + + @contextmanager + def synchronize_input_prep(self): + """Synchronize 
input preparation for async scheduling.""" + if getattr(self, 'use_async_scheduling', False): + if not hasattr(self, "prepare_inputs_event") or self.prepare_inputs_event is None: + self.prepare_inputs_event = torch.npu.Event() + self.prepare_inputs_event.record(torch.npu.current_stream()) + + if not hasattr(self, "prepare_inputs_event") or self.prepare_inputs_event is None: + yield + return + + self.prepare_inputs_event.synchronize() + try: + yield + finally: + self.prepare_inputs_event.record() + + @contextmanager + def maybe_get_kv_connector_output( + self, scheduler_output: "SchedulerOutput" + ): + """KV connector context manager.""" + if not has_kv_transfer_group(): + yield None + return + + output = KVConnectorOutput() + + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + kv_connector.start_load_kv(get_forward_context()) + try: + yield output + finally: + kv_connector.wait_for_save() + output.finished_sending, output.finished_recving = ( + kv_connector.get_finished(scheduler_output.finished_req_ids)) + output.kv_connector_stats = kv_connector.get_kv_connector_stats() + kv_connector.clear_connector_metadata() + + def pad_out_ubatch_slice(self, ubatch_slices, num_total_tokens: int): + """Pad ubatch slice for DBO (Dynamic Batch Overlap).""" + from vllm.v1.worker.ubatch_utils import UBatchSlice + if len(ubatch_slices) < 2: + return + padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + num_total_tokens) + ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, + padded_second_ubatch_slice) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: + """Step for EPLB (Expert Parallelism Load Balancing).""" + if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.forward_end() + 
+ def _get_mm_dummy_batch( + self, + modality: str, + max_items_per_batch: int, + ) -> dict: + """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_model_len, + mm_counts={modality: 1}, + cache=self.mm_budget.cache, + ) + dummy_mm_data = dummy_decoder_data.multi_modal_data + + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch + + model = cast(SupportsMultiModal, self.model) + return next(mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=getattr(self, "pin_memory", False), + merge_by_field_config=model.merge_by_field_config, + )) + + def _dummy_mm_kwargs(self, num_seqs: int) -> dict: + """Return dummy multimodal kwargs for dummy runs.""" + if not self.is_multimodal_raw_input_only_model: + return {} + + mm_budget = self.mm_budget + assert mm_budget is not None + + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) + + @contextmanager + def maybe_randomize_inputs(self, input_ids: Optional[torch.Tensor]): + """Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.""" + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 + + if not randomize_inputs: + yield + return + + if input_ids is None: + yield + return + + import functools + + @functools.cache + def rand_input_ids() -> torch.Tensor: + return torch.randint_like( + self.input_ids, + low=0, + high=self.model_config.get_vocab_size(), + dtype=input_ids.dtype, + ) + + logger.debug_once("Randomizing dummy data for DP Rank") + input_ids.copy_( + rand_input_ids()[:input_ids.size(0)], + non_blocking=True + ) + yield + input_ids.fill_(0) + + @torch.inference_mode() + def 
_dummy_run( + self, + num_tokens: int, + cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run a dummy forward pass to warm up/profile run or capture the ACL graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: Used to control the behavior. + - if not set will determine the aclgraph mode based on using + the self.aclgraph_dispatcher. + - CUDAGraphMode.NONE: No aclgraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise aclgraph. + - CUDAGraphMode.FULL: Full aclgraph, attention metadata is needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + allow_microbatching: If True, allow ubatch splitting if DBO is enabled. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. + remove_lora: If False, dummy LoRAs are not destroyed after the run. 
+ """ + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { + CUDAGraphMode.NONE, + CUDAGraphMode.PIECEWISE, + CUDAGraphMode.FULL, + } + + if hasattr(self, "use_aclgraph") and self.use_aclgraph and enable_sp(self.vllm_config): + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_tokens = math.ceil(num_tokens / tp_size) * tp_size + + with_prefill = create_mixed_batch or (not uniform_decode and num_tokens > 1) + + if hasattr(self, "is_kv_producer") and self.is_kv_producer and \ + hasattr(self, "is_kv_consumer") and not self.is_kv_consumer: + with_prefill = True + + if hasattr(self, "_sync_metadata_across_dp"): + (num_tokens, num_tokens_across_dp, with_prefill, _) = \ + self._sync_metadata_across_dp(num_tokens, with_prefill, False) + else: + num_tokens_across_dp = None + + if hasattr(self, "_select_moe_comm_method"): + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + else: + moe_comm_type = None + + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + + if create_mixed_batch: + assert not uniform_decode + num_decode_tokens = num_tokens // 2 + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + max_query_len = num_prefill_tokens + elif uniform_decode: + assert not create_mixed_batch + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + decode_token_per_req = getattr(self, "decode_token_per_req", 1) + num_reqs = (num_tokens + decode_token_per_req - 1) // decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = 
num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + + ubatch_slices = None + num_tokens_after_padding = None + + if self.parallel_config.enable_dbo and allow_microbatching: + ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + total_num_scheduled_tokens, + total_num_scheduled_tokens, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config, + ) + if ubatch_num_tokens_after_padding is not None: + num_tokens_after_padding = ubatch_num_tokens_after_padding * 2 + + if num_tokens_after_padding is None: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens_after_padding = num_tokens + num_pad + else: + if isinstance(num_tokens_after_padding, torch.Tensor): + num_tokens_after_padding = int(num_tokens_after_padding[0].item()) + elif isinstance(num_tokens_after_padding, (list, np.ndarray)): + num_tokens_after_padding = int(num_tokens_after_padding[0]) + + attn_metadata: Optional[PerLayerAttnMetadata] = None + + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + + if create_mixed_batch: + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + seq_lens = max_query_len + + self.seq_lens_np[:num_reqs] = seq_lens + self.seq_lens_np[num_reqs:] = 0 + if isinstance(seq_lens, list): + self.seq_lens_cpu[:num_reqs] = torch.tensor(seq_lens, dtype=torch.int32) + else: + self.seq_lens_cpu[:num_reqs] = seq_lens + self.seq_lens_cpu[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_( + 
self.seq_lens_cpu[:num_reqs], non_blocking=True) + self.seq_lens[num_reqs:].fill_(0) + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc_cpu[0] = 0 + self.query_start_loc_cpu[1 : num_reqs + 1] = torch.from_numpy(cum_num_tokens) + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups + ): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=self.max_model_len, + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id + ].get_device_tensor(num_reqs), + slot_mapping=self.input_batch.block_table[ + kv_cache_group_id + ].slot_mapping[:num_tokens], + causal=True, + ) + for attn_group in self.attn_groups[kv_cache_group_id]: + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata + ) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list + ): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) + for layer_name in attn_group.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert type(attn_metadata) is dict + attn_metadata_i = attn_group.get_metadata_builder().build_for_cudagraph_capture( + common_attn_metadata + ) + for layer_name 
in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): + model_kwargs = self._init_model_kwargs(num_tokens) + + # Prepare inputs (NPU uses direct tensor access, not .gpu buffers) + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + model_kwargs = { + **model_kwargs, + **self._dummy_mm_kwargs(num_reqs), + } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + model_kwargs = self._init_model_kwargs(num_tokens) + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + # Prepare intermediate_tensors if PP + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + ) + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False + ) + + # Dispatch graph mode (check for aclgraph_dispatcher) + if hasattr(self, "aclgraph_dispatcher") and not is_profile: + _cg_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + else: + _cg_mode, batch_descriptor = (CUDAGraphMode.NONE, None) + + # Map GPU parameter name to NPU internal name for clarity + # cudagraph_runtime_mode (GPU signature) → aclgraph_runtime_mode (NPU internal) + if cudagraph_runtime_mode is not None: + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( + f"ACL graph runtime mode mismatch at dummy_run. 
" + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." + ) + aclgraph_runtime_mode = cudagraph_runtime_mode + else: + aclgraph_runtime_mode = _cg_mode + + # Adjust for ubatch if needed + if ubatch_slices is not None: + num_tokens_after_padding = ubatch_slices[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_after_padding + + original_in_profile_run = self.in_profile_run + self.in_profile_run = is_profile + + if not self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.forward_before() + + need_dummy_logits = (not self.in_profile_run and lmhead_tp_enable()) + dummy_indices = None + dummy_compute_logits = None + + if need_dummy_logits: + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32, device=self.device) + + def dummy_compute_logits(hidden_states): + return self.model.compute_logits(hidden_states[dummy_indices]) + + try: + with self.maybe_randomize_inputs(input_ids), set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_after_padding, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + reserved_mc2_mask=getattr(self, "reserved_mc2_mask", None), + moe_comm_type=moe_comm_type, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + prefetch_stream=getattr(self, "prefetch_stream", None), + model_instance=self.model, + weight_prefetch_method=getattr(self, "weight_prefetch_method", None), + ): + hidden_states = self._generate_dummy_run_hidden_states( + with_prefill=with_prefill, + is_torchair_compile=False, + input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata, + num_tokens=num_tokens, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + 
model_kwargs=model_kwargs, + ) + + if need_dummy_logits: + dummy_compute_logits(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens, + with_prefill=with_prefill, + skip_attn=True, + num_reqs=num_reqs, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + if need_dummy_logits: + self.drafter.model.compute_logits(hidden_states[dummy_indices]) + + if self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "model"): + self.model.clear_all_moe_loads() + if not self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.take_update_info_from_eplb_process() + finally: + self.in_profile_run = original_in_profile_run + + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + return hidden_states, hidden_states[logit_indices] From 54aa97203a1c951c8738026c7bfe864188542a26 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sat, 22 Nov 2025 15:35:47 +0000 Subject: [PATCH 02/11] fix lint Signed-off-by: gcanlin --- vllm_omni/entrypoints/utils.py | 13 ++++++++++--- vllm_omni/utils/__init__.py | 11 +++++++++++ vllm_omni/worker/npu_model_runner.py | 19 ++++++++++--------- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index a806a9df6c4..e0b10c93c47 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -33,14 +33,21 @@ def select_worker_class(worker_cls: Optional[str], device_type: Optional[str] = """Select appropriate worker class based on device type.""" if worker_cls is None: return None - + if device_type is None: device_type = detect_device_type() - + if device_type == "npu": + # Replace module path: 
gpu_ar_worker -> npu_ar_worker if "gpu_ar_worker" in worker_cls: worker_cls = worker_cls.replace("gpu_ar_worker", "npu_ar_worker") elif "gpu_diffusion_worker" in worker_cls: worker_cls = worker_cls.replace("gpu_diffusion_worker", "npu_diffusion_worker") - + + # Replace class name: GPUARWorker -> NPUARWorker, GPUDiffusionWorker -> NPUDiffusionWorker + if "GPUARWorker" in worker_cls: + worker_cls = worker_cls.replace("GPUARWorker", "NPUARWorker") + elif "GPUDiffusionWorker" in worker_cls: + worker_cls = worker_cls.replace("GPUDiffusionWorker", "NPUDiffusionWorker") + return worker_cls diff --git a/vllm_omni/utils/__init__.py b/vllm_omni/utils/__init__.py index e69de29bb2d..50dbb478d90 100644 --- a/vllm_omni/utils/__init__.py +++ b/vllm_omni/utils/__init__.py @@ -0,0 +1,11 @@ +from vllm_omni.utils.platform_utils import ( + detect_device_type, + get_device_control_env_var, + is_npu, +) + +__all__ = [ + "detect_device_type", + "get_device_control_env_var", + "is_npu", +] diff --git a/vllm_omni/worker/npu_model_runner.py b/vllm_omni/worker/npu_model_runner.py index 902ea284256..faf958f3aaa 100644 --- a/vllm_omni/worker/npu_model_runner.py +++ b/vllm_omni/worker/npu_model_runner.py @@ -395,7 +395,7 @@ def _bookkeeping_sync( req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ + logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -524,7 +524,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill, """Override to support model_kwargs for multimodal/pooling models.""" if model_kwargs is None: model_kwargs = {} - + hidden_states = self.model(input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -809,6 +809,7 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, + with_prefill: bool = False, 
cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, @@ -817,7 +818,7 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """Run a dummy forward pass to warm up/profile run or capture the ACL graph for the model. Args: @@ -943,9 +944,8 @@ def _dummy_run( else: self.seq_lens_cpu[:num_reqs] = seq_lens self.seq_lens_cpu[num_reqs:] = 0 - self.seq_lens[:num_reqs].copy_( - self.seq_lens_cpu[:num_reqs], non_blocking=True) - self.seq_lens[num_reqs:].fill_(0) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + self.seq_lens[num_reqs:].fill_(0) cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) self.query_start_loc_np[0] = 0 @@ -1040,7 +1040,7 @@ def _dummy_run( ) ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False + num_tokens, self.intermediate_tensors, False ) # Dispatch graph mode (check for aclgraph_dispatcher) @@ -1149,6 +1149,7 @@ def dummy_compute_logits(hidden_states): if not skip_eplb: self.eplb_step(is_dummy=True, is_profile=is_profile) - logit_indices = np.cumsum(num_scheduled_tokens) - 1 + # Extract text hidden states from multimodal outputs + # Parent class will handle logit indexing in profile_run hidden_states, _ = self.extract_multimodal_outputs(hidden_states) - return hidden_states, hidden_states[logit_indices] + return hidden_states From 05fa5bf8b8bed5e17290f974434d640327386315 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sun, 23 Nov 2025 12:07:43 +0000 Subject: [PATCH 03/11] fix device mapping Signed-off-by: gcanlin --- vllm_omni/entrypoints/omni_stage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index b1e8fed7522..a2a3b36ff73 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ 
b/vllm_omni/entrypoints/omni_stage.py @@ -551,7 +551,9 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: _logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) From b90fad01f25ba18c33fdca904a30b7f950fcf21c Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 24 Nov 2025 01:30:45 +0000 Subject: [PATCH 04/11] remove use_aux_hidden_state_outputs Signed-off-by: gcanlin --- vllm_omni/worker/npu_ar_model_runner.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm_omni/worker/npu_ar_model_runner.py b/vllm_omni/worker/npu_ar_model_runner.py index 032433a1545..b51cb4928ad 100644 --- a/vllm_omni/worker/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu_ar_model_runner.py @@ -385,11 +385,8 @@ def execute_model( ) with record_function_or_nullcontext("Postprocess"): - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None + hidden_states = model_output + aux_hidden_states = None hidden_states, multimodal_outputs = self.extract_multimodal_outputs( hidden_states @@ -516,7 +513,7 @@ def propose_draft_token_ids(sampled_token_ids): and self.speculative_config.use_eagle() and not self.speculative_config.disable_padded_drafter_batch ) - effective_drafter_max_model_len = self.max_model_len + effective_drafter_max_model_len = self.self.model_config.max_model_len.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len if ( From 881fca6444abec2b54c1f556fe885f53b8091791 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 24 Nov 2025 03:59:29 +0000 Subject: [PATCH 05/11] fix typo 
Signed-off-by: gcanlin --- vllm_omni/worker/npu_ar_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/worker/npu_ar_model_runner.py b/vllm_omni/worker/npu_ar_model_runner.py index b51cb4928ad..1fce6992b70 100644 --- a/vllm_omni/worker/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu_ar_model_runner.py @@ -513,7 +513,7 @@ def propose_draft_token_ids(sampled_token_ids): and self.speculative_config.use_eagle() and not self.speculative_config.disable_padded_drafter_batch ) - effective_drafter_max_model_len = self.self.model_config.max_model_len.max_model_len + effective_drafter_max_model_len = self.model_config.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len if ( From 2d07514f3f6c739790a241809dfcd5c0710dd51c Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 24 Nov 2025 10:22:26 +0000 Subject: [PATCH 06/11] Make ops adapted on Ascend NPU Signed-off-by: gcanlin --- .../models/qwen2_5_omni_token2wav.py | 64 ++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py index 8c541cc2a09..8b17bd8b580 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py @@ -774,20 +774,21 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): else: beta = 0.0 + # Compute on CPU to avoid NPU kernel issues (kaiser_window not supported on NPU) kaiser_window = torch.kaiser_window( - kernel_size, beta=beta, periodic=False, dtype=torch.float32 + kernel_size, beta=beta, periodic=False, dtype=torch.float32, device='cpu' ) # Compute time indices if is_even: - time_indices = torch.arange(-half_size, half_size) + 0.5 + time_indices = torch.arange(-half_size, half_size, device='cpu') + 0.5 else: - time_indices = torch.arange(kernel_size) - half_size + time_indices 
= torch.arange(kernel_size, device='cpu') - half_size # Compute sinc filter if cutoff == 0: return torch.zeros( - (1, 1, kernel_size), dtype=torch.float32 + (1, 1, kernel_size), dtype=torch.float32, device='cpu' ) # Ensures correct shape sinc_filter = torch.sinc(2 * cutoff * time_indices) @@ -799,6 +800,29 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): return normalized_filter.view(1, 1, kernel_size) +def replication_pad_1d(hidden_states: torch.Tensor, pad_left: int, + pad_right: int) -> torch.Tensor: + """Manual replicate padding to avoid replication_pad1d kernel limits on NPU.""" + # NOTE: an immature implementation for running on Ascend NPU. Needs discussion. + if pad_left == 0 and pad_right == 0: + return hidden_states + + segments = [] + if pad_left > 0: + left = hidden_states[..., :1].expand( + *hidden_states.shape[:-1], pad_left) + segments.append(left) + + segments.append(hidden_states) + + if pad_right > 0: + right = hidden_states[..., -1:].expand( + *hidden_states.shape[:-1], pad_right) + segments.append(right) + + return torch.cat(segments, dim=-1) + + class UpSample1d(nn.Module): def __init__(self, ratio=2, kernel_size=None): super().__init__() @@ -820,16 +844,21 @@ def __init__(self, ratio=2, kernel_size=None): def forward(self, hidden_states): channels = hidden_states.shape[1] - hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to( - self.filter.dtype + input_dtype = hidden_states.dtype + # F.pad in Ascend doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor.
+ hidden_states = replication_pad_1d( + hidden_states.to(self.filter.dtype), self.pad, self.pad + ) + filter_on_device = self.filter.to( + device=hidden_states.device, dtype=hidden_states.dtype ) hidden_states = self.ratio * F.conv_transpose1d( hidden_states, - self.filter.expand(channels, -1, -1), + filter_on_device.expand(channels, -1, -1), stride=self.stride, groups=channels, - ).to(hidden_states_dtype) + ).to(input_dtype) hidden_states = hidden_states[..., self.pad_left : -self.pad_right] return hidden_states @@ -855,16 +884,21 @@ def __init__(self, ratio=2, kernel_size=None): def forward(self, hidden_states): channels = hidden_states.shape[1] - hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad( - hidden_states, (self.pad_left, self.pad_right), mode="replicate" - ).to(self.filter.dtype) + input_dtype = hidden_states.dtype + # F.pad in Ascend doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor. + hidden_states = replication_pad_1d( + hidden_states.to(self.filter.dtype), self.pad_left, self.pad_right + ) + filter_on_device = self.filter.to( + device=hidden_states.device, dtype=hidden_states.dtype + ) out = F.conv1d( hidden_states, - self.filter.expand(channels, -1, -1), + filter_on_device.expand(channels, -1, -1), stride=self.stride, groups=channels, - ).to(hidden_states_dtype) + ).to(input_dtype) return out From b5869d8aa6382248ac1b3ac08c4346ba63b838a0 Mon Sep 17 00:00:00 2001 From: AndyZhou952 Date: Tue, 25 Nov 2025 11:51:27 +0800 Subject: [PATCH 07/11] sync main - support multimodal inputs w/ multiple requests --- .../offline_inference/qwen2_5_omni/README.md | 8 + .../offline_inference/qwen2_5_omni/end2end.py | 339 +++++++++------- .../qwen2_5_omni/processing_omni.py | 367 ------------------ .../qwen2_5_omni/run_multiple_prompts.sh | 11 +- .../qwen2_5_omni/run_single_prompt.sh | 10 +- .../offline_inference/qwen2_5_omni/utils.py | 312 --------------- examples/online_serving/README.md | 12 +- 
...letion_client_for_multimodal_generation.py | 155 +++++++- .../run_curl_multimodal_generation.sh | 144 ++++++- vllm_omni/entrypoints/chat_utils.py | 242 ++++++++++++ vllm_omni/entrypoints/omni_stage.py | 21 +- vllm_omni/entrypoints/openai/serving_chat.py | 146 ++++++- .../models/qwen2_5_omni/__init__.py | 0 .../qwen2_5_omni/__init__.py:Zone.Identifier | Bin 0 -> 25 bytes .../models/{ => qwen2_5_omni}/qwen2_5_omni.py | 0 .../{ => qwen2_5_omni}/qwen2_5_omni_talker.py | 0 .../qwen2_5_omni_thinker.py | 0 .../qwen2_5_omni_token2wav.py | 0 .../models/{ => qwen2_5_omni}/qwen2_old.py | 0 vllm_omni/model_executor/models/registry.py | 13 +- vllm_omni/model_executor/models/utils.py | 22 ++ vllm_omni/model_executor/models/vision.py | 23 ++ .../stage_input_processors/qwen2_5_omni.py | 13 +- vllm_omni/worker/gpu_model_runner.py | 39 +- vllm_omni/worker/npu_model_runner.py | 42 +- 25 files changed, 1022 insertions(+), 897 deletions(-) delete mode 100644 examples/offline_inference/qwen2_5_omni/processing_omni.py delete mode 100644 examples/offline_inference/qwen2_5_omni/utils.py create mode 100644 vllm_omni/entrypoints/chat_utils.py create mode 100644 vllm_omni/model_executor/models/qwen2_5_omni/__init__.py create mode 100644 vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier rename vllm_omni/model_executor/models/{ => qwen2_5_omni}/qwen2_5_omni.py (100%) rename vllm_omni/model_executor/models/{ => qwen2_5_omni}/qwen2_5_omni_talker.py (100%) rename vllm_omni/model_executor/models/{ => qwen2_5_omni}/qwen2_5_omni_thinker.py (100%) rename vllm_omni/model_executor/models/{ => qwen2_5_omni}/qwen2_5_omni_token2wav.py (100%) rename vllm_omni/model_executor/models/{ => qwen2_5_omni}/qwen2_old.py (100%) create mode 100644 vllm_omni/model_executor/models/vision.py diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index 112ef5335e2..14203978972 100644 --- 
a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -32,3 +32,11 @@ Then run the command below. ```bash bash run_single_prompt.sh ``` + +### FAQ + +If you encounter an error about the librosa backend, try installing ffmpeg with the commands below. +``` +sudo apt update +sudo apt install ffmpeg +``` diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index e3e496f46d1..be4e2c1a1a2 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -1,166 +1,145 @@ -import argparse +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This example shows how to use vLLM-omni for running offline inference +with the correct prompt format on Qwen2.5-Omni +""" + import os -import os as _os_env_toggle -import random +from typing import NamedTuple -import numpy as np import soundfile as sf -import torch -from utils import make_omni_prompt +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.multimodal.image import convert_image_mode from vllm.sampling_params import SamplingParams +from vllm.utils import FlexibleArgumentParser -from vllm_omni.entrypoints.omni_llm import OmniLLM - -_os_env_toggle.environ["VLLM_USE_V1"] = "1" +from vllm_omni import OmniLLM SEED = 42 -# Set all random seeds -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) -torch.cuda.manual_seed_all(SEED) -# Make PyTorch deterministic -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False -# Set environment variables for deterministic behavior -os.environ["PYTHONHASHSEED"] = str(SEED) -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] -def
parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model", - required=True, - help="Path to merged model directory (will be created if downloading).", - ) - parser.add_argument("--thinker-model", type=str, default=None) - parser.add_argument("--talker-model", type=str, default=None) - parser.add_argument("--code2wav-model", type=str, default=None) - parser.add_argument( - "--hf-hub-id", - default="Qwen/Qwen2.5-Omni-7B", - help="Hugging Face repo id to download if needed.", - ) - parser.add_argument("--hf-revision", default=None, help="Optional HF revision (branch/tag/commit).") - parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.") - parser.add_argument("--voice-type", default="default", help="Voice type, e.g., m02, f030, default.") - parser.add_argument( - "--code2wav-dir", - default=None, - help="Path to code2wav folder (contains spk_dict.pt).", - ) - parser.add_argument("--dit-ckpt", default=None, help="Path to DiT checkpoint file (e.g., dit.pt).") - parser.add_argument("--bigvgan-ckpt", default=None, help="Path to BigVGAN checkpoint file.") - parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]) - parser.add_argument("--max-model-len", type=int, default=32768) - parser.add_argument( - "--init-sleep-seconds", - type=int, - default=20, - help="Sleep seconds after starting each stage process to allow initialization (default: 20)", - ) +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. 
- parser.add_argument("--thinker-only", action="store_true") - parser.add_argument("--text-only", action="store_true") - parser.add_argument("--do-wave", action="store_true") - parser.add_argument( - "--prompt_type", - choices=[ - "text", - "audio", - "audio-long", - "audio-long-chunks", - "audio-long-expand-chunks", - "image", - "video", - "video-frames", - "audio-in-video", - "audio-in-video-v2", - "audio-multi-round", - "badcase-vl", - "badcase-text", - "badcase-image-early-stop", - "badcase-two-audios", - "badcase-two-videos", - "badcase-multi-round", - "badcase-voice-type", - "badcase-voice-type-v2", - "badcase-audio-tower-1", - "badcase-audio-only", - ], - default="text", - ) - parser.add_argument("--use-torchvision", action="store_true") - parser.add_argument("--tokenize", action="store_true") - parser.add_argument( - "--output-wav", - default="output.wav", - help="[Deprecated] Output wav directory (use --output-dir).", +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." +) + + +def get_text_query(question: str = None) -> QueryResult: + if question is None: + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--output-dir", - default="outputs", - help="Output directory to save text and wav files together.", + return QueryResult( + inputs={ + "prompt": prompt, + }, + limit_mm_per_prompt={}, ) - parser.add_argument( - "--thinker-hidden-states-dir", - default="thinker_hidden_states", - help="Path to thinker hidden states directory.", + + +def get_mixed_modalities_query() -> QueryResult: + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--batch-timeout", - type=int, - default=5, - help="Timeout for batching in seconds (default: 5)", + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + "image": convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"), + "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, + }, + }, + limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1}, ) - parser.add_argument( - "--init-timeout", - type=int, - default=300, - help="Timeout for initializing stages in seconds (default: 300)", + + +def get_use_audio_in_video_query() -> QueryResult: + question = "Describe the content of the video, then convert what the baby say into text." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--shm-threshold-bytes", - type=int, - default=65536, - help="Threshold for using shared memory in bytes (default: 65536)", + asset = VideoAsset(name="baby_reading", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={"audio": 1, "video": 1}, ) - parser.add_argument( - "--enable-stats", - action="store_true", - default=False, - help="Enable writing detailed statistics (default: disabled)", + + +def get_multi_audios_query() -> QueryResult: + question = "Are these two audio clips the same?" 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--txt-prompts", - type=str, - default=None, - help="Path to a .txt file with one prompt per line (preferred).", + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ], + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, ) - args = parser.parse_args() - return args -def main(): - args = parse_args() - model_name = args.model - try: - # Preferred: load from txt file (one prompt per line) - if getattr(args, "txt_prompts", None) and args.prompt_type == "text": - with open(args.txt_prompts, encoding="utf-8") as f: - lines = [ln.strip() for ln in f.readlines()] - args.prompts = [ln for ln in lines if ln != ""] - print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}") - except Exception as e: - print(f"[Error] Failed to load prompts: {e}") - raise - - if args.prompts is None: - raise ValueError("No prompts provided. Use --prompts ... 
or --txt-prompts (with --prompt_type text)") +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, + "text": get_text_query, +} + + +def main(args): + model_name = "Qwen/Qwen2.5-Omni-7B" + query_result = query_map[args.query_type]() + omni_llm = OmniLLM( model=model_name, log_stats=args.enable_stats, @@ -205,8 +184,16 @@ def main(): code2wav_sampling_params, ] - prompt = [make_omni_prompt(args, prompt) for prompt in args.prompts] - omni_outputs = omni_llm.generate(prompt, sampling_params_list) + if args.txt_prompts is None: + prompts = [query_result.inputs for _ in range(args.num_prompts)] + else: + assert args.query_type == "text", "txt-prompts is only supported for text query type" + with open(args.txt_prompts, encoding="utf-8") as f: + lines = [ln.strip() for ln in f.readlines()] + prompts = [get_text_query(ln).inputs for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + + omni_outputs = omni_llm.generate(prompts, sampling_params_list) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav @@ -217,7 +204,7 @@ def main(): request_id = int(output.request_id) text_output = output.outputs[0].text # Save aligned text file per request - prompt_text = args.prompts[request_id] + prompt_text = prompts[request_id]["prompt"] out_txt = os.path.join(output_dir, f"{request_id:05d}.txt") lines = [] lines.append("Prompt:\n") @@ -239,5 +226,67 @@ def main(): print(f"Request ID: {request_id}, Saved audio to {output_wav}") +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( 
+ "--enable-stats", + action="store_true", + default=False, + help="Enable writing detailed statistics (default: disabled)", + ) + parser.add_argument( + "--init-sleep-seconds", + type=int, + default=20, + help="Sleep seconds after starting each stage process to allow initialization (default: 20)", + ) + parser.add_argument( + "--batch-timeout", + type=int, + default=5, + help="Timeout for batching in seconds (default: 5)", + ) + parser.add_argument( + "--init-timeout", + type=int, + default=300, + help="Timeout for initializing stages in seconds (default: 300)", + ) + parser.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="Threshold for using shared memory in bytes (default: 65536)", + ) + parser.add_argument( + "--output-wav", + default="output_audio", + help="[Deprecated] Output wav directory (use --output-dir).", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + + return parser.parse_args() + + if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen2_5_omni/processing_omni.py b/examples/offline_inference/qwen2_5_omni/processing_omni.py deleted file mode 100644 index a22220dd388..00000000000 --- a/examples/offline_inference/qwen2_5_omni/processing_omni.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -import base64 -import logging -import math -import os -import time -import warnings -from functools import lru_cache -from io import BytesIO - -import requests -import torch -import torchvision -from packaging import version -from PIL import Image -from torchvision import io, transforms -from torchvision.transforms import InterpolationMode - -logger = logging.getLogger(__name__) - -IMAGE_FACTOR = 28 -MIN_PIXELS = 4 * 28 * 28 -MAX_PIXELS = 16384 * 
28 * 28 -MAX_RATIO = 200 - -VIDEO_MIN_PIXELS = 128 * 28 * 28 -VIDEO_MAX_PIXELS = 768 * 28 * 28 -VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 -FRAME_FACTOR = 2 -FPS = 2.0 -FPS_MIN_FRAMES = 4 -FPS_MAX_FRAMES = 768 - -temporal_patch_size = 2 -spatial_patch_size = 14 -spatial_merge_size = 2 - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is - divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is - divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -def smart_resize( - height: int, - width: int, - factor: int = IMAGE_FACTOR, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS, -) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. 
- """ - if max(height, width) / min(height, width) > MAX_RATIO: - raise ValueError( - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" - ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - return h_bar, w_bar - - -def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: - if "image" in ele: - image = ele["image"] - else: - image = ele["image_url"] - image_obj = None - if isinstance(image, Image.Image): - image_obj = image - elif image.startswith("http://") or image.startswith("https://"): - image_obj = Image.open(requests.get(image, stream=True).raw) - elif image.startswith("file://"): - image_obj = Image.open(image[7:]) - elif image.startswith("data:image"): - if "base64," in image: - _, base64_data = image.split("base64,", 1) - data = base64.b64decode(base64_data) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image) - if image_obj is None: - raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") - image = image_obj.convert("RGB") - # resize - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=size_factor, - ) - else: - width, height = image.size - min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - 
max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - - return image - - -def smart_nframes( - ele: dict, - total_frames: int, - video_fps: int | float, -) -> int: - """calculate the number of frames for video used for model inputs. - - Args: - ele (dict): a dict contains the configuration of video. - support either `fps` or `nframes`: - - nframes: the number of frames to extract for model inputs. - - fps: the fps to extract frames for model inputs. - - min_frames: the minimum number of frames of the video, - only used when fps is provided. - - max_frames: the maximum number of frames of the video, - only used when fps is provided. - total_frames (int): the original total number of frames of the video. - video_fps (int | float): the original fps of the video. - - Raises: - ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. - - Returns: - int: the number of frames for video used for model inputs. - """ - assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" - if "nframes" in ele: - nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) - else: - fps = ele.get("fps", FPS) - min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) - max_frames = floor_by_factor( - ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), - FRAME_FACTOR, - ) - nframes = total_frames / video_fps * fps - nframes = min(max(nframes, min_frames), max_frames) - nframes = round_by_factor(nframes, FRAME_FACTOR) - if not (FRAME_FACTOR <= nframes and nframes <= total_frames): - raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") - return nframes - - -def _read_video_torchvision( - ele: dict, -) -> torch.Tensor: - """read video using torchvision.io.read_video - - Args: - ele (dict): a dict contains the configuration of video. - support keys: - - video: the path of video. support "file://", "http://", - "https://" and local path. 
- - video_start: the start time of video. - - video_end: the end time of video. - Returns: - torch.Tensor: the video tensor with shape (T, C, H, W). - """ - video_path = ele["video"] - if version.parse(torchvision.__version__) < version.parse("0.19.0"): - if "http://" in video_path or "https://" in video_path: - warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") - if "file://" in video_path: - video_path = video_path[7:] - st = time.time() - video, audio, info = io.read_video( - video_path, - start_pts=ele.get("video_start", 0.0), - end_pts=ele.get("video_end", None), - pts_unit="sec", - output_format="TCHW", - ) - total_frames, video_fps = video.size(0), info["video_fps"] - total_duration = round(total_frames / video_fps, 3) - logger.info( - f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, duration={total_duration}s, time={time.time() - st:.3f}s" - ) - nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) - idx = torch.linspace(0, total_frames - 1, nframes).round().long() - video = video[idx] - return video, total_duration, nframes - - -def is_decord_available() -> bool: - import importlib.util - - return importlib.util.find_spec("decord") is not None - - -def _read_video_decord( - ele: dict, -) -> torch.Tensor: - """read video using decord.VideoReader - - Args: - ele (dict): a dict contains the configuration of video. - support keys: - - video: the path of video. support "file://", "http://", - "https://" and local path. - - video_start: the start time of video. - - video_end: the end time of video. - Returns: - torch.Tensor: the video tensor with shape (T, C, H, W). 
- """ - import decord - - video_path = ele["video"] - st = time.time() - vr = decord.VideoReader(video_path) - # TODO: support start_pts and end_pts - if "video_start" in ele or "video_end" in ele: - raise NotImplementedError("not support start_pts and end_pts in decord for now.") - total_frames, video_fps = len(vr), vr.get_avg_fps() - total_duration = round(total_frames / video_fps, 3) - logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") - nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) - idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() - video = vr.get_batch(idx).asnumpy() - video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format - return video, total_duration, nframes - - -VIDEO_READER_BACKENDS = { - "decord": _read_video_decord, - "torchvision": _read_video_torchvision, -} - -FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) - - -@lru_cache(maxsize=1) -def get_video_reader_backend() -> str: - if FORCE_QWENVL_VIDEO_READER is not None: - video_reader_backend = FORCE_QWENVL_VIDEO_READER - elif is_decord_available(): - video_reader_backend = "decord" - else: - video_reader_backend = "torchvision" - # print(f"qwen-vl-utils using {video_reader_backend} to read video.", - # file=sys.stderr) - return video_reader_backend - - -def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: - if isinstance(ele["video"], str): - video_reader_backend = get_video_reader_backend() - video, total_dur, nframes = VIDEO_READER_BACKENDS[video_reader_backend](ele) - frame_timestamps = total_dur * torch.arange(1, nframes + 1) / nframes - grid_timestamps = frame_timestamps[::FRAME_FACTOR] - second_per_grid = grid_timestamps[1] - grid_timestamps[0] - nframes, _, height, width = video.shape - min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) - total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) - 
max_pixels = max( - min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), - int(min_pixels * 1.05), - ) - max_pixels = ele.get("max_pixels", max_pixels) - # min_pixels = (factor ** 2) * 52 - # max_pixels = (factor ** 2) * min(768, (16384 / nframes * temporal_patch_size)) - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=image_factor, - ) - else: - resized_height, resized_width = smart_resize( - height, - width, - factor=image_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - video = transforms.functional.resize( - video, - [resized_height, resized_width], - interpolation=InterpolationMode.BICUBIC, - antialias=True, - ).float() - return video, total_dur, nframes, second_per_grid - else: - assert isinstance(ele["video"], (list, tuple)) - process_info = ele.copy() - process_info.pop("type", None) - process_info.pop("video", None) - images = [ - fetch_image({"image": video_element, **process_info}, size_factor=image_factor) - for video_element in ele["video"] - ] - nframes = ceil_by_factor(len(images), FRAME_FACTOR) - if len(images) < nframes: - images.extend([images[-1]] * (nframes - len(images))) - return images, None, None, None - - -def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: - vision_infos = [] - if isinstance(conversations[0], dict): - conversations = [conversations] - for conversation in conversations: - for message in conversation: - if isinstance(message["content"], list): - for ele in message["content"]: - if ( - "image" in ele - or "image_url" in ele - or "video" in ele - or ele["type"] in ("image", "image_url", "video") - ): - vision_infos.append(ele) - return vision_infos - - -def process_vision_info( - conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: - vision_infos = extract_vision_info(conversations) 
- # Read images or videos - image_inputs = [] - video_inputs = [] - for vision_info in vision_infos: - if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) - elif "video" in vision_info: - video_inputs.append(fetch_video(vision_info)) - else: - raise ValueError("image, image_url or video should in content.") - if len(image_inputs) == 0: - image_inputs = None - if len(video_inputs) == 0: - video_inputs = None - return image_inputs, video_inputs diff --git a/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh index 78d3dd54fb0..2ec8a1c57ec 100644 --- a/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh +++ b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh @@ -1,8 +1,3 @@ -python end2end.py --model Qwen/Qwen2.5-Omni-7B \ - --voice-type "m02" \ - --dit-ckpt none \ - --bigvgan-ckpt none \ - --output-wav output_audio \ - --prompt_type text \ - --init-sleep-seconds 0 \ - --txt-prompts top100.txt +python end2end.py --output-wav output_audio \ + --query-type text \ + --txt-prompts top10.txt diff --git a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh index 739902b2561..5b3c19cdc27 100644 --- a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh +++ b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh @@ -1,8 +1,2 @@ -python end2end.py --model Qwen/Qwen2.5-Omni-7B \ - --voice-type "m02" \ - --dit-ckpt none \ - --bigvgan-ckpt none \ - --output-wav output_audio \ - --prompt_type text \ - --init-sleep-seconds 0 \ - --prompts "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." 
+python end2end.py --output-wav output_audio \ + --query-type use_audio_in_video diff --git a/examples/offline_inference/qwen2_5_omni/utils.py b/examples/offline_inference/qwen2_5_omni/utils.py deleted file mode 100644 index 11df5c55a0e..00000000000 --- a/examples/offline_inference/qwen2_5_omni/utils.py +++ /dev/null @@ -1,312 +0,0 @@ -import tempfile -from typing import Optional, Union -from urllib.request import urlopen - -import librosa -import requests -import resampy -import soundfile as sf -import torch -import torchvision.io -from processing_omni import fetch_image, fetch_video -from transformers import AutoConfig, AutoProcessor -from vllm.inputs import TextPrompt - -from vllm_omni.inputs.data import OmniTokensPrompt - -# Simple caches to avoid repeated heavy HF loads per prompt -_PROCESSOR_CACHE: dict[str, "AutoProcessor"] = {} -_CONFIG_CACHE: dict[str, "AutoConfig"] = {} - - -def get_system_prompt(): - return { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "You are Qwen, a virtual human developed by the Qwen Team, " - "Alibaba Group, capable of perceiving auditory and visual inputs, " - "as well as generating text and speech." 
- ), - } - ], - } - - -def resample_wav_to_16khz(input_filepath): - data, original_sample_rate = sf.read(input_filepath) - # Only use the first channel - if len(data.shape) > 1: - data = data[:, 0] - # resample to 16kHz - data_resampled = resampy.resample(data, sr_orig=original_sample_rate, sr_new=16000) - return data_resampled - - -def fetch_and_read_video(args, video_url: str, fps=2): - def read_video_with_torchvision(video_file_name: str): - video, audio, info = torchvision.io.read_video( - video_file_name, - start_pts=0.0, - end_pts=None, - pts_unit="sec", - output_format="TCHW", - ) - - total_frames, video_fps = video.size(0), info["video_fps"] - total_duration = round(total_frames / video_fps, 3) - nframes = int(total_frames / video_fps * fps) - - frame_timestamps = total_duration * torch.arange(1, nframes + 1) / nframes - grid_timestamps = frame_timestamps[::2] - second_per_grid = grid_timestamps[1] - grid_timestamps[0] - - idx = torch.linspace(0, video.size(0) - 1, nframes).round().long() - video = video[idx] - - if args.legacy_omni_video: - return [video, total_duration, nframes, second_per_grid.item()] - else: - return video - - def read_video_with_transformers(video_file_name: Union[str, list[str]]): - video, total_duration, nframes, second_per_grid = fetch_video({"video": video_file_name}) - if total_duration is None and nframes is None: - nframes = len(video) - total_duration = 0.5 * nframes - second_per_grid = 1.0 - if args.legacy_omni_video: - return [video, total_duration, nframes, second_per_grid] - else: - return video - - def read_video(video_file_name: str): - if args.use_torchvision: - return read_video_with_torchvision(video_file_name) - else: - return read_video_with_transformers(video_file_name) - - if isinstance(video_url, str) and video_url.startswith("http"): - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - resp = requests.get(video_url) - assert resp.status_code == requests.codes.ok, ( - f"Failed to fetch video from 
{video_url}, status_code:{resp.status_code}, resp:{resp}" - ) - - temp_video_file.write(urlopen(video_url).read()) - temp_video_file_path = temp_video_file.name - video_file_name = temp_video_file_path - return read_video(video_file_name) - else: - video_file_name = video_url - return read_video(video_file_name) - - -def make_inputs_qwen2_omni( - args, - messages: list[dict[str, Union[str, list[dict[str, str]]]]], - use_audio_in_video: Optional[bool] = False, - tokenize: bool = False, -) -> Union[OmniTokensPrompt, TextPrompt]: - from transformers import AutoConfig, AutoProcessor - - # Cached processor/config to prevent per-prompt reloading and repeated warnings - if args.model not in _PROCESSOR_CACHE: - _PROCESSOR_CACHE[args.model] = AutoProcessor.from_pretrained(args.model) - processor = _PROCESSOR_CACHE[args.model] - - config = _CONFIG_CACHE.get(args.model) - if config is None: - try: - config = AutoConfig.from_pretrained(args.model) - except Exception: - config = None - _CONFIG_CACHE[args.model] = config # cache even if None to avoid retry storms - - # Decide legacy flag only once based on config (default True if unknown) - if getattr(args, "legacy_omni_video", None) is None: - if config is not None and hasattr(config, "architectures"): - args.legacy_omni_video = "Qwen2_5OmniModel" not in config.architectures - else: - args.legacy_omni_video = True - - audios, images, videos = [], [], [] - for message in messages: - if not isinstance(message["content"], list): - message["content"] = [ - { - "type": "text", - "text": message["content"], - } - ] - index, num_contents = 0, len(message["content"]) - while index < num_contents: - ele = message["content"][index] - if "type" not in ele: - if "text" in ele: - ele["type"] = "text" - elif "audio" in ele: - ele["type"] = "audio" - elif "audio_url" in ele: - ele["type"] = "audio_url" - elif "image" in ele: - ele["type"] = "image" - elif "image_url" in ele: - ele["type"] = "image_url" - elif "video" in ele: - ele["type"] = 
"video" - elif "video_url" in ele: - ele["type"] = "video_url" - else: - raise ValueError(f"Unknown ele: {ele}") - - if ele["type"] == "audio" or ele["type"] == "audio_url": - if "audio_url" in ele: - audio_key = "audio_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_audio_file: - temp_audio_file.write(urlopen(ele[audio_key]).read()) - temp_audio_file_path = temp_audio_file.name - audios.append(resample_wav_to_16khz(temp_audio_file_path)) - ele["audio"] = temp_audio_file_path - elif "audio" in ele: - audio_key = "audio" - audios.append(resample_wav_to_16khz(ele[audio_key])) - else: - raise ValueError(f"Unknown ele {ele}") - elif use_audio_in_video and (ele["type"] == "video" or ele["type"] == "video_url"): - # use video as audio as well - if "video_url" in ele: - audio_key = "video_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - temp_video_file.write(urlopen(ele[audio_key]).read()) - temp_video_file_path = temp_video_file.name - ele[audio_key] = temp_video_file_path - audios.append(librosa.load(temp_video_file_path, sr=16000)[0]) - videos.append(fetch_and_read_video(args, temp_video_file_path)) - ele["video"] = temp_video_file_path - elif "video" in ele: - audio_key = "video" - audios.append(librosa.load(ele[audio_key], sr=16000)[0]) - videos.append(fetch_and_read_video(args, audio_key)) - else: - raise ValueError(f"Unknown ele {ele}") - # insert a audio after the video - message["content"].insert( - index + 1, - { - "type": "audio", - "audio": ele[audio_key], - }, - ) - # no need to load the added audio again - index += 1 - elif ele["type"] == "video" or ele["type"] == "video_url": - if "video_url" in ele: - video_key = "video_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - temp_video_file.write(urlopen(ele["video_url"]).read()) - temp_video_file_path = temp_video_file.name - videos.append(fetch_and_read_video(args, temp_video_file)) - ele["video"] = temp_video_file_path - else: - video_key = 
"video" - videos.append(fetch_and_read_video(args, ele[video_key])) - elif ele["type"] == "image" or ele["type"] == "image_url": - images.append(fetch_image(ele)) - - # move to the next content - index += 1 - - prompt = processor.apply_chat_template( - messages, - tokenize=tokenize, - add_generation_prompt=True, - add_vision_id=True, - ) - - audios = audios if len(audios) > 0 else None - images = images if len(images) > 0 else None - videos = videos if len(videos) > 0 else None - - multi_modal_data = {} - if audios: - multi_modal_data["audio"] = audios - if images: - multi_modal_data["image"] = images - if videos: - multi_modal_data["video"] = videos - - if isinstance(prompt, list) and isinstance(prompt[0], (list, str)): - prompt = prompt[0] - - if tokenize: - return OmniTokensPrompt( - prompt_token_ids=prompt, - multi_modal_data=multi_modal_data, - ) - else: - return TextPrompt( - prompt=prompt, - multi_modal_data=multi_modal_data, - ) - - -def make_text_prompt(args, prompt): - messages = [ - get_system_prompt(), - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - ], - }, - ] - - prompt = make_inputs_qwen2_omni(args, messages, tokenize=args.tokenize) - return prompt - - -def make_audio_in_video_v2_prompt(args): - messages = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "You are Qwen, a virtual human developed by the Qwen Team, " - "Alibaba Group, capable of perceiving auditory and visual " - "inputs, as well as generating text and speech." 
- ), - } - ], - }, - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": ("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw_small.mp4"), - }, - ], - }, - ] - prompt = make_inputs_qwen2_omni( - args, - messages, - use_audio_in_video=True, - tokenize=args.tokenize, - ) - return prompt - - -def make_omni_prompt(args, prompt=None) -> Union[OmniTokensPrompt, list[OmniTokensPrompt]]: - if args.prompt_type == "text": - prompt = make_text_prompt(args, prompt) - elif args.prompt_type == "audio-in-video-v2": - prompt = make_audio_in_video_v2_prompt(args) - else: - raise ValueError(f"Unsupported prompt type: {args.prompt_type}") - return prompt diff --git a/examples/online_serving/README.md b/examples/online_serving/README.md index b64989326dc..ac86f96c900 100644 --- a/examples/online_serving/README.md +++ b/examples/online_serving/README.md @@ -23,10 +23,18 @@ cd examples/online_serving Send request via python ```bash -python openai_chat_completion_client_for_multimodal_generation.py +python openai_chat_completion_client_for_multimodal_generation.py --query-type mixed_modalities ``` Send request via curl ```bash -bash run_curl_multimodal_generation.sh +bash run_curl_multimodal_generation.sh mixed_modalities +``` + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. 
+``` +sudo apt update +sudo apt install ffmpeg ``` diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py index 13d37476d8f..47e109be5ac 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py @@ -1,6 +1,9 @@ import base64 +import requests from openai import OpenAI +from vllm.assets.audio import AudioAsset +from vllm.utils import FlexibleArgumentParser # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" @@ -15,6 +18,16 @@ SEED = 42 +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode("utf-8") + + return result + + def get_system_prompt(): return { "role": "system", @@ -31,7 +44,106 @@ def get_system_prompt(): } -def run_text_to_audio(model: str) -> None: +def get_text_query(): + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + prompt = { + "role": "user", + "content": [ + { + "type": "text", + "text": f"{question}", + } + ], + } + return prompt + + +def get_mixed_modalities_query(): + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" 
+ prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("mary_had_lamb").url}, + }, + { + "type": "image_url", + "image_url": { + "url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" + }, + }, + { + "type": "video_url", + "video_url": { + "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + }, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_use_audio_in_video_query(): + question = "Describe the content of the video, then convert what the baby say into text." + + prompt = { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", + "num_frames": 16, + }, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_multi_audios_query(): + question = "Are these two audio clips the same?" 
+ prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("mary_had_lamb").url}, + }, + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("winning_call").url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, + "text": get_text_query, +} + + +def run_multimodal_generation(args) -> None: + model_name = "Qwen/Qwen2.5-Omni-7B" thinker_sampling_params = { "temperature": 0.0, # Deterministic - no randomness "top_p": 1.0, # Disable nucleus sampling @@ -67,23 +179,21 @@ def run_text_to_audio(model: str) -> None: code2wav_sampling_params, ] + prompt = query_map[args.query_type]() + extra_body = { + "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. + } + + if args.query_type == "use_audio_in_video": + extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} + chat_completion = client.chat.completions.create( messages=[ get_system_prompt(), - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words.", - }, - ], - }, + prompt, ], - model=model, - extra_body={ - "sampling_params_list": sampling_params_list - }, # Optional, it has a default setting in stage_configs of the corresponding model. 
+ model=model_name, + extra_body=extra_body, ) count = 0 @@ -99,5 +209,20 @@ def run_text_to_audio(model: str) -> None: print("Chat completion output from text:", choice.message.content) +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + + return parser.parse_args() + + if __name__ == "__main__": - run_text_to_audio("Qwen/Qwen2.5-Omni-7B") + args = parse_args() + run_multimodal_generation(args) diff --git a/examples/online_serving/run_curl_multimodal_generation.sh b/examples/online_serving/run_curl_multimodal_generation.sh index 6d2afd4a846..d0c85bc391b 100644 --- a/examples/online_serving/run_curl_multimodal_generation.sh +++ b/examples/online_serving/run_curl_multimodal_generation.sh @@ -1,6 +1,20 @@ #!/usr/bin/env bash set -euo pipefail +# Default query type +QUERY_TYPE="${1:-mixed_modalities}" + +# Validate query type +if [[ ! "$QUERY_TYPE" =~ ^(mixed_modalities|use_audio_in_video|multi_audios|text)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [mixed_modalities|use_audio_in_video|multi_audios|text]" + echo " mixed_modalities: Audio + Image + Video + Text query" + echo " use_audio_in_video: Video + Text query (with audio extraction from video)" + echo " multi_audios: Two audio clips + Text query" + echo " text: Text query" + exit 1 +fi + SEED=42 thinker_sampling_params='{ @@ -35,18 +49,121 @@ code2wav_sampling_params='{ }' # Above is optional, it has a default setting in stage_configs of the corresponding model. 
+# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +WINNING_CALL_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/winning_call.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content and extra fields based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + mixed_modalities) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "What is recited in the audio? What is the content of this image? Why is this video funny?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_audio_in_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Describe the content of the video, then convert what the baby say into text." 
+ } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs='{ + "use_audio_in_video": true + }' + ;; + multi_audios) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "audio_url", + "audio_url": { + "url": "'"$WINNING_CALL_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Are these two audio clips the same?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + + output=$(curl -sS -X POST http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d @- < "BaseMultiModalContentParser": + return OmniAsyncMultiModalContentParser(self) + + +class OmniAsyncMultiModalContentParser(AsyncMultiModalContentParser): + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__(tracker=tracker) + self._mm_processor_kwargs: Optional[dict[str, Any]] = None + + def set_mm_processor_kwargs(self, mm_processor_kwargs: Optional[dict[str, Any]]) -> None: + """Set mm_processor_kwargs for use in parsing.""" + self._mm_processor_kwargs = mm_processor_kwargs + + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + video = self._connector.fetch_video_async(video_url=video_url) if video_url else None + + placeholder = self._tracker.add("video", video, uuid) + self._add_placeholder("video", placeholder) + + # Extract audio from video if use_audio_in_video is True + if video_url and self._mm_processor_kwargs and self._mm_processor_kwargs.get("use_audio_in_video", False): + audio_coro = self._extract_audio_from_video_async(video_url) + audio_placeholder = self._tracker.add("audio", audio_coro, uuid) + self._add_placeholder("audio", audio_placeholder) + + async def 
_extract_audio_from_video_async(self, video_url: str) -> tuple[np.ndarray, Union[int, float]]: + """ + Extract audio from video URL using librosa. + Returns tuple of (audio_array, sample_rate) compatible with audio format. + + All blocking I/O operations are run in a thread pool to avoid blocking the event loop. + """ + import asyncio + import os + import tempfile + from urllib.parse import urlparse + + # Parse URL to determine type + parsed_url = urlparse(video_url) + temp_video_file_path = None + + def _download_video_sync(url: str) -> bytes: + """Synchronous video download - runs in thread pool.""" + from urllib.request import urlopen + + return urlopen(url).read() + + def _write_temp_file_sync(data: bytes, suffix: str) -> str: + """Synchronous temp file write - runs in thread pool.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(data) + return temp_file.name + + def _load_audio_sync(file_path: str) -> tuple[np.ndarray, Union[int, float]]: + """Synchronous audio loading with librosa - runs in thread pool.""" + import librosa + + return librosa.load(file_path, sr=16000) + + def _cleanup_file_sync(file_path: str) -> None: + """Synchronous file deletion - runs in thread pool.""" + try: + if os.path.exists(file_path): + os.unlink(file_path) + except OSError: + pass + + try: + if parsed_url.scheme in ("http", "https"): + # Download video from HTTP/HTTPS URL asynchronously + video_data = await asyncio.to_thread(_download_video_sync, video_url) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + elif parsed_url.scheme == "file": + # Use file path directly (handle Windows paths) + from urllib.request import url2pathname + + temp_video_file_path = url2pathname(parsed_url.path) + elif parsed_url.scheme == "data": + # Handle data URL (base64 encoded video) + import base64 + + header, data = video_url.split(",", 1) + video_data = 
base64.b64decode(data) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + else: + # Assume it's a local file path + temp_video_file_path = video_url + + # Extract audio using librosa asynchronously (CPU-intensive, runs in thread pool) + audio_array, sample_rate = await asyncio.to_thread(_load_audio_sync, temp_video_file_path) + + return audio_array, sample_rate + finally: + # Clean up temporary file if we created one (asynchronously) + if temp_video_file_path and parsed_url.scheme in ("http", "https", "data"): + await asyncio.to_thread(_cleanup_file_sync, temp_video_file_path) + + +def parse_chat_messages_futures( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: + conversation: list[ConversationMessage] = [] + mm_tracker = OmniAsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + interleave_strings=( + content_format == "string" + and model_config.multimodal_config is not None + and model_config.multimodal_config.interleave_mm_strings + ), + mm_processor_kwargs=mm_processor_kwargs, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, + interleave_strings: bool, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> list[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] 
+ elif isinstance(content, str): + content = [ChatCompletionContentPartTextParam(type="text", text=content)] + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + wrap_dicts=(content_format == "openai"), + interleave_strings=interleave_strings, + mm_processor_kwargs=mm_processor_kwargs, + ) + + for result_msg in result: + if role == "assistant": + parsed_msg = _AssistantParser(message) + + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, + *, + wrap_dicts: bool, + interleave_strings: bool, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> list[ConversationMessage]: + content = list[_ContentPart]() + + mm_parser = mm_tracker.create_parser() + # Set mm_processor_kwargs if parser supports it + if hasattr(mm_parser, "set_mm_processor_kwargs"): + mm_parser.set_mm_processor_kwargs(mm_processor_kwargs) + + for part in parts: + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + interleave_strings=interleave_strings, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, content=content)] # type: ignore + texts = cast(list[str], content) + mm_placeholder_storage = mm_parser.mm_placeholder_storage() + if mm_placeholder_storage: 
+ text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, texts, interleave_strings) + else: + text_prompt = "\n".join(texts) + + return [ConversationMessage(role=role, content=text_prompt)] \ No newline at end of file diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index a2a3b36ff73..861235c5ed9 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -50,6 +50,7 @@ def __init__(self, stage_config): self.stage_id = stage_config.stage_id self.engine_args = stage_config.engine_args self.model_stage = stage_config.engine_args.model_stage + self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) self.engine_input_source = getattr(stage_config, "engine_input_source", []) self.engine_output_type = stage_config.engine_args.engine_output_type self.engine_outputs = None @@ -205,14 +206,20 @@ def process_engine_inputs( for source_output in source_outputs: engine_input = OmniTokensPrompt( prompt_token_ids=source_output.outputs[0].token_ids, - multi_modal_data=(multi_modal_data[source_output.request_id] if multi_modal_data else None), + multi_modal_data=( + multi_modal_data[source_output.request_id] + if self.requires_multimodal_data and multi_modal_data + else None + ), ) engine_inputs.append(engine_input) return engine_inputs else: engine_input_source = self.engine_input_source - return self.custom_process_input_func(stage_list, engine_input_source, prompt) + return self.custom_process_input_func( + stage_list, engine_input_source, prompt, self.requires_multimodal_data + ) def _stage_worker( @@ -227,13 +234,13 @@ def _stage_worker( import logging as _logging import time as _time - from vllm_omni.entrypoints.log_utils import ( # noqa: WPS433 + from vllm_omni.entrypoints.log_utils import ( compute_and_log_stage_request_stats, count_tokens_from_outputs, log_stage_batch_stats, log_stage_running_avg, ) - from vllm_omni.entrypoints.omni_llm import 
OmniStageLLM # noqa: WPS433 + from vllm_omni.entrypoints.omni_llm import OmniStageLLM # no inline JSONL/serialization imports; logging handled by utilities @@ -503,8 +510,8 @@ async def _stage_worker_async( import logging as _logging import time as _time - from vllm_omni.entrypoints.async_omni_llm import AsyncOmniStageLLM # noqa: WPS433 - from vllm_omni.entrypoints.log_utils import ( # noqa: WPS433 + from vllm_omni.entrypoints.async_omni_llm import AsyncOmniStageLLM + from vllm_omni.entrypoints.log_utils import ( compute_and_log_stage_request_stats, count_tokens_from_outputs, log_stage_batch_stats, @@ -724,4 +731,4 @@ def filter(self, record: _logging.LogRecord) -> bool: "stage_id": stage_id, "error": str(e), } - ) + ) \ No newline at end of file diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 6e7ade7c777..185e9997e16 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -3,10 +3,10 @@ import json import time import uuid -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterator, Sequence from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import Optional, Union +from typing import Any, Callable, Optional, Union import jinja2 from fastapi import Request @@ -18,7 +18,16 @@ soundfile = None from openai.types.chat.chat_completion_audio import ChatCompletionAudio as OpenAIChatCompletionAudio -from vllm.entrypoints.chat_utils import ConversationMessage, get_history_tool_calls_cnt, make_tool_call_id +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + get_history_tool_calls_cnt, + make_tool_call_id, + resolve_chat_template_content_format, +) from vllm.entrypoints.harmony_utils import parse_chat_output from 
vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, @@ -35,7 +44,16 @@ UsageInfo, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import RequestPrompt, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + ChatLikeRequest, + EngineTokensPrompt, + RequestPrompt, + ResponsesRequest, + TextTokensPrompt, + clamp_prompt_logprobs, + is_list_of, +) +from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -50,6 +68,7 @@ ) from vllm.utils import as_list +from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -215,6 +234,123 @@ async def create_chat_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) + async def _preprocess_chat( + self, + request: Union[ChatLikeRequest, ResponsesRequest], + tokenizer: AnyTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[list[dict[str, Any]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + add_special_tokens: bool = False, + ) -> tuple[ + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], + ]: + model_config = self.model_config + + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tool_dicts, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + conversation, mm_data_future, mm_uuids = 
parse_chat_messages_futures( + messages, + model_config, + tokenizer, + content_format=resolved_content_format, + mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + request_prompt: Union[str, list[int]] + + if tokenizer is None: + request_prompt = "placeholder" + elif isinstance(tokenizer, MistralTokenizer): + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + + mm_data = await mm_data_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and ( + hasattr(request, "tool_choice") and request.tool_choice != "none" + ) + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request + ) + + if tokenizer is None: + assert isinstance(request_prompt, str), ( + "Prompt has to be a string", + "when the tokenizer is not initialised", + ) + prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) + elif isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert 
is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt, + ) + + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return conversation, [request_prompt], [engine_prompt] + def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[SamplingParams]: final_sampling_params_list = [] for sampling_params in sampling_params_list: @@ -640,4 +776,4 @@ def _create_audio_choice(self, omni_outputs: OmniRequestOutput, role: str): stop_reason=None, ) choices.append(choice_data) - return choices + return choices \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py b/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier b/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier new file mode 100644 index 0000000000000000000000000000000000000000..d6c1ec682968c796b9f5e9e080cc6f674b57c766 GIT binary patch literal 25 dcma!!%Fjy;DN4*MPD?F{<>dl#JyUFr831@K2x set[str]: Add a prefix to the names of the loaded weights. """ return {maybe_prefix(prefix, name) for name in weights} + + +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + if lst.numel() == 0: + return [] + + # Move to CPU and convert to list once (High Speedup) + # using .item() inside a loop is very slow. 
+ data_list = lst.detach().cpu().tolist() + + # Calculate max on the list or tensor (Tensor max is fast enough) + max_val = int(torch.max(lst).item()) + + # Pre-allocate buckets + ranges: list[list[int]] = [[] for _ in range((max_val // interval) + 1)] + + for num in data_list: + index = int(num // interval) + ranges[index].append(num) + + return ranges \ No newline at end of file diff --git a/vllm_omni/model_executor/models/vision.py b/vllm_omni/model_executor/models/vision.py new file mode 100644 index 00000000000..d03892e540c --- /dev/null +++ b/vllm_omni/model_executor/models/vision.py @@ -0,0 +1,23 @@ +import torch + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, +) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index_tensor = ( + torch.Tensor(t_index).to(llm_grid_h.device).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + ) + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index c1a04d9b04f..98b5d29e35b 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -10,7 +10,12 @@ TALKER_CODEC_END_TOKEN_ID = 8294 -def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensPrompt, TextPrompt] = None): +def 
thinker2talker( + stage_list, + engine_input_source, + prompt: Union[OmniTokensPrompt, TextPrompt] = None, + requires_multimodal_data: bool = False, +): if not engine_input_source: raise ValueError("engine_input_source cannot be empty") source_stage_id = engine_input_source[0] @@ -48,9 +53,11 @@ def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensProm + [TALKER_CODEC_END_TOKEN_ID], additional_information=additional_information, multi_modal_data=( - multi_modal_data[thinker_output.request_id] if multi_modal_data is not None else None + multi_modal_data[thinker_output.request_id] + if requires_multimodal_data and multi_modal_data is not None + else None ), mm_processor_kwargs=None, ) ) - return talker_inputs + return talker_inputs \ No newline at end of file diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 0a90dbdb8ee..ecc9a3fa8f3 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.sampling_params import SamplingType from vllm.utils import LazyLoader, cdiv @@ -29,6 +30,42 @@ class OmniGPUModelRunner(GPUModelRunner): + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := 
mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + # Check for use_audio_in_video + use_audio_in_video_value = mm_input.get("use_audio_in_video") + if use_audio_in_video_value is not None: + use_audio_in_video = bool(use_audio_in_video_value.item()) + + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -543,4 +580,4 @@ def _dummy_run( logit_indices = np.cumsum(num_scheduled_tokens) - 1 hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) - return hidden_states, hidden_states[logit_indices] + return hidden_states, hidden_states[logit_indices] \ No newline at end of file diff --git a/vllm_omni/worker/npu_model_runner.py b/vllm_omni/worker/npu_model_runner.py index faf958f3aaa..3c36085cbf8 100644 --- a/vllm_omni/worker/npu_model_runner.py +++ b/vllm_omni/worker/npu_model_runner.py @@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import BatchDescriptor, DPMetadata, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.multimodal import MULTIMODAL_REGISTRY @@ -77,6 +78,42 @@ def mm_budget(self): 
) if self.supports_mm_inputs else None return self._mm_budget + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + # Check for use_audio_in_video + use_audio_in_video_value = mm_input.get("use_audio_in_video") + if use_audio_in_video_value is not None: + use_audio_in_video = bool(use_audio_in_video_value.item()) + + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. 
@@ -220,10 +257,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - if vllm_version_is("0.11.0"): - self._init_mrope_positions(self.requests[req_id]) - else: - self._init_mrope_positions_0102(self.requests[req_id]) + self._init_mrope_positions(self.requests[req_id]) reqs_to_add.append(self.requests[req_id]) From 9e9a187e9ed58d0eb1a84498f5e4a9901829f472 Mon Sep 17 00:00:00 2001 From: AndyZhou952 Date: Tue, 25 Nov 2025 13:28:13 +0800 Subject: [PATCH 08/11] rm redundant files --- .../qwen2_5_omni/__init__.py:Zone.Identifier | Bin 25 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier b/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py:Zone.Identifier deleted file mode 100644 index d6c1ec682968c796b9f5e9e080cc6f674b57c766..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 25 dcma!!%Fjy;DN4*MPD?F{<>dl#JyUFr831@K2x Date: Wed, 26 Nov 2025 04:28:33 +0000 Subject: [PATCH 09/11] Use vllm-ascend's ModelRunner style Signed-off-by: gcanlin --- vllm_omni/worker/npu_ar_model_runner.py | 1194 ++++++++++------- .../worker/npu_diffusion_model_runner.py | 958 +++++++------ 2 files changed, 1284 insertions(+), 868 deletions(-) diff --git a/vllm_omni/worker/npu_ar_model_runner.py b/vllm_omni/worker/npu_ar_model_runner.py index 1fce6992b70..e56b515c4c6 100644 --- a/vllm_omni/worker/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu_ar_model_runner.py @@ -2,78 +2,216 @@ from __future__ import annotations -from typing import Any, Optional, Union +import math +from typing import Any import numpy as np import torch - -from vllm.forward_context import BatchDescriptor -from vllm.logger import init_logger +import torch.nn as nn +from vllm.config import CUDAGraphMode 
+from vllm.distributed import tensor_model_parallel_all_gather +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import logger +from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.outputs import AsyncModelRunnerOutput -from vllm.v1.structured_output.utils import apply_grammar_bitmask -from vllm.v1.utils import record_function_or_nullcontext -from vllm.v1.worker.gpu_model_runner import ( + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.kv_cache_interface import ( + EncoderOnlyAttentionSpec, +) + +# yapf: enable +from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, - IntermediateTensors, - get_pp_group, - get_tp_group, - has_kv_transfer_group, + AsyncModelRunnerOutput, + ModelRunnerOutput, ) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import update_attn_params, update_mla_attn_params +from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.utils import ( + ProfileExecuteDuration, + enable_sp, + lmhead_tp_enable, +) from vllm_ascend.worker.model_runner_v1 import AsyncNPUModelRunnerOutput -from vllm.v1.worker.ubatch_utils import UBatchSlices -from vllm.v1.worker.utils import is_residual_scattered_for_sp + +from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload from vllm_omni.outputs import OmniModelRunnerOutput from 
vllm_omni.worker.npu_model_runner import OmniNPUModelRunner -logger = init_logger(__name__) - class NPUARModelRunner(OmniNPUModelRunner): """Autoregressive NPU model runner that returns hidden states per request.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _preprocess( + def _prepare_inputs( self, - scheduler_output: "SchedulerOutput", - num_scheduled_tokens_np: np.ndarray, - intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, ) -> tuple[ + dict[str, Any], + torch.Tensor, + np.ndarray, int, + torch.Tensor, int, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], - dict[str, Any], - Optional[dict[str, dict]], + SpecDecodeMetadata, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int, + dict[str, dict] | None, ]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. 
+ req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = num_scheduled_tokens.max() + num_valid_tokens = np.array( + [ + num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for num_tokens, i in zip(tokens, req_ids) + ], + dtype=np.int32, + ) + + if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]: + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens) + elif self.use_aclgraph and enable_sp(self.vllm_config): + # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, + # the model will fall back to running its FX graph in eager mode. + # In this case, when sequence parallelism is enabled, we need to pad tokens to align + # with tp_size because pad_size cannot be captured by the FX graph + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size + else: + # Eager mode. + num_input_tokens = total_num_scheduled_tokens + + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) + + # Get info across DP ranks. 
+ # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just max_tokens_across_dp_cpu + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp( + num_input_tokens, with_prefill, enable_dbo + ) + + # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens + # We should consider removing maybe_padded_num_tokens later + num_input_tokens = maybe_padded_num_tokens + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True + ) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + + # Prepare input_ids. + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) + + # Prepare some information for building Attention-Metadata + # Compute and commit slot mapping + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True) + + self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.query_start_loc[num_reqs + 1 :].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + # Copy the tensors to the NPU. 
+ self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True) + + # Make Attention metadata + positions_cpu = self.positions_cpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) + self.attn_state = attn_state # type: ignore + + self.with_prefill = with_prefill + self.num_tokens_across_dp = num_tokens_across_dp + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) + attn_metadata: dict[str, Any] = {} - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if ubatch_slices: - assert num_tokens_after_padding is not None - num_input_tokens = int(num_tokens_after_padding[0].item() * 2) - if hasattr(self, "pad_out_ubatch_slice"): - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad - - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order - per_req_additional_information: Optional[dict[str, dict]] = None - if ( - self.supports_mm_inputs - and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder - ): + # Omni-new: per_req_additional_information + per_req_additional_information: dict[str, dict] | None = None + + # _prepare_inputs may reorder the batch, so we must gather + # multi-modal outputs after that to ensure the correct order + if self.is_multimodal_model: # Run the multimodal encoder if any. 
self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) @@ -81,48 +219,54 @@ def _preprocess( # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, - ) - + input_ids = self.input_ids[:total_num_scheduled_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) + self.inputs_embeds[:total_num_scheduled_tokens].copy_(inputs_embeds) + # Omni-new: Reset per-step additional information collector (deprecated concat path) if hasattr(self, "_forward_additional_information"): self._forward_additional_information = None + # Omni-new: per-request additional information for this step per_req_additional_information = {} + # Omni-new: Overlay custom prompt_embeds per request for the prompt portion; + # collect additional_information (tensor/list) for prefill portion only for req_index, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) addi_cpu = getattr(req_state, "additional_information_cpu", None) - num_computed_tokens = int( - self.input_batch.num_computed_tokens_cpu[req_index] - ) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) prompt_len = len(req_state.prompt_token_ids) prompt_remaining = max(0, prompt_len - num_computed_tokens) - sched_tokens = int(num_scheduled_tokens_np[req_index]) + sched_tokens = int(num_scheduled_tokens[req_index]) overlay_len = min(sched_tokens, prompt_remaining) if overlay_len <= 0: 
continue if pe_cpu is not None: - src = pe_cpu[ - num_computed_tokens : num_computed_tokens + overlay_len - ].to(dtype=self.dtype, device=self.device, non_blocking=True) - start_offset = int(self.query_start_loc_cpu[req_index]) - self.inputs_embeds[start_offset : start_offset + overlay_len].copy_( - src + src = pe_cpu[num_computed_tokens : num_computed_tokens + overlay_len].to( + dtype=self.dtype, device=self.device, non_blocking=True ) + start_offset = int(self.query_start_loc.cpu[req_index]) + self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + # Build per-request additional information (no cross-request concat) if addi_cpu is not None and isinstance(addi_cpu, dict): req_info: dict[str, object] = {} for k, v in addi_cpu.items(): if isinstance(v, torch.Tensor): + # For prefill tokens, pass only the scheduled slice; + # for decode or no scheduled tokens, pass whole tensor if overlay_len > 0: try: - seg = v[ - num_computed_tokens : num_computed_tokens + overlay_len - ].detach().to("cpu").contiguous() + seg = ( + v[num_computed_tokens : num_computed_tokens + overlay_len] + .detach() + .to("cpu") + .contiguous() + ) except Exception: seg = v.detach().to("cpu").contiguous() req_info[k] = seg @@ -133,284 +277,355 @@ def _preprocess( else: req_info[k] = v per_req_additional_information[req_id] = req_info - - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = self.inputs_embeds[:num_input_tokens] - model_kwargs = { - **self._init_model_kwargs(num_scheduled_tokens), - **self._extract_mm_kwargs(scheduler_output), - } - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: - if hasattr(self, "is_token_ids"): - token_ids_idx = ( - self.is_token_ids[:num_scheduled_tokens] - .nonzero(as_tuple=False) - .squeeze(1) - ) - if token_ids_idx.numel() > 0: - token_ids = self.input_ids[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) - self.inputs_embeds[token_ids_idx] = tokens_to_embeds - 
inputs_embeds = self.inputs_embeds[:num_input_tokens] - model_kwargs = self._init_model_kwargs(num_input_tokens) - input_ids = None + input_ids = self.input_ids[:num_input_tokens] else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the ACL graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_input_tokens) - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] + positions = self.positions[:num_input_tokens] + input_ids, positions = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens + ) if get_pp_group().is_first_rank: intermediate_tensors = None else: - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True) + intermediate_tensors = IntermediateTensors( + {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()} + ) + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + spec_decode_metadata = None + logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True) + else: + # Get the number of draft tokens for each request. 
+ # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. + # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + spec_decode_common_attn_metadata = None + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. 
+ blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) + self.slot_mapping[total_num_scheduled_tokens:].fill_(0) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + num_input_tokens=num_input_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, ) - if ( - self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs - ): - encoder_inputs = self._extract_encoder_inputs(scheduler_output) - model_kwargs.update(encoder_inputs) + if self.speculative_config and spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in 
self.attn_groups[kv_cache_group_id]: + common_prefix_len = 0 + extra_attn_metadata_args = {} + builder = attn_group.get_metadata_builder() + if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type == "pooling": + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.get_model(), + **extra_attn_metadata_args, + ) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + if lmhead_tp_enable(): + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) return ( + attn_metadata, + positions, num_scheduled_tokens, num_input_tokens, - num_tokens_after_padding, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, input_ids, inputs_embeds, - positions, intermediate_tensors, - model_kwargs, + max_num_scheduled_tokens, per_req_additional_information, ) @torch.inference_mode() def execute_model( self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[OmniModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - with record_function_or_nullcontext("Preprocess"): - with self.synchronize_input_prep(): - super()._update_states(scheduler_output) - - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - return EMPTY_MODEL_RUNNER_OUTPUT - if hasattr(self, "kv_connector_no_forward"): - return 
self.kv_connector_no_forward(scheduler_output) - return EMPTY_MODEL_RUNNER_OUTPUT - if self.cache_config.kv_sharing_fast_prefill: - assert not self.input_batch.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect " - "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs" - ) - - ( - attn_metadata, - positions, - num_scheduled_tokens_np, - num_input_tokens, - num_tokens_across_dp, - maybe_padded_num_tokens, - logits_indices, - spec_decode_metadata, - input_ids_prep, - inputs_embeds_prep, - intermediate_tensors_prep, - max_query_len, - ) = self._prepare_inputs(scheduler_output, intermediate_tensors) - - if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.forward_before() - - if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.take_update_info_from_eplb_process() - - with_prefill = getattr(self, "with_prefill", True) - - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - input_ids = input_ids_prep - inputs_embeds = inputs_embeds_prep - intermediate_tensors = intermediate_tensors_prep - - per_req_additional_information: Optional[dict[str, dict]] = None - model_kwargs = self._init_model_kwargs(num_input_tokens) - - if ( - self.supports_mm_inputs - and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder - ): - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - - per_req_additional_information = {} - for req_index, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) - addi_cpu = getattr(req_state, "additional_information_cpu", None) - num_computed_tokens = int( - self.input_batch.num_computed_tokens_cpu[req_index] - ) - prompt_len = len(req_state.prompt_token_ids) - prompt_remaining = max(0, prompt_len - 
num_computed_tokens) - sched_tokens = int(num_scheduled_tokens_np[req_index]) - overlay_len = min(sched_tokens, prompt_remaining) - if overlay_len <= 0: - continue - if pe_cpu is not None and inputs_embeds is not None: - src = pe_cpu[ - num_computed_tokens : num_computed_tokens + overlay_len - ].to(dtype=self.dtype, device=self.device, non_blocking=True) - start_offset = int(self.query_start_loc_cpu[req_index]) - inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) - if addi_cpu is not None and isinstance(addi_cpu, dict): - req_info: dict[str, object] = {} - for k, v in addi_cpu.items(): - if isinstance(v, torch.Tensor): - if overlay_len > 0: - try: - seg = v[ - num_computed_tokens : num_computed_tokens + overlay_len - ].detach().to("cpu").contiguous() - except Exception: - seg = v.detach().to("cpu").contiguous() - req_info[k] = seg - else: - req_info[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, list): - req_info[k] = v + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + with ProfileExecuteDuration().capture_async("prepare input"): + self._update_states(scheduler_output) + + # Omni-new: Decode per-request prompt_embeds / additional_hidden_states payloads + # (if present) into CPU tensors + try: + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if new_reqs: + for nr in new_reqs: + req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) + if req_id is None: + continue + # prompt_embeds + payload_pe = getattr(nr, "prompt_embeds", None) + if payload_pe is not None: + if isinstance(payload_pe, torch.Tensor): + pe_cpu = payload_pe.detach().to("cpu").contiguous() + elif isinstance(payload_pe, PromptEmbedsPayload): + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(payload_pe.data, dtype=dt) + arr = arr.reshape(payload_pe.shape) + pe_cpu = torch.from_numpy(arr) else: 
- req_info[k] = v - per_req_additional_information[req_id] = req_info - - model_kwargs.update(self._extract_mm_kwargs(scheduler_output)) - - if ( - self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs - ): - encoder_inputs = self._extract_encoder_inputs(scheduler_output) - model_kwargs.update(encoder_inputs) + pe_cpu = None + if pe_cpu is not None and req_id in self.requests: + setattr( + self.requests[req_id], + "prompt_embeds_cpu", + pe_cpu, + ) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + # Already decoded + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict and req_id in self.requests: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding prompt_embeds / additional_information: {e}") + pass - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - scheduler_output.total_num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len - ) - batch_descriptor = BatchDescriptor( - num_tokens=num_input_tokens, uniform_decode=uniform_decode - ) - if hasattr(self, "aclgraph_dispatcher"): - aclgraph_runtime_mode, batch_descriptor = ( - self.aclgraph_dispatcher.dispatch(batch_descriptor) - ) - elif hasattr(self, "cudagraph_dispatcher"): - aclgraph_runtime_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch(batch_descriptor) - ) - else: - from vllm.config import CUDAGraphMode - aclgraph_runtime_mode = CUDAGraphMode.NONE + if not 
scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug("skip this step for we receive the data from remote disaggregate prefill node") + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output) - moe_comm_type = None - if hasattr(self, "_select_moe_comm_method"): - moe_comm_type = self._select_moe_comm_method(num_input_tokens, with_prefill) + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + ( + attn_metadata, + positions, + num_scheduled_tokens_np, + num_input_tokens, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, + input_ids, + inputs_embeds, + intermediate_tensors, + max_query_len, + per_req_additional_information, + ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) + aclgraph_runtime_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(batch_descriptor) - with ( - set_ascend_forward_context( + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=with_prefill, - reserved_mc2_mask=getattr(self, "reserved_mc2_mask", None), + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, 
num_actual_tokens=scheduler_output.total_num_scheduled_tokens, - prefetch_stream=getattr(self, "prefetch_stream", None), + prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=getattr(self, "weight_prefetch_method", None), - ), - record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, - ): - - model_kwargs_extra = {} - if per_req_additional_information: - model_kwargs_extra["additional_information_by_req_id"] = per_req_additional_information - try: - per_req_runtime_info = [] - for req_id in self.input_batch.req_ids: - req_state = self.requests.get(req_id) - info = ( - getattr(req_state, "additional_information_cpu", None) - if req_state is not None - else None - ) - per_req_runtime_info.append(info if isinstance(info, dict) else {}) - model_kwargs_extra["runtime_additional_information"] = ( - per_req_runtime_info + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + + # Omni-new + model_kwargs_extra = {} + # Pass per-request additional information map for this step (no concat) + if per_req_additional_information: + model_kwargs_extra["additional_information_by_req_id"] = per_req_additional_information + # Always pass per-request runtime additional_information (persisted in request state) + try: + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + per_req_runtime_info.append(info if isinstance(info, dict) else {}) + model_kwargs_extra["runtime_additional_information"] = per_req_runtime_info + model_kwargs_extra["request_ids"] = self.input_batch.req_ids + # Pass each request's token span within the flattened sequence for this step, + # enabling the model to map decode/prefill by request + req_token_spans = [] + for req_index in 
range(len(self.input_batch.req_ids)): + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + model_kwargs_extra["request_token_spans"] = req_token_spans + except Exception: + pass + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, + self.with_prefill, + maybe_padded_num_tokens, + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + model_kwargs_extra, ) - model_kwargs_extra["request_ids"] = self.input_batch.req_ids - req_token_spans = [] - for req_index in range(len(self.input_batch.req_ids)): - start_offset = int(self.query_start_loc_cpu[req_index]) - sched_tokens = int(num_scheduled_tokens_np[req_index]) - req_token_spans.append((start_offset, start_offset + sched_tokens)) - model_kwargs_extra["request_token_spans"] = req_token_spans - except Exception: - pass - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - sampling_metadata=self.input_batch.sampling_metadata, - logits_index=logits_indices, - sampler=self.sampler, - **model_kwargs_extra, - ) - with record_function_or_nullcontext("Postprocess"): - hidden_states = model_output - aux_hidden_states = None + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) - hidden_states, multimodal_outputs = self.extract_multimodal_outputs( - hidden_states - ) + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = hidden_states + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + with ProfileExecuteDuration().capture_async("post process"): + # Broadcast PP output for external_launcher 
(torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + # Omni-new + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + # The model side may return per-request additional_information updates (model-agnostic channel). + # Convention: multimodal_outputs["additional_information_update"] is a list[dict] in batch order; + # the runner merges it into the corresponding request's additional_information_cpu for subsequent decode. try: - if ( - isinstance(multimodal_outputs, dict) - and ( - "additional_information_update" in multimodal_outputs - or "additional_information_update_by_req_id" in multimodal_outputs - ) + if isinstance(multimodal_outputs, dict) and ( + "additional_information_update" in multimodal_outputs + or "additional_information_update_by_req_id" in multimodal_outputs ): - updates_list = multimodal_outputs.get( - "additional_information_update" - ) + # Option A: list[dict] in batch order + updates_list = multimodal_outputs.get("additional_information_update") if isinstance(updates_list, list): for idx, upd in enumerate(updates_list): if not isinstance(upd, dict) or idx >= len(self.input_batch.req_ids): continue req_id = self.input_batch.req_ids[idx] self._merge_additional_information_update(req_id, upd) - updates_map = multimodal_outputs.get( - "additional_information_update_by_req_id" - ) + # Option B: dict[str, dict] keyed by req_id + updates_map = multimodal_outputs.get("additional_information_update_by_req_id") if isinstance(updates_map, dict): for req_id, upd in updates_map.items(): if not isinstance(upd, dict): @@ -419,220 +634,267 @@ def execute_model( continue self._merge_additional_information_update(req_id, upd) except Exception as e: - logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}, with the multimodal_outputs as {multimodal_outputs}") - broadcast_pp_output 
= getattr(self, "broadcast_pp_output", False) - if not broadcast_pp_output: - - if not get_pp_group().is_last_rank: - assert isinstance(hidden_states, IntermediateTensors) + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} additional \ + information update: {e}, with the multimodal_outputs as {multimodal_outputs}" + ) + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + if not broadcast_pp_output: hidden_states.kv_connector_output = kv_connector_output return hidden_states - - if self.is_pooling_model: - output = self._pool( - hidden_states, num_scheduled_tokens, num_scheduled_tokens_np, - finished_sending=None, finished_recving=None, - kv_connector_output=kv_connector_output + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict(hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, + finished_sending, + finished_recving, + kv_connector_output, ) - return output - sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) - else: - assert not self.is_pooling_model - - if not get_pp_group().is_last_rank: - all_gather_tensors = { - "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens - ) + if broadcast_pp_output: + model_output_broadcast_data = ( + { + "logits": logits.contiguous(), } - get_pp_group().send_tensor_dict( - hidden_states.tensors, - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors, - ) - logits = None - sample_hidden_states = None - else: - sample_hidden_states = hidden_states[logits_indices] - logits = 
self.model.compute_logits(sample_hidden_states) - - model_output_broadcast_data = {} - if logits is not None: - model_output_broadcast_data["logits"] = logits.contiguous() - + if logits is not None + else {} + ) model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 ) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] - if sample_hidden_states is None: - sample_hidden_states = hidden_states[logits_indices] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask( - scheduler_output, self.input_batch, logits, self.device + logits = self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[: self.input_batch.num_reqs] + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, ) + else: + if lmhead_tp_enable() and logits is not None: + logits = logits[: len(spec_decode_metadata.logits_indices)] + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. 
+ target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) + + discard_sampled_tokens_req_indices: list[int] = [] + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. 
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[: scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) - with record_function_or_nullcontext("Sample"): - sampler_output = self._sample(logits, spec_decode_metadata) - - def propose_draft_token_ids(sampled_token_ids): - if not hasattr(self, "propose_draft_token_ids") or spec_decode_metadata is None: - return - with record_function_or_nullcontext("Draft"): - if isinstance(sampled_token_ids, torch.Tensor): - if sampled_token_ids.dim() == 1: - sampled_token_ids_list = [[t.item()] for t in sampled_token_ids] - else: - sampled_token_ids_list = sampled_token_ids.tolist() - elif isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: - if not isinstance(sampled_token_ids[0], list): - sampled_token_ids_list = [[t] for t in sampled_token_ids] - else: - sampled_token_ids_list = sampled_token_ids + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. 
+ self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if req_idx not in invalid_req_indices_set else None else: - sampled_token_ids_list = sampled_token_ids - + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. 
" + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if self.speculative_config: self._draft_token_ids = self.propose_draft_token_ids( - sampled_token_ids_list, - self.input_batch.sampling_metadata, + valid_sampled_token_ids, + sampling_metadata, scheduler_output, spec_decode_metadata, positions, - num_scheduled_tokens, + scheduler_output.total_num_scheduled_tokens, hidden_states, attn_metadata, aux_hidden_states, ) - use_padded_batch_for_eagle = ( - self.speculative_config - and self.speculative_config.use_eagle() - and not self.speculative_config.disable_padded_drafter_batch - ) - effective_drafter_max_model_len = self.model_config.max_model_len - if effective_drafter_max_model_len is None: - effective_drafter_max_model_len = self.model_config.max_model_len - if ( - self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len is not None - ): - effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len - ) - input_fits_in_drafter = False - if spec_decode_metadata is not None and attn_metadata: - try: - first_layer_metadata = next(iter(attn_metadata.values())) if attn_metadata else None - if first_layer_metadata and hasattr(first_layer_metadata, "seq_lens"): - max_seq_len = first_layer_metadata.seq_lens.max().item() - input_fits_in_drafter = ( - max_seq_len + self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len - ) - except Exception: - pass - - if use_padded_batch_for_eagle and input_fits_in_drafter: - sampled_tokens = sampler_output.sampled_token_ids - 
if isinstance(sampled_tokens, torch.Tensor): - if sampled_tokens.dim() == 1: - sampled_tokens_list = [[t.item()] for t in sampled_tokens] - else: - sampled_tokens_list = sampled_tokens.tolist() - else: - sampled_tokens_list = sampled_tokens - propose_draft_token_ids(sampled_tokens_list) - - with record_function_or_nullcontext("Bookkeep"): - ( - num_nans_in_logits, - logprobs_lists, - valid_sampled_token_ids, - prompt_logprobs_dict, - req_ids_output_copy, - req_id_to_index_output_copy, - invalid_req_indices, - ) = self._bookkeeping_sync( - scheduler_output, - sampler_output, - logits, - hidden_states, - num_scheduled_tokens, - ) - - if ( - self.speculative_config - and not use_padded_batch_for_eagle - and input_fits_in_drafter - ): - propose_draft_token_ids(valid_sampled_token_ids) - - with record_function_or_nullcontext("EPLB"): - if hasattr(self, "eplb_step"): - self.eplb_step() + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + # Omni-new: Convert to per-request tensors on CPU hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() - pooler_output: list[Optional[torch.Tensor]] = [] + pooler_output: list[torch.Tensor | None] = [] prev_logits_index = 0 for logits_index in logits_indices: - pooler_output.append( - hidden_states_cpu[prev_logits_index : logits_index + 1] - ) + pooler_output.append(hidden_states_cpu[prev_logits_index : logits_index + 1]) prev_logits_index = logits_index + 1 + # Omni-new output = OmniModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=( - pooler_output - if self.vllm_config.model_config.engine_output_type != "text" - else None - ), + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, ) + durations 
= ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items()] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() if not self.use_async_scheduling: return output return AsyncNPUModelRunnerOutput( model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, + sampled_token_ids=sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, ) - def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: - req_state = self.requests.get(req_id) - if req_state is None: - return - existing = getattr(req_state, "additional_information_cpu", {}) - if not isinstance(existing, dict): - existing = {} - merged = dict(existing) - for k, v in upd.items(): - if isinstance(v, torch.Tensor): - merged[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, list): - new_list = [] - for item in v: - if isinstance(item, torch.Tensor): - new_list.append(item.detach().to("cpu").contiguous()) - else: - new_list.append(item) - merged[k] = new_list - else: - merged[k] = v - setattr(req_state, "additional_information_cpu", merged) - - - + def _generate_process_reqs_hidden_states( + self, + attn_metadata, + with_prefill, + maybe_padded_num_tokens, + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + model_kwargs_extra, + ): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs_extra, + ) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead + 
if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params( + self.update_stream, forward_context, maybe_padded_num_tokens, self.speculative_config + ) + else: + update_attn_params(self.update_stream, forward_context, maybe_padded_num_tokens) + + if get_forward_context().sp_enabled: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + hidden_states = hidden_states[:-pad_size, :] + return hidden_states diff --git a/vllm_omni/worker/npu_diffusion_model_runner.py b/vllm_omni/worker/npu_diffusion_model_runner.py index d9bac70ca90..390cbe550ff 100644 --- a/vllm_omni/worker/npu_diffusion_model_runner.py +++ b/vllm_omni/worker/npu_diffusion_model_runner.py @@ -2,230 +2,491 @@ from __future__ import annotations -import gc -import logging -from typing import Any, List, Optional, Union +import math +from typing import TYPE_CHECKING, Any import numpy as np import torch - +import torch.nn as nn from vllm.config import CUDAGraphMode +from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import BatchDescriptor -from vllm.logger import init_logger +from vllm.logger import logger from vllm.multimodal.inputs import MultiModalKwargs -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import ( + cdiv, +) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.spec_decode.eagle import EagleProposer -from vllm.v1.utils import record_function_or_nullcontext -from vllm.v1.worker.gpu_model_runner import ( + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.kv_cache_interface import ( + EncoderOnlyAttentionSpec, +) + +# yapf: enable +from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, - IntermediateTensors, - 
PerLayerAttnMetadata, - get_pp_group, + AsyncModelRunnerOutput, + ModelRunnerOutput, ) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.utils import lmhead_tp_enable +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.utils import ( + ProfileExecuteDuration, + enable_sp, + lmhead_tp_enable, +) from vllm_ascend.worker.model_runner_v1 import AsyncNPUModelRunnerOutput -from vllm.v1.worker.ubatch_utils import UBatchSlices -from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs + from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.npu_model_runner import OmniNPUModelRunner -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput class NPUDiffusionModelRunner(OmniNPUModelRunner): """Diffusion model runner for vLLM-omni on NPU (non-autoregressive).""" - def _preprocess( + def _prepare_inputs( self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, ) -> tuple[ + dict[str, Any], + torch.Tensor, + np.ndarray, int, + torch.Tensor, int, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], + SpecDecodeMetadata, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int, dict[str, Any], ]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = 
self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = num_scheduled_tokens.max() + num_valid_tokens = np.array( + [ + num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for num_tokens, i in zip(tokens, req_ids) + ], + dtype=np.int32, + ) + + if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]: + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens) + elif self.use_aclgraph and enable_sp(self.vllm_config): + # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, + # the model will fall back to running its FX graph in eager mode. + # In this case, when sequence parallelism is enabled, we need to pad tokens to align + # with tp_size because pad_size cannot be captured by the FX graph + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size + else: + # Eager mode. + # (NOTE)Omni-new: Maybe we should remove the below ACLGraph logic + # But for eager, the logic is consistent with GPUDiffusionModelRunner + num_input_tokens = total_num_scheduled_tokens + + # Get the attention state. 
+ attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) + + # Get info across DP ranks. + # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just max_tokens_across_dp_cpu + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp( + num_input_tokens, with_prefill, enable_dbo + ) + + # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens + # We should consider removing maybe_padded_num_tokens later + num_input_tokens = maybe_padded_num_tokens + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - num_input_tokens = scheduler_output.total_num_scheduled_tokens - num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - if ( - self.supports_mm_inputs - and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder - ): + # Calculate M-RoPE positions. 
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True + ) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + + # Prepare input_ids. + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) + + # Prepare some information for building Attention-Metadata + # Compute and commit slot mapping + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True) + + self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.query_start_loc[num_reqs + 1 :].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + # Copy the tensors to the NPU. 
+ self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True) + + # Make Attention metadata + positions_cpu = self.positions_cpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) + self.attn_state = attn_state # type: ignore + + self.with_prefill = with_prefill + self.num_tokens_across_dp = num_tokens_across_dp + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) + attn_metadata: dict[str, Any] = {} + + # _prepare_inputs may reorder the batch, so we must gather + # multi-modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_input_tokens], - multimodal_embeddings=mm_embeds or None, - ) - - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds_scheduled) - input_ids = self.input_ids[:num_input_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. 
+ self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds) inputs_embeds = self.inputs_embeds[:num_input_tokens] + # NOTE: Need to model_kwargs = { **self._init_model_kwargs(num_input_tokens), **self._extract_mm_kwargs(scheduler_output), } - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: - if hasattr(self, "is_token_ids"): - token_ids_idx = ( - self.is_token_ids[:num_input_tokens] - .nonzero(as_tuple=False) - .squeeze(1) - ) - if token_ids_idx.numel() > 0: - token_ids = self.input_ids[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) - self.inputs_embeds[token_ids_idx] = tokens_to_embeds - - inputs_embeds = self.inputs_embeds[:num_input_tokens] - model_kwargs = self._init_model_kwargs(num_input_tokens) - input_ids = None + # (NOTE) Omni-new: input_ids isn't set as None else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the ACL graph. 
input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] + positions = self.positions[:num_input_tokens] + input_ids, positions = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens + ) if get_pp_group().is_first_rank: intermediate_tensors = None else: - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True) + intermediate_tensors = IntermediateTensors( + {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()} ) - if ( - self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs - ): - encoder_inputs = self._extract_encoder_inputs(scheduler_output) - model_kwargs.update(encoder_inputs) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + spec_decode_metadata = None + logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True) + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. 
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. + # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + spec_decode_common_attn_metadata = None + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. 
+ blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) + self.slot_mapping[total_num_scheduled_tokens:].fill_(0) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + num_input_tokens=num_input_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, + ) + + if self.speculative_config and spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + common_prefix_len = 0 + extra_attn_metadata_args = {} + builder = attn_group.get_metadata_builder() + if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type == 
"pooling": + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.get_model(), + **extra_attn_metadata_args, + ) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + if lmhead_tp_enable(): + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) return ( + attn_metadata, + positions, num_input_tokens, num_input_tokens, - num_tokens_after_padding, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, input_ids, inputs_embeds, - positions, intermediate_tensors, + max_num_scheduled_tokens, model_kwargs, ) @torch.inference_mode() def execute_model( self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[OmniModelRunnerOutput, IntermediateTensors]: - with record_function_or_nullcontext("Preprocess"): - with self.synchronize_input_prep(): - super()._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - return EMPTY_MODEL_RUNNER_OUTPUT - - ( - attn_metadata, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, - spec_decode_common_attn_metadata, - max_query_len, - ubatch_slices, - num_tokens_after_padding, - ) = self._prepare_inputs(scheduler_output) - - if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.forward_before() - - if hasattr(self, "dynamic_eplb") and 
self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.take_update_info_from_eplb_process() + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + with ProfileExecuteDuration().capture_async("prepare input"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + return EMPTY_MODEL_RUNNER_OUTPUT + + if self.dynamic_eplb: + self.eplb_updator.forward_before() ( - num_scheduled_tokens, + attn_metadata, + positions, + num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, input_ids, inputs_embeds, - positions, intermediate_tensors, + max_query_len, model_kwargs, - ) = self._preprocess( - scheduler_output, - intermediate_tensors, - ubatch_slices, - num_tokens_after_padding, - ) + ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + # Omni-new: don't use cudagraph_dispatcher + # and remove ubatch_slices aclgraph_runtime_mode = CUDAGraphMode.NONE - with ( - set_ascend_forward_context( + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=True, # Diffusion models process all tokens at once + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=None, num_actual_tokens=scheduler_output.total_num_scheduled_tokens, - ), - record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, - ): - - 
outputs = self._run_diffusion( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - multimodal_kwargs=model_kwargs, - logits_indices=logits_indices, - ) + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + + outputs = self._run_diffusion( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + multimodal_kwargs=model_kwargs, + logits_indices=logits_indices, + ) - if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.forward_end() + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = outputs + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + # Omni-new: extract_multimodal_outputs _, multimodal_outputs = self.extract_multimodal_outputs(outputs) - pooler_output: List[Optional[torch.Tensor]] = [] + # Ensure one tensor per request, map to CPU for output struct + pooler_output: list[torch.Tensor | None] = [] if isinstance(multimodal_outputs, torch.Tensor): + # If model returned a single stacked tensor, split by requests assert multimodal_outputs.shape[0] == self.input_batch.num_reqs for i in range(self.input_batch.num_reqs): - pooler_output.append( - multimodal_outputs[i].detach().to("cpu").contiguous() - ) + pooler_output.append(multimodal_outputs[i].detach().to("cpu").contiguous()) elif isinstance(multimodal_outputs, list): for out in multimodal_outputs: - pooler_output.append( - out.detach().to("cpu").contiguous() if out is not None else None 
- ) + pooler_output.append(out.detach().to("cpu").contiguous() if out is not None else None) elif isinstance(multimodal_outputs, dict): for out in multimodal_outputs.values(): - pooler_output.append( - out.detach().to("cpu").contiguous() if out is not None else None - ) + pooler_output.append(out.detach().to("cpu").contiguous() if out is not None else None) else: raise RuntimeError("Unsupported diffusion output type") @@ -240,6 +501,15 @@ def execute_model( num_nans_in_logits={}, ) + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items()] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + + if self.dynamic_eplb: + self.eplb_updator.forward_end() + if not self.use_async_scheduling: return output @@ -255,11 +525,11 @@ def _run_diffusion( *, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor], + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None, multimodal_kwargs: dict, logits_indices: torch.Tensor, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: """Runs the diffusion process and returns per-request tensors. Tries model interfaces in the following order for maximal compatibility: @@ -287,301 +557,185 @@ def _run_diffusion( "'forward', or 'diffuse'. Please implement one of them or adapt the runner." 
) - @torch.inference_mode() def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + with_prefill: bool = False, + is_torchair_compile: bool = False, + aclgraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, - skip_eplb: bool = False, - is_profile: bool = False, - create_mixed_batch: bool = False, - remove_lora: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Run a dummy forward pass to warm up/profile run or capture the ACL graph for the model.""" - assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { + ) -> torch.Tensor: + # only support eager mode and piecewise graph now + assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL, } - max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. + # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size. + if self.use_aclgraph and enable_sp(self.vllm_config): + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_tokens = math.ceil(num_tokens / tp_size) * tp_size - assert num_tokens <= self.scheduler_config.max_num_batched_tokens - max_num_reqs = self.scheduler_config.max_num_seqs + # Force dummy run on prefill stage when this node is deemed as kv producer. 
+ if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp( + num_tokens, with_prefill, False + ) + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - num_tokens_after_padding = None - - if num_tokens_after_padding is None: - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens_after_padding = num_tokens + num_pad - else: - num_tokens_across_dp = num_tokens_after_padding - num_tokens_after_padding = int(num_tokens_after_padding[0].item()) + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() - original_in_profile_run = self.in_profile_run - self.in_profile_run = is_profile - - if not self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.forward_before() + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None - need_dummy_logits = (not self.in_profile_run and lmhead_tp_enable()) - dummy_indices = None - dummy_compute_logits = None - - if need_dummy_logits: - max_num_reqs_across_dp = num_tokens # Diffusion always uses with_prefill=True - 
dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32, device=self.device) - - def dummy_compute_logits(hidden_states): - return self.model.compute_logits(hidden_states[dummy_indices]) - - attn_metadata: Optional[PerLayerAttnMetadata] = None - - if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - attn_metadata = {} - - seq_lens = max_query_len - if hasattr(self, "seq_lens_np"): - self.seq_lens_np[:num_reqs] = seq_lens - self.seq_lens_np[num_reqs:] = 0 - if hasattr(self, "seq_lens_cpu"): - self.seq_lens_cpu[:num_reqs] = seq_lens - self.seq_lens_cpu[num_reqs:] = 0 - if hasattr(self, "seq_lens"): - self.seq_lens[:num_reqs].copy_( - self.seq_lens_cpu[:num_reqs], non_blocking=True) - self.seq_lens[num_reqs:].fill_(0) - - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) - if hasattr(self, "query_start_loc_np"): - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1 : num_reqs + 1] = cum_num_tokens - if hasattr(self, "query_start_loc_cpu"): - self.query_start_loc_cpu[0] = 0 - self.query_start_loc_cpu[1 : num_reqs + 1] = cum_num_tokens - if hasattr(self, "query_start_loc"): - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups - ): - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[: num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs - ], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch.block_table[ - kv_cache_group_id - ].get_device_tensor(num_reqs), - slot_mapping=self.input_batch.block_table[ - kv_cache_group_id - 
].slot_mapping[:num_tokens], - causal=True, - ) - for attn_group in self.attn_groups[kv_cache_group_id]: + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] - assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder().build_for_cudagraph_capture( - common_attn_metadata + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, dtype=self.dtype, device=self.device ) - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - try: - with self.maybe_dummy_run_with_lora( - self.lora_config, num_scheduled_tokens, remove_lora - ): - model_kwargs = self._init_model_kwargs(num_tokens) - if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - model_kwargs = { - **model_kwargs, - **self._dummy_mm_kwargs(num_reqs), - } - elif self.enable_prompt_embeds: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - model_kwargs = self._init_model_kwargs(num_tokens) - else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None + intermediate_tensors = IntermediateTensors( + {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()} + ) - if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] - else: - positions = self.positions[:num_tokens] + # filter out the valid batch descriptor + _ag_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode) + ) + if aclgraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for aclgraph capture + assert aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode == _ag_mode, ( + f"Aclgraph 
runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}." + ) + else: + aclgraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. + attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, + ) - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device, - ) - ) + need_dummy_logits = not self.in_profile_run and lmhead_tp_enable() - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False - ) + if need_dummy_logits: + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) - _cg_mode, batch_descriptor = ( - self.aclgraph_dispatcher.dispatch( - BatchDescriptor( - num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode, - ) - ) - if hasattr(self, "aclgraph_dispatcher") and not is_profile - else (CUDAGraphMode.NONE, None) - ) - if cudagraph_runtime_mode is not None: - assert ( - cudagraph_runtime_mode == CUDAGraphMode.NONE - or cudagraph_runtime_mode == _cg_mode - ), ( - f"ACL graph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
- ) - else: - cudagraph_runtime_mode = _cg_mode + def dummy_compute_logits(hidden_states): + return self.model.compute_logits(hidden_states[dummy_indices]) - with self.maybe_randomize_inputs(input_ids), set_ascend_forward_context( + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + hidden_states = self._generate_dummy_run_hidden_states( + with_prefill, + is_torchair_compile, + input_ids, + positions, attn_metadata, - self.vllm_config, - num_tokens=num_tokens_after_padding, - num_tokens_across_dp=num_tokens_across_dp, - with_prefill=True, # Diffusion models process all tokens at once - in_profile_run=self.in_profile_run, - aclgraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ): - hidden_states = self._generate_dummy_run_hidden_states( - with_prefill=True, - is_torchair_compile=False, - input_ids=input_ids, - positions=positions, - attn_metadata=attn_metadata, - num_tokens=num_tokens, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - model_kwargs=model_kwargs, - ) - + num_tokens, + intermediate_tensors, + inputs_embeds, + ) if need_dummy_logits: dummy_compute_logits(hidden_states) - if self.drafter: - self.drafter.dummy_run(num_tokens) - if need_dummy_logits: - self.drafter.model.compute_logits(hidden_states[dummy_indices]) - - if self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb: - if hasattr(self, "model"): - self.model.clear_all_moe_loads() - if not self.in_profile_run and hasattr(self, "dynamic_eplb") and 
self.dynamic_eplb: - if hasattr(self, "eplb_updator"): - self.eplb_updator.take_update_info_from_eplb_process() - finally: - self.in_profile_run = original_in_profile_run - - if not skip_eplb: - self.eplb_step(is_dummy=True, is_profile=is_profile) - - hidden_states, _ = self.extract_multimodal_outputs(hidden_states) - return hidden_states, None - - @torch.inference_mode() - def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: - logger.warning("Dummy sampler run is not implemented for diffusion model") - return None - - def profile_run(self) -> None: - if self.supports_mm_inputs: - if self.model_config.multimodal_config.skip_mm_profiling: - logger.info( - "Skipping memory profiling for multimodal encoder and " - "encoder cache." + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens, + with_prefill=with_prefill, + skip_attn=True, + num_reqs=num_reqs, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, ) - else: - mm_budget = self.mm_budget - assert mm_budget is not None - - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ - dummy_modality - ] - - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) - - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) - - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs - ) - - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - - encoder_output_shape = dummy_encoder_outputs[0].shape - if encoder_output_shape[0] < encoder_budget: - expanded_outputs = [] - for output in 
dummy_encoder_outputs: - expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1]) - ) - num_tokens = output.shape[0] - expanded[:num_tokens].copy_(output) - expanded_outputs.append(expanded) - - dummy_encoder_outputs = expanded_outputs - - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - - hidden_states, _ = self._dummy_run(self.max_num_tokens, is_profile=True) - if get_pp_group().is_last_rank: - pass - self._sync_device() - del hidden_states - self.encoder_cache.clear() - gc.collect() - - + if need_dummy_logits: + self.drafter.model.compute_logits(hidden_states[dummy_indices]) + if self.in_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + return hidden_states From 428246c0fb91492024e3240130bd60236ee2e268 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Wed, 26 Nov 2025 04:36:42 +0000 Subject: [PATCH 10/11] fix the NPU adaptation issue in RMSNorm Signed-off-by: gcanlin --- .../models/qwen2_5_omni/qwen2_5_omni.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index 4b29d266f6a..229702e5805 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -244,11 +244,17 @@ def forward( if inputs_embeds is not None and inputs_embeds.device != thinker_dev: inputs_embeds = inputs_embeds.to(thinker_dev) # Run thinker + + # FIXME: a temporary method to fix the NPU adaptation issue, need more discussion. 
+ thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + thinker_output = self.thinker( - input_ids=input_ids, - positions=positions[0], + input_ids=thinker_input_ids, + positions=thinker_positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, + inputs_embeds=thinker_inputs_embeds, **kwargs, ) From 5c0028e511b4b749e26f901c2e554baa9ba3a8b4 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Wed, 26 Nov 2025 07:41:32 +0000 Subject: [PATCH 11/11] fix SamplingParams loading bug Signed-off-by: gcanlin --- vllm_omni/entrypoints/stage_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index fa7fedd6946..f7fb23b8a3b 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -7,7 +7,7 @@ import os import pickle from multiprocessing import shared_memory as _shm - +from omegaconf import OmegaConf import cloudpickle logger = logging.getLogger(__name__)