Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
)
from vllm_omni.engine.serialization import deserialize_additional_information
from vllm_omni.worker_v2.model_states.intermediate_buffer import (
_resolve_additional_information,
)

logger = init_logger(__name__)

Expand Down Expand Up @@ -461,6 +463,9 @@ def update_from_output(
)
)
if self.chunk_transfer_adapter is not None:
# Only clean receiver-side state here. Sender-side
# cleanup (cleanup_sender) is unsafe while save_async()
# background threads may still reference sender dicts.
self.chunk_transfer_adapter.cleanup_receiver(
request.request_id,
)
Expand Down Expand Up @@ -616,14 +621,15 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
}
# Also update request.additional_information for good measure
add_info = getattr(request, "additional_information", None)
# If additional_information is an AdditionalInformationPayload-like object,
# unpack it into a plain dict.
# If additional_information is an AdditionalInformationPayload-like
# object, fully resolve it into a plain dict (tensor_data → Tensor,
# list_data → list, scalar_data → scalar).
if (
add_info is not None
and hasattr(add_info, "entries")
and isinstance(getattr(add_info, "entries"), dict)
):
request.additional_information = deserialize_additional_information(add_info)
request.additional_information = _resolve_additional_information(add_info)
add_info = request.additional_information
if add_info is None:
request.additional_information = {}
Expand Down
6 changes: 5 additions & 1 deletion vllm_omni/worker_v2/model_states/intermediate_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def _resolve_additional_information(payload: Any) -> dict[str, Any]:
arr = np.frombuffer(tensor_data, dtype=dt)
arr = arr.reshape(getattr(entry, "tensor_shape", ()))
info[k] = torch.from_numpy(arr.copy())
elif getattr(entry, "list_data", None) is not None:
info[k] = entry.list_data
elif getattr(entry, "scalar_data", None) is not None:
info[k] = entry.scalar_data
else:
info[k] = getattr(entry, "list_data", None)
info[k] = None
return info
except Exception:
logger.exception("Failed to decode additional_information payload")
Expand Down
89 changes: 55 additions & 34 deletions vllm_omni/worker_v2/model_states/omni_model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,42 +83,63 @@ def __init__(
OmniModelState._rope_patch_lock = threading.Lock()

def _safe_get_rope(model_config: Any, mdl: Any, **kwargs: Any) -> Any:
result = None
needs_mrope_override = False
try:
return _orig_get_rope(model_config, mdl, **kwargs)
result = _orig_get_rope(model_config, mdl, **kwargs)
except (AssertionError, TypeError):
if not model_config.uses_mrope:
return None
logger.info(
"Model uses M-RoPE (config) but does not implement SupportsMRoPE; creating RopeState(num_dims=3)."
)
# Add get_mrope_input_positions if missing.
# Returns 3D sequential positions with delta=0
# (pure text, no vision token offsets).
if not hasattr(mdl, "get_mrope_input_positions"):

def _default_mrope_positions(
    self_model: Any,
    input_tokens: list[int],
    mm_features: list,
) -> tuple[torch.Tensor, int]:
    """Build fallback M-RoPE positions: one sequential ramp shared by 3 rows.

    Args:
        self_model: Bound model instance (this function is attached as a
            method); not read here.
        input_tokens: Prompt token ids; only the length is used.
        mm_features: Multimodal features; not read here.

    Returns:
        A ``(positions, delta)`` pair. ``positions`` is an int64 tensor of
        shape ``(3, len(input_tokens))`` whose every row is
        ``0, 1, ..., len(input_tokens) - 1``. It is an expanded view, so the
        three rows share storage. ``delta`` is always ``0``.
    """
    # A single 1-D ramp is enough: all three rows are identical.
    ramp = torch.arange(len(input_tokens), dtype=torch.long)
    # Add a leading axis, then broadcast it to 3 rows without copying.
    positions = ramp[None, :].expand(3, -1)
    delta = 0
    return positions, delta

mdl.get_mrope_input_positions = types.MethodType(_default_mrope_positions, mdl)
# has_delta=True is required so init_prefill_positions
# calls get_mrope_input_positions (not the XD-RoPE
# path). delta=0 (returned above) means no offset is
# applied during decode — positions stay sequential.
return RopeState(num_dims=3, has_delta=True, **kwargs)
# Model does not implement SupportsMRoPE — may still
# need M-RoPE if config declares mrope_section.
needs_mrope_override = model_config.uses_mrope

if result is not None and not needs_mrope_override:
# Upstream returned a rope but check dimensionality:
# config has mrope_section but upstream returned a 1D
# rope (e.g. rope_type="default" with mrope_section).
if model_config.uses_mrope and getattr(result, "num_dims", 0) < 3:
logger.info(
"Upstream returned %dD rope but config has mrope_section; "
"overriding with RopeState(num_dims=3).",
getattr(result, "num_dims", 0),
)
needs_mrope_override = True
else:
return result

if not needs_mrope_override:
return None

logger.info(
"Model uses M-RoPE (config) but does not implement SupportsMRoPE; creating RopeState(num_dims=3)."
)
# Add get_mrope_input_positions if missing.
# Returns 3D sequential positions with delta=0
# (pure text, no vision token offsets).
if not hasattr(mdl, "get_mrope_input_positions"):

def _default_mrope_positions(
    self_model: Any,
    input_tokens: list[int],
    mm_features: list,
) -> tuple[torch.Tensor, int]:
    """Build fallback M-RoPE positions: one sequential ramp shared by 3 rows.

    Args:
        self_model: Bound model instance (this function is attached as a
            method); not read here.
        input_tokens: Prompt token ids; only the length is used.
        mm_features: Multimodal features; not read here.

    Returns:
        A ``(positions, delta)`` pair. ``positions`` is an int64 tensor of
        shape ``(3, len(input_tokens))`` whose every row is
        ``0, 1, ..., len(input_tokens) - 1``. It is an expanded view, so the
        three rows share storage. ``delta`` is always ``0``.
    """
    # A single 1-D ramp is enough: all three rows are identical.
    ramp = torch.arange(len(input_tokens), dtype=torch.long)
    # Add a leading axis, then broadcast it to 3 rows without copying.
    positions = ramp[None, :].expand(3, -1)
    delta = 0
    return positions, delta

mdl.get_mrope_input_positions = types.MethodType(_default_mrope_positions, mdl)
# has_delta=True is required so init_prefill_positions
# calls get_mrope_input_positions (not the XD-RoPE
# path). delta=0 (returned above) means no offset is
# applied during decode — positions stay sequential.
return RopeState(num_dims=3, has_delta=True, **kwargs)

with OmniModelState._rope_patch_lock:
_orig_get_rope = _default_mod.get_rope_state
Expand Down
94 changes: 30 additions & 64 deletions vllm_omni/worker_v2/omni_generation_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_uniform_token_count,
)

