diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 82d628ae978..8d058335e9d 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -730,6 +730,39 @@ def execute_model( return None + def _sample( + self, + logits: torch.Tensor | None, + spec_decode_metadata: Any, + ): + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + model_sample = getattr(self.model, "sample", None) + self.input_batch.update_async_output_token_ids() + if logits is not None and callable(model_sample) and getattr(self.model, "prefer_model_sampler", False): + # Apply logit bias (min_tokens, allowed_token_ids) before + # the custom model sampler — the standard GPU sampler does + # this internally, but prefer_model_sampler bypasses it. + if hasattr(self.sampler, "logit_bias_state"): + self.sampler.logit_bias_state.apply_logit_bias( + logits, + self.input_batch.expanded_idx_mapping, + self.input_batch.idx_mapping_np, + self.input_batch.positions[self.input_batch.logits_indices], + ) + sampler_output = model_sample( + logits, + self._sampling_metadata_for_model_sampler(sampling_metadata), + ) + if sampler_output is not None: + return sampler_output + return self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + return super()._sample(logits, spec_decode_metadata) + @torch.inference_mode() def sample_tokens( self, grammar_output: GrammarOutput | None diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 2bcd4ba0169..34d571c115f 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -109,53 +109,6 @@ def _make_buffer(self, *size, dtype, numpy=True): with maybe_disable_pin_memory_for_ray(self, total_bytes): return super()._make_buffer(*size, dtype=dtype, numpy=numpy) - def 
def _build_model_sampler_output_token_ids(self) -> list[list[int]]:
    """Rebuild per-request decoded-token history for custom model samplers.

    vLLM only fills ``sampling_metadata.output_token_ids`` when penalties or
    logits processors need it. Models that opt into ``prefer_model_sampler``
    (e.g. the CosyVoice3 RAS sampler, the HunyuanImage3 stage-transition
    sampler) rely on that history as well, so it is reconstructed here from
    the input batch. Shared by the GPU and NPU AR runners.

    Returns:
        One list of token ids per entry of ``input_batch.req_ids``. Each list
        is a copy of the batch's history in which trailing ``-1``
        placeholders are replaced, where possible, by the tokens found in
        ``input_batch.sampled_token_ids_cpu``.
    """
    batch = self.input_batch
    per_request_outputs = getattr(batch, "req_output_token_ids", [])
    request_ids = list(getattr(batch, "req_ids", []))

    # Copy every history so the batch's own lists are never mutated.
    histories = []
    for slot in range(len(request_ids)):
        histories.append(list(per_request_outputs[slot] or []))

    cpu_samples = getattr(batch, "sampled_token_ids_cpu", None)
    copy_ready_event = getattr(batch, "async_copy_ready_event", None)
    previous_slots = getattr(batch, "prev_req_id_to_index", None)
    if cpu_samples is None or previous_slots is None or not histories:
        return histories

    sampled_rows: list[list[int]] | None = None
    for slot, request_id in enumerate(request_ids):
        previous_slot = previous_slots.get(request_id)
        history = histories[slot]
        # Only requests present in the previous step whose history still
        # ends in a ``-1`` placeholder need patching.
        if previous_slot is None or not history or history[-1] != -1:
            continue

        if sampled_rows is None:
            # Synchronize on the copy event once, right before first use.
            assert copy_ready_event is not None
            copy_ready_event.synchronize()
            sampled_rows = cpu_samples.tolist()

        fresh_ids = list(sampled_rows[previous_slot])
        if not fresh_ids:
            continue

        # A row ending in ``-1`` is only valid up to its first ``-1``.
        if fresh_ids[-1] == -1:
            valid_count = fresh_ids.index(-1)
        else:
            valid_count = len(fresh_ids)

        start = history.index(-1)
        fill_count = min(valid_count, len(history) - start)
        history[start : start + fill_count] = fresh_ids[:fill_count]

    return histories


def _sampling_metadata_for_model_sampler(self, sampling_metadata):
    """Return sampling metadata carrying the rebuilt output-token history.

    Args:
        sampling_metadata: Metadata object exposing ``output_token_ids``.

    Returns:
        ``sampling_metadata`` itself when its ``output_token_ids`` already
        equals the rebuilt history; otherwise a copy made with
        ``dataclasses.replace`` whose ``output_token_ids`` is the rebuilt one.
    """
    rebuilt_history = self._build_model_sampler_output_token_ids()
    if rebuilt_history != sampling_metadata.output_token_ids:
        return replace(sampling_metadata, output_token_ids=rebuilt_history)
    return sampling_metadata