From f9363e196be73dd20a40edd4b9614ad01c76af2e Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sat, 14 Mar 2026 12:48:55 +0000 Subject: [PATCH 1/4] [NPU] Upgrade to v0.17.0 Signed-off-by: gcanlin --- .../qwen3_omni_moe_code_predictor_mtp.py | 6 + .../npu/worker/npu_ar_model_runner.py | 150 +++++++++++++++--- .../npu/worker/npu_generation_model_runner.py | 92 ++++++++--- .../platforms/npu/worker/npu_model_runner.py | 29 ++-- 4 files changed, 224 insertions(+), 53 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py index ad372c8f03..2b823364f6 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -22,6 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm_omni.platforms import current_omni_platform + logger = init_logger(__name__) @@ -343,6 +345,10 @@ def _ensure_cached_refs(self) -> None: def _ensure_model_fwd(self) -> None: if self._model_fwd is not None: return + if not current_omni_platform.supports_torch_inductor(): + logger.warning_once("code_predictor: torch.compile disabled") + self._model_fwd = self.model.forward + return self._model_fwd = torch.compile( self.model.forward, mode="default", diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index f1c18e543b..22badb49ef 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger +from 
vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer from vllm.sequence import IntermediateTensors from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.outputs import ( @@ -33,7 +34,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm_ascend.ops.rotary_embedding import update_cos_sin -from vllm_ascend.utils import enable_sp +from vllm_ascend.utils import enable_sp, global_stream from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput @@ -55,7 +56,7 @@ class ExecuteModelState(NamedTuple): positions: torch.Tensor ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None - multimodal_outputs: Any + multimodal_outputs: Any # Omni-Specific class NPUARModelRunner(OmniNPUModelRunner): """Autoregressive NPU model runner that returns hidden states per request.""" @@ -88,13 +89,36 @@ def execute_model( scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors | None = None, ) -> OmniModelRunnerOutput | IntermediateTensors | None: + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() + else: + logger.warning("RoutedExpertsCapturer is not initialized.") if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") # -------------------------------------- Omni-new ------------------------------------------------- # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) + if not getattr(self, "_warmup_state_cleared", False): + self._warmup_state_cleared = True + if hasattr(self.model, "_clear_warmup_state"): + self.model._clear_warmup_state() + + # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) + 
finished_reqs = getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}) + if finished_reqs and hasattr(self.model, "get_kv_transfer_metadata"): + for req_id, data in finished_reqs.items(): + try: + model_meta = self.model.get_kv_transfer_metadata(req_id) + if model_meta: + existing = data.get("custom_metadata") or {} + existing.update(model_meta) + data["custom_metadata"] = existing + except Exception as e: + logger.warning(f"Failed to get custom metadata from model for {req_id}: {e}") self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( - finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + finished_reqs=finished_reqs, kv_caches=self.kv_caches, block_size=self.cache_config.block_size, cache_dtype=str(self.cache_config.cache_dtype), @@ -115,10 +139,8 @@ def execute_model( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("prepare input"): with self.synchronize_input_prep(): - # -------------------------------------- Omni-new ------------------------------------------------- - self._update_states(scheduler_output) - # ------------------------------------------------------------------------------------------------ # Update persistent batch states. 
+ self._update_states(scheduler_output) if has_ec_transfer() and get_ec_transfer().is_producer: with self.maybe_get_ec_connector_output( @@ -150,6 +172,7 @@ def execute_model( "logprobs for prompt tokens, tokens, please disable " "it when the requests need prompt logprobs" ) + num_reqs = self.input_batch.num_reqs req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] @@ -159,6 +182,7 @@ def execute_model( ( logits_indices, spec_decode_metadata, + total_num_scheduled_tokens, ) = self._prepare_inputs( scheduler_output, num_scheduled_tokens_np, @@ -224,10 +248,19 @@ def execute_model( # Currently, Graph Mode and SP will both pad num_tokens, # Another possible condition is num_tokens_padded != num_tokens_unpadded # but this scope is way too big and the consequences are unpredictable - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + old_num_reqs_padded = num_reqs_padded + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_mode, batch_desc.num_reqs + ) + if enable_sp() and num_tokens_padded == num_tokens_unpadded: + if num_reqs_padded > old_num_reqs_padded: + num_reqs_padded = old_num_reqs_padded + self.query_start_loc.np[num_reqs_padded + 1] = 0 (attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, + num_tokens=num_tokens_unpadded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else total_num_scheduled_tokens, num_tokens_padded=num_tokens_padded, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded, @@ -247,9 +280,21 @@ def execute_model( intermediate_tensors, model_kwargs, ec_connector_output, - ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + ) = self._preprocess( + scheduler_output, + num_tokens_padded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else 
total_num_scheduled_tokens, + intermediate_tensors, + ) + # update global cos, sin update_cos_sin(positions) + + if self.dynamic_eplb: + with record_function_or_nullcontext("EPLB weight D2D"): + self.eplb_updator.forward_before() + # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture. @@ -278,6 +323,7 @@ def execute_model( has_encoder_input = self.model_config.is_encoder_decoder and num_encoder_reqs > 0 # Run forward pass + clear_kv_metadata = self.speculative_config is None with ( record_function_or_nullcontext("forward"), set_ascend_forward_context( @@ -289,15 +335,22 @@ def execute_model( batch_descriptor=batch_desc, num_actual_tokens=scheduler_output.total_num_scheduled_tokens, model_instance=self.model, + max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, skip_compiled=has_encoder_input, ), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): hidden_states = self._model_forward( num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs ) with record_function_or_nullcontext("post process"): # -------------------------------------- Omni-new ------------------------------------------------- + # [Omni] Map pending ropes metadata to req_ids. 
+ if hasattr(self.model, "flush_pending_metadata"): + self.model.flush_pending_metadata(list(req_ids)) + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) if multimodal_outputs is not None: @@ -310,13 +363,18 @@ def execute_model( else: logger.debug("[AR] execute_model: multimodal_outputs is None") # -------------------------------------- Omni-new ------------------------------------------------- + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states if self.pcp_size > 1: # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states) - aux_hidden_states = None - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = hidden_states + if aux_hidden_states is not None: + aux_hidden_states = [ + self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp) + for aux_hidden_states_pcp in aux_hidden_states + ] if not self.broadcast_pp_output: # Common case. @@ -392,7 +450,7 @@ def execute_model( positions, ec_connector_output, cudagraph_stats, - multimodal_outputs, + multimodal_outputs, # Omni-specific ) self.kv_connector_output = kv_connector_output return None @@ -412,6 +470,11 @@ def sample_tokens( if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. + # receive sampled token ids from the last PP rank when using + # async scheduling + pipeline parallelism so downstream code + # (e.g., PCP input preparation) can access them. 
+ if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # noqa # In case of PP with kv transfer, we need to pass through the @@ -436,7 +499,7 @@ def sample_tokens( positions, ec_connector_output, cudagraph_stats, - multimodal_outputs, + multimodal_outputs, # Omni-Specific ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -450,9 +513,27 @@ def sample_tokens( apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits) logits = logits.to(self.device).to(logits_dtype) + # -------------------------------------- Omni-new ------------------------------------------------- + # Correct padding values of prompt_token_ids to match the logits vocabulary size. + if logits is not None and not self.input_batch.sampling_metadata.no_penalties: + smd = self.input_batch.sampling_metadata + if smd.prompt_token_ids is not None: + logits_vocab = logits.shape[-1] + if self.input_batch.vocab_size > logits_vocab: + smd.prompt_token_ids = smd.prompt_token_ids.clamp(max=logits_vocab) + # -------------------------------------- Omni-new ------------------------------------------------- + + with record_function_or_nullcontext("sample_token"): sampler_output = self._sample(logits, spec_decode_metadata) + if self.need_accepted_tokens: + if self.sampling_done_event is None: + self.sampling_done_event = torch.npu.Event() + + assert self.sampling_done_event is not None + self.sampling_done_event.record() + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( @@ -464,7 +545,6 @@ def propose_draft_token_ids(sampled_token_ids): positions, scheduler_output.total_num_scheduled_tokens, hidden_states, - attn_metadata, aux_hidden_states, sample_hidden_states, ) @@ -488,22 +568,30 @@ def propose_draft_token_ids(sampled_token_ids): with 
record_function_or_nullcontext("draft_token"): if self.speculative_config: - use_padded_batch_for_eagle = ( + use_padded_batch = ( self.speculative_config - and self.speculative_config.use_eagle() + and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()) and not self.speculative_config.disable_padded_drafter_batch ) - if use_padded_batch_for_eagle: + if use_padded_batch: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - if self.speculative_config and not use_padded_batch_for_eagle: + if self.speculative_config and not use_padded_batch: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() + + if self.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.cpu_slot_mapping) + else: + logger.warning("RoutedExpertsCapturer is not initialized.") + # -------------------------------------- Omni-new ------------------------------------------------- hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) @@ -562,13 +650,31 @@ def propose_draft_token_ids(sampled_token_ids): model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids # -------------------------------------- Omni-new ------------------------------------------------- - if self.dynamic_eplb: - self.eplb_updator.forward_end() + with record_function_or_nullcontext("EPLB update"): + self.eplb_updator.forward_end() if self.debugger is not None: self.debugger.stop() self.debugger.step() + + if self.need_accepted_tokens: + assert self.sampling_done_event is not 
None + with ( + record_function_or_nullcontext("async_state_update"), + torch.npu.stream(global_stream()), + ): + global_stream().wait_event(self.sampling_done_event) + self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + + # In async scheduling + PP, broadcast sampled token ids from the + # last PP rank so other PP ranks can receive them without going + # through the scheduler/engine IPC path. + if self.use_async_scheduling: + pp = get_pp_group() + if pp.world_size > 1 and pp.is_last_rank: + self._pp_broadcast_prev_sampled_token_ids(sampler_output.sampled_token_ids) + if not self.use_async_scheduling: return model_runner_output return AsyncGPUModelRunnerOutput( diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py index 7651e365a3..0296c042d5 100644 --- a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer import has_kv_transfer_group from vllm.distributed.parallel_state import get_pp_group from vllm.logger import logger +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -69,7 +70,13 @@ def execute_model( self, scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors | None = None, - ) -> OmniModelRunnerOutput | IntermediateTensors: + ) -> OmniModelRunnerOutput | IntermediateTensors | None: + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() + else: + logger.warning("RoutedExpertsCapturer is not initialized.") if self.execute_model_state is not None: raise RuntimeError("State 
error: sample_tokens() must be called after execute_model() returns None.") # self._draft_token_ids is None when `input_fits_in_drafter=False` @@ -123,6 +130,7 @@ def execute_model( "logprobs for prompt tokens, tokens, please disable " "it when the requests need prompt logprobs" ) + num_reqs = self.input_batch.num_reqs req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] @@ -132,6 +140,7 @@ def execute_model( ( logits_indices, spec_decode_metadata, + total_num_scheduled_tokens, ) = self._prepare_inputs( scheduler_output, num_scheduled_tokens_np, @@ -197,10 +206,19 @@ def execute_model( # Currently, Graph Mode and SP will both pad num_tokens, # Another possible condition is num_tokens_padded != num_tokens_unpadded # but this scope is way too big and the consequences are unpredictable - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + old_num_reqs_padded = num_reqs_padded + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_mode, batch_desc.num_reqs + ) + if enable_sp() and num_tokens_padded == num_tokens_unpadded: + if num_reqs_padded > old_num_reqs_padded: + num_reqs_padded = old_num_reqs_padded + self.query_start_loc.np[num_reqs_padded + 1] = 0 (attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, + num_tokens=num_tokens_unpadded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else total_num_scheduled_tokens, num_tokens_padded=num_tokens_padded, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded, @@ -220,7 +238,13 @@ def execute_model( intermediate_tensors, model_kwargs, ec_connector_output, - ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + ) = self._preprocess( + scheduler_output, + num_tokens_padded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else 
total_num_scheduled_tokens, + intermediate_tensors, + ) # [Omni] Pass token counts per request for code2wav output slicing model_kwargs["seq_token_counts"] = tokens @@ -228,6 +252,10 @@ def execute_model( # update global cos, sin update_cos_sin(positions) + if self.dynamic_eplb: + with record_function_or_nullcontext("EPLB weight D2D"): + self.eplb_updator.forward_before() + # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture. @@ -256,6 +284,7 @@ def execute_model( has_encoder_input = self.model_config.is_encoder_decoder and num_encoder_reqs > 0 # Run forward pass + clear_kv_metadata = self.speculative_config is None with ( record_function_or_nullcontext("forward"), set_ascend_forward_context( @@ -267,9 +296,12 @@ def execute_model( batch_descriptor=batch_desc, num_actual_tokens=scheduler_output.total_num_scheduled_tokens, model_instance=self.model, + max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, skip_compiled=has_encoder_input, ), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): # -------------------------------------- Omni-new ------------------------------------------------- outputs = self._run_generation_model( @@ -296,7 +328,7 @@ def execute_model( positions, ec_connector_output, cudagraph_stats, - multimodal_outputs, # Omni-new: pass multimodal_outputs to ExecuteModelState + multimodal_outputs, # Omni-specific: pass multimodal_outputs to ExecuteModelState ) self.kv_connector_output = kv_connector_output return None @@ -310,6 +342,11 @@ def sample_tokens( if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. 
+ # receive sampled token ids from the last PP rank when using + # async scheduling + pipeline parallelism so downstream code + # (e.g., PCP input preparation) can access them. + if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # noqa # In case of PP with kv transfer, we need to pass through the @@ -334,7 +371,7 @@ def sample_tokens( positions, ec_connector_output, cudagraph_stats, - multimodal_outputs, + multimodal_outputs, # Omni-Specific ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -464,7 +501,6 @@ def _dummy_run( allow_microbatching: bool = True, skip_eplb: bool = False, remove_lora: bool = True, - activate_lora: bool = False, is_graph_capturing: bool = False, num_active_loras: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -505,6 +541,9 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs + if not is_profile and self.dynamic_eplb: + self.eplb_updator.forward_before() + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) self.query_lens = torch.from_numpy(num_scheduled_tokens) num_tokens_unpadded = int(num_scheduled_tokens.sum()) @@ -525,7 +564,8 @@ def _dummy_run( # `force_has_lora` is used for cudagraph capture; because LoRA is # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + force_has_lora=num_active_loras > 0, + force_num_active_loras=num_active_loras, ) if self.use_cp: self.pcp_manager.init_batch_info( @@ -551,9 +591,8 @@ def _dummy_run( # vllm-ascend does not support ubatch now ubatch_slices, ubatch_slices_padded = None, None attn_metadata: PerLayerAttnMetadata | None = None - # If force_attention is True, we always capture attention. 
Otherwise, - # it only happens for cudagraph_runtime_mode=FULL. - if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + # Build attention metadata for dummy_run + if self._should_build_dummy_attn_metadata(force_attention, is_profile, cudagraph_runtime_mode): if create_mixed_batch: raise NotImplementedError( "create_mixed_batch is used for warmup deepgemm, vllm-ascend does not need it" @@ -582,8 +621,9 @@ def _dummy_run( cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs + ) pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( @@ -600,6 +640,11 @@ def _dummy_run( self.lora_config, num_scheduled_tokens, num_sampled_tokens, + remove_lora, + # TODO: The next line is a temporary workaround + # to fix the accuracy issue of test_llama32_lora.py, + # which is introduced by vllm-project/vllm#32005 + num_active_loras=(self.lora_config.max_loras if self.lora_config is not None else num_active_loras), ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens @@ -610,6 +655,19 @@ def _dummy_run( input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None + # -------------------------------------- Omni-new ------------------------------------------------- + model_kwargs = self._init_model_kwargs() + # Some generation-stage models (e.g. MammothModa2DiTPipeline) require + # model-specific runtime information (such as image size and conditioning + # embeddings) even during the dummy profiling run that vLLM uses to + # estimate KV-cache capacity. 
get_dummy_runtime_additional_information + # provides placeholder values of the correct shape so that the profiling + # run does not raise an error due to missing inputs. + if hasattr(self.model, "get_dummy_runtime_additional_information"): + runtime_addi = self.model.get_dummy_runtime_additional_information(num_reqs) + model_kwargs["runtime_additional_information"] = runtime_addi + # -------------------------------------- Omni-new ------------------------------------------------- + if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens_padded] elif self.uses_xdrope_dim > 0: @@ -656,8 +714,6 @@ def dummy_drafter_compute_logits(hidden_states): if hasattr(self.drafter, "model") and hasattr(self.drafter.model, "compute_logits"): return self.drafter.model.compute_logits(hidden_states[dummy_indices]) - model_kwargs = self._init_model_kwargs() - with set_ascend_forward_context( attn_metadata, self.vllm_config, @@ -674,9 +730,8 @@ def dummy_drafter_compute_logits(hidden_states): positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, + **model_kwargs, # Omni-specific ) - if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -698,7 +753,6 @@ def dummy_drafter_compute_logits(hidden_states): if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() if self.dynamic_eplb: - self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() # -------------------------------------- Omni-new ------------------------------------------------- diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py index 254d27a4f4..9ff5f720e9 100644 --- a/vllm_omni/platforms/npu/worker/npu_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py @@ -71,7 +71,6 @@ def _dummy_run( allow_microbatching: bool = True, skip_eplb: bool = False, remove_lora: bool = True, - activate_lora: bool = False, 
is_graph_capturing: bool = False, num_active_loras: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -112,6 +111,9 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs + if not is_profile and self.dynamic_eplb: + self.eplb_updator.forward_before() + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) self.query_lens = torch.from_numpy(num_scheduled_tokens) num_tokens_unpadded = int(num_scheduled_tokens.sum()) @@ -132,7 +134,8 @@ def _dummy_run( # `force_has_lora` is used for cudagraph capture; because LoRA is # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + force_has_lora=num_active_loras > 0, + force_num_active_loras=num_active_loras, ) if self.use_cp: self.pcp_manager.init_batch_info( @@ -158,9 +161,8 @@ def _dummy_run( # vllm-ascend does not support ubatch now ubatch_slices, ubatch_slices_padded = None, None attn_metadata: PerLayerAttnMetadata | None = None - # If force_attention is True, we always capture attention. Otherwise, - # it only happens for cudagraph_runtime_mode=FULL. 
- if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + # Build attention metadata for dummy_run + if self._should_build_dummy_attn_metadata(force_attention, is_profile, cudagraph_runtime_mode): if create_mixed_batch: raise NotImplementedError( "create_mixed_batch is used for warmup deepgemm, vllm-ascend does not need it" @@ -189,8 +191,9 @@ def _dummy_run( cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs + ) pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( @@ -207,6 +210,11 @@ def _dummy_run( self.lora_config, num_scheduled_tokens, num_sampled_tokens, + remove_lora, + # TODO: The next line is a temporary workaround + # to fix the accuracy issue of test_llama32_lora.py, + # which is introduced by vllm-project/vllm#32005 + num_active_loras=(self.lora_config.max_loras if self.lora_config is not None else num_active_loras), ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens @@ -287,9 +295,7 @@ def dummy_drafter_compute_logits(hidden_states): ) self.compilation_config.cache_dir = None # ---------------------------------------Omni-new---------------------------------------------- - # NOTE: Directly call self.model() instead of self._model_forward() to match - # GPU behavior. _model_forward contains Omni-specific logic (make_omni_output) - # that requires valid runtime_additional_information, which is empty during dummy run. 
+ # Call self.model() directly (like GPU) to avoid make_omni_output during dummy_run outputs = self.model( input_ids=input_ids, positions=positions, @@ -320,7 +326,6 @@ def dummy_drafter_compute_logits(hidden_states): if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() if self.dynamic_eplb: - self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() return hidden_states, hidden_states @@ -382,7 +387,7 @@ def _model_forward( ) # NPU-specific: all-gather for sequence parallelism - if get_forward_context().sp_enabled and not isinstance(model_output, IntermediateTensors): + if get_forward_context().flash_comm_v1_enabled and not isinstance(model_output, IntermediateTensors): model_output = self._all_gather_hidden_states_and_aux(model_output) return model_output From 1eae4aa8a02f1d3ac705cfc3362924b4cff52a23 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sat, 14 Mar 2026 13:03:52 +0000 Subject: [PATCH 2/4] add docs and dockerfile Signed-off-by: gcanlin --- docker/Dockerfile.npu | 12 +-- docker/Dockerfile.npu.a3 | 12 +-- .../installation/npu/npu.inc.md | 18 ++-- .../qwen3_omni_moe_async_chunk.yaml | 101 ++++++++++++++++++ 4 files changed, 125 insertions(+), 18 deletions(-) create mode 100644 vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml diff --git a/docker/Dockerfile.npu b/docker/Dockerfile.npu index 1bd8601ef0..a5abaa679f 100644 --- a/docker/Dockerfile.npu +++ b/docker/Dockerfile.npu @@ -1,17 +1,17 @@ ARG VLLM_ASCEND_IMAGE=quay.io/ascend/vllm-ascend -ARG VLLM_ASCEND_TAG=v0.14.0rc1 +ARG VLLM_ASCEND_TAG=v0.17.0rc1 FROM ${VLLM_ASCEND_IMAGE}:${VLLM_ASCEND_TAG} -WORKDIR /vllm-workspace/vllm-ascend -RUN git checkout e2175d9c7e62b437391dfee996b1375674ba7c18 -RUN pip install -v -e . - ARG APP_DIR=/vllm-workspace/vllm-omni WORKDIR ${APP_DIR} COPY . . -RUN pip install -v -e . 
+RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v -e /vllm-workspace/vllm-omni/ --no-build-isolation ENV VLLM_WORKER_MULTIPROC_METHOD=spawn diff --git a/docker/Dockerfile.npu.a3 b/docker/Dockerfile.npu.a3 index d521288f66..340bc61759 100644 --- a/docker/Dockerfile.npu.a3 +++ b/docker/Dockerfile.npu.a3 @@ -1,17 +1,17 @@ ARG VLLM_ASCEND_IMAGE=quay.io/ascend/vllm-ascend -ARG VLLM_ASCEND_TAG=v0.14.0rc1-a3 +ARG VLLM_ASCEND_TAG=v0.17.0rc1-a3 FROM ${VLLM_ASCEND_IMAGE}:${VLLM_ASCEND_TAG} -WORKDIR /vllm-workspace/vllm-ascend -RUN git checkout e2175d9c7e62b437391dfee996b1375674ba7c18 -RUN pip install -v -e . - ARG APP_DIR=/vllm-workspace/vllm-omni WORKDIR ${APP_DIR} COPY . . -RUN pip install -v -e . +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v -e /vllm-workspace/vllm-omni/ --no-build-isolation ENV VLLM_WORKER_MULTIPROC_METHOD=spawn diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md index 27a3518a8b..cf9ab49a77 100644 --- a/docs/getting_started/installation/npu/npu.inc.md +++ b/docs/getting_started/installation/npu/npu.inc.md @@ -33,9 +33,15 @@ docker run --rm \ -p 8000:8000 \ -it $IMAGE bash +cd /vllm-workspace/vllm-ascend +git pull origin main +git fetch origin --tags +git checkout v0.16.0 + # Because vllm-ascend will release v0.16.0rc1 after vllm-omni 0.16.0, # we have to pin vllm-ascend at the current commit. 
cd /vllm-workspace/vllm-ascend +git pull origin main git checkout e2175d9c7e62b437391dfee996b1375674ba7c18 pip install -v -e . @@ -44,7 +50,7 @@ cd /vllm-workspace git clone -b v0.16.0 https://github.com/vllm-project/vllm-omni.git cd vllm-omni -pip install -v -e . +pip install -v -e . --no-build-isolation export VLLM_WORKER_MULTIPROC_METHOD=spawn ``` @@ -61,22 +67,22 @@ We are keeping [issue #886](https://github.com/vllm-project/vllm-omni/issues/886 You can also build vLLM-Omni from the latest main branch if you want to use the latest features or bug fixes. (But sometimes it will break for a while. You can check [issue #886](https://github.com/vllm-project/vllm-omni/issues/886) for the status of the latest commit of vLLM-Omni main branch on NPU.) ```bash -# Pin vLLM version to 0.16.0 +# Pin vLLM version to 0.17.0 cd /vllm-workspace/vllm git pull origin main git fetch origin --tags -git checkout v0.16.0 +git checkout v0.17.0 VLLM_TARGET_DEVICE=empty pip install -v -e . # Because vllm-ascend has not yet entered continuous development and has not been officially released, we need to pin it to a specific commit. Please note that this commit may change over time. -cd ../vllm-ascend +cd /vllm-workspace/vllm-ascend git pull origin main git fetch origin --tags -git checkout e2175d9c7e62b437391dfee996b1375674ba7c18 +git checkout v0.17.0 pip install -v -e . # Install vLLM-Omni from the latest main branch -cd ../vllm-omni +cd /vllm-workspace/vllm-omni git clone https://github.com/vllm-project/vllm-omni.git pip install -v -e . --no-build-isolation # or VLLM_OMNI_TARGET_DEVICE=npu pip install -v -e . 
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml new file mode 100644 index 0000000000..4ede584b59 --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml @@ -0,0 +1,101 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0,1" + max_batch_size: 10 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 2 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "2" + max_batch_size: 10 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + 
+ engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk + engine_input_source: [0] + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.0 + stop_token_ids: [2150] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "2" + max_batch_size: 10 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 51200 # [TODO] if max_num_batched_tokens < max_batch_size * 800, there will be a precision problem. 
+ hf_config_name: thinker_config + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 From cf40cb83d461cd3a94ec4eac6e49000b1de9de99 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sun, 15 Mar 2026 12:02:02 +0000 Subject: [PATCH 3/4] update docs Signed-off-by: gcanlin --- docs/getting_started/installation/npu/npu.inc.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md index cf9ab49a77..bc2d3b60cb 100644 --- a/docs/getting_started/installation/npu/npu.inc.md +++ b/docs/getting_started/installation/npu/npu.inc.md @@ -33,7 +33,7 @@ docker run --rm \ -p 8000:8000 \ -it $IMAGE bash -cd /vllm-workspace/vllm-ascend +cd /vllm-workspace/vllm git pull origin main git fetch origin --tags git checkout v0.16.0 @@ -48,9 +48,10 @@ pip install -v -e . # Inside the container, install vLLM-Omni from source cd /vllm-workspace git clone -b v0.16.0 https://github.com/vllm-project/vllm-omni.git - cd vllm-omni pip install -v -e . --no-build-isolation +# or VLLM_OMNI_TARGET_DEVICE=npu pip install -v -e . 
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn ``` From f14f84818aa68d352521ad95240eaab10b97fd46 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sun, 15 Mar 2026 12:30:44 +0000 Subject: [PATCH 4/4] fix qwen3-tts Signed-off-by: gcanlin --- .../models/qwen3_tts/qwen3_tts_code_predictor_vllm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 7f3fdfd7a9..2520a1d87e 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -21,6 +21,8 @@ ) from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm_omni.platforms import current_omni_platform + from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig logger = init_logger(__name__) @@ -410,6 +412,10 @@ def _setup_compile(self) -> None: """ if self._compiled_model_fwd is not None: return + if not current_omni_platform.supports_torch_inductor(): + logger.warning_once("code_predictor: torch.compile disabled") + self._compiled_model_fwd = self.model.forward + return self._compiled_model_fwd = torch.compile( self.model.forward, mode="default",