from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
from vllm_omni.core.sched.output import OmniCachedRequestData
from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker_v2.omni_model_runner import OmniGPUModelRunner
Expand Down Expand Up @@ -52,22 +52,25 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# ------------------------------------------------------------------

def _handle_async_chunk_updates(self, scheduler_output: SchedulerOutput) -> None:
"""Re-initialize cached requests whose prompt_token_ids changed.
"""In-place update cached requests whose prompt_token_ids changed.

In async_chunk mode, the ``ChunkTransferAdapter`` replaces
``Request.prompt_token_ids`` with new codec frames for each
chunk and resets ``num_computed_tokens`` to 0. The scheduler
propagates the new ``prompt_token_ids`` via
``OmniCachedRequestData``.

Upstream V2 ``update_requests`` only handles block allocation
and does **not** replace token state. This method removes the
stale request from ``req_states`` and re-adds it with the
updated tokens, mirroring V1's ``_update_request_states``.

Note: ``finish_requests`` / ``free_states`` (called before us)
already handle unscheduled request cleanup, so we only need to
process requests with new prompt_token_ids here.
Instead of remove + re-add (which involves free_indices churn
and redundant model_state init), we update the existing slot
in-place. This is safe for Code2Wav because:
- No KV cache / rope state to reinitialize
- staged writes are applied once at the end

