diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 655d0307fbf..69083cb26de 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -1,10 +1,11 @@ +from typing import Any, Optional, Sequence, Union +import logging import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor, as_completed import os +import queue import sys import time -from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional, Union import cloudpickle from pydantic import ValidationError @@ -25,17 +26,19 @@ from vllm_omni.engine.arg_utils import OmniEngineArgs from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.engine.processor import OmniProcessor +from vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.entrypoints.utils import load_stage_configs_from_model, select_worker_class from vllm_omni.entrypoints.log_utils import ( - OrchestratorMetrics, + remove_old_logs, configure_orchestrator_logger, init_stats_paths, - remove_old_logs, + OrchestratorMetrics, +) +from vllm_omni.entrypoints.stage_utils import ( + maybe_load_from_ipc as _load, + encode_for_ipc as _encode, + serialize_obj as _ser, ) -from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import encode_for_ipc as _encode -from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load -from vllm_omni.entrypoints.stage_utils import serialize_obj as _set -from vllm_omni.entrypoints.utils import load_stage_configs_from_model, load_stage_configs_from_yaml from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -49,25 +52,23 @@ class OmniLLM: - max_inflight=1 per stage (serial within a stage), but pipeline across stages """ - def __init__( - self, - model: str, - stage_configs_path: Optional[str] = None, - log_stats: bool = False, - log_file: Optional[str] 
= None, - init_sleep_seconds: int = 20, - shm_threshold_bytes: int = 65536, - batch_timeout: int = 10, - init_timeout: int = 300, - **kwargs, - ): + def __init__(self, model: str, + stage_configs=None, + log_stats: bool = False, + log_file: Optional[str] = None, + init_sleep_seconds: int = 20, + shm_threshold_bytes: int = 65536, + batch_timeout: int = 10, + init_timeout: int = 300, + **kwargs): self.batch_timeout = batch_timeout self._enable_stats: bool = bool(log_stats) # Do NOT call super().__init__ to avoid creating OmniStageLLM instances in parent. - if stage_configs_path is None: + if stage_configs is None: self.stage_configs = load_stage_configs_from_model(model) else: - self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) + self.stage_configs = stage_configs + # Optional file handler for orchestrator self._log_file = log_file @@ -78,15 +79,8 @@ def __init__( self._stats_file, self._overall_stats_file = init_stats_paths(self._enable_stats, self._log_file) self._initialize_stages(model, init_sleep_seconds, shm_threshold_bytes, init_timeout) - def _initialize_stages( - self, - model: str, - init_sleep_seconds: int, - shm_threshold_bytes: int, - init_timeout: int, - ) -> None: + def _initialize_stages(self, model: str, init_sleep_seconds: int, shm_threshold_bytes: int, init_timeout: int) -> None: self.stage_list: list[OmniStage] = [] - # Build OmniStage instances in parallel, preserve original order def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: idx, cfg = idx_cfg @@ -110,6 +104,7 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: # Wait for all stages to report readiness before seeding self._stages_ready: set[int] = set() self._wait_for_stages_ready(timeout=init_timeout) + def _start_stage_processes(self, model: str) -> None: for stage_id, stage in enumerate(self.stage_list): @@ -136,10 +131,7 @@ def close(self) -> None: try: q.put_nowait(None) except Exception as e: - logger.warning( - "[Orchestrator] 
Failed to send shutdown signal to stage input queue: %s", - e, - ) + logger.warning("[Orchestrator] Failed to send shutdown signal to stage input queue: %s", e) for stage in self.stage_list: try: stage.stop_stage_worker() @@ -155,26 +147,17 @@ def __del__(self) -> None: # best-effort def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, - ) -> list[OmniRequestOutput]: - try: - return self._run_generation(prompts, sampling_params_list) - except Exception as e: - logger.exception("[Orchestrator] Failed to run generation: %s", e) - raise e - finally: - self.close() - - def _run_generation( - self, - prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, + sampling_params_list: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, ) -> list[OmniRequestOutput]: logger.debug("[Orchestrator] generate() called") if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + raise ValueError( + f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}" + ) # Normalize prompts to a list for per-request iteration if not isinstance(prompts, (list, tuple)): @@ -193,6 +176,7 @@ def _run_generation( # Track per-request start time for end-to-end timing _req_start_ts: dict[int, float] = {} _wall_start_ts: float = time.time() + _last_finish_ts: float = _wall_start_ts # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) final_stage_id_for_e2e = -1 @@ -203,20 +187,10 @@ def _run_generation( if final_stage_id_for_e2e < 0: final_stage_id_for_e2e = len(self.stage_list) - 1 except Exception as e: - 
logger.debug( - "[Orchestrator] Failed to determine final stage for E2E; falling back to last: %s", - e, - exc_info=True, - ) + logger.debug("[Orchestrator] Failed to determine final stage for E2E; falling back to last: %s", e, exc_info=True) final_stage_id_for_e2e = len(self.stage_list) - 1 # Metrics/aggregation helper - metrics = OrchestratorMetrics( - num_stages, - self._enable_stats, - self._stats_file, - self._overall_stats_file, - _wall_start_ts, - ) + metrics = OrchestratorMetrics(num_stages, self._enable_stats, self._stats_file, self._overall_stats_file, _wall_start_ts) # Seed stage-0 queue with all requests logger.debug("[Orchestrator] Seeding %d requests into stage-0", len(request_prompts)) @@ -240,11 +214,7 @@ def _run_generation( completed_requests = 0 total_requests = len(request_prompts) - logger.debug( - "[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", - total_requests, - num_stages, - ) + logger.debug("[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", total_requests, num_stages) while completed_requests < total_requests: made_progress = False for stage_id, stage in enumerate(self.stage_list): @@ -255,17 +225,11 @@ def _run_generation( made_progress = True req_id = result.get("request_id") if "error" in result: - logger.error( - "Stage %s error on request %s: %s", - stage_id, - req_id, - result["error"], - ) + logger.error("Stage %s error on request %s: %s", stage_id, req_id, result["error"]) continue - + if result.get("type") == "stage_ready": - # Only happens when stage is initialized slower than expected, - # so we wait for a short time and try again + # Only happens when a stage initializes more slowly than expected, so we wait for a short time and try again time.sleep(0.05) continue @@ -277,17 +241,8 @@ def _run_generation( if _m is not None: metrics.on_stage_metrics(stage_id, req_id, _m) except Exception as e: - logger.exception( - "[Orchestrator] Failed to process metrics for stage %s, req %s: %s", - 
stage_id, - req_id, - e, - ) - logger.debug( - "[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", - stage_id, - req_id, - ) + logger.exception("[Orchestrator] Failed to process metrics for stage %s, req %s: %s", stage_id, req_id, e) + logger.debug("[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", stage_id, req_id) stage.set_engine_outputs(engine_outputs) if getattr(stage, "final_output", False): @@ -298,30 +253,15 @@ def _run_generation( request_output=engine_outputs, ) ) - logger.debug( - "[Orchestrator] Request %s finalized at stage-%s", - req_id, - stage_id, - ) + logger.debug("[Orchestrator] Request %s finalized at stage-%s", req_id, stage_id) - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) + # End-to-end timing and time-per-token for final output (only once per request at the designated final stage) try: rid_int = int(req_id) if isinstance(req_id, (int, str)) and str(req_id).isdigit() else req_id if stage_id == final_stage_id_for_e2e and rid_int not in metrics.e2e_done: - metrics.on_finalize_request( - stage_id, - req_id, - engine_outputs, - _req_start_ts.get(req_id, _wall_start_ts), - ) + metrics.on_finalize_request(stage_id, req_id, engine_outputs, _req_start_ts.get(req_id, _wall_start_ts)) except Exception as e: - logger.exception( - "[Orchestrator] Finalize request handling error for req %s at stage %s: %s", - req_id, - stage_id, - e, - ) + logger.exception("[Orchestrator] Finalize request handling error for req %s at stage %s: %s", req_id, stage_id, e) next_stage_id = stage_id + 1 if next_stage_id < num_stages: @@ -332,7 +272,7 @@ def _run_generation( # Measure transfer size and time (encode + enqueue) size_bytes = 0 try: - size_bytes = len(_set(next_inputs)) + size_bytes = len(_ser(next_inputs)) except Exception: size_bytes = 0 t0 = time.time() @@ -342,51 +282,27 @@ def _run_generation( obj_key="engine_inputs", shm_key="engine_inputs_shm", ) 
- ipc_payload.update( - { - "request_id": req_id, - "sampling_params": sp_next, - "sent_ts": time.time(), - } - ) + ipc_payload.update({ + "request_id": req_id, + "sampling_params": sp_next, + "sent_ts": time.time(), + }) self.stage_list[next_stage_id].submit(ipc_payload) t1 = time.time() tx_ms = (t1 - t0) * 1000.0 - metrics.on_forward( - stage_id, - next_stage_id, - req_id, - int(size_bytes), - float(tx_ms), - bool("engine_inputs_shm" in ipc_payload), - ) + metrics.on_forward(stage_id, next_stage_id, req_id, int(size_bytes), float(tx_ms), bool("engine_inputs_shm" in ipc_payload)) except Exception as e: - logger.warning( - "[Orchestrator] IPC encode failed for req %s: %s; falling back to inline payload", - req_id, - e, - ) - self.stage_list[next_stage_id].submit( - { - "request_id": req_id, - "engine_inputs": next_inputs, - "sampling_params": sp_next, - } - ) - logger.debug( - "[Orchestrator] Forwarded request %s to stage-%s", - req_id, - next_stage_id, - ) + logger.warning("[Orchestrator] IPC encode failed for req %s: %s; falling back to inline payload", req_id, e) + self.stage_list[next_stage_id].submit({ + "request_id": req_id, + "engine_inputs": next_inputs, + "sampling_params": sp_next, + }) + logger.debug("[Orchestrator] Forwarded request %s to stage-%s", req_id, next_stage_id) remaining_by_stage[next_stage_id] += 1 else: completed_requests += 1 - logger.debug( - "[Orchestrator] Request %s fully completed (%d/%d)", - req_id, - completed_requests, - total_requests, - ) + logger.debug("[Orchestrator] Request %s fully completed (%d/%d)", req_id, completed_requests, total_requests) if not made_progress: time.sleep(0.005) @@ -425,9 +341,7 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) logger.warning( "[Orchestrator] Initialization timeout: only %s/%s stages are ready; not ready: %s", - len(self._stages_ready), - num_stages, - not_ready, + len(self._stages_ready), num_stages, 
not_ready, ) # Provide actionable suggestions before shutdown try: @@ -438,7 +352,9 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: "Increase initialization wait time (init_sleep_seconds or call-site timeout).", ] if getattr(self, "_log_file", None): - suggestions.append(f"Inspect per-stage log files for details: {self._log_file}.stage.log") + suggestions.append( + f"Inspect per-stage log files for details: {self._log_file}.stage.log" + ) logger.error( "[Orchestrator] Stage initialization failed, shutting down. Suggestions:\n- %s", "\n- ".join(suggestions), @@ -468,9 +384,13 @@ class OmniStageLLM(LLM): def __init__( self, model: str, - compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + compilation_config: Optional[ + Union[int, dict[str, Any], CompilationConfig] + ] = None, hf_overrides: Optional[HfOverrides] = None, - structured_outputs_config: Optional[Union[dict[str, Any], StructuredOutputsConfig]] = None, + structured_outputs_config: Optional[ + Union[dict[str, Any], StructuredOutputsConfig] + ] = None, **kwargs, ): """LLM constructor.""" @@ -479,12 +399,18 @@ def __init__( if "worker_cls" in kwargs: worker_cls = kwargs["worker_cls"] + # Select appropriate worker class based on device type + if isinstance(worker_cls, str): + worker_cls = select_worker_class(worker_cls) + kwargs["worker_cls"] = worker_cls # if the worker_cls is not qualified string name, # we serialize it using cloudpickle to avoid pickling issues - if isinstance(worker_cls, type): + elif isinstance(worker_cls, type): kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) - if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict): + if "kv_transfer_config" in kwargs and isinstance( + kwargs["kv_transfer_config"], dict + ): from vllm.config.kv_transfer import KVTransferConfig raw_config_dict = kwargs["kv_transfer_config"] @@ -492,7 +418,8 @@ def __init__( kwargs["kv_transfer_config"] = 
KVTransferConfig(**raw_config_dict) except ValidationError as e: logger.error( - "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. Error: %s", + "Failed to convert 'kv_transfer_config' dict to " + "KVTransferConfig object. Dict: %s. Error: %s", raw_config_dict, e, ) @@ -505,10 +432,16 @@ def __init__( if compilation_config is not None: if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig(level=compilation_config) + compilation_config_instance = CompilationConfig( + level=compilation_config + ) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( - **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)} + **{ + k: v + for k, v in compilation_config.items() + if is_init_field(CompilationConfig, k) + } ) else: compilation_config_instance = compilation_config @@ -518,7 +451,11 @@ def __init__( if structured_outputs_config is not None: if isinstance(structured_outputs_config, dict): structured_outputs_instance = StructuredOutputsConfig( - **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)} + **{ + k: v + for k, v in structured_outputs_config.items() + if is_init_field(StructuredOutputsConfig, k) + } ) else: structured_outputs_instance = structured_outputs_config @@ -534,7 +471,9 @@ def __init__( ) # Create the Engine (autoselects V0 vs V1) - self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS + ) self.llm_engine.output_processor = MultimodalOutputProcessor( tokenizer=self.llm_engine.tokenizer, log_stats=self.llm_engine.log_stats, @@ -556,4 +495,6 @@ def __init__( # Load the Input/Output processor plugin if any io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = 
get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) + self.io_processor = get_io_processor( + self.llm_engine.vllm_config, io_processor_plugin + ) \ No newline at end of file diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 481c4821359..861235c5ed9 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -31,7 +31,7 @@ _to_dict, maybe_dump_to_shm, maybe_load_from_ipc_with_metrics, - set_stage_gpu_devices, + set_stage_devices, ) from vllm_omni.inputs.data import OmniTokensPrompt @@ -281,7 +281,9 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: _logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) @@ -556,7 +558,9 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: _logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) @@ -727,4 +731,4 @@ def filter(self, record: _logging.LogRecord) -> bool: "stage_id": stage_id, "error": str(e), } - ) + ) \ No newline at end of file diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 14c266c85f9..185e9997e16 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -776,4 +776,4 @@ def _create_audio_choice(self, omni_outputs: OmniRequestOutput, role: str): stop_reason=None, ) choices.append(choice_data) - 
return choices + return choices \ No newline at end of file diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 36ad2b0ab1c..f7fb23b8a3b 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -1,122 +1,144 @@ from __future__ import annotations -import json +from typing import Any, Dict, Optional, Tuple, Union + import logging +import json import os import pickle from multiprocessing import shared_memory as _shm -from typing import Any - -import cloudpickle from omegaconf import OmegaConf +import cloudpickle logger = logging.getLogger(__name__) -def set_stage_gpu_devices(stage_id: int, devices: str | int | None) -> None: - """Configure per-stage CUDA visibility and current device. - - Behavior - - Comma-separated string (e.g. "2,5,7"): set CUDA_VISIBLE_DEVICES exactly - to this list; logical index 0 is used as current device. - - Integer or digit-string: treat as logical index (0-based) into the current - CUDA_VISIBLE_DEVICES mapping; map to the physical device, and then set - CUDA_VISIBLE_DEVICES to this single device. - - None/"cpu": keep default visibility. - - Otherwise: set CUDA_VISIBLE_DEVICES to the provided single device string. +def set_stage_devices( + stage_id: int, + devices: Optional[Union[str, int]], + device_type: Optional[str] = None, +) -> None: + """Configure per-stage device visibility and current device (CUDA or NPU). + + This function sets environment variables that control which devices are visible + to the process, and sets the current device. It must be called BEFORE worker + initialization so that workers see the correct devices. + + Args: + stage_id: Stage identifier for logging + devices: Device specification: + - Comma-separated string (e.g. "2,5,7"): set device visibility env var + exactly to this list; logical index 0 is used as current device. 
+ - Integer or digit-string: treat as logical index (0-based) into the + current device visibility mapping; map to physical device, then set + env var to this single device. + - None/"cpu": keep default visibility. + - Otherwise: set env var to the provided single device string. + device_type: Device type ("cuda" or "npu"). If None, auto-detects. + + Behavior: + - CUDA: Sets CUDA_VISIBLE_DEVICES and calls torch.cuda.set_device() + - NPU: Sets ASCEND_RT_VISIBLE_DEVICES and calls torch.npu.set_device() """ + from vllm_omni.utils import detect_device_type, get_device_control_env_var + + if device_type is None: + device_type = detect_device_type() + + env_var = get_device_control_env_var() + + # Select device-specific torch functions + if device_type == "npu": + try: + import torch.npu # type: ignore[import-untyped] + except ImportError: + logger.debug("[Stage-%s] torch.npu not available, skipping NPU device setup", stage_id) + return + + is_available_fn = torch.npu.is_available + set_device_fn = torch.npu.set_device + device_count_fn = torch.npu.device_count + get_device_properties_fn = torch.npu.get_device_properties + mem_get_info_fn = torch.npu.mem_get_info + get_device_name_fn = torch.npu.get_device_name + device_type_label = "NPU" + elif device_type == "cuda": + import torch # noqa: WPS433 + + is_available_fn = torch.cuda.is_available + set_device_fn = torch.cuda.set_device + device_count_fn = torch.cuda.device_count + get_device_properties_fn = torch.cuda.get_device_properties + mem_get_info_fn = torch.cuda.mem_get_info + get_device_name_fn = torch.cuda.get_device_name + device_type_label = "CUDA" + else: + logger.debug("[Stage-%s] Unsupported device type: %s", stage_id, device_type) + return + try: - selected_physical: int | None = None - logical_idx: int | None = None + selected_physical: Optional[int] = None + logical_idx: Optional[int] = None if isinstance(devices, str) and "," in devices: - os.environ["CUDA_VISIBLE_DEVICES"] = devices + os.environ[env_var] 
= devices toks = [t.strip() for t in devices.split(",") if t.strip() != ""] if toks: try: selected_physical = int(toks[0]) logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to %s; logical 0 -> physical %s", - stage_id, - devices, - selected_physical, + "[Stage-%s] Set %s to %s; logical 0 -> physical %s", + stage_id, env_var, devices, selected_physical, ) except Exception as e: - logger.debug("[Stage-%s] Failed to parse first CUDA device: %s", stage_id, e) + logger.debug("[Stage-%s] Failed to parse first %s device: %s", stage_id, device_type_label, e) selected_physical = None elif isinstance(devices, (int, str)) and (isinstance(devices, int) or str(devices).isdigit()): logical_idx = max(0, int(devices)) - vis = os.environ.get("CUDA_VISIBLE_DEVICES") + vis = os.environ.get(env_var) if vis: try: mapping = [int(x) for x in vis.split(",") if x.strip() != ""] if 0 <= logical_idx < len(mapping): selected_physical = mapping[logical_idx] except Exception as e: - logger.debug( - "[Stage-%s] Failed to map logical index via CUDA_VISIBLE_DEVICES: %s", - stage_id, - e, - ) + logger.debug("[Stage-%s] Failed to map logical index via %s: %s", stage_id, env_var, e) selected_physical = None if selected_physical is None: selected_physical = int(logical_idx) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) + os.environ[env_var] = str(selected_physical) logger.debug( - "[Stage-%s] Logical index %d -> physical %s; set CUDA_VISIBLE_DEVICES to single device", - stage_id, - logical_idx + 1, - selected_physical, + "[Stage-%s] Logical index %d -> physical %s; set %s to single device", + stage_id, logical_idx + 1, selected_physical, env_var, ) elif devices in (None, "cpu"): - logger.debug( - "[Stage-%s] Using default device visibility (devices=%s)", - stage_id, - devices, - ) + logger.debug("[Stage-%s] Using default device visibility (devices=%s)", stage_id, devices) else: selected_physical = int(str(devices)) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) - 
logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to single device %s (fallback)", - stage_id, - selected_physical, - ) + os.environ[env_var] = str(selected_physical) + logger.debug("[Stage-%s] Set %s to single device %s (fallback)", stage_id, env_var, selected_physical) try: import torch # noqa: WPS433 - - if torch.cuda.is_available(): + if is_available_fn(): try: - torch.cuda.set_device(0) + set_device_fn(0) except Exception as e: - logger.debug( - "[Stage-%s] torch.cuda.set_device(0) failed: %s", - stage_id, - e, - exc_info=True, - ) - num = torch.cuda.device_count() + logger.debug("[Stage-%s] %s set_device(0) failed: %s", stage_id, device_type_label, e, exc_info=True) + num = device_count_fn() info = [] for i in range(num): - total = torch.cuda.get_device_properties(i).total_memory - free, _ = torch.cuda.mem_get_info(i) - info.append( - { - "idx": i, - "name": torch.cuda.get_device_name(i), - "total": int(total), - "free": int(free), - } - ) - logger.debug("[Stage-%s] CUDA devices visible=%s info=%s", stage_id, num, info) + total = get_device_properties_fn(i).total_memory + free, _ = mem_get_info_fn(i) + info.append({ + "idx": i, + "name": get_device_name_fn(i), + "total": int(total), + "free": int(free), + }) + logger.debug("[Stage-%s] %s devices visible=%s info=%s", stage_id, device_type_label, num, info) except Exception as e: - logger.debug( - "[Stage-%s] Failed to query CUDA devices: %s", - stage_id, - e, - exc_info=True, - ) + logger.debug("[Stage-%s] Failed to query %s devices: %s", stage_id, device_type_label, e, exc_info=True) except Exception as e: logger.warning("Failed to interpret devices for stage %s: %s", stage_id, e) @@ -126,7 +148,7 @@ def serialize_obj(obj: Any) -> bytes: return cloudpickle.dumps(obj) -def shm_write_bytes(payload: bytes) -> dict[str, Any]: +def shm_write_bytes(payload: bytes) -> Dict[str, Any]: """Write bytes into SharedMemory and return meta dict {name,size}. Caller should close the segment; the receiver should unlink. 
@@ -143,7 +165,7 @@ def shm_write_bytes(payload: bytes) -> dict[str, Any]: return meta -def shm_read_bytes(meta: dict[str, Any]) -> bytes: +def shm_read_bytes(meta: Dict[str, Any]) -> bytes: """Read bytes from SharedMemory by meta {name,size} and cleanup.""" shm = _shm.SharedMemory(name=meta["name"]) # type: ignore[index] mv = memoryview(shm.buf) @@ -170,7 +192,7 @@ def _ensure_parent_dir(path: str) -> None: pass -def append_jsonl(path: str, record: dict[str, Any]) -> None: +def append_jsonl(path: str, record: Dict[str, Any]) -> None: """Append a JSON record as one line to a JSONL file (best-effort). This is safe to call from multiple processes when each process writes @@ -187,7 +209,7 @@ def append_jsonl(path: str, record: dict[str, Any]) -> None: logger.exception("Failed to append JSONL to %s", path) -def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: +def maybe_dump_to_shm(obj: Any, threshold: int) -> Tuple[bool, Any]: """Dump object to SHM if serialized size exceeds threshold. Returns (True, meta) when dumped; otherwise (False, original_obj). @@ -198,7 +220,7 @@ def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: return False, obj -def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any: +def maybe_load_from_ipc(container: Dict[str, Any], obj_key: str, shm_key: str) -> Any: """Load object from container that may carry SHM or inline object. Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain @@ -209,9 +231,7 @@ def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) - return container[obj_key] -def maybe_load_from_ipc_with_metrics( - container: dict[str, Any], obj_key: str, shm_key: str -) -> tuple[Any, dict[str, float]]: +def maybe_load_from_ipc_with_metrics(container: Dict[str, Any], obj_key: str, shm_key: str) -> tuple[Any, Dict[str, float]]: """Load object and return (object, metrics) with RX bytes and decode time. 
Metrics keys: @@ -219,7 +239,6 @@ def maybe_load_from_ipc_with_metrics( - rx_decode_time_ms: float """ import time as _time # local import to avoid overhead at module import - t0 = _time.time() if shm_key in container: meta = container[shm_key] # type: ignore[index] @@ -243,13 +262,13 @@ def maybe_load_from_ipc_with_metrics( } -def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict[str, Any]: +def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> Dict[str, Any]: """Return a dict payload for IPC: inline (obj_key) or SHM (shm_key). When serialized size exceeds threshold, returns {shm_key: {name,size}}; otherwise returns {obj_key: obj}. """ - payload: dict[str, Any] = {} + payload: Dict[str, Any] = {} use_shm, data = maybe_dump_to_shm(obj, threshold) if use_shm: payload[shm_key] = data @@ -259,7 +278,7 @@ def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict # Convert OmegaConf/objects to plain dicts -def _to_dict(x: Any) -> dict[str, Any]: +def _to_dict(x: Any) -> Dict[str, Any]: try: if isinstance(x, dict): return dict(x) @@ -268,4 +287,4 @@ def _to_dict(x: Any) -> dict[str, Any]: try: return dict(x) except Exception: - return {} + return {} \ No newline at end of file diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 7662304c068..e0b10c93c47 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -1,13 +1,11 @@ -from __future__ import annotations - -import logging import os from pathlib import Path +from typing import Optional from omegaconf import OmegaConf -from vllm.transformers_utils.config import get_config -logger = logging.getLogger(__name__) +from vllm.transformers_utils.config import get_config +from vllm_omni.utils import detect_device_type # Get the project root directory (2 levels up from this file) PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -29,3 +27,27 @@ def load_stage_configs_from_yaml(config_path: str): 
"""Load stage configs from yaml file.""" config_data = OmegaConf.load(config_path) return config_data.stage_args + + +def select_worker_class(worker_cls: Optional[str], device_type: Optional[str] = None) -> Optional[str]: + """Select appropriate worker class based on device type.""" + if worker_cls is None: + return None + + if device_type is None: + device_type = detect_device_type() + + if device_type == "npu": + # Replace module path: gpu_ar_worker -> npu_ar_worker + if "gpu_ar_worker" in worker_cls: + worker_cls = worker_cls.replace("gpu_ar_worker", "npu_ar_worker") + elif "gpu_diffusion_worker" in worker_cls: + worker_cls = worker_cls.replace("gpu_diffusion_worker", "npu_diffusion_worker") + + # Replace class name: GPUARWorker -> NPUARWorker, GPUDiffusionWorker -> NPUDiffusionWorker + if "GPUARWorker" in worker_cls: + worker_cls = worker_cls.replace("GPUARWorker", "NPUARWorker") + elif "GPUDiffusionWorker" in worker_cls: + worker_cls = worker_cls.replace("GPUDiffusionWorker", "NPUDiffusionWorker") + + return worker_cls diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index b650712b59f..7e98ead6e3b 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -235,12 +235,17 @@ def forward( positions = positions.to(thinker_dev) if inputs_embeds is not None and inputs_embeds.device != thinker_dev: inputs_embeds = inputs_embeds.to(thinker_dev) + + thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + # Run thinker thinker_output = self.thinker( - input_ids=input_ids, - positions=positions[0], + input_ids=thinker_input_ids, + positions=thinker_positions, 
intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, + inputs_embeds=thinker_inputs_embeds, **kwargs, ) @@ -858,7 +863,7 @@ def _init_token2wav_model(self, hf_model_folder): __init__.""" if self.token2wav is None or self.token2wav_config is None: return - device = "cuda" if torch.cuda.is_available() else "cpu" + device = self._module_device(self.token2wav) # optional speaker resources conds = getattr(self.token2wav_config, "conds", None) ref_mels = getattr(self.token2wav_config, "ref_mels", None) @@ -1019,19 +1024,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Load token2wav weights (if any) if token2wav_weights and self.token2wav is not None: # download weights from huggingface for spk_dict.pt - model_path = self.vllm_config.model_config.model - download_dir = self.vllm_config.load_config.download_dir - if os.path.exists(model_path): - hf_model_folder = model_path - else: - hf_model_folder = download_weights_from_hf_specific( - model_path, - download_dir, - allow_patterns=["*.pt"], - ) + + hf_model_folder = download_weights_from_hf_specific( + self.vllm_config.model_config.model, + self.vllm_config.load_config.download_dir, + allow_patterns=["*.pt"], + ) self._init_token2wav_model(hf_model_folder) t2w_loaded = self.token2wav.load_weights(token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt")) t2w_loaded = add_prefix_to_loaded_weights(t2w_loaded, "token2wav") loaded_weights.update(t2w_loaded) - return loaded_weights + return loaded_weights \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 28107cb455e..6dc1553c9ca 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -1056,4 +1056,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> 
def replication_pad_1d(hidden_states: torch.Tensor, pad_left: int,
                       pad_right: int) -> torch.Tensor:
    """Replicate-pad the last dimension of ``hidden_states`` without ``F.pad``.

    The left edge value is repeated ``pad_left`` times and the right edge value
    ``pad_right`` times, then everything is concatenated along the last axis.
    Written by hand to sidestep replication_pad1d kernel limits on NPU.

    Args:
        hidden_states: Tensor to pad along its last dimension.
        pad_left: Number of copies of the first element to prepend.
        pad_right: Number of copies of the last element to append.

    Returns:
        The input itself when both pads are zero, otherwise a new padded tensor.
    """
    # NOTE: provisional implementation for Ascend NPU; still open for discussion.
    if pad_left == 0 and pad_right == 0:
        return hidden_states

    leading_dims = hidden_states.shape[:-1]
    pieces = [hidden_states]

    if pad_left > 0:
        left_edge = hidden_states[..., :1]
        pieces.insert(0, left_edge.expand(*leading_dims, pad_left))

    if pad_right > 0:
        right_edge = hidden_states[..., -1:]
        pieces.append(right_edge.expand(*leading_dims, pad_right))

    return torch.cat(pieces, dim=-1)
+ hidden_states = replication_pad_1d( + hidden_states.to(self.filter.dtype), self.pad_left, self.pad_right + ) + filter_on_device = self.filter.to( + device=hidden_states.device, dtype=hidden_states.dtype + ) out = F.conv1d( hidden_states, - self.filter.expand(channels, -1, -1), + filter_on_device.expand(channels, -1, -1), stride=self.stride, groups=channels, - ).to(hidden_states_dtype) + ).to(input_dtype) return out @@ -1768,4 +1804,4 @@ def process_chunk( prev_generated=(prev_generated if isinstance(prev_generated, torch.Tensor) else None), finished=finished, ) - return _mel if _mel is not None else prev_generated, out + return _mel if _mel is not None else prev_generated, out \ No newline at end of file diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 8e6977950ea..a2d569c8f7b 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -52,4 +52,4 @@ for model_arch, (mod_folder, mod_relname, cls_name) in _OMNI_MODELS.items() }, } -) +) \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index ce2d281fb70..98b5d29e35b 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -60,4 +60,4 @@ def thinker2talker( mm_processor_kwargs=None, ) ) - return talker_inputs + return talker_inputs \ No newline at end of file diff --git a/vllm_omni/utils/__init__.py b/vllm_omni/utils/__init__.py index e69de29bb2d..50dbb478d90 100644 --- a/vllm_omni/utils/__init__.py +++ b/vllm_omni/utils/__init__.py @@ -0,0 +1,11 @@ +from vllm_omni.utils.platform_utils import ( + detect_device_type, + get_device_control_env_var, + is_npu, +) + +__all__ = [ + "detect_device_type", + "get_device_control_env_var", + "is_npu", +] diff --git a/vllm_omni/utils/platform_utils.py 
def detect_device_type() -> str:
    """Return the lower-cased device type of the running platform.

    ``current_platform.device_type`` wins when it is a non-empty string.
    Otherwise the function probes torch for CUDA, then for an available
    ``torch.npu`` backend, and finally falls back to ``"cpu"``.
    """
    platform_device = getattr(current_platform, "device_type", None)
    if isinstance(platform_device, str) and platform_device:
        return platform_device.lower()

    if torch.cuda.is_available():
        return "cuda"

    npu_ready = hasattr(torch, "npu") and torch.npu.is_available()  # type: ignore[attr-defined]
    return "npu" if npu_ready else "cpu"


