From cc549869b5765b536f5afdb29cc77f3358f6d411 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Fri, 16 Jan 2026 10:07:04 +0000 Subject: [PATCH 01/11] [NPU] Enable ACLGraph for Qwen3-Omni Signed-off-by: gcanlin --- .../stage_configs/npu/qwen3_omni_moe.yaml | 4 +- vllm_omni/worker/npu/npu_ar_model_runner.py | 44 ++++++++++++++++++- vllm_omni/worker/npu/npu_model_runner.py | 42 +++++++++++++++--- 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml index db51e6d7d3..f89f205a21 100644 --- a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml @@ -15,7 +15,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.6 - enforce_eager: true + enforce_eager: false trust_remote_code: true engine_output_type: latent # Output hidden states for talker distributed_executor_backend: "mp" @@ -44,7 +44,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.2 - enforce_eager: true + enforce_eager: false trust_remote_code: true engine_output_type: latent # Output codec codes for code2wav # tensor_parallel_size: 2 diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py index fcaf367953..a1b9fbe26a 100644 --- a/vllm_omni/worker/npu/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu/npu_ar_model_runner.py @@ -420,6 +420,7 @@ def _prepare_inputs( if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] for req_index, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests.get(req_id) req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None @@ -433,6 +434,14 @@ def _prepare_inputs( req_input_ids, req_embeds, update_dict = self.model.preprocess( input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos ) + if hasattr(self.model, "talker_mtp") and span_len == 1: + last_talker_hidden, text_step = update_dict.pop("mtp_inputs") + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) @@ -442,6 +451,9 @@ def _prepare_inputs( if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: input_ids[s : s + seg_len] = req_input_ids + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) # -------------------------------------- Omni-new ------------------------------------------------- @@ -869,7 +881,9 @@ def execute_model( isinstance(hidden_states[0], torch.Tensor): hidden_states = hidden_states[0] sample_hidden_states = hidden_states[logits_indices] - logits = 
self.model.compute_logits(sample_hidden_states) + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) if broadcast_pp_output: model_output_broadcast_data = { "logits": logits.contiguous(), @@ -1277,3 +1291,31 @@ def _process_additional_information_updates( import traceback traceback.print_exc() + + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] + last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] + text_step = self.text_step.gpu[:decode_batch_size] + with set_ascend_forward_context( + None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + self._merge_additional_information_update(req_id, update_dict) \ No newline at end of file diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index 264d5e3413..05c55388db 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -10,6 +10,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling @@ -40,6 +41,25 @@ def __init__(self, *args, **kwargs): self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None + def load_model(self, *args, **kwargs) -> None: + super().load_model(*args, **kwargs) + # TODO move this model specific logic to a separate class + if hasattr(self.model, "talker_mtp") and self.model.talker is not None: + self.talker_mtp = self.model.talker_mtp + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + if cudagraph_mode.has_full_cudagraphs(): + self.talker_mtp = ACLGraphWrapper( + self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size + self.talker_mtp_input_ids = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.talker_mtp_inputs_embeds = self._make_buffer( + self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False + ) + self.last_talker_hidden = 
self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False) + self.text_step = self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False) + def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] video_grid_thw = [] @@ -302,18 +322,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() @torch.inference_mode() - def extract_multimodal_outputs( - self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput - ) -> tuple[torch.Tensor, torch.Tensor | list[torch.Tensor] | dict]: - """Extract multimodal outputs from hidden states.""" - if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs: + def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput) -> dict: + if ( + hasattr(self.model, "have_multimodal_outputs") + and self.model.have_multimodal_outputs + and isinstance(hidden_states, OmniOutput) + ): text_hidden_states = hidden_states.text_hidden_states multimodal_outputs = hidden_states.multimodal_outputs elif isinstance(hidden_states, torch.Tensor): text_hidden_states = hidden_states multimodal_outputs = {} - elif isinstance(hidden_states, list): + elif isinstance(hidden_states, list) or isinstance(hidden_states, tuple): text_hidden_states = hidden_states[0] multimodal_outputs = {} else: @@ -496,6 +517,13 @@ def dummy_drafter_compute_logits(hidden_states): model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method, ): + if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + hidden_states = self.talker_mtp( + self.talker_mtp_input_ids.gpu[:num_tokens_padded], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], + self.last_talker_hidden.gpu[:num_tokens_padded], + self.text_step.gpu[:num_tokens_padded], + ) hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds ) @@ -663,6 +691,8 @@ def _model_forward( **model_kwargs, **model_kwargs_extra, ) + if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"): + model_output = self.model.make_omni_output(model_output, **model_kwargs_extra) # Cache model output so later sample_tokens can consume multimodal results. 
self._omni_last_model_output = model_output return model_output From 02c06c2370f3e0e1a085f96a522233bc859f31df Mon Sep 17 00:00:00 2001 From: gcanlin Date: Fri, 16 Jan 2026 10:08:52 +0000 Subject: [PATCH 02/11] fix lint Signed-off-by: gcanlin --- vllm_omni/worker/npu/npu_ar_model_runner.py | 2 +- vllm_omni/worker/npu/npu_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py index a1b9fbe26a..f740cb5562 100644 --- a/vllm_omni/worker/npu/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu/npu_ar_model_runner.py @@ -1318,4 +1318,4 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} - self._merge_additional_information_update(req_id, update_dict) \ No newline at end of file + self._merge_additional_information_update(req_id, update_dict) diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index 05c55388db..e461f055ce 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -10,7 +10,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling @@ -19,6 +18,7 @@ from vllm.utils.math_utils import cdiv from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm_ascend.utils import enable_sp, lmhead_tp_enable from vllm_ascend.worker.model_runner_v1 import NPUModelRunner From 78cb352fcdfffdf136f7239ce42427bcfd27efdb Mon Sep 17 00:00:00 2001 From: gcanlin Date: Fri, 16 Jan 2026 16:19:07 +0000 Subject: [PATCH 03/11] revert to eager mode by default Signed-off-by: gcanlin --- .../model_executor/stage_configs/npu/qwen3_omni_moe.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml index f89f205a21..db51e6d7d3 100644 --- a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml @@ -15,7 +15,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.6 - enforce_eager: false + enforce_eager: true trust_remote_code: true engine_output_type: latent # Output hidden states for talker distributed_executor_backend: "mp" @@ -44,7 +44,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.2 - enforce_eager: false + enforce_eager: true trust_remote_code: true engine_output_type: latent # Output codec codes for code2wav # tensor_parallel_size: 2 From bcb3b10b9644abda36a29f34fd32e6f9fb21bc46 Mon Sep 17 
00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 03:05:59 +0000 Subject: [PATCH 04/11] [NPU] Upgrade to v0.14.0 Signed-off-by: gcanlin --- .../qwen3_omni/qwen3_omni_moe_talker.py | 28 +- .../stage_configs/npu/qwen3_omni_moe.yaml | 4 +- vllm_omni/worker/npu/npu_ar_model_runner.py | 1040 ++--------------- vllm_omni/worker/npu/npu_ar_worker.py | 9 +- .../worker/npu/npu_generation_model_runner.py | 773 +++--------- vllm_omni/worker/npu/npu_generation_worker.py | 9 +- vllm_omni/worker/npu/npu_model_runner.py | 398 ++++++- 7 files changed, 659 insertions(+), 1602 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 2f1893e00c..5b032d86ea 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -563,7 +563,10 @@ class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module): - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight - mlp.shared_expert_gate.weight (sibling, not child) - The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x) + The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x). + + It also exposes the underlying shared_expert interface to keep + compatibility with backends that split shared-expert computation. """ def __init__( @@ -575,9 +578,30 @@ def __init__( self._shared_expert = shared_expert self._shared_expert_gate = shared_expert_gate + @property + def gate_up_proj(self): + return self._shared_expert.gate_up_proj + + @property + def down_proj(self): + return self._shared_expert.down_proj + + @property + def act_fn(self): + return self._shared_expert.act_fn + + def expert_gate(self, x: torch.Tensor): + gate_out = self._shared_expert_gate(x) + if isinstance(gate_out, tuple): + return gate_out + return gate_out, None + def forward(self, x: torch.Tensor) -> torch.Tensor: out = self._shared_expert(x) - gate_values = F.sigmoid(self._shared_expert_gate(x)) # [batch, 1] + gate_out = self._shared_expert_gate(x) + if isinstance(gate_out, tuple): + gate_out = gate_out[0] + gate_values = F.sigmoid(gate_out) # [batch, 1] return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden] diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml index db51e6d7d3..f89f205a21 100644 --- a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml @@ -15,7 +15,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.6 - enforce_eager: true + enforce_eager: false trust_remote_code: true engine_output_type: latent # Output hidden states for talker distributed_executor_backend: "mp" @@ -44,7 +44,7 @@ stage_args: worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.2 - enforce_eager: true + enforce_eager: false trust_remote_code: true engine_output_type: latent # Output codec codes for code2wav # tensor_parallel_size: 2 diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py index f740cb5562..543e9b4548 100644 --- a/vllm_omni/worker/npu/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu/npu_ar_model_runner.py @@ -3,35 +3,31 @@ from 
__future__ import annotations -import math +from copy import copy from typing import Any, NamedTuple import numpy as np import torch -import torch.nn as nn from vllm.config import CUDAGraphMode -from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group -from vllm.distributed.parallel_state import get_pcp_group, get_pp_group, get_tp_group +from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ModelRunnerOutput, + ECConnectorOutput, make_empty_encoder_model_runner_output, ) from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata # yapf conflicts with isort for this block # yapf: disable @@ -41,9 +37,8 @@ update_mla_attn_dcp_pcp_params, update_mla_attn_params, ) -from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort -from vllm_ascend.spec_decode.interface import SpecDcodeType -from vllm_ascend.utils import ProfileExecuteDuration, enable_sp, lmhead_tp_enable +from vllm_ascend.ops.rotary_embedding import update_cos_sin +from vllm_ascend.utils import ProfileExecuteDuration from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.npu.npu_model_runner import OmniNPUModelRunner @@ -62,6 +57,7 @@ class ExecuteModelState(NamedTuple): kv_connector_output: KVConnectorOutput | None attn_metadata: dict[str, Any] positions: torch.Tensor + ec_connector_output: ECConnectorOutput | None multimodal_outputs: Any class NPUARModelRunner(OmniNPUModelRunner): @@ -73,6 +69,7 @@ def __init__(self, *args, **kwargs): # each model stage has their own hidden size self.hidden_size = self.model_config.hf_text_config.hidden_size self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) + self.omni_connector = None def _make_buffer(self, *size, dtype, numpy=True): # Prevent ray from pinning the buffer due to large size @@ -87,648 +84,12 @@ def _make_buffer(self, *size, dtype, numpy=True): with maybe_disable_pin_memory_for_ray(self, total_bytes): return super()._make_buffer(*size, dtype=dtype, numpy=numpy) - def _prepare_inputs( - self, - scheduler_output: SchedulerOutput, - intermediate_tensors: IntermediateTensors | None = None, - ) -> tuple[ - dict[str, Any], - torch.Tensor, - np.ndarray, - int, - torch.Tensor, - int, - torch.Tensor, - SpecDecodeMetadata, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - int - ]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = 
self.input_batch.num_reqs - assert num_reqs > 0 - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit_block_table(num_reqs) - - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - ) - - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) - if self.pcp_size > 1: - if not self.vllm_config.model_config.use_mla: - self.generate_kv_idx(scheduler_output) - tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( - tokens) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - else: - position_pcp, pcp_unpad_mask = None, None - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] - - total_num_pcp_pads = sum(self.num_pcp_pads) - max_num_scheduled_tokens = max(tokens) - num_valid_tokens = np.array([ - num_tokens - - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip(tokens, req_ids) - ], - dtype=np.int32) - - if (self.use_aclgraph and total_num_scheduled_tokens - <= self.cudagraph_batch_sizes[-1]): - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - total_num_scheduled_tokens) - elif self.use_aclgraph and enable_sp(self.vllm_config): - # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, - # the model will fall back to running its FX graph in eager mode. - # In this case, when sequence parallelism is enabled, we need to pad tokens to align - # with tp_size because pad_size cannot be captured by the FX graph - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - num_input_tokens = math.ceil( - total_num_scheduled_tokens / tp_size) * tp_size - else: - # Eager mode. - num_input_tokens = total_num_scheduled_tokens - - # Get the attention state. - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_state = attn_state # type: ignore - - # Determine if it's a splitfuse batch - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - - self.query_lens = torch.from_numpy(num_scheduled_tokens) - - # Get info across DP ranks. - # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, - # Otherwise, it's just max_tokens_across_dp_cpu - (maybe_padded_num_tokens, num_tokens_across_dp, - with_prefill) = self._sync_metadata_across_dp(num_input_tokens, - with_prefill) - - # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens - # We should consider removing maybe_padded_num_tokens later - num_input_tokens = maybe_padded_num_tokens - - # Hot-Swap lora model - if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) - - # Get request indices. 
- # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - if self.pcp_size > 1: - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - position_pcp[:total_num_scheduled_tokens], - out=positions_np) - else: - self.positions.np[:total_num_scheduled_tokens] = positions_np - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) - - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - token_indices_tensor = torch.from_numpy(token_indices) - # Prepare input_ids. - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) - if self.enable_prompt_embeds: - is_token_ids = self.input_batch.is_token_ids_tensor.flatten() - torch.index_select( - is_token_ids, - 0, - token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) - - # Because we did not pre-allocate a massive prompt_embeds CPU tensor on - # the InputBatch, we need to fill in the prompt embeds into the expected - # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. - if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or - self.enable_prompt_embeds): - output_idx = 0 - for req_idx in range(num_reqs): - num_sched = num_scheduled_tokens[req_idx] - - # Skip if this request doesn't have embeddings - if req_idx not in self.input_batch.req_prompt_embeds: - output_idx += num_sched - continue - - # Skip if no tokens scheduled - if num_sched <= 0: - output_idx += num_sched - continue - - req_embeds = self.input_batch.req_prompt_embeds[req_idx] - start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] - - # Skip if trying to read beyond available embeddings - if start_pos >= req_embeds.shape[0]: - output_idx += num_sched - continue - - # Copy available embeddings - end_pos = start_pos + num_sched - actual_end = min(end_pos, req_embeds.shape[0]) - actual_num_sched = actual_end - start_pos - - if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) - - output_idx += num_sched - - self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens - self.query_start_loc.copy_to_gpu() - - self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - self.seq_lens.copy_to_gpu() - - # Fill unused with -1. 
Needed for reshape_and_cache - self.query_start_loc.gpu[num_reqs + 1:].fill_(-1) - self.seq_lens.gpu[num_reqs:].fill_(0) - - self.query_lens = torch.from_numpy(num_scheduled_tokens) - - # Copy the tensors to the NPU. - self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, - cu_num_tokens) - self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_() - self.positions.copy_to_gpu() - - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_mask = self._make_attention_mask(attn_state) - self.attn_state = attn_state # type: ignore - - self.with_prefill = with_prefill - self.num_tokens_across_dp = num_tokens_across_dp - attn_metadata: dict[str, Any] = {} - - # Record the index of requests that should not be sampled, - # so that we could clear the sampled tokens before returning - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] - num_tokens_np = np.array(num_tokens, dtype=np.int32) - num_reqs = self.input_batch.num_reqs - if self.pcp_size > 1: - # while pcp > 1, we need the original num_scheduled_tokens before split - # to calculate discard_requests_mask - tokens_original = [ - scheduler_output.num_scheduled_tokens[i] for i in req_ids - ] - original_seq_lens_np = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - np.array(tokens_original, dtype=np.int32)) - discard_requests_mask = original_seq_lens_np < num_tokens_np - else: - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np - - discard_request_indices = np.nonzero(discard_requests_mask)[0] - self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) - self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) - - # _prepare_inputs may reorder the batch, so we must gather - # multi-modal outputs after that to ensure the correct order - if self.is_multimodal_model: - with self.maybe_get_ec_connector_output( - scheduler_output, - encoder_cache=self.encoder_cache, - ): - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids.gpu[:total_num_scheduled_tokens] - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) - - inputs_embeds = self.model.embed_input_ids( - input_ids, - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_( - inputs_embeds) - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - # -------------------------------------- Omni-new ------------------------------------------------- - input_ids = self.input_ids.gpu[:num_input_tokens] - # -------------------------------------- Omni-new ------------------------------------------------- - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: - # Get the input embeddings for the tokens that are not input embeds, - # then put them into the appropriate positions. 
- # TODO(qthequartermasterman): Since even when prompt embeds are - # enabled, (a) not all requests will use prompt embeds, and (b) - # after the initial prompt is processed, the rest of the generated - # tokens will be token ids, it is not desirable to have the - # embedding layer outside of the acl graph all the time. The v0 - # engine avoids this by "double compiling" the acl graph, once - # with input_ids and again with inputs_embeds, for all num_tokens. - # If a batch only has token ids, then including the embedding layer - # in the acl graph will be more performant (like in the else case - # below). - token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ - .squeeze(1) - # Some tokens ids may need to become embeds - if token_ids_idx.numel() > 0: - token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.embed_input_ids( - input_ids=token_ids) - self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds - - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the ACL graph. - input_ids = self.input_ids.gpu[:num_input_tokens] - inputs_embeds = None - positions = self.positions.gpu[:num_input_tokens] - if self.uses_mrope: - positions = self.mrope_positions.gpu[:, :num_input_tokens] - - # -------------------------------------- Omni-new ------------------------------------------------- - self._omni_num_scheduled_tokens_np = num_scheduled_tokens - - # Note: only prefill need collect additional_information for now. - # Decode don't need per_req_additional_information anymore. 
- if inputs_embeds is not None: - # Prefill: overlay prompt_embeds and collect additional_information - self._collect_additional_information_for_prefill(num_scheduled_tokens) - - if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: - # Overlay custom prompt_embeds per request for the prompt portion; - # collect additional_information (tensor/list) for prefill portion only - decode_req_ids = [] - for req_index, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests.get(req_id) - req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None - - start_offset = int(self.query_start_loc.cpu[req_index]) - sched_tokens = int(num_scheduled_tokens[req_index]) - s, e = start_offset, start_offset + sched_tokens - span_len = int(e) - int(s) - - # call the custom process function - req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos - ) - if hasattr(self.model, "talker_mtp") and span_len == 1: - last_talker_hidden, text_step = update_dict.pop("mtp_inputs") - decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) - self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) - self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) - self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) - self.text_step.gpu[decode_slice].copy_(text_step) - decode_req_ids.append(req_id) - # TODO(Peiqi): the merge stage could move out from the critical path - self._merge_additional_information_update(req_id, update_dict) - - # update the inputs_embeds and input_ids - seg_len = min(span_len, req_embeds.shape[0]) - inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] - if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: - input_ids[s : s + seg_len] = req_input_ids - - # run talker mtp decode - if hasattr(self.model, "talker_mtp"): - self._talker_mtp_forward(decode_req_ids, inputs_embeds) - # -------------------------------------- Omni-new ------------------------------------------------- - - - # type: ignore - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - assert intermediate_tensors is not None - assert self.intermediate_tensors is not None - # If both flashcomm1 and pp are used simultaneously, - # the shape of the received data and the shape of the space to be copied to will not match, - # requiring a recalculation of the incoming data's shape. - tp_size = get_tensor_model_parallel_world_size() - num_input_tokens_with_flashcomm1 = num_input_tokens - if enable_sp(): - num_input_tokens_with_flashcomm1 = (num_input_tokens + - tp_size - 1) // tp_size - for k, v in intermediate_tensors.items(): - self.intermediate_tensors[ - k][:num_input_tokens_with_flashcomm1].copy_( - v[:num_input_tokens_with_flashcomm1], - non_blocking=True) - intermediate_tensors = IntermediateTensors({ - k: - v[:num_input_tokens_with_flashcomm1] - for k, v in self.intermediate_tensors.items() - }) - - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if not use_spec_decode: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. 
- spec_decode_metadata = None - if self.pcp_size * self.dcp_size > 1: - logits_indices = torch.from_numpy( - cu_num_tokens - ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1 - logits_indices = logits_indices.pin_memory().to( - self.device, non_blocking=True) - else: - logits_indices = self.query_start_loc.gpu[1:num_reqs + 1] - 1 - else: - # Get the number of draft tokens for each request. - # Iterate over the dictionary rather than all requests since not all - # requests have draft tokens. - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - # For chunked prefills, use -1 as mask rather than 0, as guided - # decoding may rollback speculative tokens. - num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) - - spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs]) - logits_indices = spec_decode_metadata.logits_indices - - # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens - self.num_decode_draft_tokens.np[num_reqs:].fill(-1) - self.num_decode_draft_tokens.copy_to_gpu() - # save logits_indices for pcp spec decode usage - self.logits_indices = logits_indices - - # Used in the below loop. - # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) - self.spec_decode_common_attn_metadata = None - if use_spec_decode and self.need_accepted_tokens: - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() - - if self.speculative_config and self.pcp_size > 1: - self._generate_pcp_mtp_input( - num_reqs, scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens) - - long_seq_metadata = self._generate_pcp_metadata( - total_num_scheduled_tokens) - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - # NOTE: This is strange, why did we use total_num_scheduled_tokens before? - slot_mapping_size = (total_num_scheduled_tokens - if self.pcp_size == 1 else - total_num_scheduled_tokens * self.pcp_size - - total_num_pcp_pads) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): - # Encoder-only layers do not have KV cache, so we need to - # create a dummy block table and slot mapping for them. - blk_table_tensor = torch.zeros( - (num_reqs, 1), - dtype=torch.int32, - device=self.device, - ) - slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), - dtype=torch.int64, - device=self.device, - ) - else: - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor() - blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0) - if self.pcp_size > 1: - slot_mapping_for_pcp = blk_table.slot_mapping.gpu[: - long_seq_metadata - . 
- num_actual_tokens_pcp_padded] - slot_mapping_for_pcp[slot_mapping_size:].fill_(-1) - assert pcp_unpad_mask is not None - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: - pcp_unpad_mask - . - shape[ - 0]] - pcp_padded_slot_mapping.fill_(-1) - pcp_padded_slot_mapping[ - pcp_unpad_mask] = slot_mapping_for_pcp[: - slot_mapping_size] - slot_mapping_for_pcp[:long_seq_metadata. - num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping - blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \ - slot_mapping_for_pcp - slot_mapping = blk_table.slot_mapping.gpu - - # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs - # has been split to multiple parts, and there are 3 parts that is related to this - # `num_reqs`, we'll take `query_start_loc` as an example: - # 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens - # 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the - # following simplified `dispatch` logic here, we try to minimize the impact - # 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] - uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \ - and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs) - - # TODO: We should make this official ASAP. Also note that if we pad here, - # the builders won’t need to add any extra padding. - if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - uniform_decode: - num_reqs_padded = num_input_tokens // self.uniform_decode_query_len - pad_size = num_reqs_padded - num_reqs - if pad_size > 0: - last_query_loc = self.query_start_loc.gpu[num_reqs] - - steps = torch.arange(1, - pad_size + 1, - device=self.device, - dtype=self.query_start_loc.gpu.dtype) - fill_values = last_query_loc + ( - steps * self.uniform_decode_query_len) - - self.query_start_loc.gpu[num_reqs + 1:num_reqs_padded + - 1] = fill_values - # So we are trying to simulate the behavior of GPUModelRunner's - # prepare_inputs for uniform decode mode by padding query_start_loc - num_reqs = num_reqs_padded - - # Make AscendCommonAttentionMetadata - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=slot_mapping_size, - num_input_tokens=num_input_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - # TODO: change this to the right block table for linear attn - block_table_tensor=blk_table_tensor[:num_reqs], - slot_mapping=slot_mapping, - num_computed_tokens_cpu=num_computed_tokens_cpu, - positions=self.positions.gpu, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - is_only_prefill=bool(np.all(num_valid_tokens != 1)), - max_query_len=max_num_scheduled_tokens, - decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=long_seq_metadata, - ) - - if self.speculative_config and self.pcp_size > 1: - # For pcp + spec decode, we flatten block_table - # to avoid irregular spec_attn_mask shape, e.g., - # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, - # ori block_table: # [d0, d1, p0, p1, p2] - # (num_reqs_d + num_reqs_p, max_num_blocks), - # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] - # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens = 
self.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \ - self.query_start_loc_pcp_full.cpu[:num_reqs] - num_prefill_reqs = (ori_query_lens - > self.decode_threshold).sum().item() - num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold - blk_table_tensor[ - num_decode_reqs_flatten:num_decode_reqs_flatten + - num_prefill_reqs].copy_( - blk_table_tensor[num_decode_reqs:num_decode_reqs + - num_prefill_reqs].clone()) - blk_table_tensor[:num_decode_reqs_flatten].copy_( - blk_table_tensor[:num_decode_reqs].repeat_interleave( - self.decode_threshold, dim=0)) - common_attn_metadata.block_table_tensor = \ - blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] - - if self.speculative_config and \ - self.spec_decode_common_attn_metadata is None: - self.spec_decode_common_attn_metadata = common_attn_metadata - - for attn_group in self.attn_groups[kv_cache_group_id]: - common_prefix_len = 0 - extra_attn_metadata_args = {} - builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder): - if use_spec_decode: - patch_torch_npu_argsort() - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. - num_decode_draft_tokens.cpu[:num_reqs], - ) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - elif self.model_config.runner_type == "pooling": - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - else: - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - model=self.get_model(), - **extra_attn_metadata_args) - - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - if lmhead_tp_enable(): - max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len - logits_indices = nn.functional.pad( - logits_indices, - (0, max_num_reqs_across_dp - logits_indices.shape[0])) - - return (attn_metadata, positions, num_scheduled_tokens, - num_input_tokens, num_tokens_across_dp, - maybe_padded_num_tokens, logits_indices, spec_decode_metadata, - input_ids, inputs_embeds, intermediate_tensors, - max_num_scheduled_tokens) - @torch.inference_mode() def execute_model( self, scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors | None = None, - ) -> OmniModelRunnerOutput | IntermediateTensors | None: + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " "after execute_model() returns None.") @@ -743,7 +104,7 @@ def execute_model( with self.maybe_get_ec_connector_output( scheduler_output, encoder_cache=self.encoder_cache, - ): + ) as ec_connector_output: self._execute_mm_encoder(scheduler_output) return make_empty_encoder_model_runner_output( scheduler_output) @@ -761,21 +122,23 @@ def execute_model( if self.dynamic_eplb: self.eplb_updator.forward_before() - (attn_metadata, positions, num_scheduled_tokens_np, - num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, - logits_indices, spec_decode_metadata, input_ids, inputs_embeds, - intermediate_tensors, - max_query_len) = (self._prepare_inputs(scheduler_output, - intermediate_tensors)) + (attn_metadata, num_scheduled_tokens_np, 
num_input_tokens, + num_tokens_across_dp, logits_indices, spec_decode_metadata, + max_query_len) = self._prepare_inputs(scheduler_output) + + (input_ids, inputs_embeds, positions, intermediate_tensors, + model_kwargs, ec_connector_output) = self._preprocess(scheduler_output, + num_input_tokens, + intermediate_tensors) + + # update global cos, sin + update_cos_sin(positions) if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() - moe_comm_type = self._select_moe_comm_method(num_input_tokens) # prevent debugger is None - need_dump = self.dump_enable and self.debugger is not None - if need_dump: - assert self.debugger is not None + if self.debugger is not None: dbg_cfg = getattr(self.debugger, "config", None) dump_level = str( getattr(dbg_cfg, "level", @@ -791,7 +154,16 @@ def execute_model( has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 aclgraph_runtime_mode, batch_descriptor = \ self.cudagraph_dispatcher.dispatch( - num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=has_lora + ) + + if self.ascend_config.enable_async_exponential: + self.sampler.do_async_exponential( + b_s=logits_indices.shape[0], + head_dim=self.model_config.get_vocab_size(), + generators=self.input_batch.sampling_metadata.generators) # Run forward pass with ProfileExecuteDuration().capture_async("forward"): @@ -800,27 +172,23 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=self.with_prefill, - moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. total_num_scheduled_tokens, - prefetch_stream=self.prefetch_stream, - model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method): + model_instance=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( - maybe_padded_num_tokens, input_ids, positions, - intermediate_tensors, inputs_embeds) + num_input_tokens, input_ids, positions, + intermediate_tensors, inputs_embeds, model_kwargs) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( scheduler_output) aux_hidden_states = None - if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = hidden_states kv_connector_output = KVConnectorOutput( @@ -854,8 +222,8 @@ def execute_model( # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: hidden_states.kv_connector_output = kv_connector_output - if need_dump: - assert self.debugger is not None + self.kv_connector_output = kv_connector_output + if self.debugger is not None: self.debugger.stop() self.debugger.step() return hidden_states @@ -868,22 +236,13 @@ def execute_model( pool_output = self._pool( hidden_states, scheduler_output.total_num_scheduled_tokens, - num_scheduled_tokens_np) - if need_dump: - assert self.debugger is not None + num_scheduled_tokens_np, kv_connector_output) + if self.debugger is not None: self.debugger.stop() self.debugger.step() return pool_output - # Sometimes, after the model is compiled through the AOT backend, - # the model output may become a list containing only one Tensor object. 
- if isinstance(hidden_states, list) and \ - len(hidden_states) == 1 and \ - isinstance(hidden_states[0], torch.Tensor): - hidden_states = hidden_states[0] sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits( - sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata - ) + logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: model_output_broadcast_data = { "logits": logits.contiguous(), @@ -905,18 +264,33 @@ def execute_model( kv_connector_output, attn_metadata, positions, - multimodal_outputs # Omni-new + ec_connector_output, + multimodal_outputs, # Omni-new ) + self.kv_connector_output = kv_connector_output return None - @torch.inference_mode + @torch.inference_mode() def sample_tokens( - self, grammar_output: GrammarOutput | None - ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. - return None # noqa - need_dump = self.dump_enable and self.debugger is not None + if not kv_connector_output: + return None # noqa + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + # Unpack ephemeral state. ( scheduler_output, @@ -924,160 +298,34 @@ def sample_tokens( spec_decode_metadata, hidden_states, sample_hidden_states, - aux_hidden_states, # noqa + aux_hidden_states, kv_connector_output, attn_metadata, positions, - multimodal_outputs, # Omni-new + ec_connector_output, + multimodal_outputs, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None # Apply structured output bitmasks if present. if grammar_output is not None: - logits = self.apply_grammar_bitmask(scheduler_output, - grammar_output, logits) + # here we are different from gpu_model_runner, + # the apply_grammar_bitmask uses torch.compile to optimize this,ascend does not support it now + logits_dtype = logits.dtype + logits = logits.to("cpu").float() + apply_grammar_bitmask(scheduler_output, grammar_output, + self.input_batch, logits) + logits = logits.to(self.device).to(logits_dtype) with ProfileExecuteDuration().capture_async("Sample"): - # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.sampling_metadata - if spec_decode_metadata is None: - if lmhead_tp_enable() and logits is not None: - logits = logits[:self.input_batch.num_reqs] - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. 
- assert logits is not None - bonus_logits = logits[ - spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[ - spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - if self.need_accepted_tokens: - self._update_states_after_model_execute(output_token_ids) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] - for i in discard_sampled_tokens_req_indices: - generator = self.input_batch.generators.get(int(i)) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # Copy some objects so they don't get modified after returning. - # This is important when using async scheduling. - req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], - scheduler_output, - ) - - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - sampled_token_ids = sampler_output.sampled_token_ids - - if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. It's a tensor. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. It's a numpy array - valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - else: - valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist( - ) - invalid_req_indices_set = set(invalid_req_indices) - if self.num_spec_tokens <= 0: - assert sampled_token_ids.shape[-1] == 1 - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - self.input_batch.prev_sampled_token_ids = sampled_token_ids - - - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set - self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set - } - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. 
- for req_idx in range(num_sampled_tokens): - if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None - else: - sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, - start_idx:end_idx] = True - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) + sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( sampled_token_ids, - sampling_metadata, + self.input_batch.sampling_metadata, scheduler_output, spec_decode_metadata, positions, @@ -1085,12 +333,30 @@ def propose_draft_token_ids(sampled_token_ids): hidden_states, attn_metadata, aux_hidden_states, + sample_hidden_states ) + self._copy_draft_token_ids_to_cpu(scheduler_output) + + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) with ProfileExecuteDuration().capture_async("Draft"): if self.speculative_config: use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.method == "mtp" and \ + self.speculative_config.use_eagle() and \ not self.speculative_config.disable_padded_drafter_batch if use_padded_batch_for_eagle: # EAGLE speculative decoding can use the GPU sampled tokens @@ -1171,151 +437,63 @@ def propose_draft_token_ids(sampled_token_ids): if self.dynamic_eplb: self.eplb_updator.forward_end() if not self.use_async_scheduling: - if need_dump: + if self.debugger is not None: assert self.debugger is not None self.debugger.stop() self.debugger.step() return model_runner_output - if need_dump: + if self.debugger is not None: assert self.debugger is not None self.debugger.stop() self.debugger.step() return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=sampler_output.sampled_token_ids, logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, vocab_size=self.input_batch.vocab_size, ) - def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: - req_state = self.requests.get(req_id) - if req_state is None: - return - existing = getattr(req_state, "additional_information_cpu", {}) - if not isinstance(existing, dict): - existing = {} - merged = dict(existing) - for k, v in upd.items(): - if isinstance(v, torch.Tensor): - merged[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, list): - merged[k] = [ - (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v - ] - else: - merged[k] 
= v - setattr(req_state, "additional_information_cpu", merged) - - def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens, + def _generate_process_reqs_hidden_states(self, num_input_tokens, input_ids, positions, intermediate_tensors, - inputs_embeds): + inputs_embeds, model_kwargs): + assert self.model is not None hidden_states = self._model_forward(input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **self._init_model_kwargs()) + **model_kwargs) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ and not self.use_sparse: - # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead if self.vllm_config.model_config.use_mla: if self.pcp_size * self.dcp_size > 1: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + num_input_tokens) else: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens, + num_input_tokens, self.speculative_config) else: if self.pcp_size * self.dcp_size > 1: update_attn_dcp_pcp_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + num_input_tokens) else: update_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + num_input_tokens, + self.vllm_config) if get_forward_context().sp_enabled and not isinstance( hidden_states, IntermediateTensors): - hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) - pad_size = get_forward_context().pad_size - if pad_size > 0: - hidden_states = hidden_states[:-pad_size, :] - - if self.pcp_size > 1: - hidden_states = get_pcp_group().all_gather( - hidden_states[:self.num_actual_tokens_pcp_padded // - self.pcp_size], 0) - hidden_states = torch.index_select( - hidden_states, 0, - self.pcp_allgather_restore_idx[:hidden_states.shape[0]]) - return hidden_states - - def _process_additional_information_updates( - self, - hidden_states: torch.Tensor, - multimodal_outputs: object, - num_scheduled_tokens_np: np.ndarray, - ) -> None: - """Process model-provided per-request additional_information updates and merge into request state.""" - try: - # execute the custom postprocess function - # TODO(Peiqi): do we have a more elegant way to do this? - if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: - for req_index, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests.get(req_id) - req_infos = ( - getattr(req_state, "additional_information_cpu", None) if req_state is not None else None - ) - start_offset = int(self.query_start_loc.cpu[req_index]) - sched_tokens = int(num_scheduled_tokens_np[req_index]) - s, e = start_offset, start_offset + sched_tokens - # only consider to store data into update dict. 
- hidden_states_slice = hidden_states[s:e] - update_dict = self.model.postprocess(hidden_states_slice, **req_infos) - self._merge_additional_information_update(req_id, update_dict) - except Exception as e: - logger.error( - f"Error merging for requests:{self.input_batch.req_ids} " - f"additional information update: {e}, with the multimodal_outputs " - f"as {multimodal_outputs}" - ) - import traceback - - traceback.print_exc() - - def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: - decode_batch_size = len(decode_req_ids) - if decode_batch_size == 0: - return - _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( - num_tokens=decode_batch_size, - num_reqs=decode_batch_size, - num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), - max_num_scheduled_tokens=1, - use_cascade_attn=False, - ) - req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] - req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] - last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] - text_step = self.text_step.gpu[:decode_batch_size] - with set_ascend_forward_context( - None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc - ): - req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # update the inputs_embeds and code_predictor_codes - code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() - for idx, req_id in enumerate(decode_req_ids): - req_index = self.input_batch.req_ids.index(req_id) - start_offset = int(self.query_start_loc.cpu[req_index]) - inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} - self._merge_additional_information_update(req_id, update_dict) + hidden_states = self._all_gather_hidden_states_and_aux( + hidden_states) + return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( + hidden_states) diff --git a/vllm_omni/worker/npu/npu_ar_worker.py b/vllm_omni/worker/npu/npu_ar_worker.py index fb2f0586eb..c4895ee1b2 100644 --- a/vllm_omni/worker/npu/npu_ar_worker.py +++ b/vllm_omni/worker/npu/npu_ar_worker.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm_ascend.worker.worker_v1 import NPUWorker +from vllm.v1.worker.workspace import init_workspace_manager +from vllm_ascend.worker.worker import NPUWorker from vllm_omni.worker.npu.npu_ar_model_runner import NPUARModelRunner @@ -10,6 +11,8 @@ class NPUARWorker(NPUWorker): """NPU AR worker for thinker/talker stages in Omni model.""" def init_device(self): - device = self._init_device() + self.device = self._init_device() + num_ubatches = 1 + init_workspace_manager(self.device, num_ubatches) - self.model_runner: NPUARModelRunner = NPUARModelRunner(self.vllm_config, device) + self.model_runner = NPUARModelRunner(self.vllm_config, self.device) diff --git a/vllm_omni/worker/npu/npu_generation_model_runner.py b/vllm_omni/worker/npu/npu_generation_model_runner.py index e6641ca124..c6ad1f12a1 100644 --- a/vllm_omni/worker/npu/npu_generation_model_runner.py +++ b/vllm_omni/worker/npu/npu_generation_model_runner.py @@ -5,623 +5,86 @@ import gc import math -from typing import Any +from copy import copy import numpy as np import torch -import torch.nn as nn from vllm.config import CUDAGraphMode from 
vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import has_kv_transfer_group from vllm.distributed.parallel_state import get_pp_group from vllm.logger import logger -from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, make_empty_encoder_model_runner_output from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm_ascend.ascend_forward_context import MoECommType, get_mc2_tokens_capacity, set_ascend_forward_context -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort +from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.platform import NPUPlatform -from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.utils import ProfileExecuteDuration, enable_sp, lmhead_tp_enable from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.npu.npu_ar_model_runner import ExecuteModelState from vllm_omni.worker.npu.npu_model_runner import OmniNPUModelRunner class NPUGenerationModelRunner(OmniNPUModelRunner): """Generation model runner for vLLM-omni on NPU (non-autoregressive).""" - def _prepare_inputs( - self, - scheduler_output: SchedulerOutput, - intermediate_tensors: IntermediateTensors | None = None, - ) -> tuple[ - dict[str, Any], - torch.Tensor, - np.ndarray, - int, - torch.Tensor, - int, - torch.Tensor, - SpecDecodeMetadata, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - int, - ]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit_block_table(num_reqs) - - # Get the number of scheduled tokens for each request. 
- req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - ) - - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - if self.pcp_size > 1: - if not self.vllm_config.model_config.use_mla: - self.generate_kv_idx(scheduler_output) - tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(tokens) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - else: - position_pcp, pcp_unpad_mask = None, None - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] - - total_num_pcp_pads = sum(self.num_pcp_pads) - max_num_scheduled_tokens = max(tokens) - num_valid_tokens = np.array( - [ - num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip(tokens, req_ids) - ], - dtype=np.int32, - ) - - if self.use_aclgraph and total_num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]: - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph(total_num_scheduled_tokens) - elif self.use_aclgraph and enable_sp(self.vllm_config): - # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, - # the model will fall back to running its FX graph in eager mode. - # In this case, when sequence parallelism is enabled, we need to pad tokens to align - # with tp_size because pad_size cannot be captured by the FX graph - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size - else: - # Eager mode. - num_input_tokens = total_num_scheduled_tokens - - # Get the attention state. - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) - self.attn_state = attn_state # type: ignore - - # Determine if it's a splitfuse batch - with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] - - self.query_lens = torch.from_numpy(num_scheduled_tokens) - - # Get info across DP ranks. - # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, - # Otherwise, it's just max_tokens_across_dp_cpu - (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp( - num_input_tokens, with_prefill - ) - - # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens - # We should consider removing maybe_padded_num_tokens later - num_input_tokens = maybe_padded_num_tokens - - # Hot-Swap lora model - if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) - - # Get request indices. 
- # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - - if self.pcp_size > 1: - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - position_pcp[:total_num_scheduled_tokens], - out=positions_np, - ) - else: - self.positions.np[:total_num_scheduled_tokens] = positions_np - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) - - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True - ) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] - token_indices_tensor = torch.from_numpy(token_indices) - # Prepare input_ids. - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select( - self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens], - ) - if self.enable_prompt_embeds: - is_token_ids = self.input_batch.is_token_ids_tensor.flatten() - torch.index_select( - is_token_ids, 0, token_indices_tensor, out=self.is_token_ids.cpu[:total_num_scheduled_tokens] - ) - - # Because we did not pre-allocate a massive prompt_embeds CPU tensor on - # the InputBatch, we need to fill in the prompt embeds into the expected - # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. - if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or self.enable_prompt_embeds): - output_idx = 0 - for req_idx in range(num_reqs): - num_sched = num_scheduled_tokens[req_idx] - - # Skip if this request doesn't have embeddings - if req_idx not in self.input_batch.req_prompt_embeds: - output_idx += num_sched - continue - - # Skip if no tokens scheduled - if num_sched <= 0: - output_idx += num_sched - continue - - req_embeds = self.input_batch.req_prompt_embeds[req_idx] - start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] - - # Skip if trying to read beyond available embeddings - if start_pos >= req_embeds.shape[0]: - output_idx += num_sched - continue - - # Copy available embeddings - end_pos = start_pos + num_sched - actual_end = min(end_pos, req_embeds.shape[0]) - actual_num_sched = actual_end - start_pos - - if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx : output_idx + actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) - - output_idx += num_sched - - self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens - self.query_start_loc.copy_to_gpu() - - self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens - self.seq_lens.copy_to_gpu() - - # Fill unused with -1. 
Needed for reshape_and_cache - self.query_start_loc.gpu[num_reqs + 1 :].fill_(-1) - self.seq_lens.gpu[num_reqs:].fill_(0) - - self.query_lens = torch.from_numpy(num_scheduled_tokens) - - # Copy the tensors to the NPU. - self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, cu_num_tokens) - self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_() - self.positions.copy_to_gpu() - - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) - self.attn_mask = self._make_attention_mask(attn_state) - self.attn_state = attn_state # type: ignore - - self.with_prefill = with_prefill - self.num_tokens_across_dp = num_tokens_across_dp - attn_metadata: dict[str, Any] = {} - - # Record the index of requests that should not be sampled, - # so that we could clear the sampled tokens before returning - num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] - num_tokens_np = np.array(num_tokens, dtype=np.int32) - num_reqs = self.input_batch.num_reqs - if self.pcp_size > 1: - # while pcp > 1, we need the original num_scheduled_tokens before split - # to calculate discard_requests_mask - tokens_original = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - original_seq_lens_np = self.input_batch.num_computed_tokens_cpu[:num_reqs] + np.array( - tokens_original, dtype=np.int32 - ) - discard_requests_mask = original_seq_lens_np < num_tokens_np - else: - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np - - discard_request_indices = np.nonzero(discard_requests_mask)[0] - self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[: self.num_discarded_requests] = discard_request_indices - self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) - - # _prepare_inputs may reorder the batch, so we must gather - # multi-modal outputs after that to ensure the correct order - if self.is_multimodal_model: - with self.maybe_get_ec_connector_output( - scheduler_output, - encoder_cache=self.encoder_cache, - ): - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids.gpu[:total_num_scheduled_tokens] - mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) - - inputs_embeds = self.model.embed_input_ids( - input_ids, - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(inputs_embeds) - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - # -------------------------------------- Omni-new ------------------------------------------------- - # NOTE(gcanlin): We don't set input_ids to None in vllm-omni. - model_kwargs = { - **self._init_model_kwargs(), - **self._extract_mm_kwargs(scheduler_output), - } - # -------------------------------------- Omni-new ------------------------------------------------- - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: - # Get the input embeddings for the tokens that are not input embeds, - # then put them into the appropriate positions. 
- # TODO(qthequartermasterman): Since even when prompt embeds are - # enabled, (a) not all requests will use prompt embeds, and (b) - # after the initial prompt is processed, the rest of the generated - # tokens will be token ids, it is not desirable to have the - # embedding layer outside of the acl graph all the time. The v0 - # engine avoids this by "double compiling" the acl graph, once - # with input_ids and again with inputs_embeds, for all num_tokens. - # If a batch only has token ids, then including the embedding layer - # in the acl graph will be more performant (like in the else case - # below). - token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens].nonzero(as_tuple=False).squeeze(1) - # Some tokens ids may need to become embeds - if token_ids_idx.numel() > 0: - token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids) - self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds - - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the ACL graph. - input_ids = self.input_ids.gpu[:num_input_tokens] - inputs_embeds = None - positions = self.positions.gpu[:num_input_tokens] - if self.uses_mrope: - positions = self.mrope_positions.gpu[:, :num_input_tokens] - - # type: ignore - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - assert intermediate_tensors is not None - assert self.intermediate_tensors is not None - # If both flashcomm1 and pp are used simultaneously, - # the shape of the received data and the shape of the space to be copied to will not match, - # requiring a recalculation of the incoming data's shape. - tp_size = get_tensor_model_parallel_world_size() - num_input_tokens_with_flashcomm1 = num_input_tokens - if enable_sp(): - num_input_tokens_with_flashcomm1 = (num_input_tokens + tp_size - 1) // tp_size - for k, v in intermediate_tensors.items(): - self.intermediate_tensors[k][:num_input_tokens_with_flashcomm1].copy_( - v[:num_input_tokens_with_flashcomm1], non_blocking=True - ) - intermediate_tensors = IntermediateTensors( - {k: v[:num_input_tokens_with_flashcomm1] for k, v in self.intermediate_tensors.items()} - ) - - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - if not use_spec_decode: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. - spec_decode_metadata = None - if self.pcp_size * self.dcp_size > 1: - logits_indices = torch.from_numpy(cu_num_tokens) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1 - logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True) - else: - logits_indices = self.query_start_loc.gpu[1 : num_reqs + 1] - 1 - else: - # Get the number of draft tokens for each request. - # Iterate over the dictionary rather than all requests since not all - # requests have draft tokens. - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - # For chunked prefills, use -1 as mask rather than 0, as guided - # decoding may rollback speculative tokens. 
- num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in scheduler_output.scheduled_spec_decode_tokens.items(): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = ( - len(draft_token_ids) - if ( - self.input_batch.num_computed_tokens_cpu[req_idx] >= self.input_batch.num_prompt_tokens[req_idx] - ) - else -1 - ) - - spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs] - ) - logits_indices = spec_decode_metadata.logits_indices - - # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens - self.num_decode_draft_tokens.np[num_reqs:].fill(-1) - self.num_decode_draft_tokens.copy_to_gpu() - # save logits_indices for pcp spec decode usage - self.logits_indices = logits_indices - - # Used in the below loop. - # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] - num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - self.spec_decode_common_attn_metadata = None - if use_spec_decode and self.need_accepted_tokens: - self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() - - if self.speculative_config and self.pcp_size > 1: - self._generate_pcp_mtp_input( - num_reqs, scheduler_output.total_num_scheduled_tokens, scheduler_output.num_scheduled_tokens - ) - - long_seq_metadata = self._generate_pcp_metadata(total_num_scheduled_tokens) - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): - # NOTE: This is strange, why did we use total_num_scheduled_tokens before? - slot_mapping_size = ( - total_num_scheduled_tokens - if self.pcp_size == 1 - else total_num_scheduled_tokens * self.pcp_size - total_num_pcp_pads - ) - if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): - # Encoder-only layers do not have KV cache, so we need to - # create a dummy block table and slot mapping for them. 
- blk_table_tensor = torch.zeros( - (num_reqs, 1), - dtype=torch.int32, - device=self.device, - ) - slot_mapping = torch.zeros( - (total_num_scheduled_tokens,), - dtype=torch.int64, - device=self.device, - ) - else: - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor() - blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0) - if self.pcp_size > 1: - slot_mapping_for_pcp = blk_table.slot_mapping.gpu[: long_seq_metadata.num_actual_tokens_pcp_padded] - slot_mapping_for_pcp[slot_mapping_size:].fill_(-1) - assert pcp_unpad_mask is not None - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: pcp_unpad_mask.shape[0]] - pcp_padded_slot_mapping.fill_(-1) - pcp_padded_slot_mapping[pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size] - slot_mapping_for_pcp[: long_seq_metadata.num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping - blk_table.slot_mapping.gpu[: long_seq_metadata.num_actual_tokens_pcp_padded] = slot_mapping_for_pcp - slot_mapping = blk_table.slot_mapping.gpu - - # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs - # has been split to multiple parts, and there are 3 parts that is related to this - # `num_reqs`, we'll take `query_start_loc` as an example: - # 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens - # 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the - # following simplified `dispatch` logic here, we try to minimize the impact - # 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] - uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) and ( - total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs - ) - - # TODO: We should make this official ASAP. Also note that if we pad here, - # the builders won’t need to add any extra padding. 
- if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and uniform_decode: - num_reqs_padded = num_input_tokens // self.uniform_decode_query_len - pad_size = num_reqs_padded - num_reqs - if pad_size > 0: - last_query_loc = self.query_start_loc.gpu[num_reqs] - - steps = torch.arange(1, pad_size + 1, device=self.device, dtype=self.query_start_loc.gpu.dtype) - fill_values = last_query_loc + (steps * self.uniform_decode_query_len) - - self.query_start_loc.gpu[num_reqs + 1 : num_reqs_padded + 1] = fill_values - # So we are trying to simulate the behavior of GPUModelRunner's - # prepare_inputs for uniform decode mode by padding query_start_loc - num_reqs = num_reqs_padded - - # Make AscendCommonAttentionMetadata - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=slot_mapping_size, - num_input_tokens=num_input_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - # TODO: change this to the right block table for linear attn - block_table_tensor=blk_table_tensor[:num_reqs], - slot_mapping=slot_mapping, - num_computed_tokens_cpu=num_computed_tokens_cpu, - positions=self.positions.gpu, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - is_only_prefill=bool(np.all(num_valid_tokens != 1)), - max_query_len=max_num_scheduled_tokens, - decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=long_seq_metadata, - ) - - if self.speculative_config and self.pcp_size > 1: - # For pcp + spec decode, we flatten block_table - # to avoid irregular spec_attn_mask shape, e.g., - # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, - # ori block_table: # [d0, d1, p0, p1, p2] - # (num_reqs_d + num_reqs_p, max_num_blocks), - # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] - # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens = ( - self.query_start_loc_pcp_full.cpu[1 : num_reqs + 1] - self.query_start_loc_pcp_full.cpu[:num_reqs] - ) - num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() - num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold - blk_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_( - blk_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone() - ) - blk_table_tensor[:num_decode_reqs_flatten].copy_( - blk_table_tensor[:num_decode_reqs].repeat_interleave(self.decode_threshold, dim=0) - ) - common_attn_metadata.block_table_tensor = blk_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs] - - if self.speculative_config and self.spec_decode_common_attn_metadata is None: - self.spec_decode_common_attn_metadata = common_attn_metadata - - for attn_group in self.attn_groups[kv_cache_group_id]: - common_prefix_len = 0 - extra_attn_metadata_args = {} - builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder): - if use_spec_decode: - patch_torch_npu_argsort() - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], - num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[:num_reqs], - ) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - 
common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args, - ) - elif self.model_config.runner_type == "pooling": - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args, - ) - else: - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - model=self.get_model(), - **extra_attn_metadata_args, - ) - - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - if lmhead_tp_enable(): - max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len - logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) - - return ( - attn_metadata, - positions, - num_scheduled_tokens, - num_input_tokens, - num_tokens_across_dp, - maybe_padded_num_tokens, - logits_indices, - spec_decode_metadata, - input_ids, - inputs_embeds, - intermediate_tensors, - max_num_scheduled_tokens, - model_kwargs, - ) - @torch.inference_mode() def execute_model( self, scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors | None = None, ) -> OmniModelRunnerOutput | IntermediateTensors: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") + with ProfileExecuteDuration().capture_async("prepare input"): self._update_states(scheduler_output) + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + logger.debug("skip this step for we receive the data from remote disaggregate prefill node") + # Return empty ModelRunnerOutput if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) if self.dynamic_eplb: self.eplb_updator.forward_before() ( attn_metadata, - positions, num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, - maybe_padded_num_tokens, logits_indices, spec_decode_metadata, - input_ids, - inputs_embeds, - intermediate_tensors, max_query_len, - model_kwargs, - ) = self._prepare_inputs(scheduler_output, intermediate_tensors) + ) = self._prepare_inputs(scheduler_output) + + (input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, ec_connector_output) = ( + self._preprocess(scheduler_output, num_input_tokens, intermediate_tensors) + ) + + # update global cos, sin + update_cos_sin(positions) if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() - moe_comm_type = self._select_moe_comm_method(num_input_tokens) # prevent debugger is None - need_dump = self.dump_enable and self.debugger is not None - if need_dump: - assert self.debugger is not None + if self.debugger is not None: dbg_cfg = getattr(self.debugger, "config", None) dump_level = str(getattr(dbg_cfg, "level", "L1")).upper() if dbg_cfg is not None else "L1" if dump_level in ("L0", "MIX"): @@ -637,6 +100,13 @@ def execute_model( num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora ) + if self.ascend_config.enable_async_exponential: + self.sampler.do_async_exponential( + b_s=logits_indices.shape[0], + head_dim=self.model_config.get_vocab_size(), + generators=self.input_batch.sampling_metadata.generators, + ) + # Run forward pass with ProfileExecuteDuration().capture_async("forward"): with set_ascend_forward_context( @@ -644,23 +114,19 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=self.with_prefill, - moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output.total_num_scheduled_tokens, - prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, ): self.maybe_setup_kv_connector(scheduler_output) # -------------------------------------- Omni-new ------------------------------------------------- - outputs = self._run_generation( + outputs = self._run_generation_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - multimodal_kwargs=model_kwargs, + model_kwargs=model_kwargs, logits_indices=logits_indices, ) # -------------------------------------- Omni-new ------------------------------------------------- @@ -669,15 +135,70 @@ def execute_model( finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) aux_hidden_states = None - if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: - hidden_states, aux_hidden_states = outputs + if self.use_aux_hidden_state_outputs: + outputs, aux_hidden_states = outputs kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) finished_sending = None finished_recving = None - # -------------------------------------- Omni-new ------------------------------------------------- - # We don't need any post-process for generation model outputs + _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + # Apply structured output bitmasks if present + self.execute_model_state = ExecuteModelState( + scheduler_output, + 
None, + spec_decode_metadata, + outputs, + None, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # noqa + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # -------------------------------------- Omni-new ------------------------------------------------- pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): assert multimodal_outputs.shape[0] == 1, ( @@ -699,35 +220,39 @@ def execute_model( pooler_output.append({key: out.detach().to("cpu").contiguous() if out is not None else None}) else: raise RuntimeError("Unsupported diffusion output type") + req_ids_output_copy = list(self.input_batch.req_ids) + req_id_to_index_output_copy = dict(self.input_batch.req_id_to_index) output = OmniModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, kv_connector_output=kv_connector_output, num_nans_in_logits={}, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, ) # -------------------------------------- Omni-new ------------------------------------------------- if not self.use_async_scheduling: return output - return AsyncGPUModelRunnerOutput( model_runner_output=output, - sampled_token_ids=[], + sampled_token_ids=torch.tensor([], device=self.device), + logprobs_tensors=None, invalid_req_indices=[], async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) - def _run_generation( + def _run_generation_model( self, *, - input_ids: torch.Tensor, - positions: torch.Tensor, + input_ids: torch.Tensor | None, + positions: torch.Tensor | None, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None, - multimodal_kwargs: dict, + model_kwargs: dict, logits_indices: torch.Tensor, ) -> torch.Tensor | list[torch.Tensor]: """Run generation from codec codes to waveforms. 
@@ -739,12 +264,13 @@ def _run_generation( Returns: Audio waveforms: [batch, 1, waveform_len] or list of tensors """ + # Keep inputs identical to AR runner kwargs = dict( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs(multimodal_kwargs, device=self.device), + **model_kwargs, sampling_metadata=self.input_batch.sampling_metadata, logits_index=logits_indices, sampler=self.sampler, @@ -754,7 +280,7 @@ def _run_generation( return self._model_forward(**kwargs) raise RuntimeError( - "The loaded model does not expose generation interfaces 'sample', " + "The loaded model does not expose diffusion interfaces 'sample', " "'forward', or 'diffuse'. Please implement one of them or adapt the runner." ) @@ -768,12 +294,18 @@ def _dummy_run( self, num_tokens: int, with_prefill: bool = False, - aclgraph_runtime_mode: CUDAGraphMode | None = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, + is_profile: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, ) -> torch.Tensor: # only support eager mode and piecewise graph now - assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL, @@ -788,8 +320,15 @@ def _dummy_run( if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True + has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + # Padding for DP - (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp(num_tokens, with_prefill) + (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp( + batch_descriptor.num_tokens, with_prefill + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). 
This means that we are using @@ -830,13 +369,13 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - if not self.in_profile_run and self.dynamic_eplb: + if not is_profile and self.dynamic_eplb: self.eplb_updator.forward_before() - has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False - _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora - ) + if num_tokens != batch_descriptor.num_tokens: + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) num_tokens_padded = batch_descriptor.num_tokens num_reqs_padded = batch_descriptor.num_reqs if batch_descriptor.num_reqs is not None else num_reqs @@ -845,19 +384,17 @@ def _dummy_run( num_tokens_across_dp[:] = num_tokens_padded num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) - moe_comm_type = self._select_moe_comm_method(num_tokens_padded) - # filter out the valid batch descriptor - if aclgraph_runtime_mode is not None: + if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for aclgraph capture - if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode: + if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode: raise ValueError( f"Aclgraph runtime mode mismatch at dummy_run. " - f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}." + f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}." ) else: - aclgraph_runtime_mode = _ag_mode + cudagraph_runtime_mode = _ag_mode # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup # and not supported in ASCEND now. We could remove it in the future. 
@@ -866,15 +403,16 @@ def _dummy_run( num_reqs=num_reqs_padded, num_tokens=num_tokens_padded, max_query_len=max_query_len, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, force_attention=force_attention, + is_graph_capturing=is_graph_capturing, num_scheduled_tokens=num_scheduled_tokens, ) with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, num_sampled_tokens): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - if self.is_multimodal_model: + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] elif self.enable_prompt_embeds: @@ -886,9 +424,14 @@ def _dummy_run( if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: positions = self.positions.gpu[:num_tokens_padded] + # update global cos, sin + update_cos_sin(positions) + if get_pp_group().is_first_rank: intermediate_tensors = None else: @@ -908,7 +451,7 @@ def _dummy_run( {k: v[:num_tokens_padded] for k, v in self.intermediate_tensors.items()} ) - need_dummy_logits = not self.in_profile_run and lmhead_tp_enable() + need_dummy_logits = not is_profile and lmhead_tp_enable() max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) @@ -928,16 +471,11 @@ def dummy_drafter_compute_logits(hidden_states): self.vllm_config, num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=with_prefill, - in_profile_run=self.in_profile_run, - # reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_type=moe_comm_type, + in_profile_run=is_profile, num_actual_tokens=0, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, ): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds @@ -950,14 +488,15 @@ def dummy_drafter_compute_logits(hidden_states): with_prefill=with_prefill, num_reqs=num_reqs_padded, num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, dummy_compute_logits=dummy_drafter_compute_logits, - skip_attn=not force_attention, + in_graph_capturing=not force_attention, + is_profile=is_profile, ) - if self.in_profile_run and self.dynamic_eplb: + if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() - if not self.in_profile_run and self.dynamic_eplb: + if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() # -------------------------------------- Omni-new ------------------------------------------------- diff --git a/vllm_omni/worker/npu/npu_generation_worker.py b/vllm_omni/worker/npu/npu_generation_worker.py index 23cdf41c78..588d21fc1f 100644 --- a/vllm_omni/worker/npu/npu_generation_worker.py +++ b/vllm_omni/worker/npu/npu_generation_worker.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm_ascend.worker.worker_v1 import NPUWorker +from vllm.v1.worker.workspace 
import init_workspace_manager +from vllm_ascend.worker.worker import NPUWorker from vllm_omni.worker.npu.npu_generation_model_runner import NPUGenerationModelRunner @@ -10,6 +11,8 @@ class NPUGenerationWorker(NPUWorker): """NPU generation worker for code2wav stage in Omni model.""" def init_device(self): - device = self._init_device() + self.device = self._init_device() + num_ubatches = 1 + init_workspace_manager(self.device, num_ubatches) - self.model_runner: NPUGenerationModelRunner = NPUGenerationModelRunner(self.vllm_config, device) + self.model_runner = NPUGenerationModelRunner(self.vllm_config, self.device) diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index e461f055ce..82ca90c2ef 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -19,6 +19,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.utils import enable_sp, lmhead_tp_enable from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -119,6 +120,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -139,7 +141,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. + unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) # NOTE(woosuk): The persistent batch optimization assumes that # consecutive batches contain mostly the same requests. 
If batches
        # have low request overlap (e.g., alternating between two distinct
@@ -231,21 +240,64 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             except Exception as e:
                 logger.error(f"Error decoding additional information: {e}")
                 pass
-            # ------------------------------------------------------------------------------------------------
+            # -------------------------------------- Omni-new -------------------------------------------------
+
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
 
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(req_state)
 
+            # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+            if self.uses_xdrope_dim > 0:
+                self._init_xdrope_positions(req_state)
+
             reqs_to_add.append(self.requests[req_id])
 
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+        # Wait until valid_sampled_tokens_count is copied to the CPU,
+        # then use it to update the actual num_computed_tokens of each request.
+        valid_sampled_token_count = self._get_valid_sampled_token_count()
+
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids
+            num_output_tokens = req_data.num_output_tokens[i]
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if req_state.prev_num_draft_len and self.use_async_scheduling:
+                # prev_num_draft_len is used in async scheduling mode with
+                # spec decode. It indicates whether num_computed_tokens of the
+                # request needs to be updated. For example:
+                # first step: num_computed_tokens = 0, spec_tokens = [],
+                # prev_num_draft_len = 0.
+                # second step: num_computed_tokens = 100 (prompt length),
+                # spec_tokens = [a, b], prev_num_draft_len = 0.
+                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
+                # prev_num_draft_len = 2.
+                # num_computed_tokens in the first and second steps doesn't
+                # contain the spec token length, but in the third step it does.
+                # We only need to update num_computed_tokens when
+                # prev_num_draft_len > 0.
+                if req_index is None:
+                    req_state.prev_num_draft_len = 0
+                else:
+                    assert self.input_batch.prev_req_id_to_index is not None
+                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
+                    num_rejected = req_state.prev_num_draft_len - num_accepted
+                    num_computed_tokens -= num_rejected
+                    req_state.output_token_ids.extend([-1] * num_accepted)
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
@@ -263,6 +315,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     req_state.output_token_ids.append(new_token_ids[-1])
                 elif num_new_tokens > 0:
                     req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    # Some output tokens were discarded due to a sync-KV-load
+                    # failure. Align the cached state.
+ del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens + self.input_batch.num_tokens_no_spec[req_index] = end_idx # Update the block IDs. if not resumed_from_preemption: @@ -271,6 +330,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: + assert req_index is None assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -281,6 +341,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + reqs_to_add.append(req_state) continue @@ -297,22 +364,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index - self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) - start_index = self.input_batch.num_tokens_no_spec[req_index] - end_token_index = start_index + num_spec_tokens - self.input_batch.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec tokens. - self.input_batch.num_tokens[req_index] += num_spec_tokens + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
for request in reqs_to_add: self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -346,12 +406,18 @@ def _dummy_run( self, num_tokens: int, with_prefill: bool = False, - aclgraph_runtime_mode: CUDAGraphMode | None = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, + is_profile: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, ) -> torch.Tensor: # only support eager mode and piecewise graph now - assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL, @@ -366,8 +432,15 @@ def _dummy_run( if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True + has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + # Padding for DP - (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp(num_tokens, with_prefill) + (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp( + batch_descriptor.num_tokens, with_prefill + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -408,13 +481,13 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - if not self.in_profile_run and self.dynamic_eplb: + if not is_profile and self.dynamic_eplb: self.eplb_updator.forward_before() - has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False - _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora - ) + if num_tokens != batch_descriptor.num_tokens: + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) num_tokens_padded = batch_descriptor.num_tokens num_reqs_padded = batch_descriptor.num_reqs if batch_descriptor.num_reqs is not None else num_reqs @@ -423,19 +496,17 @@ def _dummy_run( num_tokens_across_dp[:] = num_tokens_padded num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) - moe_comm_type = self._select_moe_comm_method(num_tokens_padded) - # filter out the valid batch descriptor - if aclgraph_runtime_mode is not None: + if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for aclgraph capture - if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode: + if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode: raise ValueError( f"Aclgraph runtime mode mismatch at dummy_run. " - f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}." + f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}." 
) else: - aclgraph_runtime_mode = _ag_mode + cudagraph_runtime_mode = _ag_mode # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup # and not supported in ASCEND now. We could remove it in the future. @@ -444,15 +515,16 @@ def _dummy_run( num_reqs=num_reqs_padded, num_tokens=num_tokens_padded, max_query_len=max_query_len, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, force_attention=force_attention, + is_graph_capturing=is_graph_capturing, num_scheduled_tokens=num_scheduled_tokens, ) with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, num_sampled_tokens): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - if self.is_multimodal_model: + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] elif self.enable_prompt_embeds: @@ -464,9 +536,14 @@ def _dummy_run( if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: positions = self.positions.gpu[:num_tokens_padded] + # update global cos, sin + update_cos_sin(positions) + if get_pp_group().is_first_rank: intermediate_tensors = None else: @@ -486,7 +563,7 @@ def _dummy_run( {k: v[:num_tokens_padded] for k, v in self.intermediate_tensors.items()} ) - need_dummy_logits = not self.in_profile_run and lmhead_tp_enable() + need_dummy_logits = not is_profile and lmhead_tp_enable() max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) @@ -506,16 +583,11 @@ def dummy_drafter_compute_logits(hidden_states): self.vllm_config, num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, - with_prefill=with_prefill, - in_profile_run=self.in_profile_run, - # reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_type=moe_comm_type, + in_profile_run=is_profile, num_actual_tokens=0, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, ): if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): hidden_states = self.talker_mtp( @@ -529,26 +601,26 @@ def dummy_drafter_compute_logits(hidden_states): ) dummy_compute_logits(hidden_states) + hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + if self.drafter: self.drafter.dummy_run( num_tokens=num_tokens_padded, with_prefill=with_prefill, num_reqs=num_reqs_padded, num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=aclgraph_runtime_mode, + aclgraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, dummy_compute_logits=dummy_drafter_compute_logits, - skip_attn=not force_attention, + in_graph_capturing=not force_attention, + is_profile=is_profile, ) - if self.in_profile_run and self.dynamic_eplb: + if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() - if not self.in_profile_run and self.dynamic_eplb: + if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() - - hidden_states, _ = self.extract_multimodal_outputs(hidden_states) - - return hidden_states + return hidden_states, hidden_states def 
_decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None:
         """Decode per-request prompt_embeds and additional_information for newly
@@ -643,6 +715,39 @@ def _build_model_kwargs_extra(self) -> dict:
             traceback.print_exc()
         return model_kwargs_extra
 
+    def _process_additional_information_updates(
+        self,
+        hidden_states: torch.Tensor,
+        multimodal_outputs: object,
+        num_scheduled_tokens_np: np.ndarray,
+    ) -> None:
+        """Process model-provided per-request additional_information updates and merge them into request state."""
+        try:
+            # execute the custom postprocess function
+            # TODO(Peiqi): do we have a more elegant way to do this?
+            if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
+                for req_index, req_id in enumerate(self.input_batch.req_ids):
+                    req_state = self.requests.get(req_id)
+                    req_infos = (
+                        getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
+                    )
+                    start_offset = int(self.query_start_loc.cpu[req_index])
+                    sched_tokens = int(num_scheduled_tokens_np[req_index])
+                    s, e = start_offset, start_offset + sched_tokens
+                    # only consider storing data into the update dict.
+                    hidden_states_slice = hidden_states[s:e]
+                    update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
+                    self._merge_additional_information_update(req_id, update_dict)
+        except Exception as e:
+            logger.error(
+                f"Error merging for requests: {self.input_batch.req_ids} "
+                f"additional information update: {e}, with the multimodal_outputs "
+                f"as {multimodal_outputs}"
+            )
+            import traceback
+
+            traceback.print_exc()
+
     def _collect_additional_information_for_prefill(
         self,
         num_scheduled_tokens_np: np.ndarray,
@@ -666,6 +771,192 @@ def _collect_additional_information_for_prefill(
             start_offset = int(self.query_start_loc.cpu[req_index])
             self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)
 
+    def _preprocess(
+        self,
+        scheduler_output: "SchedulerOutput",
+        num_input_tokens: int,
+        intermediate_tensors: IntermediateTensors | None = None,
+    ):
+        """Align with v0.14.0 preprocess and omni's additional information handling."""
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        is_first_rank = get_pp_group().is_first_rank
+        is_encoder_decoder = self.model_config.is_encoder_decoder
+
+        # _prepare_inputs may reorder the batch, so we must gather multi-
+        # modal outputs after that to ensure the correct order
+        ec_connector_output = None
+
+        if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder:
+            # Run the multimodal encoder if any.
+            with self.maybe_get_ec_connector_output(
+                scheduler_output,
+                encoder_cache=self.encoder_cache,
+            ) as ec_connector_output:
+                self._execute_mm_encoder(scheduler_output)
+            mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
+
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            inputs_embeds_scheduled = self.model.embed_input_ids(
+                self.input_ids.gpu[:num_scheduled_tokens],
+                multimodal_embeddings=mm_embeds,
+                is_multimodal=is_mm_embed,
+            )
+
+            # TODO(woosuk): Avoid the copy. Optimize.
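+            # Stage the gathered embeddings for the scheduled tokens in the
+            # persistent inputs_embeds buffer before the model inputs are built below.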
+            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
+
+            input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
+            model_kwargs = {
+                **self._init_model_kwargs(),
+                **self._extract_mm_kwargs(scheduler_output),
+            }
+        elif self.enable_prompt_embeds and is_first_rank:
+            # Get the input embeddings for the tokens that are not input embeds,
+            # then put them into the appropriate positions.
+            # TODO(qthequartermasterman): Since even when prompt embeds are
+            # enabled, (a) not all requests will use prompt embeds, and (b)
+            # after the initial prompt is processed, the rest of the generated
+            # tokens will be token ids, it is not desirable to have the
+            # embedding layer outside of the CUDA graph all the time. The v0
+            # engine avoids this by "double compiling" the CUDA graph, once
+            # with input_ids and again with inputs_embeds, for all num_tokens.
+            # If a batch only has token ids, then including the embedding layer
+            # in the CUDA graph will be more performant (like in the else case
+            # below).
+            token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens].nonzero(as_tuple=False).squeeze(1)
+            # Some token ids may need to become embeds
+            if token_ids_idx.numel() > 0:
+                token_ids = self.input_ids.gpu[token_ids_idx]
+                tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids)
+                self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
+
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
+            model_kwargs = self._init_model_kwargs()
+            input_ids = self.input_ids.gpu[:num_input_tokens]
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids.gpu[:num_input_tokens]
+            inputs_embeds = None
+            model_kwargs = self._init_model_kwargs()
+
+        if self.uses_mrope:
+            positions = self.mrope_positions.gpu[:, :num_input_tokens]
+        elif self.uses_xdrope_dim > 0:
+            positions = self.xdrope_positions.gpu[:, :num_input_tokens]
+        else:
+            positions = self.positions.gpu[:num_input_tokens]
+
+        if is_first_rank:
+            intermediate_tensors = None
+        else:
+            assert intermediate_tensors is not None
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_input_tokens, intermediate_tensors, True
+            )
+
+        if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
+            # Run the encoder, just like we do with other multimodal inputs.
+            # For an encoder-decoder model, our processing here is a bit
+            # simpler, because the outputs are just passed to the decoder.
+            # We are not doing any prompt replacement. We also will only
+            # ever have a single encoder input.
+            encoder_outputs = self._execute_mm_encoder(scheduler_output)
+            model_kwargs.update({"encoder_outputs": encoder_outputs})
+
+        req_ids = self.input_batch.req_ids
+        num_scheduled_tokens_np = np.array(
+            [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids],
+            dtype=np.int32,
+        )
+        self._omni_num_scheduled_tokens_np = num_scheduled_tokens_np
+
+        # Note: only prefill needs to collect additional_information for now.
+        # Decode doesn't need per_req_additional_information anymore.
+ if inputs_embeds is not None: + # Prefill: overlay prompt_embeds and collect additional_information + self._collect_additional_information_for_prefill(num_scheduled_tokens_np) + + if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: + # Overlay custom prompt_embeds per request for the prompt portion; + # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests.get(req_id) + req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + span_len = int(e) - int(s) + + # call the custom process function + req_input_ids, req_embeds, update_dict = self.model.preprocess( + input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + ) + if hasattr(self.model, "talker_mtp") and span_len == 1: + last_talker_hidden, text_step = update_dict.pop("mtp_inputs") + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) + + # TODO(Peiqi): the merge stage could move out from the critical path + self._merge_additional_information_update(req_id, update_dict) + + # update the inputs_embeds and input_ids + seg_len = min(span_len, req_embeds.shape[0]) + inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] + if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: + input_ids[s : s + seg_len] = req_input_ids + + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) + + return ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) + + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] + last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] + text_step = self.text_step.gpu[:decode_batch_size] + with set_ascend_forward_context( + None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : 
idx + 1]} + self._merge_additional_information_update(req_id, update_dict) + def _model_forward( self, input_ids: torch.Tensor | None = None, @@ -696,3 +987,22 @@ def _model_forward( # Cache model output so later sample_tokens can consume multimodal results. self._omni_last_model_output = model_output return model_output + + def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + req_state = self.requests.get(req_id) + if req_state is None: + return + existing = getattr(req_state, "additional_information_cpu", {}) + if not isinstance(existing, dict): + existing = {} + merged = dict(existing) + for k, v in upd.items(): + if isinstance(v, torch.Tensor): + merged[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + merged[k] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v + ] + else: + merged[k] = v + setattr(req_state, "additional_information_cpu", merged) From 74362bf9105dca1cd3e9c38c91cac6cdbe780621 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 03:37:14 +0000 Subject: [PATCH 05/11] remove patch Signed-off-by: gcanlin --- vllm_omni/patch.py | 62 ---------------------------------------------- 1 file changed, 62 deletions(-) diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index 9ecb5eeb09..f63acea2d0 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -32,65 +32,3 @@ module.Request = OmniRequest if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest: module.EngineCoreRequest = OmniEngineCoreRequest - - -# Patch for vllm-ascend prefetch functions bug fix -# Issue: The original functions access forward_context attributes like -# prefetch_mlp_gate_up_proj, prefetch_mlp_down_proj, layer_idx without checking -# if they exist, which causes AttributeError when prefetch_mlp_enabled is not set. -# TODO: Remove this patch after upgrading to vllm-ascend v0.13.0 or later. -# This issue has been fixed in https://github.com/vllm-project/vllm-ascend/pull/5035 -if is_npu(): - import torch - import torch.nn as nn - from vllm.model_executor.models.qwen2_5_omni_thinker import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs - from vllm_ascend.ascend_forward_context import set_ascend_forward_context - - from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import ( - Qwen2_5OmniThinkerForConditionalGeneration, - ) - - class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module): - def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": - return image_input["image_embeds"].type(self.visual.dtype) - - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - with set_ascend_forward_context(None, self.vllm_config): - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - # Split concatenated embeddings for each image item. 
- merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return image_embeds.split(sizes.tolist()) - - def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs, - video_hashes: list[str] | None = None, - cached_video_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: - if video_input["type"] == "video_embeds": - return video_input["video_embeds"].type(self.visual.dtype) - - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - - pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) - with set_ascend_forward_context(None, self.vllm_config): - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - # Split concatenated embeddings for each video item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return video_embeds.split(sizes.tolist()) - - Qwen2_5OmniThinkerForConditionalGeneration._process_image_input = ( - AscendQwen2_5OmniThinkerForConditionalGeneration._process_image_input - ) - Qwen2_5OmniThinkerForConditionalGeneration._process_video_input = ( - AscendQwen2_5OmniThinkerForConditionalGeneration._process_video_input - ) From 2b31ab0383473afbc3a8dd9232b01ad340125475 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 03:37:40 +0000 Subject: [PATCH 06/11] remove patch Signed-off-by: gcanlin --- vllm_omni/patch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index f63acea2d0..687ff51865 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -14,7 +14,6 @@ from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding from vllm_omni.request import OmniRequest -from vllm_omni.utils import is_npu for module_name, module in sys.modules.items(): # only do patch on module of vllm, pass others From 81198844cbd86eb76cc92878e51be218aa2d4a02 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 07:19:31 +0000 Subject: [PATCH 07/11] disable async scheduling in stage2 Signed-off-by: gcanlin --- vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml index f89f205a21..eadfea5705 100644 --- a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml @@ -75,6 +75,7 @@ stage_args: scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true + async_scheduling: false enable_prefix_caching: false engine_output_type: audio # Final output: audio waveform gpu_memory_utilization: 0.1 From 7fecf335dc77cde37a714682185ff787dfa61d65 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 08:23:41 +0000 Subject: [PATCH 08/11] don't need copy anymore Signed-off-by: gcanlin --- vllm_omni/worker/npu/npu_ar_model_runner.py | 28 +++---------------- .../worker/npu/npu_generation_model_runner.py | 6 ++-- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py index 543e9b4548..3d632ee474 100644 --- a/vllm_omni/worker/npu/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu/npu_ar_model_runner.py @@ -31,12 +31,7 @@ # yapf conflicts with isort for this block # yapf: disable -from 
vllm_ascend.compilation.acl_graph import ( - update_attn_dcp_pcp_params, - update_attn_params, - update_mla_attn_dcp_pcp_params, - update_mla_attn_params, -) +from vllm_ascend.compilation.acl_graph import update_full_graph_params from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.utils import ProfileExecuteDuration @@ -472,24 +467,9 @@ def _generate_process_reqs_hidden_states(self, num_input_tokens, and not self.use_sparse: if self.vllm_config.model_config.use_mla: if self.pcp_size * self.dcp_size > 1: - # FIXME: Try using `auto_dispatch_capture=True` - update_mla_attn_dcp_pcp_params(self.update_stream, - forward_context, - num_input_tokens) - else: - # FIXME: Try using `auto_dispatch_capture=True` - update_mla_attn_params(self.update_stream, forward_context, - num_input_tokens, - self.speculative_config) - else: - if self.pcp_size * self.dcp_size > 1: - update_attn_dcp_pcp_params(self.update_stream, - forward_context, - num_input_tokens) - else: - update_attn_params(self.update_stream, forward_context, - num_input_tokens, - self.vllm_config) + update_full_graph_params(self.attn_backend, self.update_stream, forward_context, + num_input_tokens, self.vllm_config, + self.vllm_config.speculative_config) if get_forward_context().sp_enabled and not isinstance( hidden_states, IntermediateTensors): diff --git a/vllm_omni/worker/npu/npu_generation_model_runner.py b/vllm_omni/worker/npu/npu_generation_model_runner.py index c6ad1f12a1..f0494b93af 100644 --- a/vllm_omni/worker/npu/npu_generation_model_runner.py +++ b/vllm_omni/worker/npu/npu_generation_model_runner.py @@ -220,11 +220,9 @@ def sample_tokens( pooler_output.append({key: out.detach().to("cpu").contiguous() if out is not None else None}) else: raise RuntimeError("Unsupported diffusion output type") - req_ids_output_copy = list(self.input_batch.req_ids) - req_id_to_index_output_copy = dict(self.input_batch.req_id_to_index) output = OmniModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, From 254d4461799b62fd0a6d2eddb619cd9ddddfe88f Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 09:35:20 +0000 Subject: [PATCH 09/11] align with GPUModelRunner Signed-off-by: gcanlin --- vllm_omni/worker/npu/npu_generation_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_omni/worker/npu/npu_generation_model_runner.py b/vllm_omni/worker/npu/npu_generation_model_runner.py index f0494b93af..e40671fd16 100644 --- a/vllm_omni/worker/npu/npu_generation_model_runner.py +++ b/vllm_omni/worker/npu/npu_generation_model_runner.py @@ -216,8 +216,11 @@ def sample_tokens( {"model_outputs": out.detach().to("cpu").contiguous() if out is not None else None} ) elif isinstance(multimodal_outputs, dict): + mm_payload = {} for key, out in multimodal_outputs.items(): - pooler_output.append({key: out.detach().to("cpu").contiguous() if out is not None else None}) + if out is not None and isinstance(out, torch.Tensor): + mm_payload[key] = out.detach().to("cpu").contiguous() + pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") output = OmniModelRunnerOutput( From daf45202143b043a108ec1efd3609efa04f5a99e Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 15:28:37 +0000 Subject: [PATCH 10/11] align bagel and chunks Signed-off-by: gcanlin --- 
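Note: the _extract_kv_cache helper added in this patch flattens the per-block KV layout
[2, n_blocks, block_size, n_heads, head_dim] into a per-token [2, seq_len, n_heads, head_dim]
layout before handing it to the connector. A toy sketch of that reshape (tensor sizes are made
up for illustration, not taken from the model):

    import torch

    # Toy sizes: 3 blocks of 4 slots each, 2 KV heads, head_dim 8.
    kv_layer = torch.randn(2, 3, 4, 2, 8)   # [K/V, n_blocks, block_size, n_heads, head_dim]
    block_ids = [0, 2]                       # blocks owned by the finished request
    seq_len = 6                              # real tokens; the tail of the last block is padding

    selected = kv_layer[:, block_ids]        # [2, 2, 4, 2, 8]
    n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape
    flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head)[:, :seq_len]
    key_cache, value_cache = flat[0].cpu(), flat[1].cpu()   # shipped to the next stage per layer
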
vllm_omni/worker/npu/npu_ar_model_runner.py | 229 +++++++++++++++++- .../worker/npu/npu_generation_model_runner.py | 15 ++ vllm_omni/worker/npu/npu_model_runner.py | 57 ++++- 3 files changed, 290 insertions(+), 11 deletions(-) diff --git a/vllm_omni/worker/npu/npu_ar_model_runner.py b/vllm_omni/worker/npu/npu_ar_model_runner.py index 3d632ee474..0eedb61fb4 100644 --- a/vllm_omni/worker/npu/npu_ar_model_runner.py +++ b/vllm_omni/worker/npu/npu_ar_model_runner.py @@ -35,6 +35,7 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.utils import ProfileExecuteDuration +from vllm_omni.core.sched.omni_ar_scheduler import KVCacheTransferData from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.npu.npu_model_runner import OmniNPUModelRunner @@ -89,6 +90,11 @@ def execute_model( raise RuntimeError("State error: sample_tokens() must be called " "after execute_model() returns None.") + # -------------------------------------- Omni-new ------------------------------------------------- + # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) + self.kv_extracted_req_ids = self._handle_finished_requests_kv_transfer(scheduler_output) + # -------------------------------------- Omni-new ------------------------------------------------- + with ProfileExecuteDuration().capture_async("prepare input"): # -------------------------------------- Omni-new ------------------------------------------------- self._update_states(scheduler_output) @@ -193,8 +199,7 @@ def execute_model( finished_recving = None with ProfileExecuteDuration().capture_async("post process"): # -------------------------------------- Omni-new ------------------------------------------------- - multimodal_outputs = hidden_states.multimodal_outputs - hidden_states = hidden_states.text_hidden_states + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) if multimodal_outputs is not None: keys_or_type = ( @@ -237,7 +242,15 @@ def execute_model( self.debugger.step() return pool_output sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states) + # -------------------------------------- Omni-new ------------------------------------------------- + # Try with sampling_metadata first; fall back to without for models that don't support it + try: + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) + except TypeError: + logits = self.model.compute_logits(sample_hidden_states) + # -------------------------------------- Omni-new ------------------------------------------------- if broadcast_pp_output: model_output_broadcast_data = { "logits": logits.contiguous(), @@ -273,6 +286,11 @@ def sample_tokens( kv_connector_output = self.kv_connector_output self.kv_connector_output = None + # -------------------------------------- Omni-new ------------------------------------------------- + kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None) + self.kv_extracted_req_ids = None + # -------------------------------------- Omni-new ------------------------------------------------- + if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. 
if not kv_connector_output: @@ -375,7 +393,9 @@ def propose_draft_token_ids(sampled_token_ids): dtype=np.int32, ) - self._process_additional_information_updates(hidden_states, multimodal_outputs, num_scheduled_tokens_np) + self._process_additional_information_updates( + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + ) pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: @@ -418,6 +438,7 @@ def propose_draft_token_ids(sampled_token_ids): pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), kv_connector_output=kv_connector_output, ) + output.kv_extracted_req_ids = kv_extracted_req_ids # -------------------------------------- Omni-new ------------------------------------------------- durations = ProfileExecuteDuration().pop_captured_sync() @@ -477,3 +498,203 @@ def _generate_process_reqs_hidden_states(self, num_input_tokens, hidden_states) return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( hidden_states) + + def _handle_finished_requests_kv_transfer(self, scheduler_output: SchedulerOutput) -> list[str]: + """Handle KV cache transfer for finished requests. + + Returns list of request IDs that were processed (for Scheduler to free blocks). + """ + finished_reqs = getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}) + if not finished_reqs: + return [] + + logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests") + + extracted_ids = [] + for req_id, data in finished_reqs.items(): + try: + seq_len = data.get("seq_len", 0) + block_ids = data.get("block_ids", []) + if not block_ids: + logger.warning(f"Request {req_id} has no block IDs, skipping") + continue + + # Extract KV cache from GPU blocks -> CPU tensors + kv_data = self._extract_kv_cache(req_id, block_ids, seq_len) + if kv_data: + # Transfer to downstream stage via connector + self._transfer_kv_cache(kv_data) + + except Exception as e: + logger.error(f"Failed KV transfer for {req_id}: {e}") + finally: + extracted_ids.append(req_id) + + return extracted_ids + + def _extract_kv_cache(self, req_id: str, block_ids: list[int], seq_len: int) -> KVCacheTransferData | None: + """Extract KV cache from GPU blocks for a single request.""" + num_layers = len(self.kv_caches) + key_cache = [None] * num_layers + value_cache = [None] * num_layers + + for layer_idx, kv_tensor in enumerate(self.kv_caches): + # Validate block IDs + max_block = kv_tensor.shape[1] - 1 + valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block] + if not valid_ids: + continue + + # Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim] + # -> [2, seq_len, n_heads, head_dim] + selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim] + n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape + flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head) + if seq_len < flat.shape[1]: + flat = flat[:, :seq_len] + + # Move to CPU + flat_cpu = flat.detach().cpu().contiguous() + key_cache[layer_idx] = flat_cpu[0] + value_cache[layer_idx] = flat_cpu[1] + + if not any(k is not None for k in key_cache): + return None + + return KVCacheTransferData( + request_id=req_id, + layer_blocks={"key_cache": key_cache, "value_cache": value_cache}, + block_ids=block_ids, + metadata={ + "block_size": self.cache_config.block_size, + "num_layers": num_layers, + "dtype": str(self.cache_config.cache_dtype), + "seq_len": seq_len, + }, + ) + + def _transfer_kv_cache(self, kv_data: 
KVCacheTransferData) -> None: + """Transfer KV cache data to downstream stage via OmniConnector.""" + connector = self._get_or_create_connector() + if not connector: + return + + # Resolve global request ID if available + transfer_req_id = self._resolve_global_request_id(kv_data.request_id) + from_stage, to_stage = self._detect_transfer_stages() + + # Prepare data and transfer with retry + data_dict = kv_data.to_dict() + data_dict["request_id"] = transfer_req_id + + success, size, _ = self._transfer_with_retry( + connector, from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict + ) + + if success: + logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes") + else: + logger.error(f"KV transfer FAILED: {transfer_req_id}") + + def _get_or_create_connector(self) -> Any | None: + """Get existing connector or create one from config.""" + if self.omni_connector: + return self.omni_connector + + from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory + from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec + + config = self._get_omni_connector_config() + if not config or not isinstance(config, dict): + logger.warning("No valid OmniConnector config found") + return None + + c_type = config.get("type") + if not c_type: + logger.error("OmniConnector config missing 'type' field") + return None + + c_extra = {k: v for k, v in config.items() if k != "type"} + self.omni_connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) + return self.omni_connector + + def _get_omni_connector_config(self) -> dict[str, Any] | None: + """Get OmniConnector configuration from model config.""" + # Primary: omni_kv_config from YAML + omni_kv = getattr(self.model_config, "omni_kv_config", None) + if isinstance(omni_kv, dict): + cfg = omni_kv.get("connector_config") + if isinstance(cfg, dict) and cfg: + return cfg + + # Fallback: kv_transfer_config + kv_cfg = getattr(self.vllm_config, "kv_transfer_config", None) + if kv_cfg: + direct = getattr(kv_cfg, "omni_connector_config", None) + if isinstance(direct, dict) and direct: + return direct + extra = getattr(kv_cfg, "kv_connector_extra_config", None) + if isinstance(extra, dict): + omni = extra.get("omni_connector_config") + if isinstance(omni, dict) and omni: + return omni + + return None + + def _detect_transfer_stages(self) -> tuple[str, str]: + """Detect source and target stages for KV transfer.""" + omni_kv = getattr(self.model_config, "omni_kv_config", None) + if isinstance(omni_kv, dict): + from_s = omni_kv.get("omni_from_stage") + to_s = omni_kv.get("omni_to_stage") + if from_s and to_s: + return str(from_s), str(to_s) + + raise ValueError( + "KV transfer stages not configured. Please set 'omni_from_stage' and 'omni_to_stage' in omni_kv_config." 
+ ) + + def _resolve_global_request_id(self, req_id: str) -> str: + """Resolve global request ID from request state.""" + req_state = self.requests.get(req_id) + if not req_state: + return req_id + + add_info = getattr(req_state, "additional_information_cpu", {}) or {} + global_id = add_info.get("global_request_id") + if global_id: + if isinstance(global_id, list) and global_id: + global_id = global_id[0] + if isinstance(global_id, bytes): + return global_id.decode("utf-8") + return str(global_id) + return req_id + + def _transfer_with_retry( + self, + connector: Any, + from_stage: str, + to_stage: str, + request_id: str, + data: dict[str, Any], + max_retries: int = 3, + ) -> tuple[bool, int, dict[str, Any] | None]: + """Transfer data with retry and exponential backoff.""" + import time + + for attempt in range(max_retries): + try: + put_key = f"omni_{from_stage}_to_{to_stage}_{request_id}" + success, size, metadata = connector.put( + from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data + ) + if success: + return success, size, metadata + logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") + except Exception as e: + logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") + + if attempt < max_retries - 1: + time.sleep(0.1 * (2**attempt)) + + return False, 0, None diff --git a/vllm_omni/worker/npu/npu_generation_model_runner.py b/vllm_omni/worker/npu/npu_generation_model_runner.py index e40671fd16..e1beef9825 100644 --- a/vllm_omni/worker/npu/npu_generation_model_runner.py +++ b/vllm_omni/worker/npu/npu_generation_model_runner.py @@ -34,6 +34,17 @@ class NPUGenerationModelRunner(OmniNPUModelRunner): """Generation model runner for vLLM-omni on NPU (non-autoregressive).""" + def _update_request_states(self, scheduler_output: SchedulerOutput): + cached_reqs = scheduler_output.scheduled_cached_reqs + for _, req_id in enumerate(cached_reqs.req_ids): + req_state = self.requests.get(req_id) + assert req_state is not None + req_state.prompt_token_ids = cached_reqs.prompt_token_ids.get(req_id) + self.input_batch.remove_request(req_id) + # update the request state in self.input_batch + self.input_batch.add_request(req_state) + self._init_mrope_positions(req_state) + @torch.inference_mode() def execute_model( self, @@ -44,6 +55,10 @@ def execute_model( raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") with ProfileExecuteDuration().capture_async("prepare input"): + # -------------------------------------- Omni-new ------------------------------------------------- + if self.model_config.async_chunk: + self._update_request_states(scheduler_output) + # -------------------------------------- Omni-new ------------------------------------------------- self._update_states(scheduler_output) if has_ec_transfer() and get_ec_transfer().is_producer: with self.maybe_get_ec_connector_output( diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index 82ca90c2ef..b0fbf7a4e2 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -720,6 +720,7 @@ def _process_additional_information_updates( hidden_states: torch.Tensor, multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, + scheduler_output: "SchedulerOutput", ) -> None: """Process model-provided per-request additional_information updates and merge into request state.""" try: @@ -727,10 +728,13 @@ def _process_additional_information_updates( # TODO(Peiqi): do we have a 
more elegant way to do this? if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: for req_index, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests.get(req_id) - req_infos = ( - getattr(req_state, "additional_information_cpu", None) if req_state is not None else None - ) + if self.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) start_offset = int(self.query_start_loc.cpu[req_index]) sched_tokens = int(num_scheduled_tokens_np[req_index]) s, e = start_offset, start_offset + sched_tokens @@ -771,6 +775,40 @@ def _collect_additional_information_for_prefill( start_offset = int(self.query_start_loc.cpu[req_index]) self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + def _get_additional_information(self, scheduler_output: "SchedulerOutput", req_id: str) -> dict: + req_infos = None + req_state = self.requests.get(req_id) + additional_information_cpu = getattr(req_state, "additional_information_cpu", None) + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id == req_id: + payload_info = getattr(new_req, "additional_information", None) + if payload_info is not None: + return payload_info + + if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"): + cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {}) + if isinstance(cached_infos, dict) and req_id in cached_infos: + req_infos = cached_infos[req_id] + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None or req_infos.get("last_talker_hidden", None) is None: + if req_infos is None: + additional_information_cpu.pop("thinker_embeddings", None) + req_infos = additional_information_cpu + else: + req_infos["last_talker_hidden"] = additional_information_cpu.get("last_talker_hidden", None) + req_infos["num_processed_thinker_tokens"] = additional_information_cpu.get( + "num_processed_thinker_tokens", 0 + ) + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None: + logger.warning(f"No additional_information found for req_id: {req_id}") + + return req_infos + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -886,9 +924,14 @@ def _preprocess( # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] for req_index, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests.get(req_id) - req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None - + # Try to get additional_information from multiple sources + if self.vllm_config.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) start_offset = int(self.query_start_loc.cpu[req_index]) sched_tokens = int(num_scheduled_tokens_np[req_index]) s, e = start_offset, start_offset + sched_tokens From 25b545fad74a47d76b319a9d29f6b28643b55456 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 26 Jan 2026 15:54:58 +0000 Subject: [PATCH 11/11] Support Qwen3-TTS Signed-off-by: gcanlin --- .../stage_configs/npu/qwen3_tts.yaml | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 
vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml new file mode 100644 index 0000000000..86b5ca60ea --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml @@ -0,0 +1,22 @@ +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSForConditionalGeneration + worker_cls: vllm_omni.worker.npu.npu_generation_worker.NPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + + final_output: true + final_output_type: audio
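
As a usage reference (not part of the patch), a minimal sketch of reading a stage config such
as the qwen3_tts.yaml added above with PyYAML; only the keys visible in the file are assumed,
and the helper name is illustrative:

    import yaml  # PyYAML

    def load_stage_args(path: str) -> list[dict]:
        # Parse the stage config and return the list of per-stage definitions.
        with open(path) as f:
            cfg = yaml.safe_load(f)
        stages = cfg["stage_args"]
        for stage in stages:
            # Every stage in these configs carries an id, a type, and its engine args.
            assert {"stage_id", "stage_type", "engine_args"} <= stage.keys()
        return stages

    stages = load_stage_args("vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml")
    print(stages[0]["engine_args"]["model_arch"])  # Qwen3TTSForConditionalGeneration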