``additional_information`` is NOT merged here — the inherited
``OmniGPUModelRunner.update_requests`` (called right after this
method in ``execute_model``) is the single source of truth for
``intermediate_buffer`` updates. Doing it in both places would
clone every tensor to CPU twice per step.
"""
cached = scheduler_output.scheduled_cached_reqs
if not cached.req_ids:
Expand All @@ -80,68 +83,31 @@ def _handle_async_chunk_updates(self, scheduler_output: SchedulerOutput) -> None
if not new_prompt_ids:
return

addl_info = cached.additional_information
updated = False

# Phase 1: remove all stale states, collecting re-add data.
updates: list[tuple[str, list[int], Any]] = []
for req_id in cached.req_ids:
new_ids = new_prompt_ids.get(req_id)
if new_ids is None:
continue

if req_id not in self.req_states.req_id_to_index:
req_idx = self.req_states.req_id_to_index.get(req_id)
if req_idx is None:
continue

old_index = self.req_states.req_id_to_index[req_id]

self.req_states.remove_request(req_id)
self.model_state.remove_request(old_index)

updates.append((req_id, new_ids, addl_info.get(req_id)))

if not updates:
return

# Phase 2: re-add all requests with new tokens, then batch-apply.
for req_id, new_ids, info in updates:
# req_id_to_index is updated eagerly by add_request() (not
# staged), so new_index is valid before apply_staged_writes().
self.req_states.add_request(
req_id=req_id,
prompt_len=len(new_ids),
all_token_ids=new_ids,
num_computed_tokens=0,
)
# In-place update token state — same slot, no remove/re-add.
# .np[] = direct write (no GPU buffer); stage_write = GPU-synced.
n = len(new_ids)
self.req_states.prompt_len.np[req_idx] = n
self.req_states.prefill_len.np[req_idx] = n
self.req_states.total_len.stage_write_elem(req_idx, n)
self.req_states.all_token_ids.stage_write(req_idx, 0, new_ids)
self.req_states.num_computed_tokens.stage_write_elem(req_idx, 0)
self.req_states.num_computed_prefill_tokens[req_idx] = 0

new_index = self.req_states.req_id_to_index[req_id]

# Build a synthetic NewRequestData so model_state.add_request
# can run _resolve_additional_information and notify plugins.
# sampling_params is None because the generation runner does
# not sample tokens — Code2Wav output goes directly to
# pooler_output. block_ids is empty because generation
# models have no KV cache.
synthetic = OmniNewRequestData(
req_id=req_id,
prompt_token_ids=new_ids,
mm_features=None,
sampling_params=None,
pooling_params=None,
block_ids=([],),
num_computed_tokens=0,
lora_request=None,
prefill_token_ids=new_ids,
additional_information=info,
)
# model_state.add_request calls super().add_request()
# (DefaultModelState) internally. This is safe for Code2Wav
# because generation models have no attention/rope state to
# initialize — the super() call only touches intermediate
# buffer and encoder cache, both of which are idempotent
# overwrites on the same slot.
self.model_state.add_request(new_index, synthetic)
updated = True

self.req_states.apply_staged_writes()
if updated:
self.req_states.apply_staged_writes()

# ------------------------------------------------------------------
# profile / warmup — skip sampler since there are no logits
Expand Down Expand Up @@ -175,8 +141,8 @@ def execute_model(
self.finish_requests(scheduler_output)
self.free_states(scheduler_output)
# Handle async_chunk prompt_token_ids replacement for cached
# requests BEFORE add/update — so the stale request state is
# removed and re-created with the new chunk's tokens.
# requests BEFORE add/update — update the existing slot
# in-place with the new chunk's tokens.
self._handle_async_chunk_updates(scheduler_output)
self.add_requests(scheduler_output)
self.update_requests(scheduler_output)
Expand Down
32 changes: 32 additions & 0 deletions vllm_omni/worker_v2/omni_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@

from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.worker_v2.model_states import init_omni_model_state
from vllm_omni.worker_v2.model_states.intermediate_buffer import (
_resolve_additional_information,
)
from vllm_omni.worker_v2.model_states.omni_model_state import OmniModelState

logger = init_logger(__name__)
Expand Down Expand Up @@ -305,6 +308,35 @@ def execute_model(
assert isinstance(hidden_states, torch.Tensor)
return None

# ------------------------------------------------------------------
# Request lifecycle: update intermediate buffer from cached requests
# ------------------------------------------------------------------

def update_requests(self, scheduler_output: SchedulerOutput) -> None:
    """Run the parent update, then fold new additional_information into the buffer.

    After delegating to ``super().update_requests``, this reads the
    optional ``additional_information`` mapping (request id -> payload) on
    ``scheduler_output.scheduled_cached_reqs``. Each non-``None`` payload
    whose request id is present in ``self.req_states.req_id_to_index`` is
    decoded with ``_resolve_additional_information``. A non-empty result is
    passed to ``self.model_state.intermediate_buffer.update`` for that
    request's index.

    Payloads that are ``None``, belong to an unknown request id, or decode
    to an empty result are skipped. Nothing more happens when the mapping
    is missing or empty.

    Args:
        scheduler_output: Scheduler output for the current step.
    """
    super().update_requests(scheduler_output)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    pending = getattr(cached_reqs, "additional_information", None)
    if not pending:
        return

    for request_id, payload in pending.items():
        if payload is not None:
            slot = self.req_states.req_id_to_index.get(request_id)
            if slot is not None:
                decoded = _resolve_additional_information(payload)
                # An empty decode result carries nothing to merge.
                if decoded:
                    self.model_state.intermediate_buffer.update(slot, decoded)

# ------------------------------------------------------------------
# Request lifecycle: clean up intermediate buffer on finish
# ------------------------------------------------------------------
Expand Down