def is_npu() -> bool:
    """Return ``True`` when the detected device type is ``"npu"``."""
    return detect_device_type() == "npu"


def get_device_control_env_var() -> str:
    """Return the environment variable name for device visibility control."""
    # Prefer the name advertised by the platform object, if it is a usable string.
    platform_env_var = getattr(current_platform, "device_control_env_var", None)
    if isinstance(platform_env_var, str) and platform_env_var:
        return platform_env_var

    if detect_device_type() == "npu":
        return "ASCEND_RT_VISIBLE_DEVICES"
    return "CUDA_VISIBLE_DEVICES"  # fallback
def _prepare_inputs(
    self,
    scheduler_output: SchedulerOutput,
    intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
    dict[str, Any],
    torch.Tensor,
    np.ndarray,
    int,
    torch.Tensor,
    int,
    torch.Tensor,
    SpecDecodeMetadata | None,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    int,
    dict[str, dict] | None,
]:
    """Build every per-step input the forward pass needs from the scheduler output.

    The method fills the runner's persistent buffers (positions, query start
    locations, sequence lengths, slot mapping, input ids / embeddings), builds
    the per-layer attention metadata and, for multimodal models, overlays
    per-request custom prompt embeddings and collects per-request additional
    information for the prefill portion of each request.

    Args:
        scheduler_output: Scheduling decision for this step. At least one token
            and one request must be scheduled.
        intermediate_tensors: Tensors received from the previous pipeline stage.
            Required on every rank except the first pipeline rank.

    Returns:
        A tuple, in this order: ``attn_metadata`` (layer name -> metadata),
        ``positions``, ``num_scheduled_tokens`` (per request),
        ``num_input_tokens``, ``num_tokens_across_dp``,
        ``maybe_padded_num_tokens``, ``logits_indices``,
        ``spec_decode_metadata`` (``None`` without speculative tokens),
        ``input_ids``, ``inputs_embeds`` (``None`` for text-only models),
        ``intermediate_tensors`` (``None`` on the first pipeline rank),
        ``max_num_scheduled_tokens`` and ``per_req_additional_information``
        (``None`` for text-only models).
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = self.input_batch.num_reqs
    assert num_reqs > 0

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    self.input_batch.block_table.commit_block_table(num_reqs)

    # Get the number of scheduled tokens for each request.
    req_ids = self.input_batch.req_ids
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = num_scheduled_tokens.max()
    num_valid_tokens = np.array(
        [
            num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
            for num_tokens, i in zip(tokens, req_ids)
        ],
        dtype=np.int32,
    )

    if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]:
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens)
    elif self.use_aclgraph and enable_sp(self.vllm_config):
        # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
        # the model will fall back to running its FX graph in eager mode.
        # In this case, when sequence parallelism is enabled, we need to pad tokens to align
        # with tp_size because pad_size cannot be captured by the FX graph
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size
    else:
        # Eager mode.
        num_input_tokens = total_num_scheduled_tokens

    # Get the attention state.
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_state = attn_state  # type: ignore

    # Determine if it's a splitfuse batch
    with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]

    self.query_lens = torch.from_numpy(num_scheduled_tokens)
    enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens)

    # Get info across DP ranks.
    # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
    # Otherwise, it's just max_tokens_across_dp_cpu
    (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp(
        num_input_tokens, with_prefill, enable_dbo
    )

    # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
    # We should consider removing maybe_padded_num_tokens later
    num_input_tokens = maybe_padded_num_tokens

    # Hot-Swap lora model
    if self.lora_config:
        self.set_active_loras(self.input_batch, num_scheduled_tokens)

    # Get request indices.
    # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

    # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
    # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)

    positions_np = self.positions_np[:total_num_scheduled_tokens]
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np)

    # Calculate M-RoPE positions.
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        self._calc_mrope_positions(scheduler_output)

        # NOTE(review): the device copy is kept under the uses_mrope guard; the
        # flattened source does not show its indentation -- confirm.
        self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
            self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True
        )

    # Get token indices.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
    # where M is the max_model_len.
    token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]

    # Prepare input_ids.
    # NOTE(woosuk): We use torch.index_select instead of np.take here
    # because torch.index_select is much faster than np.take for large
    # tensors.
    torch.index_select(
        self.input_batch.token_ids_cpu_tensor.flatten(),
        0,
        torch.from_numpy(token_indices),
        out=self.input_ids_cpu[:total_num_scheduled_tokens],
    )

    # Prepare some information for building Attention-Metadata
    # Compute and commit slot mapping
    self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
    self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

    self.query_start_loc_np[0] = 0
    self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True)

    self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
    self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True)

    # Fill unused with -1. Needed for reshape_and_cache
    self.query_start_loc[num_reqs + 1 :].fill_(-1)
    self.seq_lens[num_reqs:].fill_(0)

    self.query_lens = torch.from_numpy(num_scheduled_tokens)

    # Copy the tensors to the NPU.
    self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
    self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
    self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True)

    # Make Attention metadata
    positions_cpu = self.positions_cpu[:num_input_tokens]
    positions = self.positions[:num_input_tokens]
    seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state)
    self.attn_state = attn_state  # type: ignore

    self.with_prefill = with_prefill
    self.num_tokens_across_dp = num_tokens_across_dp
    self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
    attn_metadata: dict[str, Any] = {}

    # Omni-new: per_req_additional_information
    per_req_additional_information: dict[str, dict] | None = None

    # _prepare_inputs may reorder the batch, so we must gather
    # multi-modal outputs after that to ensure the correct order
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_mm_encoder(scheduler_output)
        mm_embeds = self._gather_mm_embeddings(scheduler_output)

        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        input_ids = self.input_ids[:total_num_scheduled_tokens]
        if mm_embeds:
            inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:total_num_scheduled_tokens].copy_(inputs_embeds)

        # Omni-new: Reset per-step additional information collector (deprecated concat path)
        if hasattr(self, "_forward_additional_information"):
            self._forward_additional_information = None
        # Omni-new: per-request additional information for this step
        per_req_additional_information = {}

        # Omni-new: Overlay custom prompt_embeds per request for the prompt portion;
        # collect additional_information (tensor/list) for prefill portion only
        for req_index, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            pe_cpu = getattr(req_state, "prompt_embeds_cpu", None)
            addi_cpu = getattr(req_state, "additional_information_cpu", None)
            num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index])
            prompt_len = len(req_state.prompt_token_ids)
            prompt_remaining = max(0, prompt_len - num_computed_tokens)
            sched_tokens = int(num_scheduled_tokens[req_index])
            overlay_len = min(sched_tokens, prompt_remaining)
            if overlay_len <= 0:
                # Nothing of the prompt is scheduled this step (pure decode).
                continue
            prompt_slice = slice(num_computed_tokens, num_computed_tokens + overlay_len)
            if pe_cpu is not None:
                src = pe_cpu[prompt_slice].to(dtype=self.dtype, device=self.device, non_blocking=True)
                # Read the request's start offset from the host-side array filled
                # above ([0] = 0, [1:num_reqs + 1] = cu_num_tokens).
                # `self.query_start_loc` is used as a tensor in this method, so
                # `self.query_start_loc.cpu[...]` would subscript a bound method.
                start_offset = int(self.query_start_loc_np[req_index])
                self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)
            # Build per-request additional information (no cross-request concat)
            if isinstance(addi_cpu, dict):
                req_info: dict[str, object] = {}
                for k, v in addi_cpu.items():
                    if not isinstance(v, torch.Tensor):
                        # Lists and any other payloads are forwarded untouched.
                        req_info[k] = v
                        continue
                    # For prefill tokens, pass only the scheduled slice; if the
                    # tensor cannot be sliced that way, fall back to the whole
                    # tensor (best effort).
                    try:
                        seg = v[prompt_slice].detach().to("cpu").contiguous()
                    except Exception:
                        seg = v.detach().to("cpu").contiguous()
                    req_info[k] = seg
                per_req_additional_information[req_id] = req_info
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        input_ids = self.input_ids[:num_input_tokens]
    else:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the ACL graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
    positions = self.positions[:num_input_tokens]
    input_ids, positions = self._update_input_ids_and_positions(
        input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens
    )

    if get_pp_group().is_first_rank:
        intermediate_tensors = None
    else:
        assert intermediate_tensors is not None
        assert self.intermediate_tensors is not None
        for k, v in intermediate_tensors.items():
            self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True)
        intermediate_tensors = IntermediateTensors(
            {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()}
        )

    use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
    if not use_spec_decode:
        # NOTE(woosuk): Due to chunked prefills, the batch may contain
        # partial requests. While we should not sample any token
        # from these partial requests, we do so for simplicity.
        # We will ignore the sampled tokens from the partial requests.
        # TODO: Support prompt logprobs.
        spec_decode_metadata = None
        logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True)
    else:
        # Get the number of draft tokens for each request.
        # Iterate over the dictionary rather than all requests since not all
        # requests have draft tokens.
        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
        for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items():
            req_idx = self.input_batch.req_id_to_index[req_id]
            num_draft_tokens[req_idx] = len(draft_token_ids)

        spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens)
        logits_indices = spec_decode_metadata.logits_indices
        self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
        self.num_draft_tokens.np[num_reqs:].fill(0)
        self.num_draft_tokens.copy_to_gpu()

    # Used in the below loop.
    num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
    spec_decode_common_attn_metadata = None
    if use_spec_decode and self.need_accepted_tokens:
        self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs]
        self.num_accepted_tokens.np[num_reqs:].fill(1)
        self.num_accepted_tokens.copy_to_gpu()

    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups):
        if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
            # Encoder-only layers do not have KV cache, so we need to
            # create a dummy block table and slot mapping for them.
            blk_table_tensor = torch.zeros(
                (num_reqs, 1),
                dtype=torch.int32,
                device=self.device,
            )
            slot_mapping = torch.zeros(
                (total_num_scheduled_tokens,),
                dtype=torch.int64,
                device=self.device,
            )
        else:
            blk_table = self.input_batch.block_table[kv_cache_group_id]
            blk_table_tensor = blk_table.get_device_tensor()
            slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens]
            self.slot_mapping[:total_num_scheduled_tokens].copy_(
                slot_mapping[:total_num_scheduled_tokens],
                non_blocking=True,
            )
            self.slot_mapping[total_num_scheduled_tokens:].fill_(0)

        # Make AscendCommonAttentionMetadata
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
            seq_lens_cpu=self.seq_lens_cpu,
            seq_lens=self.seq_lens_cpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            num_input_tokens=num_input_tokens,
            actual_seq_lengths_q=self.actual_seq_lengths_q,
            # TODO: change this to the right block table for linear attn
            block_table_tensor=blk_table_tensor[:num_reqs],
            slot_mapping=self.slot_mapping,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            enable_dbo_across_dp=enable_dbo,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
            max_query_len=max_num_scheduled_tokens,
            graph_pad_size=self.graph_pad_size,
            decode_token_per_req=self.decode_token_per_req,
            cos=self.cos,
            sin=self.sin,
        )

        if self.speculative_config and spec_decode_common_attn_metadata is None:
            spec_decode_common_attn_metadata = common_attn_metadata

        for attn_group in self.attn_groups[kv_cache_group_id]:
            common_prefix_len = 0
            extra_attn_metadata_args = {}
            builder = attn_group.get_metadata_builder()
            if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type == "pooling":
                if use_spec_decode:
                    extra_attn_metadata_args = dict(
                        num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
                        num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
                    )
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args,
                )
            else:
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    model=self.get_model(),
                    **extra_attn_metadata_args,
                )

            for layer_name in attn_group.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

    if lmhead_tp_enable():
        max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
        logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

    return (
        attn_metadata,
        positions,
        num_scheduled_tokens,
        num_input_tokens,
        num_tokens_across_dp,
        maybe_padded_num_tokens,
        logits_indices,
        spec_decode_metadata,
        input_ids,
        inputs_embeds,
        intermediate_tensors,
        max_num_scheduled_tokens,
        per_req_additional_information,
    )
payload_pe.detach().to("cpu").contiguous() + elif isinstance(payload_pe, PromptEmbedsPayload): + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(payload_pe.data, dtype=dt) + arr = arr.reshape(payload_pe.shape) + pe_cpu = torch.from_numpy(arr) + else: + pe_cpu = None + if pe_cpu is not None and req_id in self.requests: + setattr( + self.requests[req_id], + "prompt_embeds_cpu", + pe_cpu, + ) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + # Already decoded + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict and req_id in self.requests: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding prompt_embeds / additional_information: {e}") + pass + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug("skip this step for we receive the data from remote disaggregate prefill node") + # Return empty ModelRunnerOutput if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + ( + attn_metadata, + positions, + num_scheduled_tokens_np, + num_input_tokens, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, + input_ids, + inputs_embeds, + intermediate_tensors, + max_query_len, + per_req_additional_information, + ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) + aclgraph_runtime_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(batch_descriptor) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + + # Omni-new + model_kwargs_extra = {} + # Pass per-request additional information map for this step (no concat) + if per_req_additional_information: + model_kwargs_extra["additional_information_by_req_id"] = per_req_additional_information + # Always pass per-request runtime 
additional_information (persisted in request state) + try: + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + per_req_runtime_info.append(info if isinstance(info, dict) else {}) + model_kwargs_extra["runtime_additional_information"] = per_req_runtime_info + model_kwargs_extra["request_ids"] = self.input_batch.req_ids + # Pass each request's token span within the flattened sequence for this step, + # enabling the model to map decode/prefill by request + req_token_spans = [] + for req_index in range(len(self.input_batch.req_ids)): + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + model_kwargs_extra["request_token_spans"] = req_token_spans + except Exception: + pass + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, + self.with_prefill, + maybe_padded_num_tokens, + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + model_kwargs_extra, + ) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = hidden_states + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + with ProfileExecuteDuration().capture_async("post process"): + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + # Omni-new + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + # The model 
side may return per-request additional_information updates (model-agnostic channel). + # Convention: multimodal_outputs["additional_information_update"] is a list[dict] in batch order; + # the runner merges it into the corresponding request's additional_information_cpu for subsequent decode. + try: + if isinstance(multimodal_outputs, dict) and ( + "additional_information_update" in multimodal_outputs + or "additional_information_update_by_req_id" in multimodal_outputs + ): + # Option A: list[dict] in batch order + updates_list = multimodal_outputs.get("additional_information_update") + if isinstance(updates_list, list): + for idx, upd in enumerate(updates_list): + if not isinstance(upd, dict) or idx >= len(self.input_batch.req_ids): + continue + req_id = self.input_batch.req_ids[idx] + self._merge_additional_information_update(req_id, upd) + # Option B: dict[str, dict] keyed by req_id + updates_map = multimodal_outputs.get("additional_information_update_by_req_id") + if isinstance(updates_map, dict): + for req_id, upd in updates_map.items(): + if not isinstance(upd, dict): + continue + if req_id not in self.requests: + continue + self._merge_additional_information_update(req_id, upd) + except Exception as e: + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} additional \ + information update: {e}, with the multimodal_outputs as {multimodal_outputs}" + ) + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. 
+ if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict(hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, + finished_sending, + finished_recving, + kv_connector_output, + ) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + if broadcast_pp_output: + model_output_broadcast_data = ( + { + "logits": logits.contiguous(), + } + if logits is not None + else {} + ) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[: self.input_batch.num_reqs] + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + if lmhead_tp_enable() and logits is not None: + logits = logits[: len(spec_decode_metadata.logits_indices)] + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. 
+ assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) + + discard_sampled_tokens_req_indices: list[int] = [] + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. 
+ logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[: scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. 
+ for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if self.speculative_config: + self._draft_token_ids = self.propose_draft_token_ids( + valid_sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + scheduler_output.total_num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + # Omni-new: Convert to per-request tensors on CPU + hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() + pooler_output: list[torch.Tensor | None] = [] + prev_logits_index = 0 + for logits_index in logits_indices: + pooler_output.append(hidden_states_cpu[prev_logits_index : logits_index + 1]) + prev_logits_index = logits_index + 1 + + # Omni-new + output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + kv_connector_output=kv_connector_output, + ) + + 
def _generate_process_reqs_hidden_states(
    self,
    attn_metadata,
    with_prefill,
    maybe_padded_num_tokens,
    input_ids,
    positions,
    intermediate_tensors,
    inputs_embeds,
    model_kwargs_extra,
):
    """Run the model forward pass and post-process its hidden states.

    The model is called with the token ids / embeddings, positions and
    intermediate tensors plus every entry of ``model_kwargs_extra``.
    Afterwards, when the forward context reports a FULL graph runtime mode,
    the attention parameters are refreshed on ``self.update_stream``; when
    sequence parallelism is enabled the hidden states are all-gathered
    along dim 0 and the trailing ``pad_size`` rows are dropped.

    ``attn_metadata`` and ``with_prefill`` are accepted for signature
    compatibility but are not read in this method.

    Returns:
        Whatever the model returned, possibly all-gathered and un-padded.
    """
    assert self.model is not None
    model_output = self.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **model_kwargs_extra,
    )

    graph_ctx = get_forward_context()
    if graph_ctx.cudagraph_runtime_mode == CUDAGraphMode.FULL:
        # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
        uses_mla = self.vllm_config.model_config.use_mla
        if not uses_mla:
            update_attn_params(self.update_stream, graph_ctx, maybe_padded_num_tokens)
        else:
            # FIXME: Try using `auto_dispatch_capture=True`
            update_mla_attn_params(
                self.update_stream, graph_ctx, maybe_padded_num_tokens, self.speculative_config
            )

    # Without sequence parallelism the model output is returned untouched.
    if not get_forward_context().sp_enabled:
        return model_output

    gathered = tensor_model_parallel_all_gather(model_output, 0)
    trailing_pad = get_forward_context().pad_size
    if trailing_pad > 0:
        # Strip the rows that were only added as padding.
        gathered = gathered[:-trailing_pad, :]
    return gathered
class NPUARWorker(NPUWorker):
    """NPU worker for the autoregressive (thinker/talker) stages of Qwen2.5-Omni."""

    def init_device(self):
        """Initialise the device and attach an ``NPUARModelRunner`` to this worker."""
        npu_device = self._init_device()
        runner: NPUARModelRunner = NPUARModelRunner(self.vllm_config, npu_device)
        self.model_runner = runner
def _prepare_inputs(
    self,
    scheduler_output: SchedulerOutput,
    intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
    dict[str, Any],
    torch.Tensor,
    np.ndarray,
    int,
    torch.Tensor,
    int,
    torch.Tensor,
    SpecDecodeMetadata,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    int,
    dict[str, Any],
]:
    """Build model inputs and attention metadata for one scheduler step.

    Args:
        scheduler_output: The scheduler's decision for this step; must
            schedule at least one token for at least one request.
        intermediate_tensors: Tensors received from the previous pipeline
            stage. Required on every rank except the first PP rank, where
            it is ignored.

    Returns:
        A 13-tuple, in this order:

        1. ``attn_metadata``: mapping from layer name to attention metadata.
        2. ``positions``: positions as returned by
           ``_update_input_ids_and_positions``.
        3. ``num_scheduled_tokens``: int32 array with the number of tokens
           scheduled for each request, in batch order.
        4. ``num_input_tokens``: token count after graph / SP padding and
           the cross-DP sync.
        5. ``num_tokens_across_dp``: result of ``_sync_metadata_across_dp``.
        6. ``maybe_padded_num_tokens``: same sync result that
           ``num_input_tokens`` is set to.
        7. ``logits_indices``: indices of the positions to sample from.
        8. ``spec_decode_metadata``: ``None`` when no draft tokens are
           scheduled.
        9. ``input_ids``.
        10. ``inputs_embeds``: ``None`` for text-only models.
        11. ``intermediate_tensors``: ``None`` on the first PP rank.
        12. ``max_num_scheduled_tokens``: largest per-request token count.
        13. ``model_kwargs``: extra keyword arguments for the model.
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = self.input_batch.num_reqs
    assert num_reqs > 0

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    self.input_batch.block_table.commit_block_table(num_reqs)

    # Get the number of scheduled tokens for each request.
    req_ids = self.input_batch.req_ids
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = num_scheduled_tokens.max()
    # Scheduled tokens minus the draft (spec-decode) tokens of each request.
    num_valid_tokens = np.array(
        [
            num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
            for num_tokens, i in zip(tokens, req_ids)
        ],
        dtype=np.int32,
    )

    if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]:
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens)
    elif self.use_aclgraph and enable_sp(self.vllm_config):
        # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
        # the model will fall back to running its FX graph in eager mode.
        # In this case, when sequence parallelism is enabled, we need to pad tokens to align
        # with tp_size because pad_size cannot be captured by the FX graph
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size
    else:
        # Eager mode.
        # (NOTE)Omni-new: Maybe we should remove the below ACLGraph logic
        # But for eager, the logic is consistent with GPUDiffusionModelRunner
        num_input_tokens = total_num_scheduled_tokens

    # Get the attention state.
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_state = attn_state  # type: ignore

    # Determine if it's a splitfuse batch
    with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]

    self.query_lens = torch.from_numpy(num_scheduled_tokens)
    enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens)

    # Get info across DP ranks.
    # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
    # Otherwise, it's just max_tokens_across_dp_cpu
    (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp(
        num_input_tokens, with_prefill, enable_dbo
    )

    # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
    # We should consider removing maybe_padded_num_tokens later
    num_input_tokens = maybe_padded_num_tokens

    # Hot-Swap lora model
    if self.lora_config:
        self.set_active_loras(self.input_batch, num_scheduled_tokens)

    # Get request indices.
    # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

    # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
    # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)

    positions_np = self.positions_np[:total_num_scheduled_tokens]
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np)

    # Calculate M-RoPE positions.
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        self._calc_mrope_positions(scheduler_output)

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
            self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True
        )

    # Get token indices.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
    # where M is the max_model_len.
    token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]

    # Prepare input_ids.
    # NOTE(woosuk): We use torch.index_select instead of np.take here
    # because torch.index_select is much faster than np.take for large
    # tensors.
    torch.index_select(
        self.input_batch.token_ids_cpu_tensor.flatten(),
        0,
        torch.from_numpy(token_indices),
        out=self.input_ids_cpu[:total_num_scheduled_tokens],
    )

    # Prepare some information for building Attention-Metadata
    # Compute and commit slot mapping
    self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
    self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

    self.query_start_loc_np[0] = 0
    self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True)

    self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
    self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True)

    # Fill unused with -1. Needed for reshape_and_cache
    self.query_start_loc[num_reqs + 1 :].fill_(-1)
    self.seq_lens[num_reqs:].fill_(0)

    self.query_lens = torch.from_numpy(num_scheduled_tokens)

    # Copy the tensors to the NPU.
    self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
    self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
    self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True)

    # Make Attention metadata
    positions_cpu = self.positions_cpu[:num_input_tokens]
    positions = self.positions[:num_input_tokens]
    seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state)
    self.attn_state = attn_state  # type: ignore

    self.with_prefill = with_prefill
    self.num_tokens_across_dp = num_tokens_across_dp
    self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
    attn_metadata: dict[str, Any] = {}

    # _prepare_inputs may reorder the batch, so we must gather
    # multi-modal outputs after that to ensure the correct order
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_mm_encoder(scheduler_output)
        mm_embeds = self._gather_mm_embeddings(scheduler_output)

        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        input_ids = self.input_ids[:num_input_tokens]
        if mm_embeds:
            inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds)
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        # Merge the base model kwargs with the multimodal kwargs extracted
        # from the scheduler output.
        model_kwargs = {
            **self._init_model_kwargs(num_input_tokens),
            **self._extract_mm_kwargs(scheduler_output),
        }
        # (NOTE) Omni-new: input_ids isn't set as None
    else:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the ACL graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
        model_kwargs = self._init_model_kwargs(num_input_tokens)
    positions = self.positions[:num_input_tokens]
    input_ids, positions = self._update_input_ids_and_positions(
        input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens
    )

    if get_pp_group().is_first_rank:
        intermediate_tensors = None
    else:
        assert intermediate_tensors is not None
        assert self.intermediate_tensors is not None
        for k, v in intermediate_tensors.items():
            self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True)
        intermediate_tensors = IntermediateTensors(
            {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()}
        )

    use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
    if not use_spec_decode:
        # NOTE(woosuk): Due to chunked prefills, the batch may contain
        # partial requests. While we should not sample any token
        # from these partial requests, we do so for simplicity.
        # We will ignore the sampled tokens from the partial requests.
        # TODO: Support prompt logprobs.
        spec_decode_metadata = None
        logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True)
    else:
        # Get the number of draft tokens for each request.
        # Iterate over the dictionary rather than all requests since not all
        # requests have draft tokens.
        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
        for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items():
            req_idx = self.input_batch.req_id_to_index[req_id]
            num_draft_tokens[req_idx] = len(draft_token_ids)

        spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens)
        logits_indices = spec_decode_metadata.logits_indices
        self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
        self.num_draft_tokens.np[num_reqs:].fill(0)
        self.num_draft_tokens.copy_to_gpu()

    # Used in the below loop.
    # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
    num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
    spec_decode_common_attn_metadata = None
    if use_spec_decode and self.need_accepted_tokens:
        self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs]
        self.num_accepted_tokens.np[num_reqs:].fill(1)
        self.num_accepted_tokens.copy_to_gpu()

    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups):
        if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
            # Encoder-only layers do not have KV cache, so we need to
            # create a dummy block table and slot mapping for them.
            blk_table_tensor = torch.zeros(
                (num_reqs, 1),
                dtype=torch.int32,
                device=self.device,
            )
            slot_mapping = torch.zeros(
                (total_num_scheduled_tokens,),
                dtype=torch.int64,
                device=self.device,
            )
        else:
            blk_table = self.input_batch.block_table[kv_cache_group_id]
            blk_table_tensor = blk_table.get_device_tensor()
            slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens]
            self.slot_mapping[:total_num_scheduled_tokens].copy_(
                slot_mapping[:total_num_scheduled_tokens],
                non_blocking=True,
            )
            self.slot_mapping[total_num_scheduled_tokens:].fill_(0)

        # Make AscendCommonAttentionMetadata
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
            seq_lens_cpu=self.seq_lens_cpu,
            seq_lens=self.seq_lens_cpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            num_input_tokens=num_input_tokens,
            actual_seq_lengths_q=self.actual_seq_lengths_q,
            # TODO: change this to the right block table for linear attn
            block_table_tensor=blk_table_tensor[:num_reqs],
            slot_mapping=self.slot_mapping,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            enable_dbo_across_dp=enable_dbo,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
            max_query_len=max_num_scheduled_tokens,
            graph_pad_size=self.graph_pad_size,
            decode_token_per_req=self.decode_token_per_req,
            cos=self.cos,
            sin=self.sin,
        )

        if self.speculative_config and spec_decode_common_attn_metadata is None:
            spec_decode_common_attn_metadata = common_attn_metadata

        for attn_group in self.attn_groups[kv_cache_group_id]:
            common_prefix_len = 0
            extra_attn_metadata_args = {}
            builder = attn_group.get_metadata_builder()
            if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type == "pooling":
                if use_spec_decode:
                    extra_attn_metadata_args = dict(
                        num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
                        num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
                    )
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args,
                )
            else:
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    model=self.get_model(),
                    **extra_attn_metadata_args,
                )

            # Every layer of the group shares the same metadata object.
            for layer_name in attn_group.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

    if lmhead_tp_enable():
        max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
        logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

    return (
        attn_metadata,
        positions,
        # Fix: slot 3 is the per-request token-count array (annotated as
        # np.ndarray and unpacked by execute_model as
        # `num_scheduled_tokens_np`); the padded scalar used to be
        # returned here a second time.
        num_scheduled_tokens,
        num_input_tokens,
        num_tokens_across_dp,
        maybe_padded_num_tokens,
        logits_indices,
        spec_decode_metadata,
        input_ids,
        inputs_embeds,
        intermediate_tensors,
        max_num_scheduled_tokens,
        model_kwargs,
    )
