Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,39 @@ def execute_model(

return None

def _sample(
self,
logits: torch.Tensor | None,
spec_decode_metadata: Any,
):
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
model_sample = getattr(self.model, "sample", None)
self.input_batch.update_async_output_token_ids()
if logits is not None and callable(model_sample) and getattr(self.model, "prefer_model_sampler", False):
# Apply logit bias (min_tokens, allowed_token_ids) before
# the custom model sampler — the standard GPU sampler does
# this internally, but prefer_model_sampler bypasses it.
if hasattr(self.sampler, "logit_bias_state"):
self.sampler.logit_bias_state.apply_logit_bias(
logits,
self.input_batch.expanded_idx_mapping,
self.input_batch.idx_mapping_np,
self.input_batch.positions[self.input_batch.logits_indices],
)
sampler_output = model_sample(
logits,
self._sampling_metadata_for_model_sampler(sampling_metadata),
)
if sampler_output is not None:
return sampler_output
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)

return super()._sample(logits, spec_decode_metadata)

@torch.inference_mode()
def sample_tokens(
self, grammar_output: GrammarOutput | None
Expand Down
47 changes: 0 additions & 47 deletions vllm_omni/worker/gpu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,53 +109,6 @@ def _make_buffer(self, *size, dtype, numpy=True):
with maybe_disable_pin_memory_for_ray(self, total_bytes):
return super()._make_buffer(*size, dtype=dtype, numpy=numpy)

def _build_model_sampler_output_token_ids(self) -> list[list[int]]:
"""Build decoded-token history for custom model samplers.

vLLM only populates sampling_metadata.output_token_ids when penalties or
logits processors require it. CosyVoice3's custom RAS sampler also
depends on this history, so we reconstruct it directly from the input
batch for prefer_model_sampler models.
"""
req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", [])
req_ids = list(getattr(self.input_batch, "req_ids", []))
output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))]

sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None)
async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None)
prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None)
if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None:
return output_token_ids

sampled_token_ids: list[list[int]] | None = None
for index, req_id in enumerate(req_ids):
prev_index = prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_history = output_token_ids[index]
if not req_history or req_history[-1] != -1:
continue
if sampled_token_ids is None:
assert async_copy_ready_event is not None
async_copy_ready_event.synchronize()
sampled_token_ids = sampled_token_ids_cpu.tolist()
new_ids = list(sampled_token_ids[prev_index])
if not new_ids:
continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
first_placeholder = req_history.index(-1)
num_placeholders = len(req_history) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace]

return output_token_ids

def _sampling_metadata_for_model_sampler(self, sampling_metadata):
output_token_ids = self._build_model_sampler_output_token_ids()
if output_token_ids == sampling_metadata.output_token_ids:
return sampling_metadata
return replace(sampling_metadata, output_token_ids=output_token_ids)

def _request_final_stage_id(self, req_id: str) -> int | None:
info = self.model_intermediate_buffer.get(req_id)
if not isinstance(info, dict):
Expand Down
49 changes: 49 additions & 0 deletions vllm_omni/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any, cast

import numpy as np
Expand Down Expand Up @@ -121,6 +122,54 @@ def load_model(self, *args, **kwargs) -> None:
self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)

def _build_model_sampler_output_token_ids(self) -> list[list[int]]:
"""Build decoded-token history for ``prefer_model_sampler`` models.

vLLM only populates ``sampling_metadata.output_token_ids`` when penalties
or logits processors require it. Models that opt into a custom sampler
(e.g. CosyVoice3 RAS sampler, HunyuanImage3 stage-transition sampler)
also depend on this history, so we reconstruct it directly from the
input batch. Shared by GPU and NPU AR runners.
"""
req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", [])
req_ids = list(getattr(self.input_batch, "req_ids", []))
output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))]

sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None)
async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None)
prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None)
if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None:
return output_token_ids

sampled_token_ids: list[list[int]] | None = None
for index, req_id in enumerate(req_ids):
prev_index = prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_history = output_token_ids[index]
if not req_history or req_history[-1] != -1:
continue
if sampled_token_ids is None:
assert async_copy_ready_event is not None
async_copy_ready_event.synchronize()
sampled_token_ids = sampled_token_ids_cpu.tolist()
new_ids = list(sampled_token_ids[prev_index])
if not new_ids:
continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
first_placeholder = req_history.index(-1)
num_placeholders = len(req_history) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace]

return output_token_ids

def _sampling_metadata_for_model_sampler(self, sampling_metadata):
output_token_ids = self._build_model_sampler_output_token_ids()
if output_token_ids == sampling_metadata.output_token_ids:
return sampling_metadata
return replace(sampling_metadata, output_token_ids=output_token_ids)

def _init_mrope_positions(self, req_state: CachedRequestState):
"""Initialize M-RoPE positions for multimodal inputs.

Expand Down
Loading