From a20ee78b063d1b79ae08e7ca74e921eb3400fd25 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Mon, 11 May 2026 16:10:21 +0000 Subject: [PATCH 1/4] [BugFix][NPU] Honor prefer_model_sampler in NPU AR runner] Signed-off-by: gcanlin --- .../npu/worker/npu_ar_model_runner.py | 34 +++++++++++++ vllm_omni/worker/gpu_ar_model_runner.py | 48 ------------------ vllm_omni/worker/gpu_model_runner.py | 49 +++++++++++++++++++ 3 files changed, 83 insertions(+), 48 deletions(-) diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 8cff1849aa5..67a2f8f584a 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -717,6 +717,40 @@ def execute_model( deferred_state_corrections_fn() return None + def _sample( + self, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ): + """Dispatch to the model's custom sampler when ``prefer_model_sampler`` + is set; otherwise delegate to the parent ``_sample`` (which handles + ``lmhead_tp_enable`` slicing and the spec-decode rejection sampler). + + Mirrors ``GPUARModelRunner._sample`` so models like HunyuanImage3 and + CosyVoice3 get the same stage-transition / RAS sampler behavior on NPU. + """ + if spec_decode_metadata is None: + model_sample = getattr(self.model, "sample", None) + if logits is not None and callable(model_sample) and getattr(self.model, "prefer_model_sampler", False): + sampling_metadata = self.input_batch.sampling_metadata + # Apply logit bias (min_tokens, allowed_token_ids) before the + # custom model sampler — the standard sampler does this + # internally, but prefer_model_sampler bypasses it. + if hasattr(self.sampler, "logit_bias_state"): + self.sampler.logit_bias_state.apply_logit_bias( + logits, + self.input_batch.expanded_idx_mapping, + self.input_batch.idx_mapping_np, + self.input_batch.positions[self.input_batch.logits_indices], + ) + sampler_output = model_sample( + logits, + self._sampling_metadata_for_model_sampler(sampling_metadata), + ) + if sampler_output is not None: + return sampler_output + return super()._sample(logits, spec_decode_metadata) + @torch.inference_mode() def sample_tokens( self, grammar_output: GrammarOutput | None diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 1decf13cb69..f689069c2e9 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -8,7 +8,6 @@ from contextlib import nullcontext from copy import copy -from dataclasses import replace from typing import Any, NamedTuple import numpy as np @@ -96,53 +95,6 @@ def _make_buffer(self, *size, dtype, numpy=True): with maybe_disable_pin_memory_for_ray(self, total_bytes): return super()._make_buffer(*size, dtype=dtype, numpy=numpy) - def _build_model_sampler_output_token_ids(self) -> list[list[int]]: - """Build decoded-token history for custom model samplers. - - vLLM only populates sampling_metadata.output_token_ids when penalties or - logits processors require it. CosyVoice3's custom RAS sampler also - depends on this history, so we reconstruct it directly from the input - batch for prefer_model_sampler models. - """ - req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", []) - req_ids = list(getattr(self.input_batch, "req_ids", [])) - output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))] - - sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None) - async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None) - prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None) - if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None: - return output_token_ids - - sampled_token_ids: list[list[int]] | None = None - for index, req_id in enumerate(req_ids): - prev_index = prev_req_id_to_index.get(req_id) - if prev_index is None: - continue - req_history = output_token_ids[index] - if not req_history or req_history[-1] != -1: - continue - if sampled_token_ids is None: - assert async_copy_ready_event is not None - async_copy_ready_event.synchronize() - sampled_token_ids = sampled_token_ids_cpu.tolist() - new_ids = list(sampled_token_ids[prev_index]) - if not new_ids: - continue - num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) - first_placeholder = req_history.index(-1) - num_placeholders = len(req_history) - first_placeholder - num_to_replace = min(num_sampled_ids, num_placeholders) - req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace] - - return output_token_ids - - def _sampling_metadata_for_model_sampler(self, sampling_metadata): - output_token_ids = self._build_model_sampler_output_token_ids() - if output_token_ids == sampling_metadata.output_token_ids: - return sampling_metadata - return replace(sampling_metadata, output_token_ids=output_token_ids) - def _request_final_stage_id(self, req_id: str) -> int | None: info = self.model_intermediate_buffer.get(req_id) if not isinstance(info, dict): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index f4eadedc046..336d552915d 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +from dataclasses import replace from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -114,6 +115,54 @@ def load_model(self, *args, **kwargs) -> None: self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + def _build_model_sampler_output_token_ids(self) -> list[list[int]]: + """Build decoded-token history for ``prefer_model_sampler`` models. + + vLLM only populates ``sampling_metadata.output_token_ids`` when penalties + or logits processors require it. Models that opt into a custom sampler + (e.g. CosyVoice3 RAS sampler, HunyuanImage3 stage-transition sampler) + also depend on this history, so we reconstruct it directly from the + input batch. Shared by GPU and NPU AR runners. + """ + req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", []) + req_ids = list(getattr(self.input_batch, "req_ids", [])) + output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))] + + sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None) + async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None) + prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None) + if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None: + return output_token_ids + + sampled_token_ids: list[list[int]] | None = None + for index, req_id in enumerate(req_ids): + prev_index = prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + req_history = output_token_ids[index] + if not req_history or req_history[-1] != -1: + continue + if sampled_token_ids is None: + assert async_copy_ready_event is not None + async_copy_ready_event.synchronize() + sampled_token_ids = sampled_token_ids_cpu.tolist() + new_ids = list(sampled_token_ids[prev_index]) + if not new_ids: + continue + num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) + first_placeholder = req_history.index(-1) + num_placeholders = len(req_history) - first_placeholder + num_to_replace = min(num_sampled_ids, num_placeholders) + req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace] + + return output_token_ids + + def _sampling_metadata_for_model_sampler(self, sampling_metadata): + output_token_ids = self._build_model_sampler_output_token_ids() + if output_token_ids == sampling_metadata.output_token_ids: + return sampling_metadata + return replace(sampling_metadata, output_token_ids=output_token_ids) + def _init_mrope_positions(self, req_state: CachedRequestState): """Initialize M-RoPE positions for multimodal inputs. From 8b3c64b1ee043247767b4cf8faca4e4730eb1194 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Fri, 22 May 2026 07:31:16 +0000 Subject: [PATCH 2/4] fix Signed-off-by: gcanlin --- .../npu/worker/npu_ar_model_runner.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index eb710e28aa7..8d058335e9d 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -733,22 +733,16 @@ def execute_model( def _sample( self, logits: torch.Tensor | None, - spec_decode_metadata: SpecDecodeMetadata | None, + spec_decode_metadata: Any, ): - """Dispatch to the model's custom sampler when ``prefer_model_sampler`` - is set; otherwise delegate to the parent ``_sample`` (which handles - ``lmhead_tp_enable`` slicing and the spec-decode rejection sampler). - - Mirrors ``GPUARModelRunner._sample`` so models like HunyuanImage3 and - CosyVoice3 get the same stage-transition / RAS sampler behavior on NPU. - """ + sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: model_sample = getattr(self.model, "sample", None) + self.input_batch.update_async_output_token_ids() if logits is not None and callable(model_sample) and getattr(self.model, "prefer_model_sampler", False): - sampling_metadata = self.input_batch.sampling_metadata - # Apply logit bias (min_tokens, allowed_token_ids) before the - # custom model sampler — the standard sampler does this - # internally, but prefer_model_sampler bypasses it. + # Apply logit bias (min_tokens, allowed_token_ids) before + # the custom model sampler — the standard GPU sampler does + # this internally, but prefer_model_sampler bypasses it. if hasattr(self.sampler, "logit_bias_state"): self.sampler.logit_bias_state.apply_logit_bias( logits, @@ -762,6 +756,11 @@ def _sample( ) if sampler_output is not None: return sampler_output + return self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return super()._sample(logits, spec_decode_metadata) @torch.inference_mode() From 50376bf59b68b5642595290ee719cd7853778959 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Fri, 22 May 2026 07:54:30 +0000 Subject: [PATCH 3/4] fix Signed-off-by: gcanlin --- vllm_omni/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 0f5d254ea40..5f76f7be18d 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ -from dataclasses import replace import contextlib from collections.abc import Callable +from dataclasses import replace from typing import TYPE_CHECKING, Any, cast import numpy as np From fd9dc7d6bbb7814a70f458ee0d9a2f9f4c1afce2 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Sat, 23 May 2026 14:27:44 +0000 Subject: [PATCH 4/4] fix lint Signed-off-by: gcanlin --- vllm_omni/worker/gpu_ar_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index e5dbeb02ce9..34d571c115f 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -8,6 +8,7 @@ from contextlib import nullcontext from copy import copy +from dataclasses import replace from typing import Any, NamedTuple import numpy as np