self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + # Omni-new: don't use cudagraph_dispatcher + # and remove ubatch_slices + aclgraph_runtime_mode = CUDAGraphMode.NONE + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=None, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + + outputs = self._run_diffusion( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + multimodal_kwargs=model_kwargs, + logits_indices=logits_indices, + ) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = outputs + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + + # Omni-new: extract_multimodal_outputs + _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + # Ensure one tensor per request, map to CPU for output struct + pooler_output: list[torch.Tensor | None] = [] + if isinstance(multimodal_outputs, torch.Tensor): + # If model returned a single stacked tensor, split by requests + assert multimodal_outputs.shape[0] == self.input_batch.num_reqs + for i in 
def _run_diffusion(
    self,
    *,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None,
    multimodal_kwargs: dict,
    logits_indices: torch.Tensor,
) -> torch.Tensor | list[torch.Tensor]:
    """Invoke the diffusion model and return whatever it produces.

    The keyword arguments passed to the model are the token ids, positions,
    intermediate tensors and input embeddings, the multimodal kwargs
    converted by ``MultiModalKwargs.as_kwargs`` for ``self.device``, plus
    ``sampling_metadata``, ``logits_index`` and ``sampler``.

    Only ``model.forward`` is dispatched to at the moment; ``sample`` and
    ``diffuse`` are named in the error message but are not tried.

    Raises:
        RuntimeError: If the loaded model has no ``forward`` attribute.
    """
    mm_kwargs = MultiModalKwargs.as_kwargs(multimodal_kwargs, device=self.device)
    call_kwargs = dict(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **mm_kwargs,
        sampling_metadata=self.input_batch.sampling_metadata,
        logits_index=logits_indices,
        sampler=self.sampler,
    )

    # TODO: add the diffuse method for other models
    if not hasattr(self.model, "forward"):
        raise RuntimeError(
            "The loaded model does not expose diffusion interfaces 'sample', "
            "'forward', or 'diffuse'. Please implement one of them or adapt the runner."
        )
    return self.model.forward(**call_kwargs)
