diff --git a/docs/api/README.md b/docs/api/README.md index bb4e9d87779..a8049b01118 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -111,3 +111,8 @@ Worker classes and model runners for distributed inference. - [vllm_omni.worker.gpu_generation_model_runner.GPUGenerationModelRunner][] - [vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker][] - [vllm_omni.worker.gpu_model_runner.OmniGPUModelRunner][] +- [vllm_omni.worker.npu.npu_ar_model_runner.NPUARModelRunner][] +- [vllm_omni.worker.npu.npu_ar_worker.NPUARWorker][] +- [vllm_omni.worker.npu.npu_diffusion_model_runner.NPUDiffusionModelRunner][] +- [vllm_omni.worker.npu.npu_diffusion_worker.NPUDiffusionWorker][] +- [vllm_omni.worker.npu.npu_model_runner.OmniNPUModelRunner][] diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index c301e7a1279..7303bebbcfd 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -4,4 +4,4 @@ vLLM supports the following hardware platforms: - [GPU](gpu.md) - [NVIDIA CUDA](gpu.md) -- [Ascend NPU](npu.md) +- [NPU](npu.md) diff --git a/docs/getting_started/installation/npu.md b/docs/getting_started/installation/npu.md index d341c1504c2..197bcec305b 100644 --- a/docs/getting_started/installation/npu.md +++ b/docs/getting_started/installation/npu.md @@ -1,3 +1,23 @@ -# Ascend-NPU +# NPU -vLLM-Omni is a Python library that supports the following NPU variants. Select your NPU type to see vendor specific instructions: +vLLM-Omni supports NPU through the vLLM Ascend Plugin (vllm-ascend). This is a community maintained hardware plugin for running vLLM on NPU. + +## Requirements + +- OS: Linux +- Python: 3.12 + +!!! note + vLLM-Omni is currently not natively supported on Windows. 
+ +=== "NPU" + + --8<-- "docs/getting_started/installation/npu/npu.inc.md:requirements" + +## Installation + +### Recommended + +=== "NPU" + + --8<-- "docs/getting_started/installation/npu/npu.inc.md:installation" diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md index c20460ed939..02d757dfcc0 100644 --- a/docs/getting_started/installation/npu/npu.inc.md +++ b/docs/getting_started/installation/npu/npu.inc.md @@ -1,5 +1,48 @@ +# --8<-- [start:requirements] + +For detailed hardware and software requirements, please refer to the [vllm-ascend installation documentation](https://docs.vllm.ai/projects/ascend/en/latest/installation.html). + +# --8<-- [end:requirements] # --8<-- [start:installation] -vLLM-Omni mainly contains python implementations for framework and models. +The recommended way to use vLLM-Omni on NPU is through the vllm-ascend pre-built Docker images: + +```bash +# Update DEVICE according to your NPUs (/dev/davinci[0-7]) +export DEVICE0=/dev/davinci0 +export DEVICE1=/dev/davinci1 +# Update the vllm-ascend image +# Atlas A2: +# export IMAGE=quay.io/ascend/vllm-ascend:v0.11.0rc2 +# Atlas A3: +# export IMAGE=quay.io/ascend/vllm-ascend:v0.11.0rc2-a3 +export IMAGE=quay.io/ascend/vllm-ascend:v0.11.0rc2 +docker run --rm \ + --name vllm-omni-npu \ + --device $DEVICE0 \ + --device $DEVICE1 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v /root/.cache:/root/.cache \ + -p 8000:8000 \ + -it $IMAGE bash + +# Inside the container, install vLLM-Omni from source +cd /vllm-workspace +git clone https://github.com/vllm-project/vllm-omni.git +cd vllm-omni +pip install 
-v -e . +export VLLM_WORKER_MULTIPROC_METHOD=spawn +``` + +The default workdir is `/workspace`, with vLLM, vLLM-Ascend and vLLM-Omni code placed in `/vllm-workspace` installed in development mode. + +For other installation methods (pip installation, building from source, custom Docker builds), please refer to the [vllm-ascend installation guide](https://docs.vllm.ai/projects/ascend/en/latest/installation.html). # --8<-- [end:installation] diff --git a/examples/offline_inference/qwen_image/gradio_demo.py b/examples/offline_inference/qwen_image/gradio_demo.py index 7906054a0d1..2c4979e97c1 100644 --- a/examples/offline_inference/qwen_image/gradio_demo.py +++ b/examples/offline_inference/qwen_image/gradio_demo.py @@ -5,6 +5,7 @@ import torch from vllm_omni.entrypoints.omni import Omni +from vllm_omni.utils.platform_utils import detect_device_type ASPECT_RATIOS: dict[str, tuple[int, int]] = { "1:1": (1328, 1328), @@ -62,7 +63,7 @@ def get_omni(model_name: str) -> Omni: def build_demo(args: argparse.Namespace) -> gr.Blocks: - device = "cuda" + device = detect_device_type() omni = get_omni(args.model) def run_inference( diff --git a/examples/offline_inference/qwen_image/text_to_image.py b/examples/offline_inference/qwen_image/text_to_image.py index e0a4fb8b2e9..4524fcefe39 100644 --- a/examples/offline_inference/qwen_image/text_to_image.py +++ b/examples/offline_inference/qwen_image/text_to_image.py @@ -7,6 +7,7 @@ import torch from vllm_omni.entrypoints.omni import Omni +from vllm_omni.utils.platform_utils import detect_device_type def parse_args() -> argparse.Namespace: @@ -45,7 +46,7 @@ def parse_args() -> argparse.Namespace: def main(): args = parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" + device = detect_device_type() generator = torch.Generator(device=device).manual_seed(args.seed) omni = Omni(model=args.model) diff --git a/vllm_omni/core/sched/diffusion_scheduler.py b/vllm_omni/core/sched/diffusion_scheduler.py index 
e8707aabf35..56b7be68cae 100644 --- a/vllm_omni/core/sched/diffusion_scheduler.py +++ b/vllm_omni/core/sched/diffusion_scheduler.py @@ -4,9 +4,9 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.request_queue import create_request_queue -from vllm.v1.core.sched.scheduler import EngineCoreOutputs, Request, RequestStatus, SchedulerOutput, SpecDecodingStats +from vllm.v1.core.sched.scheduler import Request, RequestStatus, SchedulerOutput, SpecDecodingStats from vllm.v1.core.sched.utils import remove_all -from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm_omni.core.sched.output import OmniNewRequestData from vllm_omni.core.sched.scheduler import OmniScheduler diff --git a/vllm_omni/core/sched/generation_scheduler.py b/vllm_omni/core/sched/generation_scheduler.py index 7916c36ed2f..9a0c072660f 100644 --- a/vllm_omni/core/sched/generation_scheduler.py +++ b/vllm_omni/core/sched/generation_scheduler.py @@ -5,14 +5,13 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.request_queue import create_request_queue from vllm.v1.core.sched.scheduler import ( - EngineCoreOutputs, Request, RequestStatus, SchedulerOutput, SpecDecodingStats, ) from vllm.v1.core.sched.utils import remove_all -from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm_omni.core.sched.output import OmniNewRequestData from vllm_omni.core.sched.scheduler import OmniScheduler diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index cd1e2045809..05f13afaf3b 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -9,7 +9,7 @@ from vllm_omni.diffusion.registry import get_diffusion_post_process_func from vllm_omni.diffusion.request import 
OmniDiffusionRequest from vllm_omni.diffusion.scheduler import scheduler -from vllm_omni.diffusion.worker.gpu_worker import WorkerProc +from vllm_omni.utils.platform_utils import get_diffusion_worker_class logger = init_logger(__name__) @@ -76,6 +76,9 @@ def _launch_workers(self, broadcast_handle): mp.set_start_method("spawn", force=True) processes = [] + # Get the appropriate worker class for current device + worker_proc = get_diffusion_worker_class() + # Launch all worker processes scheduler_pipe_readers = [] scheduler_pipe_writers = [] @@ -84,7 +87,7 @@ def _launch_workers(self, broadcast_handle): reader, writer = mp.Pipe(duplex=False) scheduler_pipe_writers.append(writer) process = mp.Process( - target=WorkerProc.worker_main, + target=worker_proc.worker_main, args=( i, # rank od_config, diff --git a/vllm_omni/diffusion/distributed/utils.py b/vllm_omni/diffusion/distributed/utils.py index fb930f21a34..2e2f0fa537a 100644 --- a/vllm_omni/diffusion/distributed/utils.py +++ b/vllm_omni/diffusion/distributed/utils.py @@ -5,7 +5,11 @@ import torch +from vllm_omni.utils.platform_utils import detect_device_type + def get_local_device() -> torch.device: - """Return the torch device for the current rank.""" - return torch.device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}") + """Return the torch device for the current rank based on detected device type.""" + device_type = detect_device_type() + local_rank = os.environ.get("LOCAL_RANK", 0) + return torch.device(f"{device_type}:{local_rank}") diff --git a/vllm_omni/diffusion/worker/npu/npu_worker.py b/vllm_omni/diffusion/worker/npu/npu_worker.py new file mode 100644 index 00000000000..4baf206033b --- /dev/null +++ b/vllm_omni/diffusion/worker/npu/npu_worker.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp +import os + +import torch +import zmq +from vllm.config import VllmConfig, set_current_vllm_config +from 
vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.registry import initialize_model +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +class NPUWorker: + """ + A worker that executes the model on a single NPU. + """ + + def __init__( + self, + local_rank: int, + rank: int, + od_config: OmniDiffusionConfig, + ): + self.local_rank = local_rank + self.rank = rank + self.od_config = od_config + self.pipeline = None + + self.init_device_and_model() + + def init_device_and_model(self) -> None: + """Initialize the device and load the model.""" + world_size = self.od_config.num_gpus + rank = self.rank + # Set environment variables for distributed initialization + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.od_config.master_port) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + device = torch.device(f"npu:{rank}") + torch.npu.set_device(device) + + # hack + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = self.od_config.num_gpus + set_current_vllm_config(vllm_config) + + init_distributed_environment(world_size=world_size, rank=rank) + initialize_model_parallel(tensor_model_parallel_size=world_size) + + with device: + self.pipeline = initialize_model(self.od_config) + self.pipeline.load_weights() + self.pipeline.eval() + logger.info(f"Worker {self.rank}: Initialized device, model, and distributed environment.") + logger.info(f"Worker {self.rank}: Model loaded successfully.") + + @torch.inference_mode() + def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: 
+ """ + Execute a forward pass. + """ + assert self.pipeline is not None + # TODO: dealing with first req for now + req = reqs[0] + output = self.pipeline.forward(req) + return output + + +class NPUWorkerProc: + """Wrapper that runs one Worker in a separate process.""" + + def __init__( + self, + od_config: OmniDiffusionConfig, + gpu_id: int, + broadcast_handle, + ): + self.od_config = od_config + + # Inter-process Communication + self.context = zmq.Context(io_threads=2) + + # Initialize MessageQueue reader from handle + self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) + + self.result_mq = None + self.result_mq_handle = None + + # Setup result sender (only for rank 0 for now, or whoever needs to reply) + # Assuming only rank 0 replies to scheduler as per original logic + if gpu_id == 0: + # Create MessageQueue for results (1 writer -> 1 reader) + # We assume the reader (SyncScheduler) will act as rank 0 + self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0]) + self.result_mq_handle = self.result_mq.export_handle() + logger.info(f"Worker {gpu_id} created result MessageQueue") + + assert od_config.master_port is not None + worker = NPUWorker( + local_rank=gpu_id, + rank=gpu_id, + od_config=od_config, + ) + self.worker = worker + self.gpu_id = gpu_id + self._running = True + + def return_result(self, output: DiffusionOutput): + """ + replies to client, only on rank 0 + """ + if self.result_mq is not None: + self.result_mq.enqueue(output) + + def recv_reqs(self): + """ + Receive requests from broadcast queue + """ + return self.mq.dequeue() + + # TODO: queueing, cancellation + def worker_busy_loop(self) -> None: + """Main busy loop for Multiprocessing Workers""" + + logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") + + while self._running: + reqs = None + # 1: receive requests + try: + reqs = self.recv_reqs() + except Exception as e: + logger.error( + f"Error receiving requests in scheduler 
event loop: {e}", + exc_info=True, + ) + continue + + # 2: execute, make sure a reply is always sent + try: + output = self.worker.execute_model(reqs, self.od_config) + except Exception as e: + logger.error( + f"Error executing forward in event loop: {e}", + exc_info=True, + ) + output = DiffusionOutput(error=str(e)) + + try: + self.return_result(output) + except zmq.ZMQError as e: + # Reply failed; log and keep loop alive to accept future requests + logger.error(f"ZMQ error sending reply: {e}") + continue + + logger.info("event loop terminated.") + # if self.result_sender is not None: + # self.result_sender.close() + self.context.term() + + @staticmethod + def worker_main( + rank: int, + od_config: OmniDiffusionConfig, + pipe_writer: mp.connection.Connection, + broadcast_handle, + ) -> None: + """Worker initialization and execution loops.""" + + worker_proc = NPUWorkerProc( + od_config, + gpu_id=rank, + broadcast_handle=broadcast_handle, + ) + logger.info(f"Worker {rank}: Scheduler loop started.") + pipe_writer.send( + { + "status": "ready", + "result_handle": worker_proc.result_mq_handle if rank == 0 else None, + } + ) + worker_proc.worker_busy_loop() + logger.info(f"Worker {rank}: Shutdown complete.") diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 635005cb5b5..5a891f95f69 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -32,7 +32,7 @@ _to_dict, maybe_dump_to_shm, maybe_load_from_ipc_with_metrics, - set_stage_gpu_devices, + set_stage_devices, ) from vllm_omni.inputs.data import OmniTokensPrompt @@ -425,7 +425,10 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: 
_logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) @@ -717,7 +720,10 @@ def filter(self, record: _logging.LogRecord) -> bool: # Device mapping try: - set_stage_gpu_devices(stage_id, runtime_cfg.get("devices")) + from vllm_omni.utils import detect_device_type + + device_type = detect_device_type() + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: _logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 8ddc7d5049d..eda037b68f7 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -13,100 +13,143 @@ logger = logging.getLogger(__name__) -def set_stage_gpu_devices(stage_id: int, devices: str | int | None) -> None: - """Configure per-stage CUDA visibility and current device. - - Behavior - - Comma-separated string (e.g. "2,5,7"): set CUDA_VISIBLE_DEVICES exactly - to this list; logical index 0 is used as current device. - - Integer or digit-string: treat as logical index (0-based) into the current - CUDA_VISIBLE_DEVICES mapping; map to the physical device, and then set - CUDA_VISIBLE_DEVICES to this single device. - - None/"cpu": keep default visibility. - - Otherwise: set CUDA_VISIBLE_DEVICES to the provided single device string. +def set_stage_devices( + stage_id: int, + devices: str | int | None, + device_type: str | None = None, +) -> None: + """Configure per-stage device visibility and current device (CUDA or NPU). + + This function sets environment variables that control which devices are visible + to the process, and sets the current device. It must be called BEFORE worker + initialization so that workers see the correct devices. + + Args: + stage_id: Stage identifier for logging + devices: Device specification: + - Comma-separated string (e.g. 
"2,5,7"): set device visibility env var + exactly to this list; logical index 0 is used as current device. + - Integer or digit-string: treat as logical index (0-based) into the + current device visibility mapping; map to physical device, then set + env var to this single device. + - None/"cpu": keep default visibility. + - Otherwise: set env var to the provided single device string. + device_type: Device type ("cuda" or "npu"). If None, auto-detects. + + Behavior: + - CUDA: Sets CUDA_VISIBLE_DEVICES and calls torch.cuda.set_device() + - NPU: Sets ASCEND_RT_VISIBLE_DEVICES and calls torch.npu.set_device() """ + from vllm_omni.utils import detect_device_type, get_device_control_env_var + + if device_type is None: + device_type = detect_device_type() + + env_var = get_device_control_env_var() + + # Select device-specific torch functions + if device_type == "npu": + try: + import torch.npu # type: ignore[import-untyped] + except ImportError: + logger.debug("[Stage-%s] torch.npu not available, skipping NPU device setup", stage_id) + return + + is_available_fn = torch.npu.is_available + set_device_fn = torch.npu.set_device + device_count_fn = torch.npu.device_count + get_device_properties_fn = torch.npu.get_device_properties + mem_get_info_fn = torch.npu.mem_get_info + get_device_name_fn = torch.npu.get_device_name + device_type_label = "NPU" + elif device_type == "cuda": + import torch # noqa: WPS433 + + is_available_fn = torch.cuda.is_available + set_device_fn = torch.cuda.set_device + device_count_fn = torch.cuda.device_count + get_device_properties_fn = torch.cuda.get_device_properties + mem_get_info_fn = torch.cuda.mem_get_info + get_device_name_fn = torch.cuda.get_device_name + device_type_label = "CUDA" + else: + logger.debug("[Stage-%s] Unsupported device type: %s", stage_id, device_type) + return + try: selected_physical: int | None = None logical_idx: int | None = None if isinstance(devices, str) and "," in devices: - os.environ["CUDA_VISIBLE_DEVICES"] = 
devices + os.environ[env_var] = devices toks = [t.strip() for t in devices.split(",") if t.strip() != ""] if toks: try: selected_physical = int(toks[0]) logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to %s; logical 0 -> physical %s", + "[Stage-%s] Set %s to %s; logical 0 -> physical %s", stage_id, + env_var, devices, selected_physical, ) except Exception as e: - logger.debug("[Stage-%s] Failed to parse first CUDA device: %s", stage_id, e) + logger.debug("[Stage-%s] Failed to parse first %s device: %s", stage_id, device_type_label, e) selected_physical = None elif isinstance(devices, (int, str)) and (isinstance(devices, int) or str(devices).isdigit()): logical_idx = max(0, int(devices)) - vis = os.environ.get("CUDA_VISIBLE_DEVICES") + vis = os.environ.get(env_var) if vis: try: mapping = [int(x) for x in vis.split(",") if x.strip() != ""] if 0 <= logical_idx < len(mapping): selected_physical = mapping[logical_idx] except Exception as e: - logger.debug("[Stage-%s] Failed to map logical index via CUDA_VISIBLE_DEVICES: %s", stage_id, e) + logger.debug("[Stage-%s] Failed to map logical index via %s: %s", stage_id, env_var, e) selected_physical = None if selected_physical is None: selected_physical = int(logical_idx) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) + os.environ[env_var] = str(selected_physical) logger.debug( - "[Stage-%s] Logical index %d -> physical %s; set CUDA_VISIBLE_DEVICES to single device", + "[Stage-%s] Logical index %d -> physical %s; set %s to single device", stage_id, logical_idx + 1, selected_physical, + env_var, ) elif devices in (None, "cpu"): logger.debug("[Stage-%s] Using default device visibility (devices=%s)", stage_id, devices) else: selected_physical = int(str(devices)) - os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_physical) - logger.debug( - "[Stage-%s] Set CUDA_VISIBLE_DEVICES to single device %s (fallback)", stage_id, selected_physical - ) + os.environ[env_var] = str(selected_physical) + 
logger.debug("[Stage-%s] Set %s to single device %s (fallback)", stage_id, env_var, selected_physical) try: - import torch + import torch # noqa: WPS433 - if torch.cuda.is_available(): + if is_available_fn(): try: - torch.cuda.set_device(0) + set_device_fn(0) except Exception as e: logger.debug( - "[Stage-%s] torch.cuda.set_device(0) failed: %s", - stage_id, - e, - exc_info=True, + "[Stage-%s] %s set_device(0) failed: %s", stage_id, device_type_label, e, exc_info=True ) - num = torch.cuda.device_count() + num = device_count_fn() info = [] for i in range(num): - total = torch.cuda.get_device_properties(i).total_memory - free, _ = torch.cuda.mem_get_info(i) + total = get_device_properties_fn(i).total_memory + free, _ = mem_get_info_fn(i) info.append( { "idx": i, - "name": torch.cuda.get_device_name(i), + "name": get_device_name_fn(i), "total": int(total), "free": int(free), } ) - logger.debug("[Stage-%s] CUDA devices visible=%s info=%s", stage_id, num, info) + logger.debug("[Stage-%s] %s devices visible=%s info=%s", stage_id, device_type_label, num, info) except Exception as e: - logger.debug( - "[Stage-%s] Failed to query CUDA devices: %s", - stage_id, - e, - exc_info=True, - ) + logger.debug("[Stage-%s] Failed to query %s devices: %s", stage_id, device_type_label, e, exc_info=True) except Exception as e: logger.warning("Failed to interpret devices for stage %s: %s", stage_id, e) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 857d9d491da..57a6b755468 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -1,13 +1,10 @@ -from __future__ import annotations - -import logging import os from pathlib import Path from omegaconf import OmegaConf from vllm.transformers_utils.config import get_config -logger = logging.getLogger(__name__) +from vllm_omni.utils import detect_device_type # Get the project root directory (2 levels up from this file) PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -16,9 +13,9 
@@ def load_stage_configs_from_model(model: str) -> list: """Load stage configurations from model's default config file. - Loads stage configurations based on the model type. Looks for a - YAML configuration file in the stage_configs directory matching - the model's model_type. + Loads stage configurations based on the model type and device type. + First tries to load a device-specific YAML file from stage_configs/{device_type}/ + directory. If not found, falls back to the default config file. Args: model: Model name or path (used to determine model_type) @@ -31,6 +28,17 @@ def load_stage_configs_from_model(model: str) -> list: """ hf_config = get_config(model, trust_remote_code=True) model_type = hf_config.model_type + device_type = detect_device_type() + + # Try device-specific config first + if device_type != "cuda": + device_config_file = f"vllm_omni/model_executor/stage_configs/{device_type}/{model_type}.yaml" + device_config_path = PROJECT_ROOT / device_config_file + if os.path.exists(device_config_path): + stage_configs = load_stage_configs_from_yaml(config_path=str(device_config_path)) + return stage_configs + + # Fall back to default config stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" stage_config_path = PROJECT_ROOT / stage_config_file if not os.path.exists(stage_config_path): diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index 34961441f3c..f6f5608cef6 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -35,6 +35,7 @@ ) from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, split_list_into_ranges from vllm_omni.model_executor.models.vision import get_llm_pos_ids_for_vision +from vllm_omni.utils.platform_utils import is_npu TALKER_CODEC_EOS_TOKEN_ID = 8294 TALKER_CODEC_BOS_TOKEN_ID = 8293 @@ -201,6 +202,7 @@ def 
forward( """ if self.model_stage == "thinker": # Normalize to batched inputs if caller provides 1D/2D unbatched tensors + # TODO: Remove this hack when NPU supports batched inputs properly added_batch_dim = False if input_ids is not None and input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) @@ -228,12 +230,25 @@ def forward( positions = positions.to(thinker_dev) if inputs_embeds is not None and inputs_embeds.device != thinker_dev: inputs_embeds = inputs_embeds.to(thinker_dev) + + if is_npu(): + # TODO: remove this hack when NPU supports batched inputs properly + thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + ) + else: + thinker_input_ids = input_ids + thinker_positions = positions[0] + thinker_inputs_embeds = inputs_embeds + # Run thinker thinker_output = self.thinker( - input_ids=input_ids, - positions=positions[0], + input_ids=thinker_input_ids, + positions=thinker_positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, + inputs_embeds=thinker_inputs_embeds, **kwargs, ) @@ -851,7 +866,7 @@ def _init_token2wav_model(self, hf_model_folder): __init__.""" if self.token2wav is None or self.token2wav_config is None: return - device = "cuda" if torch.cuda.is_available() else "cpu" + device = self._module_device(self.token2wav) # optional speaker resources conds = getattr(self.token2wav_config, "conds", None) ref_mels = getattr(self.token2wav_config, "ref_mels", None) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py index f1109a9a7b1..a3241928ec2 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py +++ 
b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py @@ -33,6 +33,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm_omni.utils.platform_utils import is_npu + # Provide a no-op auto_docstring decorator to satisfy annotations if missing def auto_docstring(func=None, **_kwargs): @@ -724,7 +726,13 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> else: beta = 0.0 - kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + # TODO: When torch.kaiser_window supports NPU, remove the device="cpu" argument + if is_npu(): + kaiser_window = torch.kaiser_window( + kernel_size, beta=beta, periodic=False, dtype=torch.float32, device="cpu" + ).to("npu") + else: + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) # Compute time indices if is_even: @@ -734,7 +742,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> # Compute sinc filter if cutoff == 0: - return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) sinc_filter = torch.sinc(2 * cutoff * time_indices) normalized_filter = 2 * cutoff * kaiser_window * sinc_filter @@ -745,6 +753,29 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> return normalized_filter.view(1, 1, kernel_size) +def replication_pad_1d(hidden_states: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor: + """ + Manual replicate padding to avoid replication_pad1d kernel limits on NPU. + TODO: remove when F.pad supports replicate mode on NPU. + """ + # NOTE: a immature implementation for running in NPU. Need to discuss. 
+ if pad_left == 0 and pad_right == 0: + return hidden_states + + segments = [] + if pad_left > 0: + left = hidden_states[..., :1].expand(*hidden_states.shape[:-1], pad_left) + segments.append(left) + + segments.append(hidden_states) + + if pad_right > 0: + right = hidden_states[..., -1:].expand(*hidden_states.shape[:-1], pad_right) + segments.append(right) + + return torch.cat(segments, dim=-1) + + class UpSample1d(nn.Module): def __init__(self, ratio=2, kernel_size=None): super().__init__() @@ -760,14 +791,28 @@ def __init__(self, ratio=2, kernel_size=None): def forward(self, hidden_states): channels = hidden_states.shape[1] - hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to(self.filter.dtype) - hidden_states = self.ratio * F.conv_transpose1d( - hidden_states, - self.filter.expand(channels, -1, -1), - stride=self.stride, - groups=channels, - ).to(hidden_states_dtype) + if is_npu(): + # TODO: When F.pad supports replicate mode on NPU, remove this branch + input_dtype = hidden_states.dtype + # F.pad in NPU doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor. 
+ hidden_states = replication_pad_1d(hidden_states.to(self.filter.dtype), self.pad, self.pad) + filter_convert_dtype = self.filter.to(hidden_states.dtype) + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, + filter_convert_dtype.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(input_dtype) + else: + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to(self.filter.dtype) + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, + self.filter.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(hidden_states_dtype) hidden_states = hidden_states[..., self.pad_left : -self.pad_right] return hidden_states @@ -793,14 +838,29 @@ def __init__(self, ratio=2, kernel_size=None): def forward(self, hidden_states): channels = hidden_states.shape[1] - hidden_states_dtype = hidden_states.dtype - hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate").to(self.filter.dtype) - out = F.conv1d( - hidden_states, - self.filter.expand(channels, -1, -1), - stride=self.stride, - groups=channels, - ).to(hidden_states_dtype) + if is_npu(): + input_dtype = hidden_states.dtype + # F.pad in NPU doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor. 
+ hidden_states = replication_pad_1d(hidden_states.to(self.filter.dtype), self.pad_left, self.pad_right) + filter_on_device = self.filter.to(device=hidden_states.device, dtype=hidden_states.dtype) + out = F.conv1d( + hidden_states, + filter_on_device.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(input_dtype) + else: + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate").to( + self.filter.dtype + ) + out = F.conv1d( + hidden_states, + self.filter.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(hidden_states_dtype) return out diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml new file mode 100644 index 00000000000..f2205b8be9c --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml @@ -0,0 +1,94 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. 
+stage_args:
+  - stage_id: 0
+    runtime:
+      process: true  # Run this stage in a separate process
+      devices: "0"  # Visible devices for this stage
+      max_batch_size: 1
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
+      scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler
+      gpu_memory_utilization: 0.8
+      enforce_eager: true  # Now we only support eager mode
+      trust_remote_code: true
+      engine_output_type: latent
+      enable_prefix_caching: false
+    is_comprehension: true
+    final_output: true
+    final_output_type: text
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 2048
+      seed: 42
+      detokenize: True
+      repetition_penalty: 1.1
+  - stage_id: 1
+    runtime:
+      process: true
+      devices: "1"
+      max_batch_size: 1
+    engine_args:
+      model_stage: talker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
+      scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler
+      gpu_memory_utilization: 0.8
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: latent
+    # NOTE(review): presumably the list of upstream stage ids feeding this stage
+    # (it matches the 0 -> 1 edge declared under `runtime.edges`) - confirm.
+    engine_input_source: [0]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+    default_sampling_params:
+      temperature: 0.9
+      top_p: 0.8
+      top_k: 40
+      max_tokens: 2048
+      seed: 42
+      detokenize: True
+      repetition_penalty: 1.05
+      stop_token_ids: [8294]
+  - stage_id: 2
+    runtime:
+      process: true
+      devices: "0"  # Example: use a different NPU than the previous stage; use "0" if single NPU
+      max_batch_size: 1
+    engine_args:
+      model_stage: code2wav
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_cls: vllm_omni.worker.npu.npu_diffusion_worker.NPUDiffusionWorker
+      scheduler_cls: vllm_omni.core.sched.diffusion_scheduler.DiffusionScheduler
+      # Stages 0 and 2 both list device "0"; their utilization values (0.8 and 0.15) sum to 0.95.
+      gpu_memory_utilization: 0.15
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: audio
+    engine_input_source: [1]
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 2048
+      seed: 42
+      detokenize: True
+      repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+  enabled: true
+  defaults:
+    window_size: -1  # Simplified: trigger downstream only after full upstream completion
+    max_inflight: 1  # Simplified: process serially within each stage
+  edges:
+    - from: 0  # thinker → talker: trigger only after receiving full input (-1)
+      to: 1
+      window_size: -1
+    - from: 1  # talker → code2wav: trigger only after receiving full input (-1)
+      to: 2
+      window_size: -1
diff --git a/vllm_omni/utils/__init__.py b/vllm_omni/utils/__init__.py
index e69de29bb2d..50dbb478d90 100644
--- a/vllm_omni/utils/__init__.py
+++ b/vllm_omni/utils/__init__.py
@@ -0,0 +1,11 @@
+# Re-export the platform helpers so callers can import them from `vllm_omni.utils`.
+from vllm_omni.utils.platform_utils import (
+    detect_device_type,
+    get_device_control_env_var,
+    is_npu,
+)
+
+__all__ = [
+    "detect_device_type",
+    "get_device_control_env_var",
+    "is_npu",
+]
diff --git a/vllm_omni/utils/platform_utils.py b/vllm_omni/utils/platform_utils.py
new file mode 100644
index 00000000000..641f763457d
--- /dev/null
+++ b/vllm_omni/utils/platform_utils.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import torch
+from vllm.platforms import current_platform
+
+
+def detect_device_type() -> str:
+    """Return the current device type as a lowercase string.
+
+    Resolution order:
+      1. ``current_platform.device_type`` when it is a non-empty string.
+      2. ``"cuda"`` when ``torch.cuda.is_available()``.
+      3. ``"npu"`` when ``torch`` exposes an ``npu`` attribute and
+         ``torch.npu.is_available()``.
+      4. ``"cpu"`` otherwise.
+    """
+    device_type = getattr(current_platform, "device_type", None)
+    if isinstance(device_type, str) and device_type:
+        return device_type.lower()
+    if torch.cuda.is_available():
+        return "cuda"
+    # `torch.npu` is only looked up after checking that the attribute exists.
+    if hasattr(torch, "npu") and torch.npu.is_available():  # type: ignore[attr-defined]
+        return "npu"
+    return "cpu"
+
+
+def is_npu() -> bool:
+    """Return True when :func:`detect_device_type` reports ``"npu"``."""
+    return detect_device_type() == "npu"
+
+
+def get_device_control_env_var() -> str:
+    """Return the environment variable name for device visibility control.
+
+    Prefers ``current_platform.device_control_env_var`` when it is a non-empty
+    string. Otherwise returns ``"ASCEND_RT_VISIBLE_DEVICES"`` when the detected
+    device type is ``"npu"`` and ``"CUDA_VISIBLE_DEVICES"`` for every other
+    device type.
+    """
+    if hasattr(current_platform, "device_control_env_var"):
+        env_var = getattr(current_platform, "device_control_env_var", None)
+        # An empty or non-string value is ignored and the fallback below is used.
+        if isinstance(env_var, str) and env_var:
+            return env_var
+
+    device_type = detect_device_type()
+    if device_type == "npu":
+        return "ASCEND_RT_VISIBLE_DEVICES"
+    return "CUDA_VISIBLE_DEVICES"  # fallback
+
+
+def get_diffusion_worker_class():
+    """Get the appropriate diffusion WorkerProc class based on current device type.
+
+    Returns:
+        The WorkerProc class for the detected device type.
+
+    Raises:
+        ImportError: If the worker module for the detected device is not available.
+    """
+    device_type = detect_device_type()
+
+    # Each worker module is imported inside its branch, so only the module for
+    # the detected device type is loaded.
+    if device_type == "npu":
+        from vllm_omni.diffusion.worker.npu.npu_worker import NPUWorkerProc
+
+        return NPUWorkerProc
+    else:
+        # Default to GPU worker for cuda and other devices
+        from vllm_omni.diffusion.worker.gpu_worker import WorkerProc
+
+        return WorkerProc
diff --git a/vllm_omni/worker/npu/__init__.py b/vllm_omni/worker/npu/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py
new file mode 100644
index 00000000000..bff15f2abc2
--- /dev/null
+++ b/vllm_omni/worker/npu/npu_ar_model_runner.py
@@ -0,0 +1,963 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from vllm.config import CUDAGraphMode
+from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.logger import logger
+from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
+from vllm.v1.core.sched.output import SchedulerOutput + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.kv_cache_interface import ( + EncoderOnlyAttentionSpec, +) + +# yapf: enable +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + ModelRunnerOutput, +) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import update_attn_params, update_mla_attn_params +from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.utils import ( + ProfileExecuteDuration, + enable_sp, + lmhead_tp_enable, +) +from vllm_ascend.worker.model_runner_v1 import AsyncNPUModelRunnerOutput + +from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.npu.npu_model_runner import OmniNPUModelRunner + + +class NPUARModelRunner(OmniNPUModelRunner): + """Autoregressive NPU model runner that returns hidden states per request.""" + + def _prepare_inputs( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> tuple[ + dict[str, Any], + torch.Tensor, + np.ndarray, + int, + torch.Tensor, + int, + torch.Tensor, + SpecDecodeMetadata, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int, + dict[str, dict] | None, + ]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. 
+ self.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = num_scheduled_tokens.max() + num_valid_tokens = np.array( + [ + num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for num_tokens, i in zip(tokens, req_ids) + ], + dtype=np.int32, + ) + + if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]: + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens) + elif self.use_aclgraph and enable_sp(self.vllm_config): + # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, + # the model will fall back to running its FX graph in eager mode. + # In this case, when sequence parallelism is enabled, we need to pad tokens to align + # with tp_size because pad_size cannot be captured by the FX graph + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size + else: + # Eager mode. + num_input_tokens = total_num_scheduled_tokens + + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) + + # Get info across DP ranks. 
+ # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just max_tokens_across_dp_cpu + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp( + num_input_tokens, with_prefill, enable_dbo + ) + + # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens + # We should consider removing maybe_padded_num_tokens later + num_input_tokens = maybe_padded_num_tokens + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True + ) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + + # Prepare input_ids. + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) + + # Prepare some information for building Attention-Metadata + # Compute and commit slot mapping + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True) + + self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.query_start_loc[num_reqs + 1 :].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + # Copy the tensors to the NPU. 
+ self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True) + + # Make Attention metadata + positions_cpu = self.positions_cpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) + self.attn_state = attn_state # type: ignore + + self.with_prefill = with_prefill + self.num_tokens_across_dp = num_tokens_across_dp + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) + attn_metadata: dict[str, Any] = {} + + # Omni-new: per_req_additional_information + per_req_additional_information: dict[str, dict] | None = None + + # _prepare_inputs may reorder the batch, so we must gather + # multi-modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:total_num_scheduled_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. 
+ self.inputs_embeds[:total_num_scheduled_tokens].copy_(inputs_embeds) + + # -------------------------------------- Omni-new ------------------------------------------------- + # Omni-new: Reset per-step additional information collector (deprecated concat path) + if hasattr(self, "_forward_additional_information"): + self._forward_additional_information = None + # Omni-new: per-request additional information for this step + per_req_additional_information = {} + + # Omni-new: Overlay custom prompt_embeds per request for the prompt portion; + # collect additional_information (tensor/list) for prefill portion only + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + addi_cpu = getattr(req_state, "additional_information_cpu", None) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + if pe_cpu is not None: + src = pe_cpu[num_computed_tokens : num_computed_tokens + overlay_len].to( + dtype=self.dtype, device=self.device, non_blocking=True + ) + start_offset = int(self.query_start_loc_cpu[req_index]) + self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + # Build per-request additional information (no cross-request concat) + if addi_cpu is not None and isinstance(addi_cpu, dict): + req_info: dict[str, object] = {} + for k, v in addi_cpu.items(): + if isinstance(v, torch.Tensor): + # For prefill tokens, pass only the scheduled slice; + # for decode or no scheduled tokens, pass whole tensor + if overlay_len > 0: + try: + seg = ( + v[num_computed_tokens : num_computed_tokens + overlay_len] + .detach() + .to("cpu") + .contiguous() + ) + except Exception: + seg = 
v.detach().to("cpu").contiguous() + req_info[k] = seg + else: + req_info[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + req_info[k] = v + else: + req_info[k] = v + per_req_additional_information[req_id] = req_info + # ------------------------------------------------------------------------------------------------ + + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = self.input_ids[:num_input_tokens] + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the ACL graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + positions = self.positions[:num_input_tokens] + input_ids, positions = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens + ) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True) + intermediate_tensors = IntermediateTensors( + {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()} + ) + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + spec_decode_metadata = None + logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True) + else: + # Get the number of draft tokens for each request. 
+ # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. + # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + spec_decode_common_attn_metadata = None + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. 
+ blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) + self.slot_mapping[total_num_scheduled_tokens:].fill_(0) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc_cpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + num_input_tokens=num_input_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, + ) + + if self.speculative_config and spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + common_prefix_len = 0 + extra_attn_metadata_args = {} + builder = attn_group.get_metadata_builder() + if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type 
== "pooling": + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.get_model(), + **extra_attn_metadata_args, + ) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + if lmhead_tp_enable(): + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) + + return ( + attn_metadata, + positions, + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, + input_ids, + inputs_embeds, + intermediate_tensors, + max_num_scheduled_tokens, + per_req_additional_information, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + with ProfileExecuteDuration().capture_async("prepare input"): + self._update_states(scheduler_output) + + # -------------------------------------- Omni-new ------------------------------------------------- + # Omni-new: Decode per-request prompt_embeds / additional_hidden_states payloads + # (if present) into CPU tensors + try: + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if new_reqs: + for nr in new_reqs: + req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) + if req_id is None: + continue + # prompt_embeds + payload_pe = getattr(nr, "prompt_embeds", None) + if 
payload_pe is not None: + if isinstance(payload_pe, torch.Tensor): + pe_cpu = payload_pe.detach().to("cpu").contiguous() + elif isinstance(payload_pe, PromptEmbedsPayload): + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(payload_pe.data, dtype=dt) + arr = arr.reshape(payload_pe.shape) + pe_cpu = torch.from_numpy(arr) + else: + pe_cpu = None + if pe_cpu is not None and req_id in self.requests: + setattr( + self.requests[req_id], + "prompt_embeds_cpu", + pe_cpu, + ) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + # Already decoded + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict and req_id in self.requests: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding prompt_embeds / additional_information: {e}") + pass + # ------------------------------------------------------------------------------------------------ + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug("skip this step for we receive the data from remote disaggregate prefill node") + # Return empty ModelRunnerOutput if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + ( + attn_metadata, + positions, + num_scheduled_tokens_np, + num_input_tokens, + num_tokens_across_dp, + maybe_padded_num_tokens, + logits_indices, + spec_decode_metadata, + input_ids, + inputs_embeds, + intermediate_tensors, + max_query_len, + per_req_additional_information, + ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) + aclgraph_runtime_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(batch_descriptor) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + # -------------------------------------- Omni-new ------------------------------------------------- + model_kwargs_extra = {} + # Pass per-request additional information map for this step (no concat) + if per_req_additional_information: + model_kwargs_extra["additional_information_by_req_id"] = 
per_req_additional_information + # Always pass per-request runtime additional_information (persisted in request state) + try: + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + per_req_runtime_info.append(info if isinstance(info, dict) else {}) + model_kwargs_extra["runtime_additional_information"] = per_req_runtime_info + model_kwargs_extra["request_ids"] = self.input_batch.req_ids + # Pass each request's token span within the flattened sequence for this step, + # enabling the model to map decode/prefill by request + req_token_spans = [] + for req_index in range(len(self.input_batch.req_ids)): + start_offset = int(self.query_start_loc_cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + model_kwargs_extra["request_token_spans"] = req_token_spans + except Exception: + pass + # ------------------------------------------------------------------------------------------------ + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, + self.with_prefill, + maybe_padded_num_tokens, + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + model_kwargs_extra, + ) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = hidden_states + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + with ProfileExecuteDuration().capture_async("post process"): + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping 
mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + + # -------------------------------------- Omni-new ------------------------------------------------- + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + # The model side may return per-request additional_information updates (model-agnostic channel). + # Convention: multimodal_outputs["additional_information_update"] is a list[dict] in batch order; + # the runner merges it into the corresponding request's additional_information_cpu for subsequent decode. + try: + if isinstance(multimodal_outputs, dict) and ( + "additional_information_update" in multimodal_outputs + or "additional_information_update_by_req_id" in multimodal_outputs + ): + # Option A: list[dict] in batch order + updates_list = multimodal_outputs.get("additional_information_update") + if isinstance(updates_list, list): + for idx, upd in enumerate(updates_list): + if not isinstance(upd, dict) or idx >= len(self.input_batch.req_ids): + continue + req_id = self.input_batch.req_ids[idx] + self._merge_additional_information_update(req_id, upd) + # Option B: dict[str, dict] keyed by req_id + updates_map = multimodal_outputs.get("additional_information_update_by_req_id") + if isinstance(updates_map, dict): + for req_id, upd in updates_map.items(): + if not isinstance(upd, dict): + continue + if req_id not in self.requests: + continue + self._merge_additional_information_update(req_id, upd) + except Exception as e: + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} additional \ + information update: {e}, with the multimodal_outputs as {multimodal_outputs}" + ) + # ------------------------------------------------------------------------------------------------ + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden 
states. + if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict(hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, + finished_sending, + finished_recving, + kv_connector_output, + ) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + if broadcast_pp_output: + model_output_broadcast_data = ( + { + "logits": logits.contiguous(), + } + if logits is not None + else {} + ) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[: self.input_batch.num_reqs] + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + if lmhead_tp_enable() and logits is not None: + logits = logits[: len(spec_decode_metadata.logits_indices)] + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. 
+ assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) + + discard_sampled_tokens_req_indices: list[int] = [] + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. 
+ logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[: scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. 
+ for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if self.speculative_config: + self._draft_token_ids = self.propose_draft_token_ids( + valid_sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + scheduler_output.total_num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + # -------------------------------------- Omni-new ------------------------------------------------- + # Omni-new: Convert to per-request tensors on CPU + hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() + pooler_output: list[torch.Tensor | None] = [] + prev_logits_index = 0 + for rid, logits_index in zip(req_ids_output_copy, logits_indices): + # Base payload: hidden slice for this request in this iteration + hidden_slice = hidden_states_cpu[prev_logits_index : logits_index + 1] + payload: dict[str, object] = {"hidden": hidden_slice} + # Merge multimodal_outputs if present + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + mm_payload: dict[str, object] = {} + for k, v in multimodal_outputs.items(): + 
try: + # Case 1: tensor aligned on token dimension + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_payload[k] = v.detach().to("cpu")[prev_logits_index : logits_index + 1].contiguous() + # Case 2: nested dict of tensors aligned on token dimension (e.g., selected_hidden_layers) + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: + sub_dict[str(sk)] = ( + sv.detach().to("cpu")[prev_logits_index : logits_index + 1].contiguous() + ) + if sub_dict: + mm_payload[k] = sub_dict + elif isinstance(v, list): + element: torch.Tensor = v[0] + multimodal_outputs[k] = v[1:] if len(v) > 1 else v + mm_payload[k] = element + except Exception as e: + # Best-effort; skip malformed entries + logger.error(f"Error in merge multimodal outputs: {e}") + if mm_payload: + payload.update(mm_payload) + pooler_output.append(payload) # type: ignore[arg-type] + prev_logits_index = logits_index + 1 + + # Omni-new + output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + kv_connector_output=kv_connector_output, + ) + # ------------------------------------------------------------------------------------------------ + + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items()] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + if not self.use_async_scheduling: + return output + + return 
AsyncNPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + + def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + req_state = self.requests.get(req_id) + if req_state is None: + return + existing = getattr(req_state, "additional_information_cpu", {}) + if not isinstance(existing, dict): + existing = {} + merged = dict(existing) + for k, v in upd.items(): + if isinstance(v, torch.Tensor): + merged[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + new_list = [] + for item in v: + if isinstance(item, torch.Tensor): + new_list.append(item.detach().to("cpu").contiguous()) + else: + new_list.append(item) + merged[k] = new_list + else: + merged[k] = v + setattr(req_state, "additional_information_cpu", merged) + + def _generate_process_reqs_hidden_states( + self, + attn_metadata, + with_prefill, + maybe_padded_num_tokens, + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + model_kwargs_extra, + ): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs_extra, + ) + + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params( + self.update_stream, forward_context, maybe_padded_num_tokens, self.speculative_config + ) + else: + update_attn_params(self.update_stream, forward_context, maybe_padded_num_tokens) + + if get_forward_context().sp_enabled: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + 
class NPUARWorker(NPUWorker):
    """Autoregressive (AR) NPU worker used for the thinker/talker stages of an Omni model."""

    def init_device(self):
        """Set up the device and attach an ``NPUARModelRunner`` bound to it.

        The device comes from ``self._init_device()``; the runner is built
        from this worker's ``vllm_config`` together with that device and
        stored on ``self.model_runner``.
        """
        npu_device = self._init_device()
        runner = NPUARModelRunner(self.vllm_config, npu_device)
        self.model_runner: NPUARModelRunner = runner
def _prepare_inputs(
    self,
    scheduler_output: SchedulerOutput,
    intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
    dict[str, Any],
    torch.Tensor,
    np.ndarray,
    int,
    torch.Tensor,
    int,
    torch.Tensor,
    SpecDecodeMetadata | None,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    int,
    dict[str, Any],
]:
    """Build the model inputs and attention metadata for one scheduled step.

    Reads the per-request scheduled token counts from ``scheduler_output``,
    fills the runner's persistent CPU/device buffers (input ids, positions,
    query start locations, sequence lengths, slot mapping), and builds one
    attention-metadata object per attention group, registered under every
    layer name of that group.

    Args:
        scheduler_output: Scheduler decision for this step. It must schedule
            at least one token for at least one request (asserted).
        intermediate_tensors: Tensors handed over from the previous
            pipeline-parallel stage. Required (asserted) when this rank is
            not the first PP rank; ignored on the first rank.

    Returns:
        A 13-tuple, in order:
            attn_metadata: layer name -> attention metadata.
            positions: position ids for the (padded) input tokens.
            num_scheduled_tokens: ``np.int32`` array with the number of
                scheduled tokens per request, in batch order.
            num_input_tokens: token count after graph/SP/DP padding.
            num_tokens_across_dp: value returned by ``_sync_metadata_across_dp``.
            maybe_padded_num_tokens: padded token count returned by
                ``_sync_metadata_across_dp``.
            logits_indices: indices of the tokens whose logits are needed.
            spec_decode_metadata: ``None`` when no spec-decode tokens are
                scheduled, otherwise the metadata from
                ``_calc_spec_decode_metadata``.
            input_ids: token ids for the (padded) input tokens.
            inputs_embeds: input embeddings for multimodal models, else ``None``.
            intermediate_tensors: ``None`` on the first PP rank, otherwise the
                received tensors copied into the persistent buffers.
            max_num_scheduled_tokens: largest per-request scheduled count.
            model_kwargs: extra keyword arguments for the model forward call.
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = self.input_batch.num_reqs
    assert num_reqs > 0

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    self.input_batch.block_table.commit_block_table(num_reqs)

    # Get the number of scheduled tokens for each request.
    req_ids = self.input_batch.req_ids
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = num_scheduled_tokens.max()
    # Scheduled tokens minus the draft (spec-decode) tokens of each request.
    num_valid_tokens = np.array(
        [
            num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
            for num_tokens, i in zip(tokens, req_ids)
        ],
        dtype=np.int32,
    )

    if self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]:
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens)
    elif self.use_aclgraph and enable_sp(self.vllm_config):
        # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
        # the model will fall back to running its FX graph in eager mode.
        # In this case, when sequence parallelism is enabled, we need to pad tokens to align
        # with tp_size because pad_size cannot be captured by the FX graph
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size
    else:
        # Eager mode.
        num_input_tokens = total_num_scheduled_tokens

    # Get the attention state.
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_state = attn_state  # type: ignore

    # Determine if it's a splitfuse batch
    with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]

    self.query_lens = torch.from_numpy(num_scheduled_tokens)
    enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens)

    # Get info across DP ranks.
    # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
    # Otherwise, it's just max_tokens_across_dp_cpu
    (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp(
        num_input_tokens, with_prefill, enable_dbo
    )

    # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
    # We should consider removing maybe_padded_num_tokens later
    num_input_tokens = maybe_padded_num_tokens

    # Hot-Swap lora model
    if self.lora_config:
        self.set_active_loras(self.input_batch, num_scheduled_tokens)

    # Get request indices.
    # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

    # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
    # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)

    positions_np = self.positions_np[:total_num_scheduled_tokens]
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np)

    # Calculate M-RoPE positions.
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        self._calc_mrope_positions(scheduler_output)

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
            self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True
        )

    # Get token indices.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
    # where M is the max_model_len.
    token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]

    # Prepare input_ids.
    # NOTE(woosuk): We use torch.index_select instead of np.take here
    # because torch.index_select is much faster than np.take for large
    # tensors.
    torch.index_select(
        self.input_batch.token_ids_cpu_tensor.flatten(),
        0,
        torch.from_numpy(token_indices),
        out=self.input_ids_cpu[:total_num_scheduled_tokens],
    )

    # Prepare some information for building Attention-Metadata
    # Compute and commit slot mapping
    self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
    self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

    self.query_start_loc_np[0] = 0
    self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc[: num_reqs + 1].copy_(self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True)

    self.seq_lens_np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
    self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True)

    # Fill unused with -1. Needed for reshape_and_cache
    self.query_start_loc[num_reqs + 1 :].fill_(-1)
    self.seq_lens[num_reqs:].fill_(0)

    self.query_lens = torch.from_numpy(num_scheduled_tokens)

    # Copy the tensors to the NPU.
    self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
    # Zero the padded tail so padding tokens carry position 0.
    self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
    self.positions[:num_input_tokens].copy_(self.positions_cpu[:num_input_tokens], non_blocking=True)

    # Make Attention metadata
    positions_cpu = self.positions_cpu[:num_input_tokens]
    positions = self.positions[:num_input_tokens]
    seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
    attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
    self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state)
    self.attn_state = attn_state  # type: ignore

    self.with_prefill = with_prefill
    self.num_tokens_across_dp = num_tokens_across_dp
    self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
    attn_metadata: dict[str, Any] = {}

    # _prepare_inputs may reorder the batch, so we must gather
    # multi-modal outputs after that to ensure the correct order
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_mm_encoder(scheduler_output)
        mm_embeds = self._gather_mm_embeddings(scheduler_output)

        # -------------------------------------- Omni-new -------------------------------------------------
        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        input_ids = self.input_ids[:num_input_tokens]
        if mm_embeds:
            inputs_embeds = self.model.get_input_embeddings(input_ids, mm_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds)
        inputs_embeds = self.inputs_embeds[:num_input_tokens]

        model_kwargs = {
            **self._init_model_kwargs(num_input_tokens),
            **self._extract_mm_kwargs(scheduler_output),
        }
        # (NOTE) Omni-new: input_ids isn't set as None
        # -------------------------------------------------------------------------------------------------
    else:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the ACL graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
        model_kwargs = self._init_model_kwargs(num_input_tokens)
    positions = self.positions[:num_input_tokens]
    input_ids, positions = self._update_input_ids_and_positions(
        input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens
    )

    if get_pp_group().is_first_rank:
        intermediate_tensors = None
    else:
        assert intermediate_tensors is not None
        assert self.intermediate_tensors is not None
        # Stage the received tensors in the persistent buffers, then hand out views of them.
        for k, v in intermediate_tensors.items():
            self.intermediate_tensors[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True)
        intermediate_tensors = IntermediateTensors(
            {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()}
        )

    use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
    if not use_spec_decode:
        # NOTE(woosuk): Due to chunked prefills, the batch may contain
        # partial requests. While we should not sample any token
        # from these partial requests, we do so for simplicity.
        # We will ignore the sampled tokens from the partial requests.
        # TODO: Support prompt logprobs.
        spec_decode_metadata = None
        logits_indices = torch.from_numpy(cu_num_tokens - 1).to(self.device, non_blocking=True)
    else:
        # Get the number of draft tokens for each request.
        # Iterate over the dictionary rather than all requests since not all
        # requests have draft tokens.
        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
        for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items():
            req_idx = self.input_batch.req_id_to_index[req_id]
            num_draft_tokens[req_idx] = len(draft_token_ids)

        spec_decode_metadata = self._calc_spec_decode_metadata(num_draft_tokens, cu_num_tokens)
        logits_indices = spec_decode_metadata.logits_indices
        self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
        self.num_draft_tokens.np[num_reqs:].fill(0)
        self.num_draft_tokens.copy_to_gpu()

    # Used in the below loop.
    # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
    num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
    spec_decode_common_attn_metadata = None
    if use_spec_decode and self.need_accepted_tokens:
        self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs]
        self.num_accepted_tokens.np[num_reqs:].fill(1)
        self.num_accepted_tokens.copy_to_gpu()

    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups):
        if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
            # Encoder-only layers do not have KV cache, so we need to
            # create a dummy block table and slot mapping for them.
            blk_table_tensor = torch.zeros(
                (num_reqs, 1),
                dtype=torch.int32,
                device=self.device,
            )
            slot_mapping = torch.zeros(
                (total_num_scheduled_tokens,),
                dtype=torch.int64,
                device=self.device,
            )
        else:
            # NOTE(review): nesting of the slot-mapping copy under this branch was
            # reconstructed from the upstream vllm-ascend runner — confirm.
            blk_table = self.input_batch.block_table[kv_cache_group_id]
            blk_table_tensor = blk_table.get_device_tensor()
            slot_mapping = blk_table.slot_mapping_cpu[:total_num_scheduled_tokens]
            self.slot_mapping[:total_num_scheduled_tokens].copy_(
                slot_mapping[:total_num_scheduled_tokens],
                non_blocking=True,
            )
            self.slot_mapping[total_num_scheduled_tokens:].fill_(0)

        # Make AscendCommonAttentionMetadata
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
            seq_lens_cpu=self.seq_lens_cpu,
            seq_lens=self.seq_lens_cpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            num_input_tokens=num_input_tokens,
            actual_seq_lengths_q=self.actual_seq_lengths_q,
            # TODO: change this to the right block table for linear attn
            block_table_tensor=blk_table_tensor[:num_reqs],
            slot_mapping=self.slot_mapping,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            enable_dbo_across_dp=enable_dbo,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
            max_query_len=max_num_scheduled_tokens,
            graph_pad_size=self.graph_pad_size,
            decode_token_per_req=self.decode_token_per_req,
            cos=self.cos,
            sin=self.sin,
        )

        if self.speculative_config and spec_decode_common_attn_metadata is None:
            spec_decode_common_attn_metadata = common_attn_metadata

        for attn_group in self.attn_groups[kv_cache_group_id]:
            common_prefix_len = 0
            extra_attn_metadata_args = {}
            builder = attn_group.get_metadata_builder()
            if isinstance(builder, GDNAttentionMetadataBuilder) or self.model_config.runner_type == "pooling":
                if use_spec_decode:
                    extra_attn_metadata_args = dict(
                        num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
                        num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
                    )
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args,
                )
            else:
                attn_metadata_i = builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    model=self.get_model(),
                    **extra_attn_metadata_args,
                )

            # Every layer of the group shares the same metadata object.
            for layer_name in attn_group.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

    if lmhead_tp_enable():
        max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
        logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

    return (
        attn_metadata,
        positions,
        # Third slot is the per-request scheduled-token array (np.ndarray), as the
        # return annotation declares; it previously repeated num_input_tokens.
        num_scheduled_tokens,
        num_input_tokens,
        num_tokens_across_dp,
        maybe_padded_num_tokens,
        logits_indices,
        spec_decode_metadata,
        input_ids,
        inputs_embeds,
        intermediate_tensors,
        max_num_scheduled_tokens,
        model_kwargs,
    )
self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) + + # -------------------------------------- Omni-new ------------------------------------------------- + # Omni-new: don't use cudagraph_dispatcher + # and remove ubatch_slices + aclgraph_runtime_mode = CUDAGraphMode.NONE + # ------------------------------------------------------------------------------------------------- + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=None, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + self.maybe_setup_kv_connector(scheduler_output) + + outputs = self._run_diffusion( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + multimodal_kwargs=model_kwargs, + logits_indices=logits_indices, + ) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = outputs + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + + # -------------------------------------- Omni-new ------------------------------------------------- + # Omni-new: extract_multimodal_outputs + _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + # Ensure one 
def _run_diffusion(
    self,
    *,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None,
    multimodal_kwargs: dict,
    logits_indices: torch.Tensor,
) -> torch.Tensor | list[torch.Tensor]:
    """Invoke the diffusion model once and return whatever it produces.

    All arguments are keyword-only and are forwarded to ``self.model.forward``
    together with the runner's sampling metadata and sampler. The multimodal
    keyword arguments are first converted with ``MultiModalKwargs.as_kwargs``
    for ``self.device``; ``logits_indices`` is passed on under the name
    ``logits_index``.

    Only the ``forward`` entry point is attempted at present.

    Raises:
        RuntimeError: if the loaded model has no ``forward`` attribute.
    """
    mm_inputs = MultiModalKwargs.as_kwargs(multimodal_kwargs, device=self.device)
    # Built with dict(...) so a key clash with the multimodal kwargs still raises.
    call_kwargs = dict(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **mm_inputs,
        sampling_metadata=self.input_batch.sampling_metadata,
        logits_index=logits_indices,
        sampler=self.sampler,
    )

    model = self.model
    if not hasattr(model, "forward"):
        # TODO: add the diffuse method for other models
        raise RuntimeError(
            "The loaded model does not expose diffusion interfaces 'sample', "
            "'forward', or 'diffuse'. Please implement one of them or adapt the runner."
        )
    return model.forward(**call_kwargs)
+ if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp( + num_tokens, with_prefill, False + ) + + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, dtype=self.dtype, device=self.device + ) + intermediate_tensors = IntermediateTensors( + {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()} + ) + + # filter out the valid batch descriptor + _ag_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode) + ) + if aclgraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # 
warm ups for aclgraph capture + assert aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode == _ag_mode, ( + f"Aclgraph runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}." + ) + else: + aclgraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. + attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, + ) + + need_dummy_logits = not self.in_profile_run and lmhead_tp_enable() + + if need_dummy_logits: + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + return self.model.compute_logits(hidden_states[dummy_indices]) + + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + ): + hidden_states = self._generate_dummy_run_hidden_states( + with_prefill, + is_torchair_compile, + input_ids, + positions, + attn_metadata, + num_tokens, + intermediate_tensors, + inputs_embeds, + ) + if need_dummy_logits: + dummy_compute_logits(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens, + with_prefill=with_prefill, + skip_attn=True, + num_reqs=num_reqs, + num_tokens_across_dp=num_tokens_across_dp, + 
class NPUDiffusionWorker(NPUWorker):
    """NPU diffusion worker for code2wav stage in Omni model."""

    def init_device(self):
        """Initialise the device and attach a diffusion model runner to this worker.

        The device returned by ``self._init_device()`` is handed to
        ``NPUDiffusionModelRunner`` together with ``self.vllm_config``; the
        resulting runner is stored on ``self.model_runner``.
        """
        npu_device = self._init_device()
        runner = NPUDiffusionModelRunner(self.vllm_config, npu_device)

        self.model_runner: NPUDiffusionModelRunner = runner
class OmniNPUModelRunner(NPUModelRunner):
    """
    Base class for NPU model runners with multimodality support.

    On top of ``NPUModelRunner`` this class decodes the serialized
    ``prompt_embeds`` / ``additional_information`` payloads that may accompany
    new requests, and splits model outputs into text hidden states and
    multimodal outputs.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``NPUModelRunner`` and cache multimodal flags."""
        super().__init__(*args, **kwargs)

        self.is_multimodal_raw_input_only_model = self.model_config.is_multimodal_raw_input_only_model
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(self.model_config)

    def _init_mrope_positions(self, req_state: CachedRequestState):
        """Compute M-RoPE positions for a request and store them on ``req_state``.

        Gathers ``image_grid_thw``, ``video_grid_thw``, ``second_per_grid_ts``,
        ``audio_feature_lengths`` and ``use_audio_in_video`` from every
        multimodal feature that carries data, then writes the result of
        ``MRotaryEmbedding.get_input_positions_tensor`` to
        ``req_state.mrope_positions`` and ``req_state.mrope_position_delta``.

        Raises:
            AssertionError: If the loaded model does not support M-RoPE.
        """
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_feature in req_state.mm_features:
            mm_item = mm_feature.data
            if mm_item is None:
                continue
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
            if (t := mm_input.get("video_grid_thw")) is not None:
                video_grid_thw.append(t.tolist())
            if (t := mm_input.get("second_per_grid_ts")) is not None:
                second_per_grid_ts.append(t)
            if (t := mm_input.get("audio_feature_lengths")) is not None:
                audio_feature_lengths.append(t)
            # Check for use_audio_in_video; the last feature that sets it wins.
            use_audio_in_video_value = mm_input.get("use_audio_in_video")
            if use_audio_in_video_value is not None:
                use_audio_in_video = bool(use_audio_in_video_value.item())

        assert supports_mrope(self.get_model()), "M-RoPE support is not implemented."

        req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor(
            req_state.prompt_token_ids,
            hf_config=self.model_config.hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )

    def _attach_prompt_embeds(self, new_req_data: Any, req_id: str) -> None:
        """Decode a serialized ``prompt_embeds`` payload onto the cached request.

        Best effort: any failure is logged with its traceback and the request
        simply proceeds without ``prompt_embeds_cpu``.
        """
        try:
            payload = getattr(new_req_data, "prompt_embeds", None)
            if payload is None:
                return
            # np.dtype() validates the dtype name, matching how the
            # additional-information payload resolves its dtype.
            arr = np.frombuffer(payload.data, dtype=np.dtype(payload.dtype))
            pe_cpu = torch.from_numpy(arr.reshape(payload.shape))
            # Store temporarily on CPU; later moved to device in builder
            setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu)
            # Also replace payload with Tensor for user visibility in
            # scheduler_output
            try:
                new_req_data.prompt_embeds = pe_cpu  # type: ignore[assignment]
            except Exception:
                # The write-back is optional; record why it failed instead of
                # hiding it.
                logger.debug(
                    "Could not replace the prompt_embeds payload with the decoded tensor",
                    exc_info=True,
                )
        except Exception as e:
            logger.exception("Error decoding prompt embeds: %s", e)

    def _attach_additional_information(self, new_req_data: Any, req_id: str) -> None:
        """Decode ``additional_information`` into a dict on the cached request.

        A plain ``dict`` is used as is. An ``AdditionalInformationPayload`` is
        unpacked entry by entry: entries with ``tensor_data`` become CPU
        tensors, the others contribute their ``list_data``. A non-empty result
        is stored as ``additional_information_cpu``. Best effort: failures are
        logged with their traceback and nothing is attached.
        """
        try:
            payload_info = getattr(new_req_data, "additional_information", None)
            if payload_info is None:
                return
            info_dict = {}
            if isinstance(payload_info, dict):
                info_dict = payload_info
            else:
                from vllm_omni.engine import AdditionalInformationPayload

                if isinstance(payload_info, AdditionalInformationPayload):
                    for k, entry in payload_info.entries.items():
                        if entry.tensor_data is not None:
                            dt = np.dtype(getattr(entry, "tensor_dtype", "float32"))
                            arr = np.frombuffer(entry.tensor_data, dtype=dt)
                            arr = arr.reshape(entry.tensor_shape)
                            info_dict[k] = torch.from_numpy(arr)
                        else:
                            info_dict[k] = entry.list_data
            if info_dict:
                setattr(
                    self.requests[req_id],
                    "additional_information_cpu",
                    info_dict,
                )
        except Exception as e:
            logger.exception("Error decoding additional information: %s", e)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Synchronise cached request state and the persistent batch with the scheduler.

        Drops finished and unscheduled requests, registers new requests
        (including Omni prompt-embedding / additional-information payloads),
        updates running or resumed requests, then condenses, optionally
        reorders, and refreshes the persistent batch.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)
        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if sampling_params and sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if pooling_params:
                assert (task := pooling_params.task) is not None, "You did not set `task` in the API"
                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
                mm_features=new_req_data.mm_features,
            )

            # -------------------------------------- Omni-new -------------------------------------------------
            # If prompt embeddings are provided, decode and attach them to the request.
            self._attach_prompt_embeds(new_req_data, req_id)
            # Decode additional_information payloads (dictionary)
            self._attach_additional_information(new_req_data, req_id)
            # -------------------------------------------------------------------------------------------------

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(self.requests[req_id])

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state.num_tokens
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())
            if spec_token_ids:
                num_spec_tokens = len(spec_token_ids)
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            self.input_batch.add_request(req_state)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow attention backend to reorder the batch, potentially
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    @torch.inference_mode()
    def extract_multimodal_outputs(
        self, hidden_states: Union[torch.Tensor, list[torch.Tensor]]
    ) -> tuple[torch.Tensor, Union[torch.Tensor, list[torch.Tensor], dict]]:
        """Split a model output into text hidden states and multimodal outputs.

        When the model sets a truthy ``have_multimodal_outputs`` attribute, the
        two parts are read from ``hidden_states.text_hidden_states`` and
        ``hidden_states.multimodal_outputs``. Otherwise a tensor is returned
        unchanged and a list contributes its first element; in both of those
        cases the multimodal part is an empty dict.

        Raises:
            ValueError: If ``hidden_states`` is neither a tensor nor a list and
                the model does not declare multimodal outputs.
        """
        if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs:
            text_hidden_states = hidden_states.text_hidden_states
            multimodal_outputs = hidden_states.multimodal_outputs

        elif isinstance(hidden_states, torch.Tensor):
            text_hidden_states = hidden_states
            multimodal_outputs = {}
        elif isinstance(hidden_states, list):
            text_hidden_states = hidden_states[0]
            multimodal_outputs = {}
        else:
            raise ValueError(f"Invalid hidden states type: {type(hidden_states)}")
        return text_hidden_states, multimodal_outputs

    def _init_model_kwargs(self, num_tokens: int):
        """Build extra model kwargs (``token_type_ids``) for pooling models.

        Returns an empty dict unless this is a pooling model and at least one
        request supplies ``compressed_token_type_ids`` in its pooling
        ``extra_kwargs``. ``num_tokens`` is accepted but not read here.
        """
        model_kwargs = dict[str, Any]()

        if not self.is_pooling_model:
            return model_kwargs

        num_reqs = self.input_batch.num_reqs
        pooling_params = self.input_batch.get_pooling_params()

        token_type_id_requests = dict[int, Any]()
        for i, param in enumerate(pooling_params):
            if (
                param.extra_kwargs is not None
                and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) is not None
            ):
                token_type_id_requests[i] = token_types

        if not token_type_id_requests:
            return model_kwargs

        seq_lens = self.seq_lens_cpu[:num_reqs]
        token_type_ids = []

        for i in range(num_reqs):
            # Positions at or after `pos` get type id 1. Requests without an
            # entry default to their own length, i.e. all zeros.
            pos = token_type_id_requests.get(i, seq_lens[i])
            ids = (torch.arange(seq_lens[i]) >= pos).int()
            token_type_ids.append(ids)

        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(device=self.device)
        return model_kwargs

    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> BatchedTensorInputs:
        """Merge the multimodal kwargs of all newly scheduled requests.

        Returns an empty dict when there is no scheduler output or the model is
        not a raw-input-only multimodal model. Otherwise every non-``None``
        ``feature.data`` of the new requests is grouped by modality and the
        groups are merged into one dict (later groups overwrite equal keys).
        """
        if not scheduler_output or not self.is_multimodal_raw_input_only_model:
            return {}

        mm_kwargs = list[MultiModalKwargsItem]()
        for req in scheduler_output.scheduled_new_reqs:
            for feature in req.mm_features:
                if feature.data is not None:
                    mm_kwargs.append(feature.data)

        # Input all modalities at once
        model = cast(SupportsMultiModal, self.model)
        mm_kwargs_combined: BatchedTensorInputs = {}
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
        ):
            mm_kwargs_combined.update(mm_kwargs_group)

        return mm_kwargs_combined

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        with_prefill: bool = False,
        is_torchair_compile: bool = False,
        aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
        force_attention: bool = False,
        uniform_decode: bool = False,
    ) -> torch.Tensor:
        """Run one forward pass over placeholder inputs and return text hidden states.

        Args:
            num_tokens: Number of placeholder tokens; may be rounded up for
                sequence parallelism and replaced by the DP-synchronised value.
            with_prefill: Whether to treat the batch as prefill. Forced to
                ``True`` when this node is a KV producer but not a consumer.
            is_torchair_compile: Forwarded to ``_generate_dummy_run_hidden_states``.
            aclgraph_runtime_mode: ``None`` to use the dispatcher's mode;
                otherwise it must be ``NONE`` or equal to the dispatcher's mode.
            force_attention: Forwarded to ``_build_dummy_attn_metadata``.
            uniform_decode: If set, every dummy request gets
                ``self.uniform_decode_query_len`` tokens (the last may be shorter).

        Returns:
            The text hidden states extracted from the dummy forward output.
        """
        # Accepted runtime modes: NONE (eager), PIECEWISE and FULL.
        assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
            CUDAGraphMode.NONE,
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        }

        # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
        # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
        if self.use_aclgraph and enable_sp(self.vllm_config):
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            num_tokens = math.ceil(num_tokens / tp_size) * tp_size

        # Force dummy run on prefill stage when this node is deemed as kv producer.
        if self.is_kv_producer and not self.is_kv_consumer:
            with_prefill = True

        # Padding for DP
        (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(
            num_tokens, with_prefill, False
        )

        moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(). This means that we are using
        # different graphs and/or modes for mixed prefill-decode batches vs.
        # uniform decode batches. A uniform decode batch means that all
        # requests have identical query length, except a potential virtual
        # request (shorter) in the batch account for padding.
        # Uniform decode batch could either be common pure decode, where
        # max_query_len == 1, or speculative decode, where
        # max_query_len == 1 + num_spec_decode_tokens.

        # When setting max_query_len = 1, we switch to and capture the optimized
        # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
        # for GQA/MQA.
        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.max_num_reqs
        if uniform_decode:
            num_reqs = cdiv(num_tokens, max_query_len)
            num_scheduled_tokens_list = [max_query_len] * num_reqs
            if num_tokens % max_query_len != 0:
                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
        else:
            if with_prefill:
                num_reqs = num_tokens
            else:
                num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req
            num_reqs = min(num_reqs, max_num_reqs)
            min_tokens_per_req = num_tokens // num_reqs
            num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
            # The last request absorbs the remainder so the total stays num_tokens.
            num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)

        if not self.in_profile_run and self.dynamic_eplb:
            self.eplb_updator.forward_before()

        with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens):
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions[:num_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
                        batch_size=num_tokens, dtype=self.dtype, device=self.device
                    )
                intermediate_tensors = IntermediateTensors(
                    {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()}
                )

            # filter out the valid batch descriptor
            _ag_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(
                BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode)
            )
            if aclgraph_runtime_mode is not None:
                # we allow forcing NONE when the dispatcher disagrees to support
                # warm ups for aclgraph capture
                assert aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode == _ag_mode, (
                    f"Aclgraph runtime mode mismatch at dummy_run. "
                    f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}."
                )
            else:
                aclgraph_runtime_mode = _ag_mode

            # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
            # and not supported in ASCEND now. We could remove it in the future.
            attn_metadata = self._build_dummy_attn_metadata(
                False,
                num_reqs=num_reqs,
                num_tokens=num_tokens,
                max_query_len=max_query_len,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                force_attention=force_attention,
            )

            need_dummy_logits = not self.in_profile_run and lmhead_tp_enable()

            # `dummy_indices` and `dummy_compute_logits` exist only when
            # need_dummy_logits is true; every later use is guarded by it.
            if need_dummy_logits:
                max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
                dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32)

                def dummy_compute_logits(hidden_states):
                    return self.model.compute_logits(hidden_states[dummy_indices])

            with set_ascend_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                with_prefill=with_prefill,
                in_profile_run=self.in_profile_run,
                reserved_mc2_mask=self.reserved_mc2_mask,
                moe_comm_type=moe_comm_type,
                num_actual_tokens=0,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                prefetch_stream=self.prefetch_stream,
                model_instance=self.model,
                weight_prefetch_method=self.weight_prefetch_method,
            ):
                hidden_states = self._generate_dummy_run_hidden_states(
                    with_prefill,
                    is_torchair_compile,
                    input_ids,
                    positions,
                    attn_metadata,
                    num_tokens,
                    intermediate_tensors,
                    inputs_embeds,
                )
                if need_dummy_logits:
                    dummy_compute_logits(hidden_states)

            if self.drafter:
                self.drafter.dummy_run(
                    num_tokens=num_tokens,
                    with_prefill=with_prefill,
                    skip_attn=True,
                    num_reqs=num_reqs,
                    num_tokens_across_dp=num_tokens_across_dp,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    batch_descriptor=batch_descriptor,
                )
                if need_dummy_logits:
                    self.drafter.model.compute_logits(hidden_states[dummy_indices])
            if self.in_profile_run and self.dynamic_eplb:
                self.model.clear_all_moe_loads()
            if not self.in_profile_run and self.dynamic_eplb:
                self.eplb_updator.take_update_info_from_eplb_process()
                self.eplb_updator.forward_end()

            # Callers of the dummy run only need the text hidden states; the
            # multimodal part of the output is discarded.
            hidden_states, _ = self.extract_multimodal_outputs(hidden_states)

            return hidden_states