+ if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp( + num_tokens, with_prefill, False + ) + + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, dtype=self.dtype, device=self.device + ) + intermediate_tensors = IntermediateTensors( + {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()} + ) + + # filter out the valid batch descriptor + _ag_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode) + ) + if aclgraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # 
warm ups for aclgraph capture + assert aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode == _ag_mode, ( + f"Aclgraph runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}." + ) + else: + aclgraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. + attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, + ) + + need_dummy_logits = not self.in_profile_run and lmhead_tp_enable() + + if need_dummy_logits: + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + return self.model.compute_logits(hidden_states[dummy_indices]) + + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + hidden_states = self._generate_dummy_run_hidden_states( + with_prefill, + is_torchair_compile, + input_ids, + positions, + attn_metadata, + num_tokens, + intermediate_tensors, + inputs_embeds, + ) + if need_dummy_logits: + dummy_compute_logits(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens, + with_prefill=with_prefill, + skip_attn=True, + num_reqs=num_reqs, + num_tokens_across_dp=num_tokens_across_dp, + 
class NPUDiffusionWorker(NPUWorker):
    """NPU diffusion worker for code2wav stage in Qwen2.5-Omni."""

    def init_device(self):
        """Set up the device and attach a diffusion model runner to it.

        The device comes from ``self._init_device()``, which is not defined
        in this class. The runner is stored on ``self.model_runner``.
        """
        # Resolve the device first, then hand it to the runner with the config.
        npu_device = self._init_device()

        self.model_runner: NPUDiffusionModelRunner = NPUDiffusionModelRunner(
            self.vllm_config,
            npu_device,
        )
+""" + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, List, Optional, Union, cast + +import math +import numpy as np +import torch + +import vllm.envs as envs +from vllm.config import CUDAGraphMode +from vllm.distributed.kv_transfer import has_kv_transfer_group +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed.kv_transfer import get_kv_transfer_group +from vllm.forward_context import BatchDescriptor, DPMetadata, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import supports_mrope +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItem +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.sampling_params import SamplingType +from vllm.utils import cdiv, round_up +from vllm.v1.attention.backends.utils import CommonAttentionMetadata, split_attn_metadata +from vllm.v1.outputs import KVConnectorOutput, LogprobsLists, LogprobsTensors, SamplerOutput +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.worker.utils import is_residual_scattered_for_sp, MultiModalBudget +from vllm.v1.worker.gpu_model_runner import IntermediateTensors, PerLayerAttnMetadata +from vllm.v1.worker.ubatch_splitting import ubatch_split +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from vllm_ascend.worker.npu_input_batch import CachedRequestState +from vllm_ascend.utils import enable_sp, vllm_version_is, lmhead_tp_enable + +from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload + +if TYPE_CHECKING: + from 
def _init_mrope_positions(self, req_state: CachedRequestState):
    """Compute M-RoPE positions for a request and store them on ``req_state``.

    Walks ``req_state.mm_features``, skipping features whose ``data`` is
    ``None``, and collects the grid / timing / audio-length entries found in
    each item's data. The collected lists are passed to the model's
    ``get_mrope_input_positions``; its two results are written to
    ``req_state.mrope_positions`` and ``req_state.mrope_position_delta``.

    Raises:
        AssertionError: If the model does not report M-RoPE support.
    """
    image_grid_thw = []
    video_grid_thw = []
    second_per_grid_ts = []
    audio_feature_lengths = []
    use_audio_in_video = False

    for feature in req_state.mm_features:
        item = feature.data
        if item is None:
            continue
        fields = item.get_data()

        image_grid = fields.get("image_grid_thw")
        if image_grid is not None:
            image_grid_thw.append(image_grid.tolist())

        video_grid = fields.get("video_grid_thw")
        if video_grid is not None:
            video_grid_thw.append(video_grid.tolist())

        # These two are appended as-is, without a ``tolist()`` conversion.
        seconds_per_grid = fields.get("second_per_grid_ts")
        if seconds_per_grid is not None:
            second_per_grid_ts.append(seconds_per_grid)

        audio_lengths = fields.get("audio_feature_lengths")
        if audio_lengths is not None:
            audio_feature_lengths.append(audio_lengths)

        # The last feature that carries the flag decides its final value.
        audio_in_video_flag = fields.get("use_audio_in_video")
        if audio_in_video_flag is not None:
            use_audio_in_video = bool(audio_in_video_flag.item())

    assert supports_mrope(self.get_model()), "M-RoPE support is not implemented."

    positions, position_delta = self.model.get_mrope_input_positions(
        req_state.prompt_token_ids,
        hf_config=self.model_config.hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        audio_feature_lengths=audio_feature_lengths,
        use_audio_in_video=use_audio_in_video,
    )
    req_state.mrope_positions = positions
    req_state.mrope_position_delta = position_delta
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + # Handle backward compatibility for mm_features/mm_kwargs + backward_kwargs = {} + if vllm_version_is("0.11.0"): + backward_kwargs["mm_features"] = getattr(new_req_data, "mm_features", None) + else: + backward_kwargs["mm_kwargs"] = getattr(new_req_data, "mm_kwargs", None) + backward_kwargs["mm_hashes"] = getattr(new_req_data, "mm_hashes", None) + backward_kwargs["mm_positions"] = getattr(new_req_data, "mm_positions", None) + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + 
num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + **backward_kwargs, + ) + self.requests[req_id] = req_state + + # If prompt embeddings are provided, decode and attach to request state + try: + if getattr(new_req_data, "prompt_embeds", None) is not None: + payload = new_req_data.prompt_embeds + if isinstance(payload, PromptEmbedsPayload): + dtype = getattr(np, payload.dtype) + arr = np.frombuffer(payload.data, dtype=dtype) + arr = arr.reshape(payload.shape) + pe_cpu = torch.from_numpy(arr) + elif isinstance(payload, torch.Tensor): + pe_cpu = payload.detach().to("cpu").contiguous() + else: + pe_cpu = None + # Store temporarily on CPU; later moved to device in builder + if pe_cpu is not None: + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # Also replace payload with Tensor for user visibility in + # scheduler_output + try: + new_req_data.prompt_embeds = pe_cpu # type: ignore[assignment] + except Exception: + pass + except Exception as e: + logger.error(f"Error decoding prompt embeds: {e}") + # Decode additional_information payloads (dictionary) + try: + if getattr(new_req_data, "additional_information", None) is not None: + payload_info = new_req_data.additional_information + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype( + getattr(entry, "tensor_dtype", "float32") + ) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding additional information: {e}") + pass + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + 
if self.uses_mrope: + self._init_mrope_positions(self.requests[req_id]) + + reqs_to_add.append(self.requests[req_id]) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + else: + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. 
+ reqs_to_add.append(req_state) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens + if new_block_ids is not None: + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) + if spec_token_ids: + num_spec_tokens = len(spec_token_ids) + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index + ] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. 
@torch.inference_mode()
def extract_multimodal_outputs(
    self, hidden_states: Union[torch.Tensor, List[torch.Tensor]]
) -> tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor], dict]]:
    """Split a model result into text hidden states and multimodal outputs.

    When the model sets a truthy ``have_multimodal_outputs`` attribute, the
    result is read through its ``text_hidden_states`` and
    ``multimodal_outputs`` attributes. Otherwise a tensor is returned
    unchanged, and a list contributes its first element; in both of those
    cases the multimodal part is an empty dict.

    Raises:
        ValueError: If ``hidden_states`` is neither a tensor nor a list and
            the model does not declare multimodal outputs.
    """
    # A missing attribute is treated the same as a falsy one.
    if getattr(self.model, "have_multimodal_outputs", False):
        return hidden_states.text_hidden_states, hidden_states.multimodal_outputs

    if isinstance(hidden_states, torch.Tensor):
        return hidden_states, {}

    if isinstance(hidden_states, list):
        return hidden_states[0], {}

    raise ValueError(f"Invalid hidden states type: {type(hidden_states)}")
hidden_states: torch.Tensor, num_scheduled_tokens: int + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: + num_nans_in_logits = {} + # placeholder for now, TODO: _get_nans_in_logits() + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + invalid_req_indices = [] + + if not self.use_async_scheduling: + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + discard_sampled_tokens_req_indices.append(i) + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + 
discard_sampled_tokens_req_indices.append(i) + invalid_req_indices = discard_sampled_tokens_req_indices + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + + req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + def get_dp_padding( + self, num_tokens: int + ) -> tuple[int, Optional[torch.Tensor]]: + """Determines the total number of tokens that each rank will run.""" + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + return 0, None + + num_tokens_across_dp = 
DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + """Calculate input token count with padding if needed.""" + if (self.use_aclgraph and num_scheduled_tokens + <= self.aclgraph_batch_sizes[-1]): + # Add padding to the batch size for ACLGraph. + # Note: pad_for_cudagraph works for both CUDA graphs and ACLGraph + return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) + + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if (self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1): + return round_up(num_scheduled_tokens, tp_size) + return num_scheduled_tokens + + def _generate_dummy_run_hidden_states(self, with_prefill, + is_torchair_compile, input_ids, + positions, attn_metadata, num_tokens, + intermediate_tensors, inputs_embeds, + model_kwargs=None): + """Override to support model_kwargs for multimodal/pooling models.""" + if model_kwargs is None: + model_kwargs = {} + + hidden_states = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs) + + from vllm_ascend.compilation.acl_graph import update_attn_params, update_mla_attn_params + + forward_context = get_forward_context() + assert forward_context is not None + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + 
positions.shape[0], + self.speculative_config) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) + + from vllm_ascend.spec_decode.interface import SpecDcodeType + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, _ = hidden_states + else: + hidden_states = hidden_states + return hidden_states + + def _init_model_kwargs(self, num_tokens: int): + """Initialize model kwargs.""" + model_kwargs = dict[str, Any]() + + if not self.is_pooling_model: + return model_kwargs + + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens_cpu[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + + def sync_and_slice_intermediate_tensors( + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool + ) -> IntermediateTensors: + """Sync and slice intermediate tensors for pipeline parallelism.""" + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens) + + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + is_scattered = k == "residual" and is_rs + copy_len = num_tokens // tp if is_scattered else num_tokens + self.intermediate_tensors[k][:copy_len].copy_( + v[:copy_len], non_blocking=True) 
+ + return IntermediateTensors({ + k: v[:num_tokens // tp] if k == "residual" and is_rs else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + def _extract_mm_kwargs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict: + """Extract multimodal kwargs.""" + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} + + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + # Handle version compatibility + if vllm_version_is("0.11.0"): + mm_features = getattr(req, "mm_features", None) + else: + mm_features = getattr(req, "mm_kwargs", None) + + if mm_features: + if isinstance(mm_features, list): + for feature in mm_features: + if hasattr(feature, "data") and feature.data is not None: + mm_kwargs.append(feature.data) + elif isinstance(mm_features, dict): + return mm_features + + if not mm_kwargs: + return {} + + model = cast(SupportsMultiModal, self.model) + mm_kwargs_combined: dict = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ): + mm_kwargs_combined.update(mm_kwargs_group) + + return mm_kwargs_combined + + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models.""" + if hasattr(self, "_batch_mm_kwargs_from_scheduler"): + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + else: + return {} + + if not mm_kwargs: + return {} + + model = cast(SupportsMultiModal, self.model) + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ): + encoder_features.update(mm_kwargs_group) + + return encoder_features + + @contextmanager + def synchronize_input_prep(self): + """Synchronize 
input preparation for async scheduling.""" + if getattr(self, 'use_async_scheduling', False): + if not hasattr(self, "prepare_inputs_event") or self.prepare_inputs_event is None: + self.prepare_inputs_event = torch.npu.Event() + self.prepare_inputs_event.record(torch.npu.current_stream()) + + if not hasattr(self, "prepare_inputs_event") or self.prepare_inputs_event is None: + yield + return + + self.prepare_inputs_event.synchronize() + try: + yield + finally: + self.prepare_inputs_event.record() + + @contextmanager + def maybe_get_kv_connector_output( + self, scheduler_output: "SchedulerOutput" + ): + """KV connector context manager.""" + if not has_kv_transfer_group(): + yield None + return + + output = KVConnectorOutput() + + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + kv_connector.start_load_kv(get_forward_context()) + try: + yield output + finally: + kv_connector.wait_for_save() + output.finished_sending, output.finished_recving = ( + kv_connector.get_finished(scheduler_output.finished_req_ids)) + output.kv_connector_stats = kv_connector.get_kv_connector_stats() + kv_connector.clear_connector_metadata() + + def pad_out_ubatch_slice(self, ubatch_slices, num_total_tokens: int): + """Pad ubatch slice for DBO (Dynamic Batch Overlap).""" + from vllm.v1.worker.ubatch_utils import UBatchSlice + if len(ubatch_slices) < 2: + return + padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + num_total_tokens) + ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, + padded_second_ubatch_slice) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: + """Step for EPLB (Expert Parallelism Load Balancing).""" + if hasattr(self, "dynamic_eplb") and self.dynamic_eplb: + if hasattr(self, "eplb_updator"): + self.eplb_updator.forward_end() + 
def _get_mm_dummy_batch(
    self,
    modality: str,
    max_items_per_batch: int,
) -> dict:
    """Build one batch of dummy multimodal kwargs for profiling/precompiling.

    Args:
        modality: Key used both to request one dummy item from the
            multimodal registry and to look that item up in the result.
        max_items_per_batch: Number of copies of the dummy item to batch.

    Returns:
        The first kwargs group yielded by ``group_mm_kwargs_by_modality``
        for the replicated dummy items.
    """
    assert self.mm_budget is not None

    dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
        model_config=self.model_config,
        seq_len=self.max_model_len,
        mm_counts={modality: 1},
        cache=self.mm_budget.cache,
    )
    dummy_mm_data = dummy_decoder_data.multi_modal_data

    # One dummy item is produced; replicate it to fill the batch.
    dummy_mm_item = dummy_mm_data[modality][0]
    dummy_mm_items = [dummy_mm_item] * max_items_per_batch

    model = cast(SupportsMultiModal, self.model)
    return next(
        mm_kwargs_group
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
            dummy_mm_items,
            device=self.device,
            pin_memory=getattr(self, "pin_memory", False),
            merge_by_field_config=model.merge_by_field_config,
        )
    )


def _dummy_mm_kwargs(self, num_seqs: int) -> dict:
    """Return dummy multimodal kwargs for dummy runs.

    Returns an empty dict unless ``self.is_multimodal_raw_input_only_model``
    is truthy; otherwise a dummy batch of ``num_seqs`` items is built for the
    modality reported by ``mm_budget.get_modality_with_max_tokens()``.
    """
    if not self.is_multimodal_raw_input_only_model:
        return {}

    mm_budget = self.mm_budget
    assert mm_budget is not None

    dummy_modality = mm_budget.get_modality_with_max_tokens()
    return self._get_mm_dummy_batch(dummy_modality, num_seqs)


@contextmanager
def maybe_randomize_inputs(self, input_ids: Optional[torch.Tensor]):
    """Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.

    Randomization only happens when the data-parallel size is greater than
    one, the environment flag is truthy and ``input_ids`` is not ``None``.
    In every other case the context manager is a no-op.

    When randomization is active, ``input_ids`` is overwritten in place for
    the duration of the ``with`` body and zero-filled on exit, including when
    the body raises.
    """
    dp_size = self.vllm_config.parallel_config.data_parallel_size
    # dp_size is checked first so the flag is not consulted when it cannot
    # matter.
    randomize_inputs = dp_size > 1 and envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS

    if not randomize_inputs or input_ids is None:
        yield
        return

    logger.debug_once("Randomizing dummy data for DP Rank")
    random_ids = torch.randint_like(
        self.input_ids,
        low=0,
        high=self.model_config.get_vocab_size(),
        dtype=input_ids.dtype,
    )
    input_ids.copy_(random_ids[:input_ids.size(0)], non_blocking=True)
    try:
        yield
    finally:
        # Always reset the buffer, even if the wrapped block raised.
        input_ids.fill_(0)


@torch.inference_mode()
def _dummy_run(
    self,
    num_tokens: int,
    with_prefill: bool = False,
    cudagraph_runtime_mode: Optional["CUDAGraphMode"] = None,
    force_attention: bool = False,
    uniform_decode: bool = False,
    allow_microbatching: bool = True,
    skip_eplb: bool = False,
    is_profile: bool = False,
    create_mixed_batch: bool = False,
    remove_lora: bool = True,
) -> torch.Tensor:
    """Run a dummy forward pass to warm up/profile run or capture the ACL graph for the model.

    Args:
        num_tokens: Number of tokens to run the dummy forward pass.
        with_prefill: NOTE(review): the value passed by the caller is
            overwritten below (it is recomputed from ``create_mixed_batch``,
            ``uniform_decode`` and ``num_tokens``) - confirm this is intended.
        cudagraph_runtime_mode: Used to control the behavior.
            - if not set will determine the aclgraph mode based on using
                the self.aclgraph_dispatcher.
            - CUDAGraphMode.NONE: No aclgraph, for warm up and profile run
            - CUDAGraphMode.PIECEWISE: Piecewise aclgraph.
            - CUDAGraphMode.FULL: Full aclgraph, attention metadata is needed.
        force_attention: If True, always create attention metadata. Used to
            warm up attention backend when mode is NONE.
        uniform_decode: If True, the batch is a uniform decode batch.
        allow_microbatching: If True, allow ubatch splitting if DBO is enabled.
        skip_eplb: If True, skip EPLB state update.
        is_profile: If True, this is a profile run.
        create_mixed_batch: If True, create a mixed batch with both decode
            (1 token) and prefill (multiple tokens) requests.
        remove_lora: If False, dummy LoRAs are not destroyed after the run.

    Returns:
        The first element returned by ``self.extract_multimodal_outputs`` for
        the hidden states produced by the dummy forward pass.
    """
    assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
        CUDAGraphMode.NONE,
        CUDAGraphMode.PIECEWISE,
        CUDAGraphMode.FULL,
    }

    # Round the token count up to a multiple of the TP size when sequence
    # parallelism is enabled together with ACL graphs.
    if hasattr(self, "use_aclgraph") and self.use_aclgraph and enable_sp(self.vllm_config):
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        num_tokens = math.ceil(num_tokens / tp_size) * tp_size

    with_prefill = create_mixed_batch or (not uniform_decode and num_tokens > 1)

    if (
        hasattr(self, "is_kv_producer")
        and self.is_kv_producer
        and hasattr(self, "is_kv_consumer")
        and not self.is_kv_consumer
    ):
        with_prefill = True

    if hasattr(self, "_sync_metadata_across_dp"):
        (num_tokens, num_tokens_across_dp, with_prefill, _) = (
            self._sync_metadata_across_dp(num_tokens, with_prefill, False)
        )
    else:
        num_tokens_across_dp = None

    if hasattr(self, "_select_moe_comm_method"):
        moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
    else:
        moe_comm_type = None

    max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens

    assert num_tokens <= self.scheduler_config.max_num_batched_tokens
    max_num_reqs = self.scheduler_config.max_num_seqs

    # Decide how the dummy tokens are distributed over dummy requests.
    if create_mixed_batch:
        assert not uniform_decode
        num_decode_tokens = num_tokens // 2
        num_prefill_tokens = num_tokens - num_decode_tokens
        num_reqs = num_decode_tokens + 1
        num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
        max_query_len = num_prefill_tokens
    elif uniform_decode:
        assert not create_mixed_batch
        num_reqs = cdiv(num_tokens, max_query_len)
        num_scheduled_tokens_list = [max_query_len] * num_reqs
        if num_tokens % max_query_len != 0:
            num_scheduled_tokens_list[-1] = num_tokens % max_query_len
    else:
        if with_prefill:
            num_reqs = num_tokens
        else:
            decode_token_per_req = getattr(self, "decode_token_per_req", 1)
            num_reqs = (num_tokens + decode_token_per_req - 1) // decode_token_per_req
        num_reqs = min(num_reqs, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        # The last request absorbs the remainder.
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs

    assert sum(num_scheduled_tokens_list) == num_tokens
    assert len(num_scheduled_tokens_list) == num_reqs
    num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)

    total_num_scheduled_tokens = int(num_scheduled_tokens.sum())

    ubatch_slices = None
    num_tokens_after_padding = None

    if self.parallel_config.enable_dbo and allow_microbatching:
        ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
            num_scheduled_tokens,
            total_num_scheduled_tokens,
            total_num_scheduled_tokens,
            uniform_decode=uniform_decode,
            vllm_config=self.vllm_config,
        )
        if ubatch_num_tokens_after_padding is not None:
            num_tokens_after_padding = ubatch_num_tokens_after_padding * 2

    if num_tokens_after_padding is None:
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens_after_padding = num_tokens + num_pad
    else:
        # Collapse a per-rank container to a plain int (first entry).
        if isinstance(num_tokens_after_padding, torch.Tensor):
            num_tokens_after_padding = int(num_tokens_after_padding[0].item())
        elif isinstance(num_tokens_after_padding, (list, np.ndarray)):
            num_tokens_after_padding = int(num_tokens_after_padding[0])

    attn_metadata: Optional[PerLayerAttnMetadata] = None

    if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
        attn_metadata = {}
        if ubatch_slices is not None:
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]

        if create_mixed_batch:
            # TODO(luka) better system for describing dummy batches
            seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
        else:
            seq_lens = max_query_len

        self.seq_lens_np[:num_reqs] = seq_lens
        self.seq_lens_np[num_reqs:] = 0
        if isinstance(seq_lens, list):
            self.seq_lens_cpu[:num_reqs] = torch.tensor(seq_lens, dtype=torch.int32)
        else:
            self.seq_lens_cpu[:num_reqs] = seq_lens
        self.seq_lens_cpu[num_reqs:] = 0
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True)
        self.seq_lens[num_reqs:].fill_(0)

        cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cum_num_tokens
        self.query_start_loc_cpu[0] = 0
        self.query_start_loc_cpu[1:num_reqs + 1] = torch.from_numpy(cum_num_tokens)
        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True
        )

        for kv_cache_group_id, kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups
        ):
            block_table = self.input_batch.block_table[kv_cache_group_id]
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=self.query_start_loc[:num_reqs + 1],
                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
                seq_lens=self.seq_lens[:num_reqs],
                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
                num_computed_tokens_cpu=(
                    self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
                ),
                num_reqs=num_reqs,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                max_seq_len=self.max_model_len,
                block_table_tensor=block_table.get_device_tensor(num_reqs),
                slot_mapping=block_table.slot_mapping[:num_tokens],
                causal=True,
            )
            for attn_group in self.attn_groups[kv_cache_group_id]:
                if ubatch_slices is not None:
                    assert isinstance(attn_metadata, list)
                    ubatch_metadata_list = split_attn_metadata(
                        ubatch_slices, common_attn_metadata
                    )
                    # Use a distinct loop variable: reusing
                    # `common_attn_metadata` here would replace the
                    # whole-batch metadata that the next attention group
                    # still has to split.
                    for ubid, ubatch_metadata in enumerate(ubatch_metadata_list):
                        assert ubatch_metadata.max_query_len == 1
                        attn_metadata_i = attn_group.get_metadata_builder(
                            ubatch_id=ubid
                        ).build_for_cudagraph_capture(ubatch_metadata)
                        for layer_name in attn_group.layer_names:
                            attn_metadata[ubid][layer_name] = attn_metadata_i
                else:
                    assert isinstance(attn_metadata, dict)
                    attn_metadata_i = (
                        attn_group.get_metadata_builder().build_for_cudagraph_capture(
                            common_attn_metadata
                        )
                    )
                    for layer_name in attn_group.layer_names:
                        attn_metadata[layer_name] = attn_metadata_i

    with self.maybe_dummy_run_with_lora(
        self.lora_config, num_scheduled_tokens, remove_lora
    ):
        model_kwargs = self._init_model_kwargs(num_tokens)

        # Prepare inputs (NPU uses direct tensor access, not .gpu buffers)
        if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
            model_kwargs = {
                **model_kwargs,
                **self._dummy_mm_kwargs(num_reqs),
            }
        elif self.enable_prompt_embeds:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
            model_kwargs = self._init_model_kwargs(num_tokens)
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_tokens]
        else:
            positions = self.positions[:num_tokens]

        # Prepare intermediate_tensors if PP
        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device,
                    )
                )
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_tokens, self.intermediate_tensors, False
            )

        # Dispatch graph mode (check for aclgraph_dispatcher)
        if hasattr(self, "aclgraph_dispatcher") and not is_profile:
            _cg_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(
                BatchDescriptor(
                    num_tokens=num_tokens_after_padding,
                    uniform_decode=uniform_decode,
                )
            )
        else:
            _cg_mode, batch_descriptor = (CUDAGraphMode.NONE, None)

        # Map GPU parameter name to NPU internal name for clarity
        # cudagraph_runtime_mode (GPU signature) → aclgraph_runtime_mode (NPU internal)
        if cudagraph_runtime_mode is not None:
            assert (
                cudagraph_runtime_mode == CUDAGraphMode.NONE
                or cudagraph_runtime_mode == _cg_mode
            ), (
                f"ACL graph runtime mode mismatch at dummy_run. "
                f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
            )
            aclgraph_runtime_mode = cudagraph_runtime_mode
        else:
            aclgraph_runtime_mode = _cg_mode

        # Adjust for ubatch if needed
        if ubatch_slices is not None:
            num_tokens_after_padding = ubatch_slices[0].num_tokens
            if num_tokens_across_dp is not None:
                num_tokens_across_dp[:] = num_tokens_after_padding

        original_in_profile_run = self.in_profile_run
        self.in_profile_run = is_profile
        # The try starts immediately after the flag is flipped so that it is
        # restored even when the setup preceding the forward pass raises.
        try:
            if (
                not self.in_profile_run
                and hasattr(self, "dynamic_eplb")
                and self.dynamic_eplb
            ):
                if hasattr(self, "eplb_updator"):
                    self.eplb_updator.forward_before()

            need_dummy_logits = not self.in_profile_run and lmhead_tp_enable()
            dummy_indices = None
            if need_dummy_logits:
                max_num_reqs_across_dp = (
                    num_tokens if not with_prefill else max_num_reqs
                )
                dummy_indices = torch.zeros(
                    max_num_reqs_across_dp, dtype=torch.int32, device=self.device
                )

            with self.maybe_randomize_inputs(input_ids), set_ascend_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens_after_padding,
                num_tokens_across_dp=num_tokens_across_dp,
                with_prefill=with_prefill,
                in_profile_run=self.in_profile_run,
                reserved_mc2_mask=getattr(self, "reserved_mc2_mask", None),
                moe_comm_type=moe_comm_type,
                num_actual_tokens=0,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                prefetch_stream=getattr(self, "prefetch_stream", None),
                model_instance=self.model,
                weight_prefetch_method=getattr(self, "weight_prefetch_method", None),
            ):
                hidden_states = self._generate_dummy_run_hidden_states(
                    with_prefill=with_prefill,
                    is_torchair_compile=False,
                    input_ids=input_ids,
                    positions=positions,
                    attn_metadata=attn_metadata,
                    num_tokens=num_tokens,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    model_kwargs=model_kwargs,
                )

                if need_dummy_logits:
                    self.model.compute_logits(hidden_states[dummy_indices])

            if self.drafter:
                self.drafter.dummy_run(
                    num_tokens=num_tokens,
                    with_prefill=with_prefill,
                    skip_attn=True,
                    num_reqs=num_reqs,
                    num_tokens_across_dp=num_tokens_across_dp,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    batch_descriptor=batch_descriptor,
                )
                if need_dummy_logits:
                    self.drafter.model.compute_logits(hidden_states[dummy_indices])

            if self.in_profile_run and hasattr(self, "dynamic_eplb") and self.dynamic_eplb:
                if hasattr(self, "model"):
                    self.model.clear_all_moe_loads()
            if (
                not self.in_profile_run
                and hasattr(self, "dynamic_eplb")
                and self.dynamic_eplb
            ):
                if hasattr(self, "eplb_updator"):
                    self.eplb_updator.take_update_info_from_eplb_process()
        finally:
            self.in_profile_run = original_in_profile_run

    if not skip_eplb:
        self.eplb_step(is_dummy=True, is_profile=is_profile)

    # Extract text hidden states from multimodal outputs
    # Parent class will handle logit indexing in profile_run
    hidden_states, _ = self.extract_multimodal_outputs(hidden_states)
    return hidden_states