From 76173da5ec06da7420658722b1f9b8fa26fceb9b Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 11:38:30 +0000 Subject: [PATCH 01/53] add OmniPayloadStruct (msgspec) alongside TypedDict for runtime payload validation Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 83 ++++++++++++++++++ vllm_omni/data_entry_keys.py | 154 ++++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index f4a50e677de..ecb9a7e4fda 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -1,12 +1,22 @@ """Tests for data_entry_keys: TypedDict payload structure, flatten/unflatten, serialize/deserialize.""" +import msgspec +import pytest import torch from vllm_omni.data_entry_keys import ( + CodesStruct, + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, + MetaStruct, OmniPayload, + OmniPayloadStruct, deserialize_payload, flatten_payload, serialize_payload, + to_dict, + to_struct, unflatten_payload, ) from vllm_omni.engine import AdditionalInformationPayload @@ -250,3 +260,76 @@ def test_none_values_skipped(self): original: OmniPayload = {"meta": {"finished": None}} wire = serialize_payload(original) assert wire is None + + +class TestOmniPayloadStruct: + """Runtime-validated mirror of OmniPayload (msgspec.Struct).""" + + def test_construct_directly(self): + p = OmniPayloadStruct( + meta=MetaStruct(left_context_size=5, finished=torch.tensor(True)), + codes=CodesStruct(audio=torch.zeros(3, 8)), + ) + assert p.meta.left_context_size == 5 + assert torch.equal(p.codes.audio, torch.zeros(3, 8)) + + def test_to_struct_validates_dict(self): + d = {"meta": {"left_context_size": 25, "finished": torch.tensor(False)}} + s = to_struct(d) + assert s.meta.left_context_size == 25 + + def test_to_struct_rejects_legacy_flat_top_level(self): + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"code_predictor_codes": 
torch.zeros(3, 8)}) + + def test_to_struct_rejects_legacy_flat_meta_field(self): + # `left_context_size` at top level (legacy) instead of under `meta` + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"left_context_size": 25}) + + def test_to_struct_rejects_typo_in_subkey(self): + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"meta": {"finisheed": True}}) + + def test_to_struct_rejects_wrong_type(self): + with pytest.raises(msgspec.ValidationError, match="Expected"): + to_struct({"meta": {"left_context_size": "not_an_int"}}) + + def test_round_trip_dict_struct_dict(self): + original = { + "meta": {"left_context_size": 7, "finished": torch.tensor(True)}, + "codes": {"audio": torch.zeros(2, 8)}, + "hidden_states": {"output": torch.zeros(4, 16)}, + } + s = to_struct(original) + d = to_dict(s) + assert sorted(d.keys()) == sorted(original.keys()) + for top, sub in original.items(): + assert sorted(d[top].keys()) == sorted(sub.keys()) + + def test_to_dict_drops_unset_fields(self): + s = OmniPayloadStruct(meta=MetaStruct(left_context_size=10)) + d = to_dict(s) + assert d == {"meta": {"left_context_size": 10}} + + def test_struct_attr_access_catches_typos_at_lookup(self): + s = to_struct({"meta": {"finished": torch.tensor(True)}}) + with pytest.raises(AttributeError): + _ = s.meta.finisheed + + def test_struct_with_all_categories(self): + d = { + "hidden_states": {"output": torch.zeros(1)}, + "embed": {"prefill": torch.zeros(1), "tts_bos": torch.zeros(1)}, + "ids": {"all": [1, 2], "prompt": [1]}, + "codes": {"audio": torch.zeros(1)}, + "meta": {"left_context_size": 3, "num_processed_tokens": 7}, + } + s = to_struct(d) + assert isinstance(s.hidden_states, HiddenStatesStruct) + assert isinstance(s.embed, EmbeddingsStruct) + assert isinstance(s.ids, IdsStruct) + assert isinstance(s.codes, CodesStruct) + assert isinstance(s.meta, MetaStruct) + assert s.ids.all == [1, 2] + assert s.meta.num_processed_tokens == 
7 diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index f22aad9d46d..f632b741185 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, TypedDict +import msgspec import numpy as np import torch @@ -105,6 +106,159 @@ class OmniPayload(TypedDict, total=False): request_id: str +# ── msgspec.Struct schema (runtime-validated mirror of the TypedDicts above) ── +# +# The TypedDicts give static type checking but are plain dicts at runtime, so +# producer/consumer key mismatches degrade silently (the regularize_data_entries +# refactor surfaced ~8 such bugs). These Struct types add runtime type +# checking via ``msgspec.convert``, attribute access (``p.meta.finished`` +# instead of ``p["meta"]["finished"]``), and unify the serialization path +# with the existing msgspec encoders. +# +# Dict and Struct forms coexist during migration; converters +# (:func:`to_struct`, :func:`to_dict`) bridge the two. + + +class _StructBase(msgspec.Struct, omit_defaults=True, kw_only=True, forbid_unknown_fields=True): + """Common base for nested payload structs. + + - ``omit_defaults``: skip ``None`` fields when serializing. + - ``kw_only``: mirror TypedDict construction style. + - ``forbid_unknown_fields``: reject typos and legacy flat keys at decode time. 
+ """ + + +class HiddenStatesStruct(_StructBase): + output: torch.Tensor | None = None + trailing_text: torch.Tensor | None = None + last: torch.Tensor | None = None + layers: dict[int, torch.Tensor] | None = None + + +class EmbeddingsStruct(_StructBase): + prefill: torch.Tensor | None = None + decode: torch.Tensor | None = None + cached_decode: torch.Tensor | None = None + tts_bos: torch.Tensor | None = None + tts_eos: torch.Tensor | None = None + tts_pad: torch.Tensor | None = None + tts_pad_projected: torch.Tensor | None = None + voice: torch.Tensor | None = None + speech_feat: torch.Tensor | None = None + thinker_reply: torch.Tensor | None = None + + +class CodesStruct(_StructBase): + audio: torch.Tensor | None = None + ref: torch.Tensor | None = None + + +class IdsStruct(_StructBase): + all: list[int] | None = None + prompt: list[int] | None = None + output: list[int] | None = None + speech_token: list[int] | None = None + prior_image: list[int] | None = None + + +class MetaStruct(_StructBase): + finished: torch.Tensor | None = None + left_context_size: int | None = None + override_keys: list[tuple[str, str]] | None = None + num_processed_tokens: int | None = None + next_stage_prompt_len: int | None = None + ar_width: int | None = None + eol_token_id: int | None = None + visual_token_start_id: int | None = None + visual_token_end_id: int | None = None + gen_token_mask: torch.Tensor | None = None + omni_task: list[str] | None = None + height: int | None = None + width: int | None = None + decode_flag: bool | None = None + codec_streaming: bool | None = None + ref_code_len: int | None = None + talker_prefill_offset: int | None = None + + +class OmniPayloadStruct(_StructBase): + hidden_states: HiddenStatesStruct | None = None + embed: EmbeddingsStruct | None = None + ids: IdsStruct | None = None + codes: CodesStruct | None = None + meta: MetaStruct | None = None + latent: torch.Tensor | None = None + generated_len: int | None = None + model_outputs: 
list[torch.Tensor] | None = None + mtp_inputs: tuple[torch.Tensor, torch.Tensor] | None = None + speaker: Any = None + language: Any = None + request_id: str | None = None + + +_NESTED_STRUCTS: dict[str, type[_StructBase]] = { + "hidden_states": HiddenStatesStruct, + "embed": EmbeddingsStruct, + "ids": IdsStruct, + "codes": CodesStruct, + "meta": MetaStruct, +} + + +def _msgspec_dec_hook(typ: type, obj: Any) -> Any: + """Bridge non-msgspec types (``torch.Tensor``) when decoding into Structs.""" + if typ is torch.Tensor: + if isinstance(obj, torch.Tensor): + return obj + raise TypeError(f"cannot decode {type(obj).__name__} into torch.Tensor") + raise NotImplementedError(f"no decoder for {typ}") + + +def _msgspec_enc_hook(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + return { + "__tensor__": True, + "data": obj.detach().cpu().contiguous().numpy().tobytes(), + "shape": list(obj.shape), + "dtype": _dtype_to_name(obj.dtype), + } + raise NotImplementedError(f"no encoder for {type(obj).__name__}") + + +def to_struct(payload: dict[str, Any]) -> OmniPayloadStruct: + """Convert a payload dict into ``OmniPayloadStruct``, validating types. + + Raises ``msgspec.ValidationError`` on: + * unknown top-level keys (typos, legacy flat keys) + * unknown sub-keys under any nested category + * type mismatches (e.g., ``meta.left_context_size`` not an ``int``) + """ + return msgspec.convert(payload, OmniPayloadStruct, dec_hook=_msgspec_dec_hook) + + +def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: + """Convert ``OmniPayloadStruct`` back to a plain dict, dropping unset fields. + + Used during migration when downstream code still expects dicts. 
+ """ + out: dict[str, Any] = {} + for field in OmniPayloadStruct.__struct_fields__: + value = getattr(struct, field) + if value is None: + continue + if isinstance(value, _StructBase): + sub: dict[str, Any] = {} + for sk in value.__struct_fields__: + sv = getattr(value, sk) + if sv is not None: + sub[sk] = sv + if sub: + out[field] = sub + else: + out[field] = value + return out + + # ── Keys whose values are nested dicts (TypedDict sub-categories) ── _NESTED_KEYS = frozenset({"hidden_states", "embed", "ids", "codes", "meta"}) From dbc16d545aae9a08ded3cd5a62b94931576331b7 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 11:45:21 +0000 Subject: [PATCH 02/53] tighten dict to OmniPayload in qwen3_omni, qwen3_tts, mimo_audio, mixin, chunk_transfer_adapter signatures Signed-off-by: Divyansh Singhvi --- .../transfer_adapter/chunk_transfer_adapter.py | 4 ++-- .../stage_input_processors/mimo_audio.py | 5 +++-- .../stage_input_processors/qwen3_omni.py | 8 ++++---- .../stage_input_processors/qwen3_tts.py | 7 ++++--- .../worker/omni_connector_model_runner_mixin.py | 13 +++++++------ 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 4535c2596dd..b0c56157e54 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -8,7 +8,7 @@ import torch from vllm.v1.request import Request, RequestStatus -from vllm_omni.data_entry_keys import unflatten_payload +from vllm_omni.data_entry_keys import OmniPayload, unflatten_payload from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec @@ -193,7 +193,7 @@ def _poll_single_request(self, request: Request): return False - def _update_request_payload(self, req_id: str, payload_data: dict[str, 
Any]) -> dict[str, Any]: + def _update_request_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: """Update the stored payload for *req_id* with the latest chunk.""" if req_id not in self.request_payload: self.request_payload[req_id] = payload_data diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 96680b2dd94..4910552edd7 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -4,6 +4,7 @@ from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID @@ -114,10 +115,10 @@ def _to_code_tensor(codes: Any) -> torch.Tensor | None: def llm2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayload | None: """ Async chunk version: convert stage-0 pooling_output to code2wav payload (pooling / connector accumulation). diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index db87465d21c..61ed74c8655 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -287,10 +287,10 @@ def _get_streaming_codec_delta_len( def thinker2talker_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -) -> list[dict[str, Any]]: +) -> OmniPayload | None: """ Process thinker outputs to create talker inputs. 1. 
thinker's text generation outputs (token IDs + hidden states) @@ -500,10 +500,10 @@ def thinker2talker( def talker2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -): +) -> OmniPayload | None: """ Pooling version. """ diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index a0c62cba30f..94851465cf8 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -5,6 +5,7 @@ import torch from vllm.logger import init_logger +from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.model_executor.stage_input_processors.chunk_size_utils import ( compute_dynamic_initial_chunk_size, max_ic_for_chunk_size, @@ -118,7 +119,7 @@ def talker2code2wav( return code2wav_inputs -def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: +def _extract_last_frame(pooling_output: OmniPayload) -> torch.Tensor | None: audio_codes = pooling_output.get("codes", {}).get("audio") if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: return None @@ -134,10 +135,10 @@ def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: def talker2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any] | None, + pooling_output: OmniPayload | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayload | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) request_payload = getattr(transfer_manager, "request_payload", None) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 8e8f5741fa6..215c3003897 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ 
b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -24,6 +24,7 @@ from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger +from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec from vllm_omni.outputs import OmniConnectorOutput @@ -341,15 +342,15 @@ def prune_inactive_requests(self, active_req_ids: Any) -> set[str]: # Local payload cache (RFC §2.4 – Model Runner ownership) # ------------------------------------------------------------------ # - def put_local_stage_payload(self, req_id: str, payload: dict[str, Any]) -> None: + def put_local_stage_payload(self, req_id: str, payload: OmniPayload) -> None: """Store a full stage payload in the local cache.""" self._local_stage_payload_cache[req_id] = payload - def get_local_stage_payload(self, req_id: str) -> dict[str, Any] | None: + def get_local_stage_payload(self, req_id: str) -> OmniPayload | None: """Read a stage payload without removing it.""" return self._local_stage_payload_cache.get(req_id) - def pop_local_stage_payload(self, req_id: str) -> dict[str, Any] | None: + def pop_local_stage_payload(self, req_id: str) -> OmniPayload | None: """Remove and return a stage payload (consume after use).""" return self._local_stage_payload_cache.pop(req_id, None) @@ -366,7 +367,7 @@ def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None: # ------------------------------------------------------------------ # @classmethod - def _extract_scheduling_metadata(cls, payload: dict[str, Any]) -> dict[str, Any]: + def _extract_scheduling_metadata(cls, payload: OmniPayload) -> dict[str, Any]: """Extract only the fields the scheduler needs from a full payload.""" extracted: dict[str, Any] = {} if "next_stage_prompt_len" in payload: @@ -427,7 +428,7 @@ def _payload_audio_codes(payload: Any) -> Any: return None @classmethod - 
def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool: + def _payload_is_consumable(cls, payload: OmniPayload | None) -> bool: """Return True when an async payload can drive a real forward step. Metadata-only wake-ups should not transition WAITING_FOR_CHUNK requests @@ -1791,7 +1792,7 @@ def _decrement_pending_save_count(self, request_id: str) -> None: # Payload accumulation (ported from OmniChunkTransferAdapter) # ------------------------------------------------------------------ # - def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: + def _accumulate_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: """Accumulate chunk payloads (concat tensors, extend lists). Returns a **shallow copy** of the accumulated state so callers From 01429d58a16832c1e35ecc9d72c976e84c6c792f Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 11:49:14 +0000 Subject: [PATCH 03/53] add validate_payload boundary helper that raises on schema drift Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 30 ++++++++++++++++++++++++++++++ vllm_omni/data_entry_keys.py | 15 +++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index ecb9a7e4fda..d00099b6ac0 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -333,3 +333,33 @@ def test_struct_with_all_categories(self): assert isinstance(s.meta, MetaStruct) assert s.ids.all == [1, 2] assert s.meta.num_processed_tokens == 7 + + +class TestValidatePayload: + def test_raises_on_unknown_top_level(self): + from vllm_omni.data_entry_keys import validate_payload + + with pytest.raises(msgspec.ValidationError, match="unknown field"): + validate_payload({"code_predictor_codes": torch.zeros(3, 8)}, context="test_boundary") + + def test_raises_on_unknown_sub_key(self): + from vllm_omni.data_entry_keys import validate_payload + + with 
pytest.raises(msgspec.ValidationError, match="unknown field"): + validate_payload({"meta": {"finisheed": True}}) + + def test_none_is_ok(self): + from vllm_omni.data_entry_keys import validate_payload + + validate_payload(None) # should not raise + + def test_valid_payload_passes(self): + from vllm_omni.data_entry_keys import validate_payload + + validate_payload({"meta": {"left_context_size": 5}}) + + def test_context_in_error_message(self): + from vllm_omni.data_entry_keys import validate_payload + + with pytest.raises(msgspec.ValidationError, match="my_call_site"): + validate_payload({"bad": 1}, context="my_call_site") diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index f632b741185..d83d95edb5e 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -236,6 +236,21 @@ def to_struct(payload: dict[str, Any]) -> OmniPayloadStruct: return msgspec.convert(payload, OmniPayloadStruct, dec_hook=_msgspec_dec_hook) +def validate_payload(payload: dict[str, Any] | None, *, context: str = "payload") -> None: + """Validate a payload matches the ``OmniPayload`` schema, raising on drift. + + Wraps :func:`to_struct` and re-raises ``msgspec.ValidationError`` with + the call-site ``context`` prepended. ``None`` is allowed (treated as + "no payload to check"). + """ + if payload is None: + return + try: + to_struct(payload) + except msgspec.ValidationError as exc: + raise msgspec.ValidationError(f"{context}: {exc}") from exc + + def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: """Convert ``OmniPayloadStruct`` back to a plain dict, dropping unset fields. 
From baae107d41e2b0e0db3718d2a826535a7da36e8f Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:02:19 +0000 Subject: [PATCH 04/53] migrate qwen3_tts talker2code2wav_async_chunk to construct OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 4 ++- .../stage_input_processors/qwen3_tts.py | 32 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index d83d95edb5e..460ed71a46d 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -149,7 +149,9 @@ class EmbeddingsStruct(_StructBase): class CodesStruct(_StructBase): - audio: torch.Tensor | None = None + # ``audio`` is a tensor inside the model (talker output), but a flattened + # ``list[int]`` once the codes are prepared for the wire / next-stage prompt. + audio: torch.Tensor | list[int] | None = None ref: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 94851465cf8..f20cac1b061 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -5,7 +5,13 @@ import torch from vllm.logger import init_logger -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.model_executor.stage_input_processors.chunk_size_utils import ( compute_dynamic_initial_chunk_size, max_ic_for_chunk_size, @@ -256,17 +262,13 @@ def talker2code2wav_async_chunk( num_frames = len(window_frames) code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)] - info: dict[str, Any] = { - "codes": {"audio": code_predictor_codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, 
dtype=torch.bool)}, - } - # Propagate speaker and language from the request so they are available - # as runtime_additional_information in subsequent pipeline stages, consistent - # with qwen3-omni and qwen2.5-omni stage input processors. - speaker = extract_speaker_from_request(request) - if speaker is not None: - info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - info["language"] = language - return info + payload = OmniPayloadStruct( + codes=CodesStruct(audio=code_predictor_codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(finished, dtype=torch.bool), + ), + speaker=extract_speaker_from_request(request), + language=extract_language_from_request(request), + ) + return to_dict(payload) From 55f6b24bcdf112dab958bd6e982ebc5fbbcd2c7e Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:10:08 +0000 Subject: [PATCH 05/53] migrate qwen3_omni talker2code2wav_async_chunk to OmniPayloadStruct, widen codes.audio to Any Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 3 ++- .../stage_input_processors/qwen3_omni.py | 21 ++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 460ed71a46d..94c31d32164 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -151,7 +151,8 @@ class EmbeddingsStruct(_StructBase): class CodesStruct(_StructBase): # ``audio`` is a tensor inside the model (talker output), but a flattened # ``list[int]`` once the codes are prepared for the wire / next-stage prompt. - audio: torch.Tensor | list[int] | None = None + # msgspec does not allow union of custom type + builtin, so use ``Any`` here. 
+ audio: Any = None ref: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 61ed74c8655..d0f5bcc0133 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -11,7 +11,13 @@ from vllm.inputs import TextPrompt from vllm.platforms import current_platform -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.stage_input_processors.tts_utils import ( @@ -559,10 +565,15 @@ def talker2code2wav_async_chunk( .tolist() ) - return { - "codes": {"audio": codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(is_finished, dtype=torch.bool)}, - } + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(is_finished, dtype=torch.bool), + ), + ) + ) def talker2code2wav( From 4e31b658794388fd34e3717ee912c63389a8878b Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:16:43 +0000 Subject: [PATCH 06/53] keep codes.audio as torch.Tensor end-to-end, convert to list only at vLLM prompt boundary Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 5 +---- .../transfer_adapter/chunk_transfer_adapter.py | 6 +++++- .../model_executor/stage_input_processors/qwen3_omni.py | 7 +------ .../model_executor/stage_input_processors/qwen3_tts.py | 5 ++++- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 94c31d32164..d83d95edb5e 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ 
-149,10 +149,7 @@ class EmbeddingsStruct(_StructBase): class CodesStruct(_StructBase): - # ``audio`` is a tensor inside the model (talker output), but a flattened - # ``list[int]`` once the codes are prepared for the wire / next-stage prompt. - # msgspec does not allow union of custom type + builtin, so use ``Any`` here. - audio: Any = None + audio: torch.Tensor | None = None ref: torch.Tensor | None = None diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index b0c56157e54..74fd4539e1d 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -162,7 +162,11 @@ def _poll_single_request(self, request: Request): if meta.get("finished"): self.finished_requests.add(req_id) - new_ids = payload_data.get("codes", {}).get("audio", []) + new_ids = payload_data.get("codes", {}).get("audio") + if isinstance(new_ids, torch.Tensor): + new_ids = new_ids.tolist() + elif new_ids is None: + new_ids = [] request.prompt_token_ids = new_ids prev_info = getattr(request, "additional_information", None) info = dict(prev_info) if isinstance(prev_info, dict) else {} diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index d0f5bcc0133..f22cca5b0b0 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -558,12 +558,7 @@ def talker2code2wav_async_chunk( left_context_size = max(0, min(length - context_length, left_context_size_config)) end_index = min(length, left_context_size + context_length) - codes = ( - torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]) - .transpose(0, 1) - .reshape(-1) - .tolist() - ) + codes = 
torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).transpose(0, 1).reshape(-1) return to_dict( OmniPayloadStruct( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index f20cac1b061..6d1f392e332 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -260,7 +260,10 @@ def talker2code2wav_async_chunk( num_quantizers = len(window_frames[0]) num_frames = len(window_frames) - code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)] + code_predictor_codes = torch.tensor( + [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)], + dtype=torch.long, + ) payload = OmniPayloadStruct( codes=CodesStruct(audio=code_predictor_codes), From a6f031ee28870c4baacb9a67b0a09a86c3467c47 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:26:20 +0000 Subject: [PATCH 07/53] migrate qwen3_omni thinker2talker_async_chunk to construct OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- .../stage_input_processors/qwen3_omni.py | 100 ++++++++---------- 1 file changed, 43 insertions(+), 57 deletions(-) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index f22cca5b0b0..7e0e7ef5bd0 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -13,6 +13,9 @@ from vllm_omni.data_entry_keys import ( CodesStruct, + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, MetaStruct, OmniPayload, OmniPayloadStruct, @@ -309,76 +312,59 @@ def thinker2talker_async_chunk( thinker_hs = pooling_output.get("hidden_states", {}) thinker_layers = thinker_hs.get("layers", {}) thinker_embed = pooling_output.get("embed", {}) + speaker = 
extract_speaker_from_request(request) + language = extract_language_from_request(request) if chunk_id == 0: - all_token_ids = request.all_token_ids # prefill + decode - prompt_token_ids = request.prompt_token_ids - # Convert ConstantList to regular list for OmniSerializer serialization - all_token_ids = _ensure_list(all_token_ids) - prompt_token_ids = _ensure_list(prompt_token_ids) - payload: OmniPayload = { - "embed": { - "prefill": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu(), - # Provide thinker-side TTS token embeddings for talker projection - "tts_bos": thinker_embed["tts_bos"].detach().cpu(), - "tts_eos": thinker_embed["tts_eos"].detach().cpu(), - "tts_pad": thinker_embed["tts_pad"].detach().cpu(), - }, - "hidden_states": {"output": thinker_layers[int(_HIDDEN_LAYER_KEY)].detach().cpu()}, - "ids": {"all": all_token_ids, "prompt": prompt_token_ids}, - "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)}, - } - talker_additional_info = payload - speaker = extract_speaker_from_request(request) - if speaker is not None: - talker_additional_info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - talker_additional_info["language"] = language + all_token_ids = _ensure_list(request.all_token_ids) + prompt_token_ids = _ensure_list(request.prompt_token_ids) + payload = OmniPayloadStruct( + embed=EmbeddingsStruct( + prefill=thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu(), + tts_bos=thinker_embed["tts_bos"].detach().cpu(), + tts_eos=thinker_embed["tts_eos"].detach().cpu(), + tts_pad=thinker_embed["tts_pad"].detach().cpu(), + ), + hidden_states=HiddenStatesStruct(output=thinker_layers[int(_HIDDEN_LAYER_KEY)].detach().cpu()), + ids=IdsStruct(all=all_token_ids, prompt=prompt_token_ids), + meta=MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)), + speaker=speaker, + language=language, + ) if transfer_manager.request_payload.get(request_id) is None: if not is_finished: - 
transfer_manager.request_payload[request_id] = talker_additional_info + transfer_manager.request_payload[request_id] = to_dict(payload) return None else: save_payload = transfer_manager.request_payload.pop(request_id) - talker_additional_info["embed"]["prefill"] = torch.cat( - ( - save_payload.get("embed", {}).get("prefill"), - talker_additional_info.get("embed", {}).get("prefill"), - ), - dim=0, + payload.embed.prefill = torch.cat( + (save_payload.get("embed", {}).get("prefill"), payload.embed.prefill), dim=0 ) - talker_additional_info["hidden_states"]["output"] = torch.cat( - ( - save_payload.get("hidden_states", {}).get("output"), - talker_additional_info.get("hidden_states", {}).get("output"), - ), - dim=0, + payload.hidden_states.output = torch.cat( + (save_payload.get("hidden_states", {}).get("output"), payload.hidden_states.output), dim=0 ) else: - output_token_ids = request.output_token_ids - # Convert ConstantList to regular list for OmniSerializer serialization - output_token_ids = _ensure_list(output_token_ids) - - talker_additional_info: OmniPayload = { - "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)}, - } - speaker = extract_speaker_from_request(request) - if speaker is not None: - talker_additional_info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - talker_additional_info["language"] = language - + output_token_ids = _ensure_list(request.output_token_ids) + meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)) if output_token_ids: - talker_additional_info["meta"]["override_keys"] = [("embed", "decode"), ("ids", "output")] - talker_additional_info["embed"] = {"decode": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu()} - talker_additional_info["ids"] = {"output": output_token_ids} + meta.override_keys = [("embed", "decode"), ("ids", "output")] + payload = OmniPayloadStruct( + meta=meta, + 
embed=EmbeddingsStruct(decode=thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu()), + ids=IdsStruct(output=output_token_ids), + speaker=speaker, + language=language, + ) else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. - talker_additional_info["embed"] = {"prefill": thinker_layers[0].detach().cpu()} - talker_additional_info["hidden_states"] = {"output": thinker_layers[24].detach().cpu()} - return talker_additional_info + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(prefill=thinker_layers[0].detach().cpu()), + hidden_states=HiddenStatesStruct(output=thinker_layers[24].detach().cpu()), + speaker=speaker, + language=language, + ) + return to_dict(payload) def thinker2talker( From cb904b0b0d205c7ee5198f216ff952c3a4f3b9d8 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:30:04 +0000 Subject: [PATCH 08/53] migrate qwen3_omni thinker2talker (full payload) to construct OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- .../stage_input_processors/qwen3_omni.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 7e0e7ef5bd0..27d47aa36a2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -442,36 +442,26 @@ def thinker2talker( except Exception as exc: logger.warning("[PD] Could not merge prefill embeddings: %s", exc) - payload: OmniPayload = { - "embed": { - "prefill": thinker_emb, - "tts_bos": _resolve_tts_token_embedding( + payload = OmniPayloadStruct( + embed=EmbeddingsStruct( + prefill=thinker_emb, + tts_bos=_resolve_tts_token_embedding( "tts_bos", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device ), - "tts_eos": _resolve_tts_token_embedding( + tts_eos=_resolve_tts_token_embedding( "tts_eos", thinker_mm=thinker_mm, 
prefill_mm=prefill_mm, device=device ), - "tts_pad": _resolve_tts_token_embedding( + tts_pad=_resolve_tts_token_embedding( "tts_pad", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device ), - }, - "hidden_states": { - "output": thinker_hid, - }, - "ids": { - "all": thinker_sequences, - "prompt": thinker_input_ids, - }, - } - info = payload - speaker = extract_speaker_from_prompt(prompt, index=i) - if speaker is not None: - info["speaker"] = speaker - language = extract_language_from_prompt(prompt, index=i) - if language is not None: - info["language"] = language - - prompt_len = _compute_talker_prompt_ids_length(payload, device=device) + ), + hidden_states=HiddenStatesStruct(output=thinker_hid), + ids=IdsStruct(all=thinker_sequences, prompt=thinker_input_ids), + speaker=extract_speaker_from_prompt(prompt, index=i), + language=extract_language_from_prompt(prompt, index=i), + ) + info = to_dict(payload) + prompt_len = _compute_talker_prompt_ids_length(info, device=device) talker_inputs.append( OmniTokensPrompt( From a50f745719abf57d88f6090a2a8d0644a5ff1d37 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:34:12 +0000 Subject: [PATCH 09/53] migrate mimo_audio llm2code2wav_async_chunk to OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 3 + .../stage_input_processors/mimo_audio.py | 65 +++++++++++-------- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index d83d95edb5e..3c765158f12 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -179,6 +179,9 @@ class MetaStruct(_StructBase): codec_streaming: bool | None = None ref_code_len: int | None = None talker_prefill_offset: int | None = None + codec_chunk_frames: int | None = None + codec_left_context_frames: int | None = None + code_flat_numel: int | None = None class OmniPayloadStruct(_StructBase): diff --git 
a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 4910552edd7..cb376ddf4c4 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -4,7 +4,13 @@ from vllm.inputs import TextPrompt from vllm.logger import init_logger -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID @@ -56,7 +62,12 @@ def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torc def _make_finished_sentinel() -> dict[str, Any]: """Return a minimal payload with finished=True so Stage-1 can end the request.""" - return {"codes": {"audio": []}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}} + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) + ) def _flush_remaining_codes( @@ -77,18 +88,20 @@ def _flush_remaining_codes( # Align with qwen3_omni talker2code2wav_async_chunk: decoder strip uses explicit frame count. 
left_ctx_frames = max(0, min(length - context_length, left_context_size)) - flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1).tolist() - - return { - "codes": {"audio": flat_codes}, - "meta": { - "left_context_size": left_ctx_frames, - "codec_chunk_frames": chunk_size, - "codec_left_context_frames": left_context_size, - "code_flat_numel": len(flat_codes), - "finished": torch.tensor(True, dtype=torch.bool), - }, - } + flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1) + + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=flat_codes), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=int(flat_codes.numel()), + finished=torch.tensor(True, dtype=torch.bool), + ), + ) + ) def _is_codes_empty(codes: Any) -> bool: @@ -168,18 +181,18 @@ def llm2code2wav_async_chunk( left_ctx_frames = max(0, min(length - context_length, left_context_size)) flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist() - return { - "codes": { - "audio": flat_codes, - }, - "meta": { - "left_context_size": left_ctx_frames, - "codec_chunk_frames": chunk_size, - "codec_left_context_frames": left_context_size, - "code_flat_numel": len(flat_codes), - "finished": torch.tensor(is_finished, dtype=torch.bool), - }, - } + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(flat_codes)), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=len(flat_codes), + finished=torch.tensor(is_finished, dtype=torch.bool), + ), + ) + ) def llm2code2wav( From d9e1624c005bbb7719054eae80c27e4c3e3d43d4 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:41:36 +0000 Subject: [PATCH 10/53] qwen3_tts and cosyvoice3: write meta.left_context_size as int, drop legacy list/flat shapes 
Signed-off-by: Divyansh Singhvi --- .../models/cosyvoice3/cosyvoice3.py | 13 +++++++++---- .../models/qwen3_tts/qwen3_tts_code2wav.py | 8 +------- .../stage_input_processors/cosyvoice3.py | 6 ++++-- .../stage_input_processors/qwen3_tts.py | 19 +++++++------------ 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 2fba8fb8af1..466611d8b7b 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -716,7 +716,11 @@ def forward( if ( req_ids.numel() > 0 and info - and ("token_offset" in info or "left_context_size" in info or "generated_len" in info) + and ( + "token_offset" in info + or "left_context_size" in info.get("meta", {}) + or "generated_len" in info + ) ): info_keys = ",".join(sorted(info.keys())) if info else "" logger.warning_once( @@ -744,16 +748,17 @@ def forward( # `generated_len` is injected for many models by the generic # runner, so only explicit chunk-routing fields should switch # code2wav into the streaming path. 
+ meta = info.get("meta", {}) if info else {} uses_streaming_decode = bool(info) and ( - "stream_finished" in info or "token_offset" in info or "left_context_size" in info + "stream_finished" in info or "token_offset" in info or "left_context_size" in meta ) if uses_streaming_decode: token_offset = 0 try: if info and "token_offset" in info: token_offset = max(0, int(info.get("token_offset", 0))) - elif info: - token_offset = max(0, int(info.get("left_context_size", 0))) + elif "left_context_size" in meta: + token_offset = max(0, int(meta.get("left_context_size", 0))) except (TypeError, ValueError): token_offset = 0 diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index b6c384881bf..1eef9d2a4b5 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -236,13 +236,7 @@ def forward( break meta = info.get("meta", {}) if "left_context_size" in meta: - # left_context_size may come through serialization as an int, [int], or tensor([int]). 
- value = meta["left_context_size"] - if isinstance(value, list): - value = value[0] if value else 0 - if isinstance(value, torch.Tensor): - value = value.reshape(-1)[0].item() if value.numel() > 0 else 0 - left_context_size[i] = int(value) + left_context_size[i] = int(meta["left_context_size"]) for i, req_ids in enumerate(request_ids_list): if req_ids.numel() < 1: parsed.append((0, 0)) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 736d3a4bdfa..1202773cacd 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -226,9 +226,11 @@ def talker2code2wav_async_chunk( payload = { "codes": {"audio": code_predictor_codes}, - "meta": {"finished": torch.tensor(finished, dtype=torch.bool)}, + "meta": { + "finished": torch.tensor(finished, dtype=torch.bool), + "left_context_size": token_offset, + }, "token_offset": token_offset, - "left_context_size": token_offset, "req_id": [request_id], "stream_finished": torch.tensor(finished, dtype=torch.bool), } diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 6d1f392e332..621873d1697 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -102,18 +102,13 @@ def talker2code2wav( ref_code_len = 0 # Code2Wav expects codebook-major flat: [Q*num_frames] codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist() - additional_information: dict[str, Any] = {} - if ref_code_len > 0: - additional_information["meta"] = {"left_context_size": [ref_code_len]} - # Propagate speaker and language from the original prompt so they are - # available as runtime_additional_information in later pipeline stages, - # consistent with qwen3-omni and qwen2.5-omni stage input processors. 
- speaker = extract_speaker_from_prompt(prompt, index=i) - if speaker is not None: - additional_information["speaker"] = speaker - language = extract_language_from_prompt(prompt, index=i) - if language is not None: - additional_information["language"] = language + additional_information = to_dict( + OmniPayloadStruct( + meta=MetaStruct(left_context_size=ref_code_len) if ref_code_len > 0 else None, + speaker=extract_speaker_from_prompt(prompt, index=i), + language=extract_language_from_prompt(prompt, index=i), + ) + ) code2wav_inputs.append( OmniTokensPrompt( prompt_token_ids=codec_codes, From d8d510e4d8769852fbc1babf86aedd4420fa2d7a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:45:47 +0000 Subject: [PATCH 11/53] drop defensive int() casts now that producers write meta.left_context_size as int Signed-off-by: Divyansh Singhvi --- .../model_executor/models/cosyvoice3/cosyvoice3.py | 12 +++++------- .../models/qwen3_tts/qwen3_tts_code2wav.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 466611d8b7b..a8f966e2f1e 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -753,13 +753,11 @@ def forward( "stream_finished" in info or "token_offset" in info or "left_context_size" in meta ) if uses_streaming_decode: - token_offset = 0 - try: - if info and "token_offset" in info: - token_offset = max(0, int(info.get("token_offset", 0))) - elif "left_context_size" in meta: - token_offset = max(0, int(meta.get("left_context_size", 0))) - except (TypeError, ValueError): + if info and "token_offset" in info: + token_offset = max(0, info["token_offset"]) + elif "left_context_size" in meta: + token_offset = max(0, meta["left_context_size"]) + else: token_offset = 0 cache_state = None diff --git 
a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 1eef9d2a4b5..202f98b022e 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -236,7 +236,7 @@ def forward( break meta = info.get("meta", {}) if "left_context_size" in meta: - left_context_size[i] = int(meta["left_context_size"]) + left_context_size[i] = meta["left_context_size"] for i, req_ids in enumerate(request_ids_list): if req_ids.numel() < 1: parsed.append((0, 0)) From 992f61e9bff3c5a81d43e99f3305adec9eb858da Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:50:27 +0000 Subject: [PATCH 12/53] drop redundant top-level cosyvoice3 token_offset (same value as meta.left_context_size) Signed-off-by: Divyansh Singhvi --- .../models/cosyvoice3/cosyvoice3.py | 17 +++-------------- .../stage_input_processors/cosyvoice3.py | 1 - 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index a8f966e2f1e..ac7282ac48c 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -716,11 +716,7 @@ def forward( if ( req_ids.numel() > 0 and info - and ( - "token_offset" in info - or "left_context_size" in info.get("meta", {}) - or "generated_len" in info - ) + and ("left_context_size" in info.get("meta", {}) or "generated_len" in info) ): info_keys = ",".join(sorted(info.keys())) if info else "" logger.warning_once( @@ -749,16 +745,9 @@ def forward( # runner, so only explicit chunk-routing fields should switch # code2wav into the streaming path. 
meta = info.get("meta", {}) if info else {} - uses_streaming_decode = bool(info) and ( - "stream_finished" in info or "token_offset" in info or "left_context_size" in meta - ) + uses_streaming_decode = bool(info) and ("stream_finished" in info or "left_context_size" in meta) if uses_streaming_decode: - if info and "token_offset" in info: - token_offset = max(0, info["token_offset"]) - elif "left_context_size" in meta: - token_offset = max(0, meta["left_context_size"]) - else: - token_offset = 0 + token_offset = max(0, meta.get("left_context_size", 0)) cache_state = None if req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"): diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 1202773cacd..66d52ac4d68 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -230,7 +230,6 @@ def talker2code2wav_async_chunk( "finished": torch.tensor(finished, dtype=torch.bool), "left_context_size": token_offset, }, - "token_offset": token_offset, "req_id": [request_id], "stream_finished": torch.tensor(finished, dtype=torch.bool), } From 2fcab85c39de26da99b0b76eac29ddf349ab5f2e Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 12:54:57 +0000 Subject: [PATCH 13/53] migrate qwen2_5_omni thinker2talker to OmniPayloadStruct, add output_shape/prefill_shape Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 2 ++ .../stage_input_processors/qwen2_5_omni.py | 29 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 3c765158f12..fd1d451e36a 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -130,6 +130,7 @@ class _StructBase(msgspec.Struct, omit_defaults=True, kw_only=True, forbid_unkno class HiddenStatesStruct(_StructBase): output: 
torch.Tensor | None = None + output_shape: list[int] | None = None trailing_text: torch.Tensor | None = None last: torch.Tensor | None = None layers: dict[int, torch.Tensor] | None = None @@ -137,6 +138,7 @@ class HiddenStatesStruct(_StructBase): class EmbeddingsStruct(_StructBase): prefill: torch.Tensor | None = None + prefill_shape: list[int] | None = None decode: torch.Tensor | None = None cached_decode: torch.Tensor | None = None tts_bos: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index c5ae43c824c..fd80a93c311 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -1,7 +1,14 @@ import torch from vllm.inputs import TextPrompt -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt TALKER_CODEC_PAD_TOKEN_ID = 8292 @@ -38,17 +45,15 @@ def thinker2talker( mm: OmniPayload = output.multimodal_output latent = mm["latent"] thinker_hidden_states = latent.clone().detach().to(latent.device) - additional_information = { - "hidden_states": { - "output": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32), - "output_shape": list(thinker_hidden_states[prompt_token_ids_len:].shape), - }, - "embed": { - "prefill": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32), - "prefill_shape": list(thinker_hidden_states[:prompt_token_ids_len].shape), - }, - "ids": {"prompt": prompt_token_ids, "output": thinker_output_ids}, - } + decode_hidden = thinker_hidden_states[prompt_token_ids_len:].to(torch.float32) + prefill_hidden = thinker_hidden_states[:prompt_token_ids_len].to(torch.float32) + additional_information = to_dict( + OmniPayloadStruct( + 
hidden_states=HiddenStatesStruct(output=decode_hidden, output_shape=list(decode_hidden.shape)), + embed=EmbeddingsStruct(prefill=prefill_hidden, prefill_shape=list(prefill_hidden.shape)), + ids=IdsStruct(prompt=list(prompt_token_ids), output=list(thinker_output_ids)), + ) + ) talker_inputs.append( OmniTokensPrompt( prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID] From fb99cbbf9b98532c570b1b896a25c5b14eb17f8d Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 13:06:48 +0000 Subject: [PATCH 14/53] migrate voxtral_tts and fish_speech async_chunk producers to OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- .../stage_input_processors/fish_speech.py | 27 +++++++++---- .../stage_input_processors/voxtral_tts.py | 38 +++++++++++++------ 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/vllm_omni/model_executor/stage_input_processors/fish_speech.py b/vllm_omni/model_executor/stage_input_processors/fish_speech.py index 365b303be2b..4416f4bb58e 100644 --- a/vllm_omni/model_executor/stage_input_processors/fish_speech.py +++ b/vllm_omni/model_executor/stage_input_processors/fish_speech.py @@ -5,6 +5,14 @@ import torch from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) + logger = init_logger(__name__) @@ -61,7 +69,7 @@ def slow_ar_to_dac_decoder_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayload | None: """Async streaming processor: emit code chunks as they are produced. Accumulates per-step codes and emits fixed-size chunks with left context @@ -140,9 +148,14 @@ def slow_ar_to_dac_decoder_async_chunk( # Pack into codebook-major flat codes. 
stacked_frames = torch.stack(window_frames, dim=0) - code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1).tolist() - - return { - "codes": {"audio": code_predictor_codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, dtype=torch.bool)}, - } + code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1) + + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=code_predictor_codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(finished, dtype=torch.bool), + ), + ) + ) diff --git a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py index e1a58b6f16d..7c878235ee6 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py @@ -4,6 +4,13 @@ from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt logger = init_logger(__name__) @@ -39,7 +46,7 @@ def generator2tokenizer( return tokenizer_inputs -def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: +def _extract_last_frame(pooling_output: OmniPayload) -> torch.Tensor | None: audio = pooling_output.get("audio") if not isinstance(audio, torch.Tensor) or audio.numel() == 0: return None @@ -48,10 +55,10 @@ def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: def generator2tokenizer_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayload | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) @@ -81,10 +88,12 @@ def generator2tokenizer_async_chunk( # finished 
and nothing was produced, emit an EOF marker. if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) + ) return None # Use a small chunk size at begin @@ -104,7 +113,14 @@ def generator2tokenizer_async_chunk( # Pack context + chunk into codebook-major flat codes for adapter. code_predictor_codes = torch.tensor(window_frames).reshape(-1).tolist() - return { - "codes": {"audio": [int(ctx_frames)] + [int(context_length)] + code_predictor_codes}, - "meta": {"finished": torch.tensor(finished, dtype=torch.bool)}, - } + return to_dict( + OmniPayloadStruct( + codes=CodesStruct( + audio=torch.tensor( + [int(ctx_frames), int(context_length)] + code_predictor_codes, + dtype=torch.long, + ), + ), + meta=MetaStruct(finished=torch.tensor(finished, dtype=torch.bool)), + ) + ) From 56bc097ff26ffbfd4a92986d719e205a310aa3bc Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 13:12:27 +0000 Subject: [PATCH 15/53] fix voxcpm latent2vae_async_chunk: write nested codes.audio + meta.finished so chunk_transfer_adapter receiver can read it Signed-off-by: Divyansh Singhvi --- .../stage_input_processors/voxcpm.py | 43 +++++++++++-------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py index c2fcf521bf4..666dae4fb45 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxcpm.py +++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py @@ -5,6 +5,13 @@ import torch from vllm.inputs import TextPrompt +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt _VOXCPM_LATENT_MAGIC 
= 131071 @@ -88,12 +95,21 @@ def latent2vae( return vae_inputs +def _eof_payload() -> OmniPayload: + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) + ) + + def latent2vae_async_chunk( transfer_manager: Any, pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayload | None: """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload).""" # Kept for callback signature compatibility with OmniChunkTransferAdapter. _ = transfer_manager @@ -101,28 +117,19 @@ def latent2vae_async_chunk( if callable(getattr(request, "is_finished", None)): finished_request = finished_request or _coerce_finished_flag(request.is_finished()) if not isinstance(pooling_output, dict): - if finished_request: - return { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } - return None + return _eof_payload() if finished_request else None latent = pooling_output.get("latent_audio_feat") if isinstance(latent, torch.Tensor) and latent.numel() == 0: latent = None if latent is None: - if finished_request: - return { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } - return None + return _eof_payload() if finished_request else None serialized_codes = _serialize_latent_to_codes(latent) - out: dict[str, Any] = { - "code_predictor_codes": serialized_codes, - "finished": torch.tensor(finished_request, dtype=torch.bool), - } - return out + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(serialized_codes, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(finished_request, dtype=torch.bool)), + ) + ) From 22635323eab9f62daf3787053ef8c07538e4f148 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 13:22:23 +0000 Subject: [PATCH 16/53] add encode_payload/decode_payload helpers 
(msgspec-native) for OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 53 +++++++++++++++++++++++++++++++++++ vllm_omni/data_entry_keys.py | 25 +++++++++++++++-- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index d00099b6ac0..49ab04d41f5 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -363,3 +363,56 @@ def test_context_in_error_message(self): with pytest.raises(msgspec.ValidationError, match="my_call_site"): validate_payload({"bad": 1}, context="my_call_site") + + +class TestNativeMsgspecEncoding: + """Phase 6 scaffolding: native msgspec encode/decode for OmniPayloadStruct.""" + + def test_encode_decode_round_trip_tensor(self): + from vllm_omni.data_entry_keys import decode_payload, encode_payload + + original = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([1, 2, 3, 4], dtype=torch.long)), + meta=MetaStruct(left_context_size=5, finished=torch.tensor(True)), + ) + wire = encode_payload(original) + assert isinstance(wire, bytes) + restored = decode_payload(wire) + assert isinstance(restored, OmniPayloadStruct) + assert torch.equal(restored.codes.audio, original.codes.audio) + assert restored.meta.left_context_size == 5 + assert bool(restored.meta.finished.item()) is True + + def test_encode_decode_round_trip_dtypes(self): + from vllm_omni.data_entry_keys import decode_payload, encode_payload + + for dtype in (torch.float32, torch.float16, torch.bfloat16, torch.int64, torch.bool): + original = OmniPayloadStruct(codes=CodesStruct(audio=torch.tensor([1, 0, 1], dtype=dtype))) + restored = decode_payload(encode_payload(original)) + assert restored.codes.audio.dtype == dtype, f"dtype mismatch for {dtype}" + + def test_encode_decode_preserves_shape(self): + from vllm_omni.data_entry_keys import decode_payload, encode_payload + + t = torch.randn(3, 4, 5) + original = 
OmniPayloadStruct(hidden_states=HiddenStatesStruct(output=t)) + restored = decode_payload(encode_payload(original)) + assert restored.hidden_states.output.shape == (3, 4, 5) + assert torch.allclose(restored.hidden_states.output, t) + + def test_encode_decode_speaker_language(self): + from vllm_omni.data_entry_keys import decode_payload, encode_payload + + original = OmniPayloadStruct(speaker="ethan", language="en") + restored = decode_payload(encode_payload(original)) + assert restored.speaker == "ethan" + assert restored.language == "en" + + def test_decode_rejects_unknown_field(self): + from vllm_omni.data_entry_keys import _OMNI_PAYLOAD_ENCODER, decode_payload + + # Manually craft msgpack with unknown top-level field + bad_dict = {"code_predictor_codes": [1, 2, 3]} + wire = _OMNI_PAYLOAD_ENCODER.encode(bad_dict) + with pytest.raises(msgspec.ValidationError, match="unknown field"): + decode_payload(wire) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index fd1d451e36a..39ba0cb42a9 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -210,11 +210,18 @@ class OmniPayloadStruct(_StructBase): } +_TENSOR_MARKER = "__tensor__" + + def _msgspec_dec_hook(typ: type, obj: Any) -> Any: - """Bridge non-msgspec types (``torch.Tensor``) when decoding into Structs.""" + """Bridge non-msgspec types when decoding bytes/dicts into Structs.""" if typ is torch.Tensor: if isinstance(obj, torch.Tensor): return obj + if isinstance(obj, dict) and obj.get(_TENSOR_MARKER): + arr = np.frombuffer(obj["data"], dtype=np.dtype(obj["dtype"])) + arr = arr.reshape(obj["shape"]) + return torch.from_numpy(arr.copy()) raise TypeError(f"cannot decode {type(obj).__name__} into torch.Tensor") raise NotImplementedError(f"no decoder for {typ}") @@ -222,7 +229,7 @@ def _msgspec_dec_hook(typ: type, obj: Any) -> Any: def _msgspec_enc_hook(obj: Any) -> Any: if isinstance(obj, torch.Tensor): return { - "__tensor__": True, + _TENSOR_MARKER: True, "data": 
obj.detach().cpu().contiguous().numpy().tobytes(), "shape": list(obj.shape), "dtype": _dtype_to_name(obj.dtype), @@ -230,6 +237,20 @@ def _msgspec_enc_hook(obj: Any) -> Any: raise NotImplementedError(f"no encoder for {type(obj).__name__}") +_OMNI_PAYLOAD_ENCODER = msgspec.msgpack.Encoder(enc_hook=_msgspec_enc_hook) +_OMNI_PAYLOAD_DECODER = msgspec.msgpack.Decoder(OmniPayloadStruct, dec_hook=_msgspec_dec_hook) + + +def encode_payload(struct: OmniPayloadStruct) -> bytes: + """Encode ``OmniPayloadStruct`` to msgpack bytes for cross-process transport.""" + return _OMNI_PAYLOAD_ENCODER.encode(struct) + + +def decode_payload(data: bytes) -> OmniPayloadStruct: + """Decode msgpack bytes back to ``OmniPayloadStruct``, validating the schema.""" + return _OMNI_PAYLOAD_DECODER.decode(data) + + def to_struct(payload: dict[str, Any]) -> OmniPayloadStruct: """Convert a payload dict into ``OmniPayloadStruct``, validating types. From 238d14a14a954078901ec03120dce8c9a47a617a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 13:52:43 +0000 Subject: [PATCH 17/53] drop tautological tests for TypedDict construction and msgspec attr access Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 39 ----------------------------------- 1 file changed, 39 deletions(-) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index 49ab04d41f5..3cfd5c72d14 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -22,32 +22,6 @@ from vllm_omni.engine import AdditionalInformationPayload -class TestOmniPayload: - def test_nested_payload_structure(self): - """Verify OmniPayload can be constructed with nested dicts.""" - payload: OmniPayload = { - "hidden_states": {"output": torch.tensor([1.0])}, - "embed": {"prefill": torch.tensor([2.0])}, - "codes": {"audio": torch.tensor([3.0])}, - "ids": {"all": [1, 2, 3]}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } - assert 
torch.equal(payload["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(payload["embed"]["prefill"], torch.tensor([2.0])) - assert torch.equal(payload["codes"]["audio"], torch.tensor([3.0])) - assert payload["ids"]["all"] == [1, 2, 3] - assert payload["meta"]["finished"].item() is True - - def test_partial_payload(self): - """OmniPayload fields are all optional (total=False).""" - payload: OmniPayload = {"meta": {"finished": torch.tensor(False, dtype=torch.bool)}} - assert payload["meta"]["finished"].item() is False - - def test_empty_payload(self): - payload: OmniPayload = {} - assert len(payload) == 0 - - class TestFlattenPayload: def test_basic_nested_to_dotted(self): nested = { @@ -265,14 +239,6 @@ def test_none_values_skipped(self): class TestOmniPayloadStruct: """Runtime-validated mirror of OmniPayload (msgspec.Struct).""" - def test_construct_directly(self): - p = OmniPayloadStruct( - meta=MetaStruct(left_context_size=5, finished=torch.tensor(True)), - codes=CodesStruct(audio=torch.zeros(3, 8)), - ) - assert p.meta.left_context_size == 5 - assert torch.equal(p.codes.audio, torch.zeros(3, 8)) - def test_to_struct_validates_dict(self): d = {"meta": {"left_context_size": 25, "finished": torch.tensor(False)}} s = to_struct(d) @@ -312,11 +278,6 @@ def test_to_dict_drops_unset_fields(self): d = to_dict(s) assert d == {"meta": {"left_context_size": 10}} - def test_struct_attr_access_catches_typos_at_lookup(self): - s = to_struct({"meta": {"finished": torch.tensor(True)}}) - with pytest.raises(AttributeError): - _ = s.meta.finisheed - def test_struct_with_all_categories(self): d = { "hidden_states": {"output": torch.zeros(1)}, From 12fc0bee7e689c306a09eb98722321555c769d7a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 14:01:23 +0000 Subject: [PATCH 18/53] drop flatten_payload/unflatten_payload from pooler_output path; data flows nested end-to-end Signed-off-by: Divyansh Singhvi --- 
.../transfer_adapter/chunk_transfer_adapter.py | 5 ++--- vllm_omni/engine/__init__.py | 3 +-- vllm_omni/engine/output_processor.py | 8 -------- vllm_omni/worker/gpu_ar_model_runner.py | 7 ++----- 4 files changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 74fd4539e1d..cd64082f65b 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -8,7 +8,7 @@ import torch from vllm.v1.request import Request, RequestStatus -from vllm_omni.data_entry_keys import OmniPayload, unflatten_payload +from vllm_omni.data_entry_keys import OmniPayload from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec @@ -226,8 +226,7 @@ def _update_request_payload(self, req_id: str, payload_data: OmniPayload) -> Omn return payload_data def _send_single_request(self, task: dict): - raw_po = task["pooling_output"] - pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po + pooling_output = task["pooling_output"] request = task["request"] is_finished = task["is_finished"] stage_id = self.connector.stage_id diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 6c92d7952de..a6d5d929458 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -5,7 +5,6 @@ from typing import Any import msgspec -import torch from vllm.v1.engine import ( EngineCoreOutput, EngineCoreOutputs, @@ -78,7 +77,7 @@ class OmniEngineCoreRequest(EngineCoreRequest): class OmniEngineCoreOutput(EngineCoreOutput): - pooling_output: dict[str, torch.Tensor] | None = None + pooling_output: dict[str, Any] | None = None # Finished flag for streaming input segment is_segment_finished: bool | None = False # Streaming update prompt length diff --git 
a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index ac18597ee0e..9179dcfa216 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -16,7 +16,6 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import IterationStats -from vllm_omni.data_entry_keys import unflatten_payload from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -135,13 +134,6 @@ def _consolidate_multimodal_tensors(self) -> None: except Exception: logger.exception("Error consolidating multimodal tensors") - # Restore nested structure from flat dotted keys now that all tensor - # lists have been concatenated into single tensors. - try: - self.mm_accumulated = unflatten_payload(self.mm_accumulated) - except Exception: - logger.exception("Error unflattening consolidated multimodal tensors") - # Override: do not route to pooling-only path; always create completion # outputs, and attach pooling_result into the CompletionOutput. 
def make_request_output( diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 947b3164f3e..f37b2224efb 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -37,7 +37,6 @@ from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from vllm_omni.data_entry_keys import flatten_payload from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element @@ -855,7 +854,7 @@ def propose_draft_token_ids(sampled_token_ids): ) # Otherwise we don't have the mm CPU data yet, so we still need to build it if self.omni_prefix_cache is None: - mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs)) + mm_cpu = build_mm_cpu(multimodal_outputs) self._process_additional_information_updates( hidden_states, @@ -910,9 +909,7 @@ def propose_draft_token_ids(sampled_token_ids): seq_len=seq_len, ) payload.update(mm_payload) - # Flatten nested dicts to dotted keys so pooling_output - # stays dict[str, torch.Tensor] for msgspec serialization. 
- pooler_output.append(flatten_payload(payload)) + pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() From d595ed248cc6e895970c2e9d6d5e5c9fab11eb9d Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 14:13:47 +0000 Subject: [PATCH 19/53] trim ai-slop docstrings, align data_entry_keys header with project convention Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 47 ++++++------------------------------ 1 file changed, 7 insertions(+), 40 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 39ba0cb42a9..a68697f4ff8 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -1,23 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Structured payload types for inter-stage communication. -Adding a new model? -~~~~~~~~~~~~~~~~~~~ -Every key you put into the inter-stage payload (``additional_information``, -``multimodal_output``, ``pooling_output``) **must** use the nested -``OmniPayload`` TypedDict structure. For each category, every known -qualifier is an explicit field so misspellings are caught statically. 
- -Categories +Categories under ``OmniPayload``: hidden_states – intermediate / output hidden-state tensors embed – embedding tensors (prefill, decode, special tokens) ids – token-ID sequences codes – codec / audio code tensors meta – scalar metadata, control flags, shapes - -This module provides: -- Structured ``TypedDict`` types for static type checking (``OmniPayload``) -- ``serialize_payload`` / ``deserialize_payload`` for transport across - process boundaries via ``AdditionalInformationPayload`` """ from __future__ import annotations @@ -31,12 +21,6 @@ if TYPE_CHECKING: from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload -# ── Structured payload types ── -# These are TypedDicts (plain dicts at runtime, zero overhead) that give -# static type checking and IDE autocomplete for inter-stage payloads. -# Every field is optional (total=False) because each stage only populates -# the subset it needs. - class HiddenStates(TypedDict, total=False): output: torch.Tensor @@ -106,26 +90,11 @@ class OmniPayload(TypedDict, total=False): request_id: str -# ── msgspec.Struct schema (runtime-validated mirror of the TypedDicts above) ── -# -# The TypedDicts give static type checking but are plain dicts at runtime, so -# producer/consumer key mismatches degrade silently (the regularize_data_entries -# refactor surfaced ~8 such bugs). These Struct types add runtime type -# checking via ``msgspec.convert``, attribute access (``p.meta.finished`` -# instead of ``p["meta"]["finished"]``), and unify the serialization path -# with the existing msgspec encoders. -# -# Dict and Struct forms coexist during migration; converters -# (:func:`to_struct`, :func:`to_dict`) bridge the two. +# ── msgspec.Struct mirror of the TypedDicts (runtime-validated) ── class _StructBase(msgspec.Struct, omit_defaults=True, kw_only=True, forbid_unknown_fields=True): - """Common base for nested payload structs. - - - ``omit_defaults``: skip ``None`` fields when serializing. 
- - ``kw_only``: mirror TypedDict construction style. - - ``forbid_unknown_fields``: reject typos and legacy flat keys at decode time. - """ + pass class HiddenStatesStruct(_StructBase): @@ -187,6 +156,7 @@ class MetaStruct(_StructBase): class OmniPayloadStruct(_StructBase): + hidden: torch.Tensor | None = None hidden_states: HiddenStatesStruct | None = None embed: EmbeddingsStruct | None = None ids: IdsStruct | None = None @@ -278,10 +248,7 @@ def validate_payload(payload: dict[str, Any] | None, *, context: str = "payload" def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: - """Convert ``OmniPayloadStruct`` back to a plain dict, dropping unset fields. - - Used during migration when downstream code still expects dicts. - """ + """Convert ``OmniPayloadStruct`` to a plain dict, dropping ``None`` fields.""" out: dict[str, Any] = {} for field in OmniPayloadStruct.__struct_fields__: value = getattr(struct, field) From f874199d0409af263e1c70a999c3b56d82af8307 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 14:33:40 +0000 Subject: [PATCH 20/53] make build_mm_cpu and to_payload_element handle nested dicts of tensors recursively Signed-off-by: Divyansh Singhvi --- vllm_omni/utils/mm_outputs.py | 45 +++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 66d4e6ffe04..78f0531232a 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -28,28 +28,30 @@ def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: if multimodal_outputs: for k, v in multimodal_outputs.items(): - if isinstance(v, torch.Tensor): - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor): - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list) and 
len(v) > 0: - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - elif v is not None: - mm_cpu[k] = v + cpu_v = _to_cpu(v) + if cpu_v is not None: + mm_cpu[k] = cpu_v return mm_cpu +def _to_cpu(value): + """Recursively detach + move tensors to CPU; preserve dict/list nesting.""" + if isinstance(value, torch.Tensor): + return value.detach().to("cpu").contiguous() + if isinstance(value, dict): + out = {} + for k, v in value.items(): + cpu_v = _to_cpu(v) + if cpu_v is not None: + out[k] = cpu_v + return out or None + if isinstance(value, list): + if not value: + return None + return [_to_cpu(v) for v in value] + return value + + def to_payload_element( element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None ): @@ -77,7 +79,10 @@ def to_payload_element( # Every other case is shared between prefix cache (passthrough data) # and running a model without prefix caching. 
elif isinstance(element, dict): - return {sk: sv[start:end].contiguous() for sk, sv in element.items()} + return { + sk: to_payload_element(sv, idx, start, end, pass_lists_through=pass_lists_through, seq_len=seq_len) + for sk, sv in element.items() + } elif isinstance(element, list): # For lists, clone tensors to avoid cross-request aliasing if pass_lists_through: From 84d7e5c69b814aaaab60a1879c7001efe8095262 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 15:40:20 +0000 Subject: [PATCH 21/53] switch additional_information wire type to OmniPayloadStruct, drop legacy AdditionalInformationPayload/Entry/serialize_payload/flatten_payload Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 222 +----------------------------- vllm_omni/data_entry_keys.py | 159 +-------------------- vllm_omni/engine/__init__.py | 40 +----- vllm_omni/engine/serialization.py | 29 ++-- vllm_omni/request.py | 9 +- 5 files changed, 21 insertions(+), 438 deletions(-) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index 3cfd5c72d14..d9c939a14b6 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -1,4 +1,4 @@ -"""Tests for data_entry_keys: TypedDict payload structure, flatten/unflatten, serialize/deserialize.""" +"""Tests for data_entry_keys.""" import msgspec import pytest @@ -10,230 +10,10 @@ HiddenStatesStruct, IdsStruct, MetaStruct, - OmniPayload, OmniPayloadStruct, - deserialize_payload, - flatten_payload, - serialize_payload, to_dict, to_struct, - unflatten_payload, ) -from vllm_omni.engine import AdditionalInformationPayload - - -class TestFlattenPayload: - def test_basic_nested_to_dotted(self): - nested = { - "codes": {"audio": torch.tensor([1.0])}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5}, - } - flat = flatten_payload(nested) - assert torch.equal(flat["codes.audio"], torch.tensor([1.0])) - assert flat["meta.finished"].item() is True - assert 
flat["meta.left_context_size"] == 5 - assert "codes" not in flat - assert "meta" not in flat - - def test_top_level_keys_preserved(self): - nested = { - "latent": torch.tensor([9.0]), - "generated_len": 42, - } - flat = flatten_payload(nested) - assert torch.equal(flat["latent"], torch.tensor([9.0])) - assert flat["generated_len"] == 42 - - def test_hidden_states_layers_expanded(self): - nested = { - "hidden_states": { - "output": torch.tensor([1.0]), - "layers": { - 0: torch.tensor([2.0]), - 24: torch.tensor([3.0]), - }, - }, - } - flat = flatten_payload(nested) - assert torch.equal(flat["hidden_states.output"], torch.tensor([1.0])) - assert torch.equal(flat["hidden_states.layer_0"], torch.tensor([2.0])) - assert torch.equal(flat["hidden_states.layer_24"], torch.tensor([3.0])) - assert "hidden_states.layers" not in flat - - def test_empty_payload(self): - assert flatten_payload({}) == {} - - def test_mixed_nested_and_top_level(self): - nested: OmniPayload = { - "codes": {"audio": torch.tensor([1.0])}, - "latent": torch.tensor([2.0]), - "meta": {"finished": torch.tensor(False, dtype=torch.bool)}, - } - flat = flatten_payload(nested) - assert set(flat.keys()) == {"codes.audio", "latent", "meta.finished"} - - -class TestUnflattenPayload: - def test_basic_dotted_to_nested(self): - flat = { - "codes.audio": torch.tensor([1.0]), - "meta.finished": torch.tensor(True, dtype=torch.bool), - "meta.left_context_size": 5, - } - nested = unflatten_payload(flat) - assert torch.equal(nested["codes"]["audio"], torch.tensor([1.0])) - assert nested["meta"]["finished"].item() is True - assert nested["meta"]["left_context_size"] == 5 - - def test_top_level_keys_preserved(self): - flat = {"latent": torch.tensor([9.0]), "generated_len": 42} - nested = unflatten_payload(flat) - assert torch.equal(nested["latent"], torch.tensor([9.0])) - assert nested["generated_len"] == 42 - - def test_hidden_states_layers_collected(self): - flat = { - "hidden_states.output": torch.tensor([1.0]), - 
"hidden_states.layer_0": torch.tensor([2.0]), - "hidden_states.layer_24": torch.tensor([3.0]), - } - nested = unflatten_payload(flat) - assert torch.equal(nested["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(nested["hidden_states"]["layers"][0], torch.tensor([2.0])) - assert torch.equal(nested["hidden_states"]["layers"][24], torch.tensor([3.0])) - - def test_empty_payload(self): - assert unflatten_payload({}) == {} - - -class TestFlattenUnflattenRoundTrip: - def test_round_trip_simple(self): - original: OmniPayload = { - "codes": {"audio": torch.tensor([1.0, 2.0])}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 10}, - "ids": {"prompt": [1, 2, 3]}, - "latent": torch.tensor([5.0]), - } - restored = unflatten_payload(flatten_payload(original)) - assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"]) - assert restored["meta"]["finished"].item() is True - assert restored["meta"]["left_context_size"] == 10 - assert restored["ids"]["prompt"] == [1, 2, 3] - assert torch.equal(restored["latent"], original["latent"]) - - def test_round_trip_with_layers(self): - original = { - "hidden_states": { - "output": torch.tensor([1.0]), - "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])}, - }, - } - restored = unflatten_payload(flatten_payload(original)) - assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0])) - assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0])) - - def test_round_trip_all_categories(self): - original: OmniPayload = { - "hidden_states": {"output": torch.tensor([1.0]), "last": torch.tensor([2.0])}, - "embed": {"prefill": torch.tensor([3.0]), "tts_bos": torch.tensor([4.0])}, - "codes": {"audio": torch.tensor([5.0]), "ref": torch.tensor([6.0])}, - "ids": {"all": [1, 2], "prompt": [3, 4]}, - "meta": {"finished": torch.tensor(False, dtype=torch.bool), 
"ar_width": 8}, - } - restored = unflatten_payload(flatten_payload(original)) - assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(restored["hidden_states"]["last"], torch.tensor([2.0])) - assert torch.equal(restored["embed"]["prefill"], torch.tensor([3.0])) - assert torch.equal(restored["embed"]["tts_bos"], torch.tensor([4.0])) - assert torch.equal(restored["codes"]["audio"], torch.tensor([5.0])) - assert torch.equal(restored["codes"]["ref"], torch.tensor([6.0])) - assert restored["ids"]["all"] == [1, 2] - assert restored["ids"]["prompt"] == [3, 4] - assert restored["meta"]["finished"].item() is False - assert restored["meta"]["ar_width"] == 8 - - -class TestSerializeDeserializePayload: - def test_tensor_round_trip(self): - original: OmniPayload = { - "hidden_states": {"output": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}, - } - wire = serialize_payload(original) - assert isinstance(wire, AdditionalInformationPayload) - restored = deserialize_payload(wire) - assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"]) - - def test_list_round_trip(self): - original: OmniPayload = { - "ids": {"prompt": [10, 20, 30]}, - } - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert restored["ids"]["prompt"] == [10, 20, 30] - - def test_finished_tensor_round_trip(self): - original: OmniPayload = { - "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5}, - } - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert isinstance(restored["meta"]["finished"], torch.Tensor) - assert restored["meta"]["finished"].dtype == torch.bool - assert restored["meta"]["finished"].item() is True - assert restored["meta"]["left_context_size"] == 5 - - def test_mixed_types_round_trip(self): - original: OmniPayload = { - "hidden_states": {"output": torch.tensor([1.0, 2.0])}, - "ids": {"all": [1, 2, 3]}, - "meta": {"finished": 
torch.tensor(False, dtype=torch.bool), "ar_width": 4}, - "codes": {"audio": torch.tensor([3.0])}, - } - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"]) - assert restored["ids"]["all"] == [1, 2, 3] - assert restored["meta"]["finished"].item() is False - assert restored["meta"]["ar_width"] == 4 - assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"]) - - def test_hidden_states_layers_round_trip(self): - original = { - "hidden_states": { - "output": torch.tensor([1.0]), - "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])}, - }, - } - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0])) - assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0])) - - def test_tensor_dtype_preserved(self): - # bfloat16 excluded: numpy() doesn't support it; callers must cast before serializing. 
- for dtype in [torch.float16, torch.float32, torch.int64, torch.int32, torch.bool]: - original: OmniPayload = {"codes": {"audio": torch.tensor([1], dtype=dtype)}} - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert restored["codes"]["audio"].dtype == dtype, f"dtype mismatch for {dtype}" - - def test_tensor_shape_preserved(self): - t = torch.randn(3, 4, 5) - original: OmniPayload = {"hidden_states": {"output": t}} - wire = serialize_payload(original) - restored = deserialize_payload(wire) - assert restored["hidden_states"]["output"].shape == (3, 4, 5) - assert torch.allclose(restored["hidden_states"]["output"], t) - - def test_empty_payload_returns_none(self): - assert serialize_payload({}) is None - - def test_none_values_skipped(self): - original: OmniPayload = {"meta": {"finished": None}} - wire = serialize_payload(original) - assert wire is None class TestOmniPayloadStruct: diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index a68697f4ff8..dda6f85ddc1 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -12,15 +12,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypedDict +from typing import Any, TypedDict import msgspec import numpy as np import torch -if TYPE_CHECKING: - from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload - class HiddenStates(TypedDict, total=False): output: torch.Tensor @@ -267,94 +264,6 @@ def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: return out -# ── Keys whose values are nested dicts (TypedDict sub-categories) ── -_NESTED_KEYS = frozenset({"hidden_states", "embed", "ids", "codes", "meta"}) - -# Sub-TypedDict for each nested category, used by runtime validation. 
-_NESTED_SCHEMAS: dict[str, type] = { - "hidden_states": HiddenStates, - "embed": Embeddings, - "ids": Ids, - "codes": Codes, - "meta": OmniPayloadMeta, -} - -_ROOT_KEYS: frozenset[str] = frozenset(OmniPayload.__annotations__.keys()) - - -def assert_payload(payload: dict[str, Any], *, context: str = "payload") -> None: - """Validate ``payload`` matches the ``OmniPayload`` nested schema. - - TypedDict is a static-only contract in Python; this helper closes the - loop at runtime by rejecting: - * non-dict payloads - * top-level keys not declared on ``OmniPayload`` - * nested-category values that aren't dicts - * sub-keys not declared on the matching nested TypedDict - - Call at producer/consumer boundaries when a schema violation should - crash the pipeline instead of silently degrading audio quality. - """ - assert isinstance(payload, dict), f"{context}: expected dict, got {type(payload).__name__}" - extra_top = set(payload) - _ROOT_KEYS - assert not extra_top, f"{context}: unknown top-level keys {sorted(extra_top)!r}" - for nested_key, schema in _NESTED_SCHEMAS.items(): - if nested_key not in payload: - continue - sub = payload[nested_key] - assert isinstance(sub, dict), f"{context}: payload[{nested_key!r}] must be dict, got {type(sub).__name__}" - known_sub = frozenset(schema.__annotations__.keys()) - extra_sub = set(sub) - known_sub - assert not extra_sub, f"{context}: payload[{nested_key!r}] unknown sub-keys {sorted(extra_sub)!r}" - - -def flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: - """Flatten a nested ``OmniPayload`` to dotted keys. - - Nested sub-dicts under ``_NESTED_KEYS`` are expanded: - ``{"codes": {"audio": tensor}}`` → ``{"codes.audio": tensor}``. - ``hidden_states["layers"]`` is expanded to ``hidden_states.layer_N``. - Top-level values are kept as-is. 
- """ - if not payload: - return {} - flat: dict[str, Any] = {} - for key, value in payload.items(): - if key in _NESTED_KEYS and isinstance(value, dict): - for qual, val in value.items(): - if qual == "layers" and key == "hidden_states" and isinstance(val, dict): - for layer_idx, tensor in val.items(): - flat[f"hidden_states.layer_{layer_idx}"] = tensor - else: - flat[f"{key}.{qual}"] = val - else: - flat[key] = value - return flat - - -def unflatten_payload(flat: dict[str, Any]) -> dict[str, Any]: - """Unflatten dotted keys back to nested dicts. - - Reverse of :func:`flatten_payload`. - ``hidden_states.layer_N`` keys are collected into ``hidden_states.layers``. - """ - result: dict[str, Any] = {} - for key, value in flat.items(): - if "." in key: - type_key, qualifier = key.split(".", 1) - sub = result.setdefault(type_key, {}) - if type_key == "hidden_states" and qualifier.startswith("layer_"): - layers = sub.setdefault("layers", {}) - layer_idx = int(qualifier[len("layer_") :]) - layers[layer_idx] = value - else: - sub[qualifier] = value - else: - result[key] = value - return result - - -# ── dtype helpers ── _DTYPE_TO_NAME: dict[torch.dtype, str] = { torch.float32: "float32", torch.float16: "float16", @@ -371,69 +280,3 @@ def unflatten_payload(flat: dict[str, Any]) -> dict[str, Any]: def _dtype_to_name(dtype: torch.dtype) -> str: return _DTYPE_TO_NAME.get(dtype, str(dtype).replace("torch.", "")) - - -def _serialize_tensor(t: torch.Tensor) -> AdditionalInformationEntry: - from vllm_omni.engine import AdditionalInformationEntry - - t_cpu = t.detach().to("cpu").contiguous() - return AdditionalInformationEntry( - tensor_data=t_cpu.numpy().tobytes(), - tensor_shape=list(t_cpu.shape), - tensor_dtype=_dtype_to_name(t_cpu.dtype), - ) - - -def _deserialize_tensor(entry: AdditionalInformationEntry) -> torch.Tensor: - dt = np.dtype(entry.tensor_dtype or "float32") - arr = np.frombuffer(entry.tensor_data, dtype=dt) # type: ignore[arg-type] - arr = 
arr.reshape(entry.tensor_shape) - return torch.from_numpy(arr.copy()) - - -def serialize_payload( - payload: OmniPayload, -) -> AdditionalInformationPayload | None: - """Serialize an ``OmniPayload`` for EngineCore transport. - - Uses :func:`flatten_payload` to produce dotted keys, then converts - each value to an ``AdditionalInformationEntry``. - """ - from vllm_omni.engine import ( - AdditionalInformationEntry, - AdditionalInformationPayload, - ) - - flat = flatten_payload(payload) - entries: dict[str, AdditionalInformationEntry] = {} - - for key, value in flat.items(): - if isinstance(value, torch.Tensor): - entries[key] = _serialize_tensor(value) - elif isinstance(value, list): - entries[key] = AdditionalInformationEntry(list_data=value) - elif value is not None: - entries[key] = AdditionalInformationEntry(scalar_data=value) - - return AdditionalInformationPayload(entries=entries) if entries else None - - -def deserialize_payload( - wire: AdditionalInformationPayload, -) -> OmniPayload: - """Deserialize an ``AdditionalInformationPayload`` back to ``OmniPayload``. - - Decodes entries to tensors/lists, then uses :func:`unflatten_payload` - to reconstruct the nested structure. - """ - flat: dict[str, Any] = {} - - for key, entry in wire.entries.items(): - if entry.tensor_data is not None: - flat[key] = _deserialize_tensor(entry) - elif entry.list_data is not None: - flat[key] = entry.list_data - elif entry.scalar_data is not None: - flat[key] = entry.scalar_data - - return unflatten_payload(flat) # type: ignore[return-value] diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index a6d5d929458..603e3f6d0dc 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -11,6 +11,8 @@ EngineCoreRequest, ) +from vllm_omni.data_entry_keys import OmniPayloadStruct + class PromptEmbedsPayload(msgspec.Struct): """Serialized prompt embeddings payload for direct transfer. 
@@ -25,37 +27,6 @@ class PromptEmbedsPayload(msgspec.Struct): dtype: str -class AdditionalInformationEntry(msgspec.Struct): - """One entry of additional_information. - - Three supported forms are encoded: - - tensor: data/shape/dtype - - list: a Python list (msgspec-serializable) - - scalar: a Python scalar (msgspec-serializable) - Exactly one of (tensor_data, list_data, scalar_data) should be non-None. - """ - - # Tensor form - tensor_data: bytes | None = None - tensor_shape: list[int] | None = None - tensor_dtype: str | None = None - - # List form - list_data: list[Any] | None = None - - # Scalar form - scalar_data: Any | None = None - - -class AdditionalInformationPayload(msgspec.Struct): - """Serialized dictionary payload for additional_information. - - Keys are strings; values are encoded as AdditionalInformationEntry. - """ - - entries: dict[str, AdditionalInformationEntry] - - class OmniEngineCoreRequest(EngineCoreRequest): """Engine core request for omni models with embeddings support. @@ -66,14 +37,9 @@ class OmniEngineCoreRequest(EngineCoreRequest): Note: prompt_embeds is inherited from EngineCoreRequest (torch.Tensor | None). PromptEmbedsPayload should be decoded to torch.Tensor before constructing this request. 
- - Attributes: - additional_information: Optional serialized additional information - dictionary containing tensors or lists to pass along with the request """ - # Optional additional information dictionary (serialized) - additional_information: AdditionalInformationPayload | None = None + additional_information: OmniPayloadStruct | None = None class OmniEngineCoreOutput(EngineCoreOutput): diff --git a/vllm_omni/engine/serialization.py b/vllm_omni/engine/serialization.py index 5b87074106c..ac05d7d754d 100644 --- a/vllm_omni/engine/serialization.py +++ b/vllm_omni/engine/serialization.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Shared serialization helpers for omni engine request payloads.""" from __future__ import annotations @@ -6,37 +8,28 @@ from vllm.logger import init_logger -from vllm_omni.data_entry_keys import OmniPayload, deserialize_payload, serialize_payload -from vllm_omni.engine import AdditionalInformationPayload +from vllm_omni.data_entry_keys import OmniPayloadStruct, to_dict, to_struct logger = init_logger(__name__) def serialize_additional_information( - raw_info: dict[str, Any] | AdditionalInformationPayload | None, + raw_info: dict[str, Any] | OmniPayloadStruct | None, *, log_prefix: str | None = None, -) -> AdditionalInformationPayload | None: - """Serialize omni request metadata for EngineCore transport. - - Delegates to ``serialize_payload`` which understands the nested - ``OmniPayload`` TypedDict structure. 
- """ +) -> OmniPayloadStruct | None: + """Convert dict-form ``OmniPayload`` into ``OmniPayloadStruct`` for cross-process transport.""" if raw_info is None: return None - if isinstance(raw_info, AdditionalInformationPayload): + if isinstance(raw_info, OmniPayloadStruct): return raw_info - - payload: OmniPayload = raw_info # type: ignore[assignment] - return serialize_payload(payload) + return to_struct(raw_info) -def deserialize_additional_information( - payload: dict | AdditionalInformationPayload | None, -) -> dict: - """Deserialize an *additional_information* payload into a plain dict.""" +def deserialize_additional_information(payload: dict | OmniPayloadStruct | None) -> dict: + """Convert an ``OmniPayloadStruct`` back into a plain dict.""" if payload is None: return {} if isinstance(payload, dict): return payload - return deserialize_payload(payload) # type: ignore[return-value] + return to_dict(payload) diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 48cbf9b31d7..687df354eb8 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -11,7 +11,8 @@ if TYPE_CHECKING: from vllm.v1.core.kv_cache_utils import BlockHash -from vllm_omni.engine import AdditionalInformationPayload, OmniEngineCoreRequest, PromptEmbedsPayload +from vllm_omni.data_entry_keys import OmniPayloadStruct +from vllm_omni.engine import OmniEngineCoreRequest, PromptEmbedsPayload class OmniRequest(Request): @@ -33,7 +34,7 @@ def __init__( prompt_embeds: PromptEmbedsPayload | torch.Tensor | None = None, # Optional external request ID for tracking external_req_id: str | None = None, - additional_information: AdditionalInformationPayload | None = None, + additional_information: OmniPayloadStruct | None = None, *args, **kwargs, ): @@ -46,7 +47,7 @@ def __init__( # Optional external request ID for tracking self.external_req_id: str | None = external_req_id # Serialized additional information payload (optional) - self.additional_information: AdditionalInformationPayload | None = 
additional_information + self.additional_information: OmniPayloadStruct | None = additional_information @staticmethod def _maybe_decode_prompt_embeds( @@ -112,7 +113,7 @@ class OmniStreamingUpdate: max_tokens: int arrival_time: float sampling_params: SamplingParams | None - additional_information: AdditionalInformationPayload | None = None + additional_information: OmniPayloadStruct | None = None @classmethod def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None": From 93b76dedc1f3b79e7edfe2265ee8f30b860b47a1 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 15:58:28 +0000 Subject: [PATCH 22/53] convert dict to OmniPayloadStruct at additional_information assignment sites Signed-off-by: Divyansh Singhvi --- vllm_omni/core/sched/omni_ar_scheduler.py | 20 +++++-------------- vllm_omni/data_entry_keys.py | 3 +++ .../chunk_transfer_adapter.py | 13 ++++++++---- vllm_omni/engine/async_omni_engine.py | 2 +- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index a5579dd4640..ca80b236124 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -20,6 +20,7 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.output import OmniSchedulerOutput +from vllm_omni.data_entry_keys import to_struct from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -109,7 +110,7 @@ def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool: result = False else: info = deserialize_additional_information(payload) - result = info.get("omni_final_stage_id") == 0 + result = info.get("meta", {}).get("omni_final_stage_id") == 0 self._omits_kv_transfer_cache[rid] = result return result @@ -628,20 +629,9 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di 
} # Also update request.additional_information for good measure add_info = getattr(request, "additional_information", None) - # If additional_information is an AdditionalInformationPayload-like object, - # unpack it into a plain dict. - if ( - add_info is not None - and hasattr(add_info, "entries") - and isinstance(getattr(add_info, "entries"), dict) - ): - request.additional_information = deserialize_additional_information(add_info) - add_info = request.additional_information - if add_info is None: - request.additional_information = {} - add_info = request.additional_information - if isinstance(add_info, dict): - add_info.update(kv_xfer_params) + merged: dict[str, Any] = deserialize_additional_information(add_info) + merged.update(kv_xfer_params) + request.additional_information = to_struct(merged) return kv_xfer_params diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index dda6f85ddc1..9fc4f6f3eee 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -150,6 +150,7 @@ class MetaStruct(_StructBase): codec_chunk_frames: int | None = None codec_left_context_frames: int | None = None code_flat_numel: int | None = None + omni_final_stage_id: int | None = None class OmniPayloadStruct(_StructBase): @@ -166,6 +167,8 @@ class OmniPayloadStruct(_StructBase): speaker: Any = None language: Any = None request_id: str | None = None + past_key_values: list[int] | None = None + kv_metadata: dict[str, Any] | None = None _NESTED_STRUCTS: dict[str, type[_StructBase]] = { diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index cd64082f65b..85c0495769a 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -8,7 +8,7 @@ import torch from vllm.v1.request import Request, RequestStatus 
-from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import OmniPayload, OmniPayloadStruct, to_dict, to_struct from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec @@ -155,7 +155,7 @@ def _poll_single_request(self, request: Request): meta = payload_data.get("meta", {}) if self.model_mode == "ar": merged_payload = self._update_request_payload(external_req_id, payload_data) - request.additional_information = merged_payload + request.additional_information = to_struct(merged_payload) if meta.get("finished"): self.finished_requests.add(req_id) else: @@ -169,7 +169,12 @@ def _poll_single_request(self, request: Request): new_ids = [] request.prompt_token_ids = new_ids prev_info = getattr(request, "additional_information", None) - info = dict(prev_info) if isinstance(prev_info, dict) else {} + if isinstance(prev_info, OmniPayloadStruct): + info = to_dict(prev_info) + elif isinstance(prev_info, dict): + info = dict(prev_info) + else: + info = {} for key, value in payload_data.items(): if key == "codes": continue @@ -183,7 +188,7 @@ def _poll_single_request(self, request: Request): info[key] = merged_sub continue info[key] = value - request.additional_information = info + request.additional_information = to_struct(info) request.num_computed_tokens = 0 # Empty chunk with more data expected: keep polling. 
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 8299a577a41..d83b789ef29 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -196,7 +196,7 @@ def _apply_omni_final_stage_metadata( merged: dict[str, Any] = {} if isinstance(request, OmniEngineCoreRequest) and request.additional_information is not None: merged = deserialize_additional_information(request.additional_information) - merged["omni_final_stage_id"] = final_stage_id + merged.setdefault("meta", {})["omni_final_stage_id"] = final_stage_id payload = serialize_additional_information(merged) return OmniEngineCoreRequest( request_id=request.request_id, From 25c674b42312daeb2167203499e992367d6ade3a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sat, 25 Apr 2026 19:05:20 +0000 Subject: [PATCH 23/53] fix stale AdditionalInformationPayload import in core/sched/output.py Signed-off-by: Divyansh Singhvi --- vllm_omni/core/sched/output.py | 4 ++-- vllm_omni/worker/gpu_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index c7a8c07c5ce..995dc101d69 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -3,7 +3,7 @@ from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request -from vllm_omni.engine import AdditionalInformationPayload +from vllm_omni.data_entry_keys import OmniPayloadStruct @dataclass @@ -23,7 +23,7 @@ class OmniNewRequestData(NewRequestData): """ external_req_id: str | None = None - additional_information: AdditionalInformationPayload | None = None + additional_information: OmniPayloadStruct | None = None @classmethod def from_request( diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index d914f1b39df..a459acba3ac 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ 
b/vllm_omni/worker/gpu_model_runner.py @@ -932,7 +932,7 @@ def _dummy_run( # ------------------------------------------------------------------ # Payload decoding helpers (torch.Tensor passthrough + legacy - # PromptEmbedsPayload / AdditionalInformationPayload support) + # PromptEmbedsPayload / OmniPayloadStruct support) # ------------------------------------------------------------------ @staticmethod From 02cc628e0724d1faa5e1519614e3894285e59a13 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 05:41:29 +0000 Subject: [PATCH 24/53] drop bfloat16 from encode/decode round-trip dtype test numpy() doesn't support bfloat16; callers must cast before serializing. Matches the dtype list previously used by test_tensor_dtype_preserved on main. Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index d9c939a14b6..a4515a2bd26 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -127,7 +127,8 @@ def test_encode_decode_round_trip_tensor(self): def test_encode_decode_round_trip_dtypes(self): from vllm_omni.data_entry_keys import decode_payload, encode_payload - for dtype in (torch.float32, torch.float16, torch.bfloat16, torch.int64, torch.bool): + # bfloat16 excluded: numpy() doesn't support it; callers must cast before serializing. 
+ for dtype in (torch.float32, torch.float16, torch.int64, torch.bool): original = OmniPayloadStruct(codes=CodesStruct(audio=torch.tensor([1, 0, 1], dtype=dtype))) restored = decode_payload(encode_payload(original)) assert restored.codes.audio.dtype == dtype, f"dtype mismatch for {dtype}" From 04ed2539439c6ac88363373ecc613b5b51427883 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 07:10:27 +0000 Subject: [PATCH 25/53] fix tests for nested meta.left_context_size and tensor codes.audio cosyvoice3 helpers: migrate runtime_info to nested meta.left_context_size (production stopped reading flat token_offset/left_context_size in 992f61e9 and d9e1624c). mimo_audio _flush_remaining_codes: assert via .tolist() since codes.audio is a tensor end-to-end (4e31b658). Signed-off-by: Divyansh Singhvi --- .../models/cosyvoice3/test_cosyvoice3_model_helpers.py | 7 +++---- .../test_mimo_audio_flush_remaining_codes.py | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 9a78c54de65..b7ccbb36fce 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -150,8 +150,7 @@ def test_forward_prefers_token_offset_when_present(): "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 2, - "left_context_size": 1, + "meta": {"left_context_size": 2}, } ] @@ -179,7 +178,7 @@ def test_forward_falls_back_to_left_context_size_for_backward_compat(): "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], 
dtype=torch.float32), - "left_context_size": 2, + "meta": {"left_context_size": 2}, } ] @@ -200,7 +199,7 @@ def test_forward_ignores_single_request_padded_tail_tokens(): "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 0, + "meta": {"left_context_size": 0}, } ] diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py index b30da97800b..ceedd6f8d5e 100644 --- a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py +++ b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py @@ -20,7 +20,7 @@ def test_flush_remaining_codes_when_no_codes_accumulated_missing_request_id(): """No entry for request_id: treat as empty, return finished sentinel with empty codes.""" tm = SimpleNamespace(code_prompt_token_ids={}) out = _flush_remaining_codes(tm, "missing", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"] == _sentinel()["codes"]["audio"] + assert out["codes"]["audio"].tolist() == _sentinel()["codes"]["audio"] assert out["meta"]["finished"].item() is True @@ -28,7 +28,7 @@ def test_flush_remaining_codes_when_no_codes_accumulated_empty_list(): """Explicit empty accumulation list returns the same sentinel.""" tm = SimpleNamespace(code_prompt_token_ids={"r": []}) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"] == [] + assert out["codes"]["audio"].tolist() == [] assert out["meta"]["finished"].item() is True @@ -42,7 +42,7 @@ def test_flush_remaining_codes_partial_chunk_remaining(): ) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) assert out["meta"]["finished"].item() is True - assert out["codes"]["audio"] == [4, 5, 6, 7] + assert 
out["codes"]["audio"].tolist() == [4, 5, 6, 7] def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): @@ -52,7 +52,7 @@ def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): ) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) # context_length = chunk_size = 3, end_index = min(6, 6) -> all 6 - assert out["codes"]["audio"] == [1, 2, 3, 4, 5, 6] + assert out["codes"]["audio"].tolist() == [1, 2, 3, 4, 5, 6] @pytest.mark.parametrize( @@ -74,5 +74,5 @@ def test_flush_remaining_codes_context_window_end_index( tm = SimpleNamespace(code_prompt_token_ids={"r": accumulated}) out = _flush_remaining_codes(tm, "r", chunk_size=chunk_size, left_context_size=left_context) expected_flat = list(range(length - expected_end_index, length)) - assert out["codes"]["audio"] == expected_flat + assert out["codes"]["audio"].tolist() == expected_flat assert out["meta"]["finished"].item() is True From 242132b169fd5c4524ae65665032447508a6cb04 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 07:18:31 +0000 Subject: [PATCH 26/53] revert additional_information wire type from OmniPayloadStruct to AdditionalInformationPayload Signed-off-by: Divyansh Singhvi --- vllm_omni/core/sched/omni_ar_scheduler.py | 20 ++- vllm_omni/core/sched/output.py | 4 +- vllm_omni/data_entry_keys.py | 121 +++++++++++++++++- .../chunk_transfer_adapter.py | 15 +-- vllm_omni/engine/__init__.py | 40 +++++- vllm_omni/engine/async_omni_engine.py | 2 +- vllm_omni/engine/serialization.py | 29 +++-- vllm_omni/request.py | 9 +- 8 files changed, 201 insertions(+), 39 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index ca80b236124..a5579dd4640 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -20,7 +20,6 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.output import 
OmniSchedulerOutput -from vllm_omni.data_entry_keys import to_struct from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -110,7 +109,7 @@ def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool: result = False else: info = deserialize_additional_information(payload) - result = info.get("meta", {}).get("omni_final_stage_id") == 0 + result = info.get("omni_final_stage_id") == 0 self._omits_kv_transfer_cache[rid] = result return result @@ -629,9 +628,20 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di } # Also update request.additional_information for good measure add_info = getattr(request, "additional_information", None) - merged: dict[str, Any] = deserialize_additional_information(add_info) - merged.update(kv_xfer_params) - request.additional_information = to_struct(merged) + # If additional_information is an AdditionalInformationPayload-like object, + # unpack it into a plain dict. 
+ if ( + add_info is not None + and hasattr(add_info, "entries") + and isinstance(getattr(add_info, "entries"), dict) + ): + request.additional_information = deserialize_additional_information(add_info) + add_info = request.additional_information + if add_info is None: + request.additional_information = {} + add_info = request.additional_information + if isinstance(add_info, dict): + add_info.update(kv_xfer_params) return kv_xfer_params diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 995dc101d69..c7a8c07c5ce 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -3,7 +3,7 @@ from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request -from vllm_omni.data_entry_keys import OmniPayloadStruct +from vllm_omni.engine import AdditionalInformationPayload @dataclass @@ -23,7 +23,7 @@ class OmniNewRequestData(NewRequestData): """ external_req_id: str | None = None - additional_information: OmniPayloadStruct | None = None + additional_information: AdditionalInformationPayload | None = None @classmethod def from_request( diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 9fc4f6f3eee..92c7bc99f74 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -12,12 +12,15 @@ from __future__ import annotations -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import msgspec import numpy as np import torch +if TYPE_CHECKING: + from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload + class HiddenStates(TypedDict, total=False): output: torch.Tensor @@ -283,3 +286,119 @@ def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: def _dtype_to_name(dtype: torch.dtype) -> str: return _DTYPE_TO_NAME.get(dtype, str(dtype).replace("torch.", "")) + + +# ── Keys whose values are nested dicts (TypedDict sub-categories) ── +_NESTED_KEYS = 
frozenset({"hidden_states", "embed", "ids", "codes", "meta"}) + + +def flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: + """Flatten a nested ``OmniPayload`` to dotted keys. + + Nested sub-dicts under ``_NESTED_KEYS`` are expanded: + ``{"codes": {"audio": tensor}}`` → ``{"codes.audio": tensor}``. + ``hidden_states["layers"]`` is expanded to ``hidden_states.layer_N``. + Top-level values are kept as-is. + """ + if not payload: + return {} + flat: dict[str, Any] = {} + for key, value in payload.items(): + if key in _NESTED_KEYS and isinstance(value, dict): + for qual, val in value.items(): + if qual == "layers" and key == "hidden_states" and isinstance(val, dict): + for layer_idx, tensor in val.items(): + flat[f"hidden_states.layer_{layer_idx}"] = tensor + else: + flat[f"{key}.{qual}"] = val + else: + flat[key] = value + return flat + + +def unflatten_payload(flat: dict[str, Any]) -> dict[str, Any]: + """Unflatten dotted keys back to nested dicts. + + Reverse of :func:`flatten_payload`. + ``hidden_states.layer_N`` keys are collected into ``hidden_states.layers``. + """ + result: dict[str, Any] = {} + for key, value in flat.items(): + if "." 
in key: + type_key, qualifier = key.split(".", 1) + sub = result.setdefault(type_key, {}) + if type_key == "hidden_states" and qualifier.startswith("layer_"): + layers = sub.setdefault("layers", {}) + layer_idx = int(qualifier[len("layer_") :]) + layers[layer_idx] = value + else: + sub[qualifier] = value + else: + result[key] = value + return result + + +def _serialize_tensor(t: torch.Tensor) -> AdditionalInformationEntry: + from vllm_omni.engine import AdditionalInformationEntry + + t_cpu = t.detach().to("cpu").contiguous() + return AdditionalInformationEntry( + tensor_data=t_cpu.numpy().tobytes(), + tensor_shape=list(t_cpu.shape), + tensor_dtype=_dtype_to_name(t_cpu.dtype), + ) + + +def _deserialize_tensor(entry: AdditionalInformationEntry) -> torch.Tensor: + dt = np.dtype(entry.tensor_dtype or "float32") + arr = np.frombuffer(entry.tensor_data, dtype=dt) # type: ignore[arg-type] + arr = arr.reshape(entry.tensor_shape) + return torch.from_numpy(arr.copy()) + + +def serialize_payload( + payload: OmniPayload, +) -> AdditionalInformationPayload | None: + """Serialize an ``OmniPayload`` for EngineCore transport. + + Uses :func:`flatten_payload` to produce dotted keys, then converts + each value to an ``AdditionalInformationEntry``. + """ + from vllm_omni.engine import ( + AdditionalInformationEntry, + AdditionalInformationPayload, + ) + + flat = flatten_payload(payload) + entries: dict[str, AdditionalInformationEntry] = {} + + for key, value in flat.items(): + if isinstance(value, torch.Tensor): + entries[key] = _serialize_tensor(value) + elif isinstance(value, list): + entries[key] = AdditionalInformationEntry(list_data=value) + elif value is not None: + entries[key] = AdditionalInformationEntry(scalar_data=value) + + return AdditionalInformationPayload(entries=entries) if entries else None + + +def deserialize_payload( + wire: AdditionalInformationPayload, +) -> OmniPayload: + """Deserialize an ``AdditionalInformationPayload`` back to ``OmniPayload``. 
+ + Decodes entries to tensors/lists, then uses :func:`unflatten_payload` + to reconstruct the nested structure. + """ + flat: dict[str, Any] = {} + + for key, entry in wire.entries.items(): + if entry.tensor_data is not None: + flat[key] = _deserialize_tensor(entry) + elif entry.list_data is not None: + flat[key] = entry.list_data + elif entry.scalar_data is not None: + flat[key] = entry.scalar_data + + return unflatten_payload(flat) # type: ignore[return-value] diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 85c0495769a..c88346d3650 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -8,8 +8,6 @@ import torch from vllm.v1.request import Request, RequestStatus -from vllm_omni.data_entry_keys import OmniPayload, OmniPayloadStruct, to_dict, to_struct - from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec from ..utils.logging import get_connector_logger @@ -155,7 +153,7 @@ def _poll_single_request(self, request: Request): meta = payload_data.get("meta", {}) if self.model_mode == "ar": merged_payload = self._update_request_payload(external_req_id, payload_data) - request.additional_information = to_struct(merged_payload) + request.additional_information = merged_payload if meta.get("finished"): self.finished_requests.add(req_id) else: @@ -169,12 +167,7 @@ def _poll_single_request(self, request: Request): new_ids = [] request.prompt_token_ids = new_ids prev_info = getattr(request, "additional_information", None) - if isinstance(prev_info, OmniPayloadStruct): - info = to_dict(prev_info) - elif isinstance(prev_info, dict): - info = dict(prev_info) - else: - info = {} + info = dict(prev_info) if isinstance(prev_info, dict) else {} for key, value in payload_data.items(): if key == 
"codes": continue @@ -188,7 +181,7 @@ def _poll_single_request(self, request: Request): info[key] = merged_sub continue info[key] = value - request.additional_information = to_struct(info) + request.additional_information = info request.num_computed_tokens = 0 # Empty chunk with more data expected: keep polling. @@ -202,7 +195,7 @@ def _poll_single_request(self, request: Request): return False - def _update_request_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: + def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: """Update the stored payload for *req_id* with the latest chunk.""" if req_id not in self.request_payload: self.request_payload[req_id] = payload_data diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 603e3f6d0dc..a6d5d929458 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -11,8 +11,6 @@ EngineCoreRequest, ) -from vllm_omni.data_entry_keys import OmniPayloadStruct - class PromptEmbedsPayload(msgspec.Struct): """Serialized prompt embeddings payload for direct transfer. @@ -27,6 +25,37 @@ class PromptEmbedsPayload(msgspec.Struct): dtype: str +class AdditionalInformationEntry(msgspec.Struct): + """One entry of additional_information. + + Three supported forms are encoded: + - tensor: data/shape/dtype + - list: a Python list (msgspec-serializable) + - scalar: a Python scalar (msgspec-serializable) + Exactly one of (tensor_data, list_data, scalar_data) should be non-None. + """ + + # Tensor form + tensor_data: bytes | None = None + tensor_shape: list[int] | None = None + tensor_dtype: str | None = None + + # List form + list_data: list[Any] | None = None + + # Scalar form + scalar_data: Any | None = None + + +class AdditionalInformationPayload(msgspec.Struct): + """Serialized dictionary payload for additional_information. + + Keys are strings; values are encoded as AdditionalInformationEntry. 
+ """ + + entries: dict[str, AdditionalInformationEntry] + + class OmniEngineCoreRequest(EngineCoreRequest): """Engine core request for omni models with embeddings support. @@ -37,9 +66,14 @@ class OmniEngineCoreRequest(EngineCoreRequest): Note: prompt_embeds is inherited from EngineCoreRequest (torch.Tensor | None). PromptEmbedsPayload should be decoded to torch.Tensor before constructing this request. + + Attributes: + additional_information: Optional serialized additional information + dictionary containing tensors or lists to pass along with the request """ - additional_information: OmniPayloadStruct | None = None + # Optional additional information dictionary (serialized) + additional_information: AdditionalInformationPayload | None = None class OmniEngineCoreOutput(EngineCoreOutput): diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index f885540699e..a37afd24b4f 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -196,7 +196,7 @@ def _apply_omni_final_stage_metadata( merged: dict[str, Any] = {} if isinstance(request, OmniEngineCoreRequest) and request.additional_information is not None: merged = deserialize_additional_information(request.additional_information) - merged.setdefault("meta", {})["omni_final_stage_id"] = final_stage_id + merged["omni_final_stage_id"] = final_stage_id payload = serialize_additional_information(merged) return OmniEngineCoreRequest( request_id=request.request_id, diff --git a/vllm_omni/engine/serialization.py b/vllm_omni/engine/serialization.py index ac05d7d754d..5b87074106c 100644 --- a/vllm_omni/engine/serialization.py +++ b/vllm_omni/engine/serialization.py @@ -1,5 +1,3 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Shared serialization helpers for omni engine request payloads.""" from __future__ import annotations @@ -8,28 +6,37 @@ from vllm.logger import init_logger -from 
vllm_omni.data_entry_keys import OmniPayloadStruct, to_dict, to_struct +from vllm_omni.data_entry_keys import OmniPayload, deserialize_payload, serialize_payload +from vllm_omni.engine import AdditionalInformationPayload logger = init_logger(__name__) def serialize_additional_information( - raw_info: dict[str, Any] | OmniPayloadStruct | None, + raw_info: dict[str, Any] | AdditionalInformationPayload | None, *, log_prefix: str | None = None, -) -> OmniPayloadStruct | None: - """Convert dict-form ``OmniPayload`` into ``OmniPayloadStruct`` for cross-process transport.""" +) -> AdditionalInformationPayload | None: + """Serialize omni request metadata for EngineCore transport. + + Delegates to ``serialize_payload`` which understands the nested + ``OmniPayload`` TypedDict structure. + """ if raw_info is None: return None - if isinstance(raw_info, OmniPayloadStruct): + if isinstance(raw_info, AdditionalInformationPayload): return raw_info - return to_struct(raw_info) + + payload: OmniPayload = raw_info # type: ignore[assignment] + return serialize_payload(payload) -def deserialize_additional_information(payload: dict | OmniPayloadStruct | None) -> dict: - """Convert an ``OmniPayloadStruct`` back into a plain dict.""" +def deserialize_additional_information( + payload: dict | AdditionalInformationPayload | None, +) -> dict: + """Deserialize an *additional_information* payload into a plain dict.""" if payload is None: return {} if isinstance(payload, dict): return payload - return to_dict(payload) + return deserialize_payload(payload) # type: ignore[return-value] diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 687df354eb8..48cbf9b31d7 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -11,8 +11,7 @@ if TYPE_CHECKING: from vllm.v1.core.kv_cache_utils import BlockHash -from vllm_omni.data_entry_keys import OmniPayloadStruct -from vllm_omni.engine import OmniEngineCoreRequest, PromptEmbedsPayload +from vllm_omni.engine import 
AdditionalInformationPayload, OmniEngineCoreRequest, PromptEmbedsPayload class OmniRequest(Request): @@ -34,7 +33,7 @@ def __init__( prompt_embeds: PromptEmbedsPayload | torch.Tensor | None = None, # Optional external request ID for tracking external_req_id: str | None = None, - additional_information: OmniPayloadStruct | None = None, + additional_information: AdditionalInformationPayload | None = None, *args, **kwargs, ): @@ -47,7 +46,7 @@ def __init__( # Optional external request ID for tracking self.external_req_id: str | None = external_req_id # Serialized additional information payload (optional) - self.additional_information: OmniPayloadStruct | None = additional_information + self.additional_information: AdditionalInformationPayload | None = additional_information @staticmethod def _maybe_decode_prompt_embeds( @@ -113,7 +112,7 @@ class OmniStreamingUpdate: max_tokens: int arrival_time: float sampling_params: SamplingParams | None - additional_information: OmniPayloadStruct | None = None + additional_information: AdditionalInformationPayload | None = None @classmethod def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None": From 28cce90e2d71ea2b110d610f6e69626715f52a8a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 08:29:34 +0000 Subject: [PATCH 27/53] revert pooler_output wire format from dict[str, Any] to dict[str, torch.Tensor] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts commit 12fc0bee. With dict[str, Any], msgspec stops decoding the wire-encoded torch.Tensor values (encoded by vLLM's MsgpackEncoder as (dtype_str, shape, data) tuples), so receivers see the dtype string "float32" / "int32" instead of a tensor. Concrete failure: cosyvoice3 e2e — multi_modal_data["latent"] arrives as list of "float32" strings, _as_tensor() returns None, code2wav falls through with "missing prompt conditioning" and the request 400s. 
Same wire-decode regression would hit any multi-stage model passing tensors through pooling_output (qwen3_omni, qwen3_tts, mimo_audio, voxtral_tts, ming_flash_omni, etc.). Restores main's flat-dotted-key flow via flatten_payload (producer side) and unflatten_payload (consumer side, in _consolidate_multimodal_tensors). Each touched block byte-matches main; no behavioral change beyond undoing 12fc0bee. Signed-off-by: Divyansh Singhvi --- .../transfer_adapter/chunk_transfer_adapter.py | 5 ++++- vllm_omni/engine/__init__.py | 3 ++- vllm_omni/engine/output_processor.py | 8 ++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index c88346d3650..1b33af1deed 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -8,6 +8,8 @@ import torch from vllm.v1.request import Request, RequestStatus +from vllm_omni.data_entry_keys import unflatten_payload + from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec from ..utils.logging import get_connector_logger @@ -224,7 +226,8 @@ def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> return payload_data def _send_single_request(self, task: dict): - pooling_output = task["pooling_output"] + raw_po = task["pooling_output"] + pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po request = task["request"] is_finished = task["is_finished"] stage_id = self.connector.stage_id diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index a6d5d929458..6c92d7952de 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -5,6 +5,7 @@ from typing import Any import msgspec +import torch from vllm.v1.engine import ( 
EngineCoreOutput, EngineCoreOutputs, @@ -77,7 +78,7 @@ class OmniEngineCoreRequest(EngineCoreRequest): class OmniEngineCoreOutput(EngineCoreOutput): - pooling_output: dict[str, Any] | None = None + pooling_output: dict[str, torch.Tensor] | None = None # Finished flag for streaming input segment is_segment_finished: bool | None = False # Streaming update prompt length diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 02e2d4dec8d..84016fdb4a1 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -16,6 +16,7 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import IterationStats +from vllm_omni.data_entry_keys import unflatten_payload from vllm_omni.engine.output_modality import DRAINABLE_MODALITIES from vllm_omni.outputs import OmniRequestOutput @@ -149,6 +150,13 @@ def _consolidate_multimodal_tensors(self) -> None: except Exception: logger.exception("Error consolidating multimodal tensors") + # Restore nested structure from flat dotted keys now that all tensor + # lists have been concatenated into single tensors. + try: + self.mm_accumulated = unflatten_payload(self.mm_accumulated) + except Exception: + logger.exception("Error unflattening consolidated multimodal tensors") + # Override: do not route to pooling-only path; always create completion # outputs, and attach pooling_result into the CompletionOutput. def make_request_output( From 8787428f488e8f172e058eedf373f54871f610a2 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:08:05 +0000 Subject: [PATCH 28/53] update qwen3_tts and voxtral_tts async-chunk test assertions for nested-int meta and tensor codes.audio MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - meta.left_context_size: int (not [int]) — matches commit d9e1624c (drop legacy list/flat shapes in qwen3_tts/cosyvoice3 producers). 
- voxtral_tts test_context_handling_format: codes.audio is a 1-D long tensor; extract via .item() and compare numel(), not Python int casts on tensor scalars. Signed-off-by: Divyansh Singhvi --- .../test_qwen3_tts_async_chunk.py | 4 ++-- .../test_voxtral_tts_async_chunk.py | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index 07e343bf030..feedde9e556 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -346,7 +346,7 @@ def test_non_async_processor_prepends_ref_code_and_sets_trim_context(): assert len(prompts) == 1 prompt = prompts[0] - assert prompt["additional_information"] == {"meta": {"left_context_size": [2]}} + assert prompt["additional_information"] == {"meta": {"left_context_size": 2}} assert prompt["prompt_token_ids"] == [ 9, 8, @@ -394,4 +394,4 @@ def test_non_async_processor_filters_out_of_range_codec_values(): prompt = prompts[0] # Only ref_code (1 frame) + 2 valid frames = 3 frames * 4 quantizers = 12 codes assert len(prompt["prompt_token_ids"]) == 4 * 3 - assert prompt["additional_information"] == {"meta": {"left_context_size": [1]}} + assert prompt["additional_information"] == {"meta": {"left_context_size": 1}} diff --git a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py index 1b78b103da8..1b41769c798 100644 --- a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py @@ -149,7 +149,7 @@ def test_eof_marker_when_finished_with_no_frames(): request=request, ) - assert payload["codes"] == {"audio": []} + assert payload["codes"]["audio"].tolist() == [] 
assert payload["meta"]["finished"].item() is True @@ -251,18 +251,15 @@ def test_context_handling_format(): assert payload is not None codes = payload["codes"]["audio"] - # First two elements are ctx_frames and context_length - ctx_frames = codes[0] - context_length = codes[1] + # codes is a 1-D long tensor: [ctx_frames, context_length, ...flat_codes] + ctx_frames = int(codes[0].item()) + context_length = int(codes[1].item()) flat_codes = codes[2:] - # The window has ctx_frames + context_length frames - assert isinstance(ctx_frames, int) - assert isinstance(context_length, int) assert ctx_frames >= 0 assert context_length > 0 # flat_codes = total_window_frames * codebook_dim total_window_frames = ctx_frames + context_length - assert len(flat_codes) == total_window_frames * 2 # codebook_dim=2 + assert flat_codes.numel() == total_window_frames * 2 # codebook_dim=2 def test_none_pooling_output_not_finished_returns_none(): From b90c384aac37abe56d44b484e03fb80034a00969 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:17:21 +0000 Subject: [PATCH 29/53] revert producer-side flatten_payload in gpu_ar_model_runner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes the 28cce90e revert. With pooling_output typed as dict[str, torch.Tensor], producers must flatten nested dicts to dotted keys before assigning, otherwise msgspec rejects the nested values. Without this, multi-stage models that emit nested payloads (e.g., {"meta": {"finished": tensor}, "codes": {"audio": tensor}}) hit "msgspec.ValidationError: not enough values to unpack (expected 3, got 1)" on the worker→engine wire. Re-adds: - from vllm_omni.data_entry_keys import flatten_payload - mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs)) - pooler_output.append(flatten_payload(payload)) gpu_ar_model_runner.py and gpu_model_runner.py now byte-match main. 
Signed-off-by: Divyansh Singhvi --- vllm_omni/worker/gpu_ar_model_runner.py | 7 +++++-- vllm_omni/worker/gpu_model_runner.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index f37b2224efb..947b3164f3e 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -37,6 +37,7 @@ from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm_omni.data_entry_keys import flatten_payload from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element @@ -854,7 +855,7 @@ def propose_draft_token_ids(sampled_token_ids): ) # Otherwise we don't have the mm CPU data yet, so we still need to build it if self.omni_prefix_cache is None: - mm_cpu = build_mm_cpu(multimodal_outputs) + mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs)) self._process_additional_information_updates( hidden_states, @@ -909,7 +910,9 @@ def propose_draft_token_ids(sampled_token_ids): seq_len=seq_len, ) payload.update(mm_payload) - pooler_output.append(payload) + # Flatten nested dicts to dotted keys so pooling_output + # stays dict[str, torch.Tensor] for msgspec serialization. 
+ pooler_output.append(flatten_payload(payload)) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index a459acba3ac..d914f1b39df 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -932,7 +932,7 @@ def _dummy_run( # ------------------------------------------------------------------ # Payload decoding helpers (torch.Tensor passthrough + legacy - # PromptEmbedsPayload / OmniPayloadStruct support) + # PromptEmbedsPayload / AdditionalInformationPayload support) # ------------------------------------------------------------------ @staticmethod From b82fd5ebd58ab3cef2216bf4c462c4334c3e90cd Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:23:15 +0000 Subject: [PATCH 30/53] update cosyvoice3 and voxcpm test assertions for nested meta/codes layout Producer migrations dropped legacy flat keys (token_offset, finished, code_predictor_codes) in favor of nested meta.left_context_size, meta.finished, codes.audio. Tests updated to match. cosyvoice3 talker2code2wav_async_chunk: read meta.left_context_size instead of flat token_offset / left_context_size. voxcpm latent2vae_async_chunk: read meta.finished and codes.audio instead of flat finished / code_predictor_codes; .tolist() for tensor codes.audio comparison. 
Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_stage_input_processors.py | 13 ++++++------- .../test_voxcpm_async_chunk.py | 10 ++++------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index 9c236664616..d12b2300c55 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -85,8 +85,7 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): assert payload is not None assert payload["meta"]["finished"].item() is True assert payload["codes"]["audio"] == [1, 2, 3] - assert payload["token_offset"] == 0 - assert payload["left_context_size"] == 0 + assert payload["meta"]["left_context_size"] == 0 assert payload["req_id"] == ["rid-0"] assert payload["stream_finished"].item() is True assert "speech_token" in payload @@ -139,7 +138,7 @@ def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): assert payload1 is not None assert payload1["codes"]["audio"] == [1, 2] - assert payload1["token_offset"] == 0 + assert payload1["meta"]["left_context_size"] == 0 assert payload2 is None @@ -169,7 +168,7 @@ def test_talker2code2wav_async_chunk_waits_for_prelookahead_and_emits_cumulative assert payload_pending is None assert payload_ready is not None assert payload_ready["codes"]["audio"] == [1, 2, 3] - assert payload_ready["token_offset"] == 0 + assert payload_ready["meta"]["left_context_size"] == 0 assert payload_ready["meta"]["finished"].item() is False @@ -199,11 +198,11 @@ def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): assert payload_stream is not None assert payload_stream["meta"]["finished"].item() is False assert payload_stream["codes"]["audio"] == [3, 4, 5] - assert 
payload_stream["token_offset"] == 0 + assert payload_stream["meta"]["left_context_size"] == 0 assert payload_final is not None assert payload_final["meta"]["finished"].item() is True assert payload_final["codes"]["audio"] == [3, 4, 5, 6] - assert payload_final["token_offset"] == 2 + assert payload_final["meta"]["left_context_size"] == 2 def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): @@ -234,7 +233,7 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): assert payload_pending is None assert payload_ready is not None assert payload_ready["codes"]["audio"] == [8, 9, 10, 11] - assert payload_ready["token_offset"] == 0 + assert payload_ready["meta"]["left_context_size"] == 0 def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio(): diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py index 7d6fc6e74c9..806bf4b153e 100644 --- a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py @@ -57,8 +57,8 @@ def test_latent2vae_async_chunk_serializes_latent_payload(): ) assert payload is not None - assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool)) - recovered = _decode_serialized_latent(payload["code_predictor_codes"]) + assert torch.equal(payload["meta"]["finished"], torch.tensor(False, dtype=torch.bool)) + recovered = _decode_serialized_latent(payload["codes"]["audio"].tolist()) torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0)) @@ -70,10 +70,8 @@ def test_latent2vae_async_chunk_returns_terminal_marker_without_latent(): is_finished=False, ) - assert payload == { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } + assert payload["codes"]["audio"].tolist() == [] + assert 
torch.equal(payload["meta"]["finished"], torch.tensor(True, dtype=torch.bool)) def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk(): From 1ed91323ebc6ae404fcc8b55c72259afcf4914d6 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:38:37 +0000 Subject: [PATCH 31/53] migrate cosyvoice3 prompt-conditioning tensors to embed.* Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_model_helpers.py | 48 ++++++++++++------- vllm_omni/data_entry_keys.py | 4 ++ .../models/cosyvoice3/cosyvoice3.py | 25 ++++++---- .../stage_input_processors/cosyvoice3.py | 6 ++- 4 files changed, 54 insertions(+), 29 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index b7ccbb36fce..945932a2bc8 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -147,9 +147,11 @@ def test_forward_prefers_token_offset_when_present(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "meta": {"left_context_size": 2}, } ] @@ -175,9 +177,11 @@ def test_forward_falls_back_to_left_context_size_for_backward_compat(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": 
torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "meta": {"left_context_size": 2}, } ] @@ -196,9 +200,11 @@ def test_forward_ignores_single_request_padded_tail_tokens(): model = _make_code2wav_model(with_stride_cfg=True) runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "meta": {"left_context_size": 0}, } ] @@ -220,9 +226,11 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "prefix_ids": [101, 102], "generated_len": 3, } @@ -258,9 +266,11 @@ def test_forward_reuses_streaming_cache_state_between_chunks(): runtime_info = [ { "req_id": ["rid-stream"], - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "token_offset": 0, "stream_finished": torch.tensor(False), 
} @@ -304,9 +314,11 @@ def test_forward_clears_streaming_cache_on_terminal_chunk(): runtime_info = [ { "req_id": ["rid-stream"], - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, "token_offset": 0, "stream_finished": torch.tensor(False), } diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 92c7bc99f74..56717f7898e 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -39,6 +39,8 @@ class Embeddings(TypedDict, total=False): tts_pad_projected: torch.Tensor voice: torch.Tensor speech_feat: torch.Tensor + speech_token: torch.Tensor + embedding: torch.Tensor thinker_reply: torch.Tensor @@ -116,6 +118,8 @@ class EmbeddingsStruct(_StructBase): tts_pad_projected: torch.Tensor | None = None voice: torch.Tensor | None = None speech_feat: torch.Tensor | None = None + speech_token: torch.Tensor | None = None + embedding: torch.Tensor | None = None thinker_reply: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index ac7282ac48c..61251108c93 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -32,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm_omni.data_entry_keys import EmbeddingsStruct, OmniPayloadStruct, to_dict from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config from vllm_omni.model_executor.models.cosyvoice3.utils import ( concat_text_with_prompt_ids, @@ -674,12 
+675,17 @@ def forward( multimodal_outputs = {} if "speech_token" in kwargs: - # Wrap in lists to pass through gpu_ar_model_runner shape filtering - multimodal_outputs = { - "speech_token": [kwargs.get("speech_token")], - "speech_feat": [kwargs.get("speech_feat")], - "embedding": [kwargs.get("embedding")], - } + # Prompt conditioning tensors for code2wav: live under + # ``embed.*`` per OmniPayloadStruct schema. + multimodal_outputs = to_dict( + OmniPayloadStruct( + embed=EmbeddingsStruct( + speech_token=kwargs.get("speech_token"), + speech_feat=kwargs.get("speech_feat"), + embedding=kwargs.get("embedding"), + ), + ) + ) return OmniOutput(text_hidden_states=hidden_states, multimodal_outputs=multimodal_outputs) elif self.model_stage == "cosyvoice3_code2wav": @@ -705,9 +711,10 @@ def forward( info = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} req_id = self._as_str(info.get("req_id")) if info else None stream_finished = self._as_bool(info.get("stream_finished")) if info else False - speech_token = self._as_tensor(info.get("speech_token")) if info else None - speech_feat = self._as_tensor(info.get("speech_feat")) if info else None - embedding = self._as_tensor(info.get("embedding")) if info else None + embed_info = info.get("embed", {}) if info else {} + speech_token = self._as_tensor(embed_info.get("speech_token")) if embed_info else None + speech_feat = self._as_tensor(embed_info.get("speech_feat")) if embed_info else None + embedding = self._as_tensor(embed_info.get("embedding")) if embed_info else None if speech_token is None or speech_feat is None or embedding is None: if stream_finished and req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"): with self._stream_audio_cache_lock: diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 51248e20dfe..25e83f1a691 100644 --- 
a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -118,16 +118,18 @@ def talker2code2wav_async_chunk( if not isinstance(request_state, dict) or "_cosyvoice3_async_state" not in request_state: with nullcontext(): info = _decode_additional_information(getattr(request, "additional_information", None)) + info_embed = info.get("embed", {}) if isinstance(info, dict) else {} prompt_payload = {} for key in ("speech_token", "speech_feat", "embedding"): - value = _to_cpu_tensor(info.get(key)) + value = _to_cpu_tensor(info_embed.get(key)) if value is not None: prompt_payload[key] = value if isinstance(pooling_output, dict): + po_embed = pooling_output.get("embed", {}) if isinstance(pooling_output.get("embed"), dict) else {} for key in ("speech_token", "speech_feat", "embedding"): if key in prompt_payload: continue - value = _to_cpu_tensor(pooling_output.get(key)) + value = _to_cpu_tensor(po_embed.get(key)) if value is not None: prompt_payload[key] = value prompt_token = prompt_payload.get("speech_token") From 9d021fcb709f06140c2a00ce6636b7f176c7c044 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:43:22 +0000 Subject: [PATCH 32/53] update cosyvoice3 stage_input_processor tests for nested embed.* layout Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_stage_input_processors.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index d12b2300c55..fb26498aa4b 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -68,9 +68,11 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): 
external_req_id="rid-0", output_token_ids=[1, 2, 6562, 3], additional_information={ - "speech_token": [torch.tensor([[11, 12, 13]])], - "speech_feat": [torch.tensor([[[0.1, 0.2], [0.3, 0.4]]])], - "embedding": [torch.tensor([[0.5, 0.6]])], + "embed": { + "speech_token": [torch.tensor([[11, 12, 13]])], + "speech_feat": [torch.tensor([[[0.1, 0.2], [0.3, 0.4]]])], + "embedding": [torch.tensor([[0.5, 0.6]])], + }, }, is_finished=lambda: True, ) @@ -211,7 +213,7 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): external_req_id="rid-pad", output_token_ids=[8, 9, 10], additional_information={ - "speech_token": [torch.tensor([[1, 2, 3]])], + "embed": {"speech_token": [torch.tensor([[1, 2, 3]])]}, }, is_finished=lambda: False, ) From 9e385277d1cf973bfa587d51bcdda05f81324a54 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 09:58:21 +0000 Subject: [PATCH 33/53] fix test_generation_mode for nested meta.left_context_size Signed-off-by: Divyansh Singhvi --- tests/core/sched/test_chunk_scheduling_coordinator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py index 5e19465e224..c9459d8ec43 100644 --- a/tests/core/sched/test_chunk_scheduling_coordinator.py +++ b/tests/core/sched/test_chunk_scheduling_coordinator.py @@ -222,7 +222,7 @@ def test_generation_mode(self): self.assertEqual(req.prompt_token_ids, [10, 20, 30]) self.assertEqual(req.num_computed_tokens, 0) self.assertIsNone(req.additional_information) - self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25}) + self.assertEqual(req._omni_initial_model_buffer, {"meta": {"left_context_size": 25}}) class TestChunkCoordinatorPostprocess(unittest.TestCase): From ce59a174428152130353f8ddf66f6965ee0720c9 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 12:28:30 +0000 Subject: [PATCH 34/53] recurse into nested 
dicts in prefix-cache mm_payload unwrap; restore main's flatten/unflatten/serialize/deserialize tests Signed-off-by: Divyansh Singhvi --- vllm_omni/worker/gpu_ar_model_runner.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 947b3164f3e..b595cc581ae 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -888,15 +888,21 @@ def propose_draft_token_ids(sampled_token_ids): if combined_multimodal_outputs: # Prefix cache enabled; all items have already been processed # and split apart for each request as needed, and all tensors - # have already been detached to the CPU. The only exception is - # lists, which we keep as passthrough data for consistent behavior - # in postprocess. + # have already been detached to the CPU. Lists are kept as + # passthrough data for consistent behavior in postprocess. + # Recurse into nested dicts so list-valued sub-keys (e.g. + # embed.tts_bos = [tensor]) are unwrapped to bare tensors + # at the leaves; downstream flatten_payload then yields a + # wire-clean dict[str, torch.Tensor]. 
+ def _unwrap_lists(v): + if isinstance(v, list): + return v[idx] if idx < len(v) else v[0] + if isinstance(v, dict): + return {k: _unwrap_lists(sv) for k, sv in v.items()} + return v + for mm_key in combined_multimodal_outputs.keys(): - value = combined_multimodal_outputs[mm_key][rid] - if isinstance(value, list): - mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] - else: - mm_payload[mm_key] = value + mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid]) else: # Prefix cache disabled; we still need to process the data From 22d1bbc2febe7aa3a0138a1fadbf3cc40ab87465 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 12:37:02 +0000 Subject: [PATCH 35/53] restore flatten/unflatten/serialize/deserialize tests (skip tautological empty-payload checks) Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 214 ++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index a4515a2bd26..ae27cfe0858 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -10,10 +10,16 @@ HiddenStatesStruct, IdsStruct, MetaStruct, + OmniPayload, OmniPayloadStruct, + deserialize_payload, + flatten_payload, + serialize_payload, to_dict, to_struct, + unflatten_payload, ) +from vllm_omni.engine import AdditionalInformationPayload class TestOmniPayloadStruct: @@ -158,3 +164,211 @@ def test_decode_rejects_unknown_field(self): wire = _OMNI_PAYLOAD_ENCODER.encode(bad_dict) with pytest.raises(msgspec.ValidationError, match="unknown field"): decode_payload(wire) + + +class TestFlattenPayload: + def test_basic_nested_to_dotted(self): + nested = { + "codes": {"audio": torch.tensor([1.0])}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5}, + } + flat = flatten_payload(nested) + assert torch.equal(flat["codes.audio"], torch.tensor([1.0])) + assert flat["meta.finished"].item() is True + 
assert flat["meta.left_context_size"] == 5 + assert "codes" not in flat + assert "meta" not in flat + + def test_top_level_keys_preserved(self): + nested = { + "latent": torch.tensor([9.0]), + "generated_len": 42, + } + flat = flatten_payload(nested) + assert torch.equal(flat["latent"], torch.tensor([9.0])) + assert flat["generated_len"] == 42 + + def test_hidden_states_layers_expanded(self): + nested = { + "hidden_states": { + "output": torch.tensor([1.0]), + "layers": { + 0: torch.tensor([2.0]), + 24: torch.tensor([3.0]), + }, + }, + } + flat = flatten_payload(nested) + assert torch.equal(flat["hidden_states.output"], torch.tensor([1.0])) + assert torch.equal(flat["hidden_states.layer_0"], torch.tensor([2.0])) + assert torch.equal(flat["hidden_states.layer_24"], torch.tensor([3.0])) + assert "hidden_states.layers" not in flat + + def test_mixed_nested_and_top_level(self): + nested: OmniPayload = { + "codes": {"audio": torch.tensor([1.0])}, + "latent": torch.tensor([2.0]), + "meta": {"finished": torch.tensor(False, dtype=torch.bool)}, + } + flat = flatten_payload(nested) + assert set(flat.keys()) == {"codes.audio", "latent", "meta.finished"} + + +class TestUnflattenPayload: + def test_basic_dotted_to_nested(self): + flat = { + "codes.audio": torch.tensor([1.0]), + "meta.finished": torch.tensor(True, dtype=torch.bool), + "meta.left_context_size": 5, + } + nested = unflatten_payload(flat) + assert torch.equal(nested["codes"]["audio"], torch.tensor([1.0])) + assert nested["meta"]["finished"].item() is True + assert nested["meta"]["left_context_size"] == 5 + + def test_top_level_keys_preserved(self): + flat = {"latent": torch.tensor([9.0]), "generated_len": 42} + nested = unflatten_payload(flat) + assert torch.equal(nested["latent"], torch.tensor([9.0])) + assert nested["generated_len"] == 42 + + def test_hidden_states_layers_collected(self): + flat = { + "hidden_states.output": torch.tensor([1.0]), + "hidden_states.layer_0": torch.tensor([2.0]), + 
"hidden_states.layer_24": torch.tensor([3.0]), + } + nested = unflatten_payload(flat) + assert torch.equal(nested["hidden_states"]["output"], torch.tensor([1.0])) + assert torch.equal(nested["hidden_states"]["layers"][0], torch.tensor([2.0])) + assert torch.equal(nested["hidden_states"]["layers"][24], torch.tensor([3.0])) + + +class TestFlattenUnflattenRoundTrip: + def test_round_trip_simple(self): + original: OmniPayload = { + "codes": {"audio": torch.tensor([1.0, 2.0])}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 10}, + "ids": {"prompt": [1, 2, 3]}, + "latent": torch.tensor([5.0]), + } + restored = unflatten_payload(flatten_payload(original)) + assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"]) + assert restored["meta"]["finished"].item() is True + assert restored["meta"]["left_context_size"] == 10 + assert restored["ids"]["prompt"] == [1, 2, 3] + assert torch.equal(restored["latent"], original["latent"]) + + def test_round_trip_with_layers(self): + original = { + "hidden_states": { + "output": torch.tensor([1.0]), + "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])}, + }, + } + restored = unflatten_payload(flatten_payload(original)) + assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) + assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0])) + assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0])) + + def test_round_trip_all_categories(self): + original: OmniPayload = { + "hidden_states": {"output": torch.tensor([1.0]), "last": torch.tensor([2.0])}, + "embed": {"prefill": torch.tensor([3.0]), "tts_bos": torch.tensor([4.0])}, + "codes": {"audio": torch.tensor([5.0]), "ref": torch.tensor([6.0])}, + "ids": {"all": [1, 2], "prompt": [3, 4]}, + "meta": {"finished": torch.tensor(False, dtype=torch.bool), "ar_width": 8}, + } + restored = unflatten_payload(flatten_payload(original)) + assert 
torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) + assert torch.equal(restored["hidden_states"]["last"], torch.tensor([2.0])) + assert torch.equal(restored["embed"]["prefill"], torch.tensor([3.0])) + assert torch.equal(restored["embed"]["tts_bos"], torch.tensor([4.0])) + assert torch.equal(restored["codes"]["audio"], torch.tensor([5.0])) + assert torch.equal(restored["codes"]["ref"], torch.tensor([6.0])) + assert restored["ids"]["all"] == [1, 2] + assert restored["ids"]["prompt"] == [3, 4] + assert restored["meta"]["finished"].item() is False + assert restored["meta"]["ar_width"] == 8 + + +class TestSerializeDeserializePayload: + def test_tensor_round_trip(self): + original: OmniPayload = { + "hidden_states": {"output": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}, + } + wire = serialize_payload(original) + assert isinstance(wire, AdditionalInformationPayload) + restored = deserialize_payload(wire) + assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"]) + + def test_list_round_trip(self): + original: OmniPayload = { + "ids": {"prompt": [10, 20, 30]}, + } + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert restored["ids"]["prompt"] == [10, 20, 30] + + def test_finished_tensor_round_trip(self): + original: OmniPayload = { + "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5}, + } + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert isinstance(restored["meta"]["finished"], torch.Tensor) + assert restored["meta"]["finished"].dtype == torch.bool + assert restored["meta"]["finished"].item() is True + assert restored["meta"]["left_context_size"] == 5 + + def test_mixed_types_round_trip(self): + original: OmniPayload = { + "hidden_states": {"output": torch.tensor([1.0, 2.0])}, + "ids": {"all": [1, 2, 3]}, + "meta": {"finished": torch.tensor(False, dtype=torch.bool), "ar_width": 4}, + "codes": {"audio": torch.tensor([3.0])}, 
+ } + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"]) + assert restored["ids"]["all"] == [1, 2, 3] + assert restored["meta"]["finished"].item() is False + assert restored["meta"]["ar_width"] == 4 + assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"]) + + def test_hidden_states_layers_round_trip(self): + original = { + "hidden_states": { + "output": torch.tensor([1.0]), + "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])}, + }, + } + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0])) + assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0])) + assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0])) + + def test_tensor_dtype_preserved(self): + # bfloat16 excluded: numpy() doesn't support it; callers must cast before serializing. 
+ for dtype in [torch.float16, torch.float32, torch.int64, torch.int32, torch.bool]: + original: OmniPayload = {"codes": {"audio": torch.tensor([1], dtype=dtype)}} + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert restored["codes"]["audio"].dtype == dtype, f"dtype mismatch for {dtype}" + + def test_tensor_shape_preserved(self): + t = torch.randn(3, 4, 5) + original: OmniPayload = {"hidden_states": {"output": t}} + wire = serialize_payload(original) + restored = deserialize_payload(wire) + assert restored["hidden_states"]["output"].shape == (3, 4, 5) + assert torch.allclose(restored["hidden_states"]["output"], t) + + def test_empty_payload_returns_none(self): + assert serialize_payload({}) is None + + def test_none_values_skipped(self): + original: OmniPayload = {"meta": {"finished": None}} + wire = serialize_payload(original) + assert wire is None From d4e7f34e9ef56cfccea5cdd0769e97126c3079cf Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 14:19:14 +0000 Subject: [PATCH 36/53] migrate cosyvoice3 talker2code2wav_async_chunk producer to OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_model_helpers.py | 18 +++-- .../test_cosyvoice3_stage_input_processors.py | 28 +++---- vllm_omni/data_entry_keys.py | 4 + .../models/cosyvoice3/cosyvoice3.py | 7 +- .../stage_input_processors/cosyvoice3.py | 76 +++++++++++++------ 5 files changed, 87 insertions(+), 46 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 945932a2bc8..44530b9d43b 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -265,14 +265,16 @@ def test_forward_reuses_streaming_cache_state_between_chunks(): ) runtime_info = [ { - "req_id": ["rid-stream"], "embed": { "speech_token": torch.tensor([[1, 
2, 3]], dtype=torch.long), "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), }, - "token_offset": 0, - "stream_finished": torch.tensor(False), + "meta": { + "req_id": ["rid-stream"], + "stream_finished": torch.tensor(False), + "left_context_size": 0, + }, } ] @@ -313,14 +315,16 @@ def test_forward_clears_streaming_cache_on_terminal_chunk(): ) runtime_info = [ { - "req_id": ["rid-stream"], "embed": { "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), }, - "token_offset": 0, - "stream_finished": torch.tensor(False), + "meta": { + "req_id": ["rid-stream"], + "stream_finished": torch.tensor(False), + "left_context_size": 0, + }, } ] @@ -332,7 +336,7 @@ def test_forward_clears_streaming_cache_on_terminal_chunk(): ) assert "rid-stream" in model._stream_vocoder_cache_by_req - runtime_info[0]["stream_finished"] = torch.tensor(True) + runtime_info[0]["meta"]["stream_finished"] = torch.tensor(True) out = model.forward( input_ids=torch.tensor([0, 1, 2], dtype=torch.long), positions=torch.tensor([0, 1, 2], dtype=torch.long), diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index fb26498aa4b..11e692d46c8 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -86,13 +86,13 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): assert payload is not None assert payload["meta"]["finished"].item() is True - assert payload["codes"]["audio"] == [1, 2, 3] + assert payload["codes"]["audio"].tolist() == [1, 2, 3] assert 
payload["meta"]["left_context_size"] == 0 - assert payload["req_id"] == ["rid-0"] - assert payload["stream_finished"].item() is True - assert "speech_token" in payload - assert "speech_feat" in payload - assert "embedding" in payload + assert payload["meta"]["req_id"] == ["rid-0"] + assert payload["meta"]["stream_finished"].item() is True + assert "speech_token" in payload["embed"] + assert "speech_feat" in payload["embed"] + assert "embedding" in payload["embed"] def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes(): @@ -112,7 +112,7 @@ def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes ) assert payload is not None - assert payload["codes"]["audio"] == [] + assert payload["codes"]["audio"].tolist() == [] assert payload["meta"]["finished"].item() is True @@ -139,7 +139,7 @@ def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): ) assert payload1 is not None - assert payload1["codes"]["audio"] == [1, 2] + assert payload1["codes"]["audio"].tolist() == [1, 2] assert payload1["meta"]["left_context_size"] == 0 assert payload2 is None @@ -169,7 +169,7 @@ def test_talker2code2wav_async_chunk_waits_for_prelookahead_and_emits_cumulative assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"] == [1, 2, 3] + assert payload_ready["codes"]["audio"].tolist() == [1, 2, 3] assert payload_ready["meta"]["left_context_size"] == 0 assert payload_ready["meta"]["finished"].item() is False @@ -199,11 +199,11 @@ def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): assert payload_stream is not None assert payload_stream["meta"]["finished"].item() is False - assert payload_stream["codes"]["audio"] == [3, 4, 5] + assert payload_stream["codes"]["audio"].tolist() == [3, 4, 5] assert payload_stream["meta"]["left_context_size"] == 0 assert payload_final is not None assert payload_final["meta"]["finished"].item() is True - assert 
payload_final["codes"]["audio"] == [3, 4, 5, 6] + assert payload_final["codes"]["audio"].tolist() == [3, 4, 5, 6] assert payload_final["meta"]["left_context_size"] == 2 @@ -234,7 +234,7 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"] == [8, 9, 10, 11] + assert payload_ready["codes"]["audio"].tolist() == [8, 9, 10, 11] assert payload_ready["meta"]["left_context_size"] == 0 @@ -262,7 +262,7 @@ def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio( assert payload_stream is not None assert payload_stream["meta"]["finished"].item() is False - assert payload_stream["codes"]["audio"] == [3, 4] + assert payload_stream["codes"]["audio"].tolist() == [3, 4] assert payload_final is not None assert payload_final["meta"]["finished"].item() is True - assert payload_final["codes"]["audio"] == [] + assert payload_final["codes"]["audio"].tolist() == [] diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 56717f7898e..04ee6e35d11 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -59,6 +59,8 @@ class Ids(TypedDict, total=False): class OmniPayloadMeta(TypedDict, total=False): finished: torch.Tensor + stream_finished: torch.Tensor + req_id: list[str] left_context_size: int override_keys: list[tuple[str, str]] num_processed_tokens: int @@ -138,6 +140,8 @@ class IdsStruct(_StructBase): class MetaStruct(_StructBase): finished: torch.Tensor | None = None + stream_finished: torch.Tensor | None = None + req_id: list[str] | None = None left_context_size: int | None = None override_keys: list[tuple[str, str]] | None = None num_processed_tokens: int | None = None diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 61251108c93..e12c3254331 100644 --- 
a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -709,8 +709,9 @@ def forward( for idx, req_ids in enumerate(request_ids_list): info = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} - req_id = self._as_str(info.get("req_id")) if info else None - stream_finished = self._as_bool(info.get("stream_finished")) if info else False + meta_info = info.get("meta", {}) if info else {} + req_id = self._as_str(meta_info.get("req_id")) + stream_finished = self._as_bool(meta_info.get("stream_finished")) embed_info = info.get("embed", {}) if info else {} speech_token = self._as_tensor(embed_info.get("speech_token")) if embed_info else None speech_feat = self._as_tensor(embed_info.get("speech_feat")) if embed_info else None @@ -752,7 +753,7 @@ def forward( # runner, so only explicit chunk-routing fields should switch # code2wav into the streaming path. meta = info.get("meta", {}) if info else {} - uses_streaming_decode = bool(info) and ("stream_finished" in info or "left_context_size" in meta) + uses_streaming_decode = bool(info) and ("stream_finished" in meta or "left_context_size" in meta) if uses_streaming_decode: token_offset = max(0, meta.get("left_context_size", 0)) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 25e83f1a691..54cabf93f2b 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -8,9 +8,30 @@ import torch from vllm.inputs import TextPrompt +from vllm_omni.data_entry_keys import ( + CodesStruct, + EmbeddingsStruct, + MetaStruct, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt +def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None: + """Wrap prompt_payload's flat 
speech_token/speech_feat/embedding tensors into EmbeddingsStruct.""" + speech_token = prompt_payload.get("speech_token") + speech_feat = prompt_payload.get("speech_feat") + embedding = prompt_payload.get("embedding") + if speech_token is None and speech_feat is None and embedding is None: + return None + return EmbeddingsStruct( + speech_token=speech_token, + speech_feat=speech_feat, + embedding=embedding, + ) + + def _ensure_list(x: Any) -> list[Any]: if hasattr(x, "_x"): return list(x._x) @@ -182,27 +203,33 @@ def talker2code2wav_async_chunk( if length <= 0: if not finished: return None - payload: dict[str, Any] = { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return payload + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, + ) + ) emitted_token_len = int(state.get("emitted_token_len", 0)) if finished and length <= emitted_token_len: - payload = { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return payload + return to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, + ) + ) with nullcontext(): token_hop_len = max(1, int(state.get("token_hop_len", chunk_size))) @@ -226,19 +253,24 @@ def talker2code2wav_async_chunk( with 
nullcontext(): code_predictor_codes = [int(frame[0]) for frame in token_frames[:prefix_len]] - payload = { - "codes": {"audio": code_predictor_codes}, - "meta": { - "finished": torch.tensor(finished, dtype=torch.bool), - "left_context_size": token_offset, - }, - "req_id": [request_id], - "stream_finished": torch.tensor(finished, dtype=torch.bool), - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True + payload = to_dict( + OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(code_predictor_codes, dtype=torch.long)), + meta=MetaStruct( + finished=torch.tensor(finished, dtype=torch.bool), + stream_finished=torch.tensor(finished, dtype=torch.bool), + req_id=[request_id], + left_context_size=token_offset, + ), + embed=embed_struct, + ) + ) + if not finished: state["emitted_token_len"] = emitted_token_len + this_token_hop_len state["token_hop_len"] = min( From c1e7a974725919fcd59724c26166136bded12813 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 18:48:51 +0000 Subject: [PATCH 37/53] narrow OmniPayloadStruct.speaker/language from Any to list[str] | str | None Signed-off-by: Divyansh Singhvi --- vllm_omni/data_entry_keys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 04ee6e35d11..1dc7d8432cc 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -175,8 +175,8 @@ class OmniPayloadStruct(_StructBase): generated_len: int | None = None model_outputs: list[torch.Tensor] | None = None mtp_inputs: tuple[torch.Tensor, torch.Tensor] | None = None - speaker: Any = None - language: Any = None + speaker: list[str] | str | None = None + language: list[str] | str | None = None request_id: str | None = None past_key_values: list[int] | None = None kv_metadata: dict[str, Any] | 
None = None From d8f4942c2a2081c11225240edb94c785b28d33ee Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Sun, 26 Apr 2026 20:19:12 +0000 Subject: [PATCH 38/53] flatten nested multimodal_outputs at prefix-cache boundary Signed-off-by: Divyansh Singhvi --- vllm_omni/worker/gpu_ar_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index b595cc581ae..e054bb13a5f 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -227,7 +227,7 @@ def _maybe_update_prefix_cache( self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, num_tokens_padded=num_tokens_padded, @@ -255,7 +255,7 @@ def _maybe_get_combined_prefix_cache_tensors( combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_scheduled_tokens=num_scheduled_tokens, ) return combined_hidden_states, combined_multimodal_outputs From adebc2035c3a661200c3d1f7e71867815f4606cd Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 03:27:43 +0000 Subject: [PATCH 39/53] migrate cosyvoice3 code2wav consumer to typed OmniPayloadStruct access Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_model_helpers.py | 2 +- .../models/cosyvoice3/cosyvoice3.py | 76 ++++++------------- 2 files changed, 23 insertions(+), 55 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py 
b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 44530b9d43b..657512d5afb 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -231,7 +231,7 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), }, - "prefix_ids": [101, 102], + "ids": {"prompt": [101, 102]}, "generated_len": 3, } ] diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index e12c3254331..ac8e5389b27 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -32,7 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler -from vllm_omni.data_entry_keys import EmbeddingsStruct, OmniPayloadStruct, to_dict +from vllm_omni.data_entry_keys import EmbeddingsStruct, OmniPayloadStruct, to_dict, to_struct from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config from vllm_omni.model_executor.models.cosyvoice3.utils import ( concat_text_with_prompt_ids, @@ -344,43 +344,6 @@ def _create_llm_vllm_config(self, parent_config: VllmConfig) -> VllmConfig: # Use parent's cache config - critical for PagedAttention to work correctly return parent_config.with_hf_config(qwen_hf_config, architectures=["Qwen2Model"]) - @staticmethod - def _as_tensor(value: object) -> torch.Tensor | None: - """Extract tensor payload from runtime info fields.""" - if isinstance(value, list): - if not value: - return None - value = value[0] - if isinstance(value, torch.Tensor): - return value - return None - - @staticmethod - def _as_str(value: object) -> str | None: - """Extract string payload from runtime info fields.""" - if 
isinstance(value, list): - if not value: - return None - value = value[0] - if value is None: - return None - return str(value) - - @staticmethod - def _as_bool(value: object) -> bool: - """Extract boolean payload from runtime info fields.""" - if isinstance(value, list): - if not value: - return False - value = value[0] - if isinstance(value, torch.Tensor): - if value.numel() == 0: - return False - return bool(value.reshape(-1)[0].item()) - if value is None: - return False - return bool(value) - @staticmethod def _cross_fade_audio(audio: torch.Tensor, prev_tail: torch.Tensor) -> torch.Tensor: """Blend previous chunk tail into current chunk head using a Hamming window. @@ -708,25 +671,29 @@ def forward( runtime_info = [] for idx, req_ids in enumerate(request_ids_list): - info = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} - meta_info = info.get("meta", {}) if info else {} - req_id = self._as_str(meta_info.get("req_id")) - stream_finished = self._as_bool(meta_info.get("stream_finished")) - embed_info = info.get("embed", {}) if info else {} - speech_token = self._as_tensor(embed_info.get("speech_token")) if embed_info else None - speech_feat = self._as_tensor(embed_info.get("speech_feat")) if embed_info else None - embedding = self._as_tensor(embed_info.get("embedding")) if embed_info else None + raw = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} + payload = to_struct(raw) + meta = payload.meta + embed = payload.embed + + req_id = meta.req_id[0] if (meta and meta.req_id) else None + stream_finished = ( + bool(meta.stream_finished.item()) if (meta and meta.stream_finished is not None) else False + ) + speech_token = embed.speech_token if embed else None + speech_feat = embed.speech_feat if embed else None + embedding = embed.embedding if embed else None if speech_token is None or speech_feat is None or embedding is None: if stream_finished and req_id is not None and 
hasattr(self, "_stream_vocoder_cache_by_req"): with self._stream_audio_cache_lock: self._stream_vocoder_cache_by_req.pop(req_id, None) audios[idx] = self._stitch_stream_audio(req_id, empty_audio, stream_finished) - if ( - req_ids.numel() > 0 - and info - and ("left_context_size" in info.get("meta", {}) or "generated_len" in info) + if req_ids.numel() > 0 and ( + (meta and meta.left_context_size is not None) or payload.generated_len is not None ): - info_keys = ",".join(sorted(info.keys())) if info else "" + info_keys = ",".join( + sorted(f for f in payload.__struct_fields__ if getattr(payload, f) is not None) + ) logger.warning_once( "CosyVoice3 code2wav missing prompt conditioning for non-empty codec tokens: " "raw_len=%d info_keys=%s", @@ -752,10 +719,11 @@ def forward( # `generated_len` is injected for many models by the generic # runner, so only explicit chunk-routing fields should switch # code2wav into the streaming path. - meta = info.get("meta", {}) if info else {} - uses_streaming_decode = bool(info) and ("stream_finished" in meta or "left_context_size" in meta) + uses_streaming_decode = meta and ( + meta.stream_finished is not None or meta.left_context_size is not None + ) if uses_streaming_decode: - token_offset = max(0, meta.get("left_context_size", 0)) + token_offset = max(0, meta.left_context_size or 0) cache_state = None if req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"): From 5cbb0f729f742146bc1aadb13d7d62050f3b71df Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 03:35:21 +0000 Subject: [PATCH 40/53] DEBUG: add PCDIAG logs to prefix-cache write/read paths Signed-off-by: Divyansh Singhvi --- vllm_omni/core/prefix_cache.py | 32 +++++++++++++++++++++++++ vllm_omni/worker/gpu_ar_model_runner.py | 14 +++++++++++ 2 files changed, 46 insertions(+) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 69e7346c4c1..f45ea7e56fe 100644 --- a/vllm_omni/core/prefix_cache.py +++ 
b/vllm_omni/core/prefix_cache.py @@ -133,6 +133,14 @@ def update_omni_tensor_prefix_cache( multimodal_outputs, seq_len=num_tokens_padded, ) + # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] update: tokens_unpadded=%d tokens_padded=%d mm_input_keys=%s mm_cache_keys=%s", + num_tokens_unpadded, + num_tokens_padded, + sorted(multimodal_outputs.keys()), + sorted(self.mm_cache_keys), + ) for mm_out_key, mm_cache in self.mm_outputs_cache.items(): if mm_out_key in multimodal_outputs: @@ -244,11 +252,27 @@ def _get_merged_tensors( start = query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) + # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] merge req=%s hit=True num_scheduled=%d cached_shape=%s new_shape=%s combined_shape=%s", + req_id, + num_scheduled_tokens[req_id], + list(cached_hs.shape), + list(new_hs.shape), + list(combined_hidden_states[req_id].shape), + ) else: # cache miss for this request, pass through normally start = query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] combined_hidden_states[req_id] = new_hs + # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] merge req=%s hit=False num_scheduled=%d new_shape=%s", + req_id, + num_scheduled_tokens[req_id], + list(new_hs.shape), + ) return combined_hidden_states @@ -259,6 +283,14 @@ def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch. num_computed = input_batch.num_computed_tokens_cpu[req_idx] # NOTE: vLLM only caches full blocks num_cached_blocks = num_computed // self.block_size + # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] cached_blocks: req_idx=%d num_computed=%d block_size=%d num_cached_blocks=%d", + req_idx, + int(num_computed), + int(self.block_size), + int(num_cached_blocks), + ) # Get the block IDs attached to this cache hit and reindex into # the flattened cached hidden states (i.e., 1 row per token). 
return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index e054bb13a5f..02fb9ebfcaa 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -903,6 +903,20 @@ def _unwrap_lists(v): for mm_key in combined_multimodal_outputs.keys(): mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid]) + # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] cache-hit pooler req=%s keys=%s shapes=%s", + rid, + sorted(combined_multimodal_outputs.keys()), + { + k: ( + list(mm_payload[k].shape) + if hasattr(mm_payload[k], "shape") + else type(mm_payload[k]).__name__ + ) + for k in mm_payload + }, + ) else: # Prefix cache disabled; we still need to process the data From 6180aaf6ce7ad89a52583df1291bd1470e397f5b Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 04:11:54 +0000 Subject: [PATCH 41/53] DEBUG: log num_computed_tokens at scheduled_new_reqs Signed-off-by: Divyansh Singhvi --- vllm_omni/worker/gpu_model_runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index efad5f3d4d5..f45d3eeb47d 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -327,6 +327,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): # num_computed_tokens > 0 means that we have a hit in prefix # caching; mark it so that we can manage the hidden states # later on as needed. 
+ # DEBUG: PCDIAG + logger.warning( + "[PCDIAG] new_req=%s num_computed=%s prefix_cache_active=%s", + req_id, + getattr(new_req_data, "num_computed_tokens", None), + self.omni_prefix_cache is not None, + ) if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0: self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id) From dee4f2781937f487cf0a306f589af56a2e824462 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 06:11:23 +0000 Subject: [PATCH 42/53] remove PCDIAG diagnostic logs from prefix cache + model runner Signed-off-by: Divyansh Singhvi --- vllm_omni/core/prefix_cache.py | 32 ------------------------- vllm_omni/worker/gpu_ar_model_runner.py | 14 ----------- vllm_omni/worker/gpu_model_runner.py | 7 ------ 3 files changed, 53 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index f45ea7e56fe..69e7346c4c1 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -133,14 +133,6 @@ def update_omni_tensor_prefix_cache( multimodal_outputs, seq_len=num_tokens_padded, ) - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] update: tokens_unpadded=%d tokens_padded=%d mm_input_keys=%s mm_cache_keys=%s", - num_tokens_unpadded, - num_tokens_padded, - sorted(multimodal_outputs.keys()), - sorted(self.mm_cache_keys), - ) for mm_out_key, mm_cache in self.mm_outputs_cache.items(): if mm_out_key in multimodal_outputs: @@ -252,27 +244,11 @@ def _get_merged_tensors( start = query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] merge req=%s hit=True num_scheduled=%d cached_shape=%s new_shape=%s combined_shape=%s", - req_id, - num_scheduled_tokens[req_id], - list(cached_hs.shape), - list(new_hs.shape), - list(combined_hidden_states[req_id].shape), - ) else: # cache miss for this request, pass through normally start = 
query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] combined_hidden_states[req_id] = new_hs - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] merge req=%s hit=False num_scheduled=%d new_shape=%s", - req_id, - num_scheduled_tokens[req_id], - list(new_hs.shape), - ) return combined_hidden_states @@ -283,14 +259,6 @@ def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch. num_computed = input_batch.num_computed_tokens_cpu[req_idx] # NOTE: vLLM only caches full blocks num_cached_blocks = num_computed // self.block_size - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] cached_blocks: req_idx=%d num_computed=%d block_size=%d num_cached_blocks=%d", - req_idx, - int(num_computed), - int(self.block_size), - int(num_cached_blocks), - ) # Get the block IDs attached to this cache hit and reindex into # the flattened cached hidden states (i.e., 1 row per token). return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 02fb9ebfcaa..e054bb13a5f 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -903,20 +903,6 @@ def _unwrap_lists(v): for mm_key in combined_multimodal_outputs.keys(): mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid]) - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] cache-hit pooler req=%s keys=%s shapes=%s", - rid, - sorted(combined_multimodal_outputs.keys()), - { - k: ( - list(mm_payload[k].shape) - if hasattr(mm_payload[k], "shape") - else type(mm_payload[k]).__name__ - ) - for k in mm_payload - }, - ) else: # Prefix cache disabled; we still need to process the data diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index f45d3eeb47d..efad5f3d4d5 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -327,13 +327,6 @@ def 
_update_states(self, scheduler_output: "SchedulerOutput"): # num_computed_tokens > 0 means that we have a hit in prefix # caching; mark it so that we can manage the hidden states # later on as needed. - # DEBUG: PCDIAG - logger.warning( - "[PCDIAG] new_req=%s num_computed=%s prefix_cache_active=%s", - req_id, - getattr(new_req_data, "num_computed_tokens", None), - self.omni_prefix_cache is not None, - ) if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0: self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id) From f6e3199fd0c18871d06773e14dc90309f19bb4b7 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 07:18:54 +0000 Subject: [PATCH 43/53] return OmniPayloadStruct from async_chunk producers, drop to_dict at function returns Signed-off-by: Divyansh Singhvi --- .../test_cosyvoice3_stage_input_processors.py | 54 +++++++++--------- .../test_mimo_audio_flush_remaining_codes.py | 18 +++--- .../test_qwen3_tts_async_chunk.py | 32 +++++------ .../test_voxcpm_async_chunk.py | 8 +-- .../test_voxtral_tts_async_chunk.py | 14 ++--- .../chunk_transfer_adapter.py | 4 +- .../stage_input_processors/cosyvoice3.py | 43 ++++++--------- .../stage_input_processors/fish_speech.py | 26 ++++----- .../stage_input_processors/mimo_audio.py | 55 ++++++++----------- .../stage_input_processors/qwen3_omni.py | 20 +++---- .../stage_input_processors/qwen3_tts.py | 13 ++--- .../stage_input_processors/voxcpm.py | 22 +++----- .../stage_input_processors/voxtral_tts.py | 27 ++++----- 13 files changed, 152 insertions(+), 184 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index 11e692d46c8..57cfd0a65dc 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -85,14 
+85,14 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): ) assert payload is not None - assert payload["meta"]["finished"].item() is True - assert payload["codes"]["audio"].tolist() == [1, 2, 3] - assert payload["meta"]["left_context_size"] == 0 - assert payload["meta"]["req_id"] == ["rid-0"] - assert payload["meta"]["stream_finished"].item() is True - assert "speech_token" in payload["embed"] - assert "speech_feat" in payload["embed"] - assert "embedding" in payload["embed"] + assert payload.meta.finished.item() is True + assert payload.codes.audio.tolist() == [1, 2, 3] + assert payload.meta.left_context_size == 0 + assert payload.meta.req_id == ["rid-0"] + assert payload.meta.stream_finished.item() is True + assert payload.embed.speech_token is not None + assert payload.embed.speech_feat is not None + assert payload.embed.embedding is not None def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes(): @@ -112,8 +112,8 @@ def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes ) assert payload is not None - assert payload["codes"]["audio"].tolist() == [] - assert payload["meta"]["finished"].item() is True + assert payload.codes.audio.tolist() == [] + assert payload.meta.finished.item() is True def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): @@ -139,8 +139,8 @@ def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): ) assert payload1 is not None - assert payload1["codes"]["audio"].tolist() == [1, 2] - assert payload1["meta"]["left_context_size"] == 0 + assert payload1.codes.audio.tolist() == [1, 2] + assert payload1.meta.left_context_size == 0 assert payload2 is None @@ -169,9 +169,9 @@ def test_talker2code2wav_async_chunk_waits_for_prelookahead_and_emits_cumulative assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"].tolist() == [1, 2, 3] - assert 
payload_ready["meta"]["left_context_size"] == 0 - assert payload_ready["meta"]["finished"].item() is False + assert payload_ready.codes.audio.tolist() == [1, 2, 3] + assert payload_ready.meta.left_context_size == 0 + assert payload_ready.meta.finished.item() is False def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): @@ -198,13 +198,13 @@ def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): ) assert payload_stream is not None - assert payload_stream["meta"]["finished"].item() is False - assert payload_stream["codes"]["audio"].tolist() == [3, 4, 5] - assert payload_stream["meta"]["left_context_size"] == 0 + assert payload_stream.meta.finished.item() is False + assert payload_stream.codes.audio.tolist() == [3, 4, 5] + assert payload_stream.meta.left_context_size == 0 assert payload_final is not None - assert payload_final["meta"]["finished"].item() is True - assert payload_final["codes"]["audio"].tolist() == [3, 4, 5, 6] - assert payload_final["meta"]["left_context_size"] == 2 + assert payload_final.meta.finished.item() is True + assert payload_final.codes.audio.tolist() == [3, 4, 5, 6] + assert payload_final.meta.left_context_size == 2 def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): @@ -234,8 +234,8 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"].tolist() == [8, 9, 10, 11] - assert payload_ready["meta"]["left_context_size"] == 0 + assert payload_ready.codes.audio.tolist() == [8, 9, 10, 11] + assert payload_ready.meta.left_context_size == 0 def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio(): @@ -261,8 +261,8 @@ def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio( ) assert payload_stream is not None - assert payload_stream["meta"]["finished"].item() is False - assert 
payload_stream["codes"]["audio"].tolist() == [3, 4] + assert payload_stream.meta.finished.item() is False + assert payload_stream.codes.audio.tolist() == [3, 4] assert payload_final is not None - assert payload_final["meta"]["finished"].item() is True - assert payload_final["codes"]["audio"].tolist() == [] + assert payload_final.meta.finished.item() is True + assert payload_final.codes.audio.tolist() == [] diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py index ceedd6f8d5e..acdd95372b8 100644 --- a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py +++ b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py @@ -20,16 +20,16 @@ def test_flush_remaining_codes_when_no_codes_accumulated_missing_request_id(): """No entry for request_id: treat as empty, return finished sentinel with empty codes.""" tm = SimpleNamespace(code_prompt_token_ids={}) out = _flush_remaining_codes(tm, "missing", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"].tolist() == _sentinel()["codes"]["audio"] - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == _sentinel()["codes"]["audio"] + assert out.meta.finished.item() is True def test_flush_remaining_codes_when_no_codes_accumulated_empty_list(): """Explicit empty accumulation list returns the same sentinel.""" tm = SimpleNamespace(code_prompt_token_ids={"r": []}) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"].tolist() == [] - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == [] + assert out.meta.finished.item() is True def test_flush_remaining_codes_partial_chunk_remaining(): @@ -41,8 +41,8 @@ def test_flush_remaining_codes_partial_chunk_remaining(): code_prompt_token_ids={"r": [[1], [2], [3], [4], [5], [6], [7]]}, 
) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) - assert out["meta"]["finished"].item() is True - assert out["codes"]["audio"].tolist() == [4, 5, 6, 7] + assert out.meta.finished.item() is True + assert out.codes.audio.tolist() == [4, 5, 6, 7] def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): @@ -52,7 +52,7 @@ def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): ) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) # context_length = chunk_size = 3, end_index = min(6, 6) -> all 6 - assert out["codes"]["audio"].tolist() == [1, 2, 3, 4, 5, 6] + assert out.codes.audio.tolist() == [1, 2, 3, 4, 5, 6] @pytest.mark.parametrize( @@ -74,5 +74,5 @@ def test_flush_remaining_codes_context_window_end_index( tm = SimpleNamespace(code_prompt_token_ids={"r": accumulated}) out = _flush_remaining_codes(tm, "r", chunk_size=chunk_size, left_context_size=left_context) expected_flat = list(range(length - expected_end_index, length)) - assert out["codes"]["audio"].tolist() == expected_flat - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == expected_flat + assert out.meta.finished.item() is True diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index feedde9e556..ff204a30190 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -79,8 +79,8 @@ def test_eof_marker_when_finished_empty(): request=_req("r", finished=True), is_finished=True, ) - assert p["codes"] == {"audio": []} - assert p["meta"]["finished"].item() is True + assert p.codes.audio.tolist() == [] + assert p.meta.finished.item() is True def test_flush_on_finish(): @@ -93,8 +93,8 @@ def test_flush_on_finish(): is_finished=True, ) assert p is not None - assert 
p["meta"]["finished"].item() is True - assert len(p["codes"]["audio"]) == _Q * 24 + assert p.meta.finished.item() is True + assert len(p.codes.audio) == _Q * 24 _CASES = [ @@ -159,8 +159,8 @@ def test_streaming_phases(config, n_frames, finished, expected): else: exp_ctx, exp_window = expected assert payload is not None - assert payload["meta"]["left_context_size"] == exp_ctx - assert len(payload["codes"]["audio"]) == _Q * exp_window + assert payload.meta.left_context_size == exp_ctx + assert len(payload.codes.audio) == _Q * exp_window def test_dynamic_ic_adapts_to_load(): @@ -170,14 +170,14 @@ def test_dynamic_ic_adapts_to_load(): # Low load (1/8) -> IC=2 -> emit at 2 p1 = _call(tm, "r", n_frames=2) assert p1 is not None - assert len(p1["codes"]["audio"]) == _Q * 2 + assert len(p1.codes.audio) == _Q * 2 # High load: add 4 others -> active=5/8 -> IC=8 -> emit at 8 for i in range(4): tm.code_prompt_token_ids[f"other-{i}"] = [[0]] p2 = _call(tm, "r", n_frames=8) assert p2 is not None - assert len(p2["codes"]["audio"]) == _Q * 8 + assert len(p2.codes.audio) == _Q * 8 # Requests past initial phase still count in load factor tm2 = _tm(max_num_seqs=4) @@ -186,7 +186,7 @@ def test_dynamic_ic_adapts_to_load(): # active=4/4=1.0 -> IC=16 p3 = _call(tm2, "new", n_frames=16) assert p3 is not None - assert len(p3["codes"]["audio"]) == _Q * 16 + assert len(p3.codes.audio) == _Q * 16 def test_ic_load_change_mid_request(): @@ -207,7 +207,7 @@ def test_ic_load_change_mid_request(): assert _call(tm, "r", n_frames=27) is None p3 = _call(tm, "r", n_frames=49) assert p3 is not None - assert p3["meta"]["left_context_size"] == 24 + assert p3.meta.left_context_size == 24 # A *new* request under high load gets IC=16 (not IC=2). # Frame 2 would emit under IC=2 but must hold under IC=16. 
@@ -261,8 +261,8 @@ def test_first_streaming_chunk_prepends_ref_code_context(): ) assert payload is not None - assert payload["meta"]["left_context_size"] == 2 - assert len(payload["codes"]["audio"]) == _Q * 12 + assert payload.meta.left_context_size == 2 + assert len(payload.codes.audio) == _Q * 12 def test_ref_code_context_applies_to_all_streaming_chunks(): @@ -283,8 +283,8 @@ def test_ref_code_context_applies_to_all_streaming_chunks(): assert payload is not None # ref_code (2 frames) prepended as left context on second chunk too - assert payload["meta"]["left_context_size"] == 10 + 2 - assert len(payload["codes"]["audio"]) == _Q * (20 + 2) + assert payload.meta.left_context_size == 10 + 2 + assert len(payload.codes.audio) == _Q * (20 + 2) def test_ref_code_context_can_be_buffered_before_first_emit(): @@ -318,8 +318,8 @@ def test_ref_code_context_can_be_buffered_before_first_emit(): assert payload is not None # ref_code (2 frames) is kept (not popped) for subsequent chunks - assert payload["meta"]["left_context_size"] == 2 - assert len(payload["codes"]["audio"]) == _Q * 12 + assert payload.meta.left_context_size == 2 + assert len(payload.codes.audio) == _Q * 12 assert rid in tm.request_payload diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py index 806bf4b153e..6738c763aec 100644 --- a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py @@ -57,8 +57,8 @@ def test_latent2vae_async_chunk_serializes_latent_payload(): ) assert payload is not None - assert torch.equal(payload["meta"]["finished"], torch.tensor(False, dtype=torch.bool)) - recovered = _decode_serialized_latent(payload["codes"]["audio"].tolist()) + assert torch.equal(payload.meta.finished, torch.tensor(False, dtype=torch.bool)) + recovered = _decode_serialized_latent(payload.codes.audio.tolist()) 
torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0)) @@ -70,8 +70,8 @@ def test_latent2vae_async_chunk_returns_terminal_marker_without_latent(): is_finished=False, ) - assert payload["codes"]["audio"].tolist() == [] - assert torch.equal(payload["meta"]["finished"], torch.tensor(True, dtype=torch.bool)) + assert payload.codes.audio.tolist() == [] + assert torch.equal(payload.meta.finished, torch.tensor(True, dtype=torch.bool)) def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk(): diff --git a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py index 1b41769c798..52e478dad66 100644 --- a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py @@ -126,8 +126,8 @@ def test_flush_tail_when_finished(): ) assert payload is not None - assert payload["meta"]["finished"].item() is True - codes = payload["codes"]["audio"] + assert payload.meta.finished.item() is True + codes = payload.codes.audio # Format: [ctx_frames, context_length, ...flat_codes] assert len(codes) >= 2 # At least ctx_frames + context_length header ctx_frames = codes[0] @@ -149,8 +149,8 @@ def test_eof_marker_when_finished_with_no_frames(): request=request, ) - assert payload["codes"]["audio"].tolist() == [] - assert payload["meta"]["finished"].item() is True + assert payload.codes.audio.tolist() == [] + assert payload.meta.finished.item() is True def test_normal_chunk_emission(): @@ -174,7 +174,7 @@ def test_normal_chunk_emission(): # A chunk should be emitted assert payload is not None - codes = payload["codes"]["audio"] + codes = payload.codes.audio ctx_frames = codes[0] context_length = codes[1] assert ctx_frames == 20 # 25 - 5(chunk_size_at_begin) @@ -201,7 +201,7 @@ def test_small_initial_chunks(): ) assert payload is not None - codes = 
payload["codes"]["audio"] + codes = payload.codes.audio ctx_frames = codes[0] context_length = codes[1] assert ctx_frames == 0 @@ -250,7 +250,7 @@ def test_context_handling_format(): ) assert payload is not None - codes = payload["codes"]["audio"] + codes = payload.codes.audio # codes is a 1-D long tensor: [ctx_frames, context_length, ...flat_codes] ctx_frames = int(codes[0].item()) context_length = int(codes[1].item()) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 1b33af1deed..14bc758f11e 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -249,7 +249,7 @@ def _send_single_request(self, task: dict): except Exception as e: logger.error(f"Failed to use custom_process_input_func for payload extraction: {e}") - if not payload_data: + if payload_data is None: return success, size, metadata = self.connector.put( @@ -262,7 +262,7 @@ def _send_single_request(self, task: dict): if success: self.put_req_chunk[external_req_id] += 1 logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}") - finished_flag = payload_data.get("meta", {}).get("finished", payload_data.get("finished")) + finished_flag = payload_data.meta.finished if payload_data.meta is not None else None is_payload_finished = False if isinstance(finished_flag, torch.Tensor): is_payload_finished = finished_flag.numel() == 1 and bool(finished_flag.item()) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 54cabf93f2b..9e003630135 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -13,7 +13,6 @@ EmbeddingsStruct, MetaStruct, OmniPayloadStruct, - to_dict, ) from 
vllm_omni.inputs.data import OmniTokensPrompt @@ -113,7 +112,7 @@ def talker2code2wav_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: """CosyVoice3 async_chunk processor: talker token stream -> code2wav chunks.""" with nullcontext(): request_id = request.external_req_id @@ -208,12 +207,10 @@ def talker2code2wav_async_chunk( embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), - embed=embed_struct, - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, ) emitted_token_len = int(state.get("emitted_token_len", 0)) @@ -223,12 +220,10 @@ def talker2code2wav_async_chunk( embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), - embed=embed_struct, - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, ) with nullcontext(): @@ -258,17 +253,15 @@ def talker2code2wav_async_chunk( embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True - payload = to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.tensor(code_predictor_codes, dtype=torch.long)), - meta=MetaStruct( - finished=torch.tensor(finished, dtype=torch.bool), - stream_finished=torch.tensor(finished, dtype=torch.bool), - 
req_id=[request_id], - left_context_size=token_offset, - ), - embed=embed_struct, - ) + payload = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(code_predictor_codes, dtype=torch.long)), + meta=MetaStruct( + finished=torch.tensor(finished, dtype=torch.bool), + stream_finished=torch.tensor(finished, dtype=torch.bool), + req_id=[request_id], + left_context_size=token_offset, + ), + embed=embed_struct, ) if not finished: diff --git a/vllm_omni/model_executor/stage_input_processors/fish_speech.py b/vllm_omni/model_executor/stage_input_processors/fish_speech.py index 4416f4bb58e..5a10f08326f 100644 --- a/vllm_omni/model_executor/stage_input_processors/fish_speech.py +++ b/vllm_omni/model_executor/stage_input_processors/fish_speech.py @@ -8,9 +8,7 @@ from vllm_omni.data_entry_keys import ( CodesStruct, MetaStruct, - OmniPayload, OmniPayloadStruct, - to_dict, ) logger = init_logger(__name__) @@ -69,7 +67,7 @@ def slow_ar_to_dac_decoder_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: """Async streaming processor: emit code chunks as they are produced. 
Accumulates per-step codes and emits fixed-size chunks with left context @@ -117,10 +115,10 @@ def slow_ar_to_dac_decoder_async_chunk( if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) return None in_initial_phase = initial_chunk_size > 0 and length <= chunk_size @@ -150,12 +148,10 @@ def slow_ar_to_dac_decoder_async_chunk( stacked_frames = torch.stack(window_frames, dim=0) code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1) - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=code_predictor_codes), - meta=MetaStruct( - left_context_size=left_context_size, - finished=torch.tensor(finished, dtype=torch.bool), - ), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=code_predictor_codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(finished, dtype=torch.bool), + ), ) diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index cb376ddf4c4..8f8744f1250 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -9,7 +9,6 @@ MetaStruct, OmniPayload, OmniPayloadStruct, - to_dict, ) from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID @@ -60,13 +59,11 @@ def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torc return y_col_major -def _make_finished_sentinel() -> dict[str, Any]: +def _make_finished_sentinel() -> OmniPayloadStruct: """Return a minimal payload with finished=True so Stage-1 can end the request.""" - return to_dict( - OmniPayloadStruct( - 
codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), ) @@ -75,7 +72,7 @@ def _flush_remaining_codes( request_id: str, chunk_size: int, left_context_size: int, -) -> dict[str, Any]: +) -> OmniPayloadStruct: """Flush any accumulated but unsent codes when the request finishes.""" accumulated = transfer_manager.code_prompt_token_ids.get(request_id, []) if not accumulated: @@ -90,17 +87,15 @@ def _flush_remaining_codes( left_ctx_frames = max(0, min(length - context_length, left_context_size)) flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1) - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=flat_codes), - meta=MetaStruct( - left_context_size=left_ctx_frames, - codec_chunk_frames=chunk_size, - codec_left_context_frames=left_context_size, - code_flat_numel=int(flat_codes.numel()), - finished=torch.tensor(True, dtype=torch.bool), - ), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=flat_codes), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=int(flat_codes.numel()), + finished=torch.tensor(True, dtype=torch.bool), + ), ) @@ -131,7 +126,7 @@ def llm2code2wav_async_chunk( pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: """ Async chunk version: convert stage-0 pooling_output to code2wav payload (pooling / connector accumulation). 
@@ -181,17 +176,15 @@ def llm2code2wav_async_chunk( left_ctx_frames = max(0, min(length - context_length, left_context_size)) flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist() - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.tensor(flat_codes)), - meta=MetaStruct( - left_context_size=left_ctx_frames, - codec_chunk_frames=chunk_size, - codec_left_context_frames=left_context_size, - code_flat_numel=len(flat_codes), - finished=torch.tensor(is_finished, dtype=torch.bool), - ), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(flat_codes)), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=len(flat_codes), + finished=torch.tensor(is_finished, dtype=torch.bool), + ), ) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 39ea86f3af1..9f2b548bb36 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -299,7 +299,7 @@ def thinker2talker_async_chunk( pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: """ Process thinker outputs to create talker inputs. 1. thinker's text generation outputs (token IDs + hidden states) @@ -364,7 +364,7 @@ def thinker2talker_async_chunk( speaker=speaker, language=language, ) - return to_dict(payload) + return payload def thinker2talker( @@ -485,7 +485,7 @@ def talker2code2wav_async_chunk( pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: """ Pooling version. 
""" @@ -536,14 +536,12 @@ def talker2code2wav_async_chunk( codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).transpose(0, 1).reshape(-1) - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=codes), - meta=MetaStruct( - left_context_size=left_context_size, - finished=torch.tensor(is_finished, dtype=torch.bool), - ), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(is_finished, dtype=torch.bool), + ), ) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 2a6c88ecae8..abd3fc1c02c 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -141,7 +141,7 @@ def talker2code2wav_async_chunk( pooling_output: OmniPayload | None, request: Any, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) request_payload = getattr(transfer_manager, "request_payload", None) @@ -212,10 +212,10 @@ def talker2code2wav_async_chunk( if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) return None in_initial_phase = initial_chunk_size > 0 and initial_chunk_size < chunk_size and length <= chunk_size @@ -262,7 +262,7 @@ def talker2code2wav_async_chunk( dtype=torch.long, ) - payload = OmniPayloadStruct( + return OmniPayloadStruct( codes=CodesStruct(audio=code_predictor_codes), meta=MetaStruct( left_context_size=left_context_size, @@ -271,4 +271,3 @@ def talker2code2wav_async_chunk( 
speaker=extract_speaker_from_request(request), language=extract_language_from_request(request), ) - return to_dict(payload) diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py index 666dae4fb45..e238b7b9de8 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxcpm.py +++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py @@ -8,9 +8,7 @@ from vllm_omni.data_entry_keys import ( CodesStruct, MetaStruct, - OmniPayload, OmniPayloadStruct, - to_dict, ) from vllm_omni.inputs.data import OmniTokensPrompt @@ -95,12 +93,10 @@ def latent2vae( return vae_inputs -def _eof_payload() -> OmniPayload: - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), - ) +def _eof_payload() -> OmniPayloadStruct: + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), ) @@ -109,7 +105,7 @@ def latent2vae_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload).""" # Kept for callback signature compatibility with OmniChunkTransferAdapter. 
_ = transfer_manager @@ -127,9 +123,7 @@ def latent2vae_async_chunk( return _eof_payload() if finished_request else None serialized_codes = _serialize_latent_to_codes(latent) - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.tensor(serialized_codes, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(finished_request, dtype=torch.bool)), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(serialized_codes, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(finished_request, dtype=torch.bool)), ) diff --git a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py index 7c878235ee6..06291d19098 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py @@ -9,7 +9,6 @@ MetaStruct, OmniPayload, OmniPayloadStruct, - to_dict, ) from vllm_omni.inputs.data import OmniTokensPrompt @@ -58,7 +57,7 @@ def generator2tokenizer_async_chunk( pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> OmniPayload | None: +) -> OmniPayloadStruct | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) @@ -88,11 +87,9 @@ def generator2tokenizer_async_chunk( # finished and nothing was produced, emit an EOF marker. if length <= 0: if finished: - return to_dict( - OmniPayloadStruct( - codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), - meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), - ) + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), ) return None @@ -113,14 +110,12 @@ def generator2tokenizer_async_chunk( # Pack context + chunk into codebook-major flat codes for adapter. 
code_predictor_codes = torch.tensor(window_frames).reshape(-1).tolist() - return to_dict( - OmniPayloadStruct( - codes=CodesStruct( - audio=torch.tensor( - [int(ctx_frames), int(context_length)] + code_predictor_codes, - dtype=torch.long, - ), + return OmniPayloadStruct( + codes=CodesStruct( + audio=torch.tensor( + [int(ctx_frames), int(context_length)] + code_predictor_codes, + dtype=torch.long, ), - meta=MetaStruct(finished=torch.tensor(finished, dtype=torch.bool)), - ) + ), + meta=MetaStruct(finished=torch.tensor(finished, dtype=torch.bool)), ) From 4d700164e86bbaa8e3535d4aa9b0a79795685b12 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 08:48:35 +0000 Subject: [PATCH 44/53] fix chunk_transfer_adapter test fixture: producer mock returns OmniPayloadStruct Signed-off-by: Divyansh Singhvi --- .../omni_connectors/test_chunk_transfer_adapter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index 22f7c268be2..bdbdf332374 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -12,7 +12,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.request import RequestStatus -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import MetaStruct, OmniPayload, OmniPayloadStruct from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -144,7 +144,9 @@ def test_send_single_request_cleans_up_after_finished_payload(build_adapter, mon adapter, _ = build_adapter(stage_id=1) request = _req("req-finished", RequestStatus.FINISHED_STOPPED, external_req_id="ext-finished") - 
adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": True} + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct( + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)) + ) cleanup_calls = [] monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: cleanup_calls.append((a, kw))) From a479da4e748ffadb6374df44573424a163a93f6a Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Mon, 27 Apr 2026 09:00:41 +0000 Subject: [PATCH 45/53] add wire-equivalence test for OmniPayloadStruct vs to_dict Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index ae27cfe0858..6d0b2797823 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -166,6 +166,66 @@ def test_decode_rejects_unknown_field(self): decode_payload(wire) +class TestWireEquivalenceStructVsDict: + """Producer return-side migration invariant: encoding an ``OmniPayloadStruct`` + via the connector serializer must decode to the same payload as encoding + the equivalent ``to_dict(struct)``. + + Guards against regressions where the wire format diverges between the two + paths (e.g. msgspec adds a struct tag, or ``to_dict`` drops a non-default + sub-field that ``omit_defaults`` retains). 
+ """ + + @staticmethod + def _round_trip(obj): + from vllm_omni.distributed.omni_connectors.utils.serialization import ( + OmniMsgpackDecoder, + OmniMsgpackEncoder, + ) + + return OmniMsgpackDecoder().decode(OmniMsgpackEncoder().encode(obj)) + + @staticmethod + def _assert_decoded_equal(a, b): + if isinstance(a, dict): + assert isinstance(b, dict) + assert sorted(a.keys()) == sorted(b.keys()) + for k in a: + TestWireEquivalenceStructVsDict._assert_decoded_equal(a[k], b[k]) + elif isinstance(a, torch.Tensor): + assert isinstance(b, torch.Tensor) + assert a.dtype == b.dtype + assert a.shape == b.shape + assert torch.equal(a, b) + elif isinstance(a, list): + assert isinstance(b, list) and len(a) == len(b) + for x, y in zip(a, b, strict=True): + TestWireEquivalenceStructVsDict._assert_decoded_equal(x, y) + else: + assert a == b + + def test_basic_payload(self): + struct = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([1, 2, 3], dtype=torch.long)), + meta=MetaStruct(left_context_size=10, finished=torch.tensor(True, dtype=torch.bool)), + ) + self._assert_decoded_equal(self._round_trip(struct), self._round_trip(to_dict(struct))) + + def test_nested_sub_structs(self): + # Exercises depth-2 sub-struct encoding (embed.*) which was the case that + # exposed schema drift in the #1829 migration. 
+ struct = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([5, 6], dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(False, dtype=torch.bool)), + embed=EmbeddingsStruct( + speech_token=torch.tensor([[1, 2, 3]]), + speech_feat=torch.tensor([[[0.1, 0.2], [0.3, 0.4]]]), + embedding=torch.tensor([[0.5, 0.6]]), + ), + ) + self._assert_decoded_equal(self._round_trip(struct), self._round_trip(to_dict(struct))) + + class TestFlattenPayload: def test_basic_nested_to_dotted(self): nested = { From 78bf29811de8921ace560f4d0b629daa245ec3ce Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 08:18:15 +0000 Subject: [PATCH 46/53] type producer hook + pin struct/dict wire asymmetry Address PR #3149 review claims 1 and 2: - Claim 1: type-hint `custom_process_next_stage_input_func` and the `payload_data` local in `_send_single_request` as `OmniPayloadStruct | None`, so the contract producers must satisfy is checkable instead of implicit. - Claim 2: the chunk-adapter sender uses struct attribute access (`payload_data.meta.finished`) while the receive path reads dict keys. That works only because `OmniMsgpackDecoder` is type-erased (no target type) and round-trips struct -> dict. Add a comment at the sender and a `test_wire_round_trip_struct_to_dict_contract` test that pins the contract: future schema changes that break this round-trip will fail loudly. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- .../test_chunk_transfer_adapter.py | 39 +++++++++++++++++++ .../chunk_transfer_adapter.py | 12 ++++-- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index bdbdf332374..e90275bbbd8 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -606,3 +606,42 @@ def _super_finish(_self, request_ids, finished_status): OmniARScheduler.finish_requests(sched, ["r1"], RequestStatus.FINISHED_ABORTED) assert order == ["adapter", "super"] + + +def test_wire_round_trip_struct_to_dict_contract(): + """Pin the wire contract: encoding ``OmniPayloadStruct`` and decoding it + yields a dict equivalent to ``to_dict(struct)``. + + The chunk-adapter sender uses struct attribute access while the receiver + uses dict-key access. This works only because ``OmniMsgpackDecoder`` has + no target type and decodes structs back to plain dicts. If this test + breaks, the receiver's dict access will silently drop fields or KeyError. 
+ """ + from vllm_omni.data_entry_keys import CodesStruct, to_dict + from vllm_omni.distributed.omni_connectors.utils.serialization import ( + OmniMsgpackDecoder, + OmniMsgpackEncoder, + ) + + struct = OmniPayloadStruct( + meta=MetaStruct( + finished=torch.tensor(True, dtype=torch.bool), + left_context_size=12, + ), + codes=CodesStruct(audio=torch.tensor([1, 2, 3], dtype=torch.int64)), + ) + + encoded = OmniMsgpackEncoder().encode(struct) + decoded = OmniMsgpackDecoder().decode(encoded) + + assert isinstance(decoded, dict) + assert isinstance(decoded["meta"], dict) + assert isinstance(decoded["meta"]["finished"], torch.Tensor) + assert bool(decoded["meta"]["finished"].item()) is True + assert decoded["meta"]["left_context_size"] == 12 + assert torch.equal(decoded["codes"]["audio"], torch.tensor([1, 2, 3], dtype=torch.int64)) + + expected = to_dict(struct) + assert set(decoded.keys()) == set(expected.keys()) + assert set(decoded["meta"].keys()) == set(expected["meta"].keys()) + assert set(decoded["codes"].keys()) == set(expected["codes"].keys()) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 14bc758f11e..8c04ca7defe 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -3,12 +3,13 @@ import importlib from collections import defaultdict, deque +from collections.abc import Callable from typing import Any import torch from vllm.v1.request import Request, RequestStatus -from vllm_omni.data_entry_keys import unflatten_payload +from vllm_omni.data_entry_keys import OmniPayloadStruct, unflatten_payload from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec @@ -43,7 +44,7 @@ def __init__(self, vllm_config: Any): super().__init__(model_config) self.model_mode = getattr(model_config, 
"worker_type", None) or "ar" # State specific to Chunk management - self.custom_process_next_stage_input_func = None + self.custom_process_next_stage_input_func: Callable[..., OmniPayloadStruct | None] | None = None custom_process_next_stage_input_func = getattr(model_config, "custom_process_next_stage_input_func", None) if custom_process_next_stage_input_func: module_path, func_name = custom_process_next_stage_input_func.rsplit(".", 1) @@ -236,7 +237,7 @@ def _send_single_request(self, task: dict): chunk_id = self.put_req_chunk[external_req_id] connector_put_key = f"{external_req_id}_{stage_id}_{chunk_id}" # Process payload in save_loop thread - payload_data = None + payload_data: OmniPayloadStruct | None = None if self.custom_process_next_stage_input_func: try: payload_data = self.custom_process_next_stage_input_func( @@ -262,6 +263,11 @@ def _send_single_request(self, task: dict): if success: self.put_req_chunk[external_req_id] += 1 logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}") + # Sender uses struct attr access here; the receive path in + # `_load_one_request` / `_update_request_payload` reads dict keys. + # That asymmetry is intentional: `OmniMsgpackDecoder` is type-erased + # (no target type), so the wire round-trips struct -> dict. If you + # change the schema, update both ends — see test_wire_round_trip. finished_flag = payload_data.meta.finished if payload_data.meta is not None else None is_payload_finished = False if isinstance(finished_flag, torch.Tensor): From fc4b9e95833756fc0a3bf6297d69b1d150420f15 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 08:25:18 +0000 Subject: [PATCH 47/53] unify chunk-payload accumulators on shared depth-2 merger Address PR #3149 review claim 3: ``OmniChunkTransferAdapter._update_request_payload`` and ``OmniConnectorModelRunnerMixin._accumulate_payload`` were two separate accumulators for the same wire payload, with disagreeing semantics. 
The chunk-adapter did proper depth-2 merge; the mixin did shallow merge keyed on flat top-level names that no in-tree producer writes anymore (``finished``, ``override_keys``, ``thinker_decode_embeddings``), so nested ``meta``/``embed``/``codes`` dicts were silently replaced wholesale instead of merged sub-key by sub-key. - Add ``utils/payload_merge.py::merge_chunk_payloads`` as the single source of truth: depth-2 merge with ``meta.finished`` taken from incoming, ``meta.override_keys`` exemptions, tensor-cat / list-extend / scalar-replace per sub-key. Span-aware merge for ``embed.decode`` and ``embed.cached_decode`` activates when the matching ``*_token_start``/``*_token_end`` fields are present (handles overlapping/retried chunks); falls back to naive cat otherwise. - Add ``decode_token_start/end`` and ``cached_decode_token_start/end`` to ``EmbeddingsStruct`` so producers can opt into span-aware merge. - Reduce ``_update_request_payload`` and ``_accumulate_payload`` to thin wrappers calling ``merge_chunk_payloads``. - Drop now-unused ``get_tensor_span``/``merge_tensor_spans`` imports from the mixin (the helpers stay in ``vllm_omni.worker.payload_span`` and are now called from ``payload_merge`` instead). - New ``test_payload_merge.py`` covers depth-2 cat, override_keys, finished override, span-adjacent, span-overlap-trim, naive fallback, list extension, first-chunk passthrough, and a parity test that feeds the same chunk pair through the adapter's ``_update_request_payload`` and through ``merge_chunk_payloads`` directly (standing in for the mixin's ``_accumulate_payload`` wrapper) and asserts equal merged output.
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- .../omni_connectors/test_payload_merge.py | 163 ++++++++++++++++++ vllm_omni/data_entry_keys.py | 8 + .../chunk_transfer_adapter.py | 33 ++-- .../omni_connectors/utils/payload_merge.py | 140 +++++++++++++++ .../omni_connector_model_runner_mixin.py | 81 +-------- 5 files changed, 329 insertions(+), 96 deletions(-) create mode 100644 tests/distributed/omni_connectors/test_payload_merge.py create mode 100644 vllm_omni/distributed/omni_connectors/utils/payload_merge.py diff --git a/tests/distributed/omni_connectors/test_payload_merge.py b/tests/distributed/omni_connectors/test_payload_merge.py new file mode 100644 index 00000000000..d42af0e9821 --- /dev/null +++ b/tests/distributed/omni_connectors/test_payload_merge.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_omni.distributed.omni_connectors.utils.payload_merge import merge_chunk_payloads + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _t(values, dtype=torch.float32): + return torch.tensor(values, dtype=dtype) + + +def test_depth2_tensor_concat_under_meta_embed_codes(): + origin = { + "embed": {"decode": _t([[1.0, 2.0]])}, + "codes": {"audio": torch.tensor([1, 2], dtype=torch.int64)}, + "meta": {"left_context_size": 4}, + } + incoming = { + "embed": {"decode": _t([[3.0, 4.0]])}, + "codes": {"audio": torch.tensor([3, 4], dtype=torch.int64)}, + "meta": {"left_context_size": 5, "finished": torch.tensor(False)}, + } + + merged = merge_chunk_payloads(origin, incoming) + + assert torch.equal(merged["embed"]["decode"], _t([[1.0, 2.0], [3.0, 4.0]])) + assert torch.equal(merged["codes"]["audio"], torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + # `left_context_size` is a scalar and not in override_keys, so incoming wins. 
+ assert merged["meta"]["left_context_size"] == 5 + # `finished` is always taken from incoming. + assert bool(merged["meta"]["finished"]) is False + + +def test_meta_finished_taken_from_incoming(): + origin = {"meta": {"finished": torch.tensor(False)}} + incoming = {"meta": {"finished": torch.tensor(True)}} + merged = merge_chunk_payloads(origin, incoming) + assert bool(merged["meta"]["finished"]) is True + + +def test_override_keys_replace_instead_of_merging(): + origin = { + "embed": {"prefill": _t([[1.0]])}, + "meta": {}, + } + incoming = { + "embed": {"prefill": _t([[9.0, 9.0]])}, + "meta": {"override_keys": [["embed", "prefill"]]}, + } + merged = merge_chunk_payloads(origin, incoming) + # prefill is replaced wholesale, not concatenated. + assert torch.equal(merged["embed"]["prefill"], _t([[9.0, 9.0]])) + + +def test_span_aware_merge_adjacent(): + origin = { + "embed": { + "decode": _t([[1.0], [2.0]]), + "decode_token_start": 10, + "decode_token_end": 12, + } + } + incoming = { + "embed": { + "decode": _t([[3.0], [4.0]]), + "decode_token_start": 12, + "decode_token_end": 14, + } + } + merged = merge_chunk_payloads(origin, incoming) + assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0], [3.0], [4.0]])) + assert merged["embed"]["decode_token_start"] == 10 + assert merged["embed"]["decode_token_end"] == 14 + + +def test_span_aware_merge_partial_overlap_trims_incoming(): + origin = { + "embed": { + "decode": _t([[1.0], [2.0], [3.0]]), + "decode_token_start": 10, + "decode_token_end": 13, + } + } + incoming = { + "embed": { + "decode": _t([[3.5], [4.0], [5.0]]), + "decode_token_start": 12, + "decode_token_end": 15, + } + } + merged = merge_chunk_payloads(origin, incoming) + # Overlap = 13 - 12 = 1, so first row of incoming is trimmed. 
+ assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0], [3.0], [4.0], [5.0]])) + assert merged["embed"]["decode_token_start"] == 10 + assert merged["embed"]["decode_token_end"] == 15 + + +def test_span_aware_falls_back_to_naive_cat_without_span_metadata(): + origin = {"embed": {"decode": _t([[1.0]])}} + incoming = {"embed": {"decode": _t([[2.0]])}} + merged = merge_chunk_payloads(origin, incoming) + assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0]])) + + +def test_list_extension_inside_subdict(): + origin = {"meta": {"some_list": [1, 2]}} + incoming = {"meta": {"some_list": [3, 4]}} + merged = merge_chunk_payloads(origin, incoming) + assert merged["meta"]["some_list"] == [1, 2, 3, 4] + + +def test_first_chunk_returns_dict_unchanged(): + incoming = {"embed": {"decode": _t([[1.0]])}} + merged = merge_chunk_payloads({}, incoming) + assert torch.equal(merged["embed"]["decode"], _t([[1.0]])) + + +def test_parity_chunk_adapter_vs_mixin_accumulator(): + """Both call sites should produce identical merged output for the same + chunk pair. This pins the contract that the two transports agree. 
+ """ + from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( + OmniChunkTransferAdapter, + ) + + chunk1 = { + "embed": { + "decode": _t([[1.0], [2.0]]), + "decode_token_start": 0, + "decode_token_end": 2, + }, + "codes": {"audio": torch.tensor([1], dtype=torch.int64)}, + "meta": {"finished": torch.tensor(False)}, + } + chunk2 = { + "embed": { + "decode": _t([[3.0]]), + "decode_token_start": 2, + "decode_token_end": 3, + }, + "codes": {"audio": torch.tensor([2], dtype=torch.int64)}, + "meta": {"finished": torch.tensor(True)}, + } + + # Adapter path + adapter = OmniChunkTransferAdapter.__new__(OmniChunkTransferAdapter) + adapter.request_payload = {} + adapter._update_request_payload("r1", dict(chunk1)) + adapter_merged = adapter._update_request_payload("r1", dict(chunk2)) + + # Mixin path: emulate by calling merge_chunk_payloads directly the same + # way `_accumulate_payload` does (post-refactor it's a thin wrapper). + mixin_merged = merge_chunk_payloads(dict(chunk1), dict(chunk2)) + + assert torch.equal(adapter_merged["embed"]["decode"], mixin_merged["embed"]["decode"]) + assert adapter_merged["embed"]["decode_token_start"] == mixin_merged["embed"]["decode_token_start"] + assert adapter_merged["embed"]["decode_token_end"] == mixin_merged["embed"]["decode_token_end"] + assert torch.equal(adapter_merged["codes"]["audio"], mixin_merged["codes"]["audio"]) + assert bool(adapter_merged["meta"]["finished"]) == bool(mixin_merged["meta"]["finished"]) diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 1dc7d8432cc..20a044f9ca1 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -113,7 +113,15 @@ class EmbeddingsStruct(_StructBase): prefill: torch.Tensor | None = None prefill_shape: list[int] | None = None decode: torch.Tensor | None = None + # Token-position span for `decode`: half-open [start, end). 
+ # Producers may set these per-chunk to enable span-aware merge across + # chunks (see merge_chunk_payloads). When absent, the accumulator falls + # back to naive tensor concat. + decode_token_start: int | None = None + decode_token_end: int | None = None cached_decode: torch.Tensor | None = None + cached_decode_token_start: int | None = None + cached_decode_token_end: int | None = None tts_bos: torch.Tensor | None = None tts_eos: torch.Tensor | None = None tts_pad: torch.Tensor | None = None diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 8c04ca7defe..b6ae06b2237 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -14,6 +14,7 @@ from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec from ..utils.logging import get_connector_logger +from ..utils.payload_merge import merge_chunk_payloads from .base import OmniTransferAdapterBase logger = get_connector_logger(__name__) @@ -199,32 +200,18 @@ def _poll_single_request(self, request: Request): return False def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: - """Update the stored payload for *req_id* with the latest chunk.""" + """Update the stored payload for *req_id* with the latest chunk. + + Delegates to ``merge_chunk_payloads`` for depth-2 merge semantics so + the chunk-adapter accumulator agrees with + ``OmniConnectorModelRunnerMixin._accumulate_payload``. 
+ """ if req_id not in self.request_payload: self.request_payload[req_id] = payload_data return payload_data - origin = self.request_payload[req_id] - raw_ok = payload_data.get("meta", {}).pop("override_keys", []) - override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} - - for type_key, new_val in payload_data.items(): - if not isinstance(new_val, dict): - continue - origin_sub = origin.get(type_key) - if not isinstance(origin_sub, dict): - continue - for qual, value in new_val.items(): - if type_key == "meta" and qual == "finished": - continue - if (type_key, qual) in override_keys: - continue - if isinstance(value, torch.Tensor) and qual in origin_sub: - new_val[qual] = torch.cat([origin_sub[qual], value], dim=0) - elif isinstance(value, list) and qual in origin_sub: - new_val[qual] = origin_sub[qual] + value - - self.request_payload[req_id] = payload_data - return payload_data + merged = merge_chunk_payloads(self.request_payload[req_id], payload_data) + self.request_payload[req_id] = merged + return merged def _send_single_request(self, task: dict): raw_po = task["pooling_output"] diff --git a/vllm_omni/distributed/omni_connectors/utils/payload_merge.py b/vllm_omni/distributed/omni_connectors/utils/payload_merge.py new file mode 100644 index 00000000000..d58fc0a749c --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/payload_merge.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Shared depth-2 chunk-payload merge for accumulating connector chunks. + +Used by both ``OmniChunkTransferAdapter._update_request_payload`` and +``OmniConnectorModelRunnerMixin._accumulate_payload`` so the two transports +agree on how nested payload dicts combine across chunks. 
+""" + +from collections.abc import Iterable +from typing import Any + +import torch + +from vllm_omni.worker.payload_span import ( + get_tensor_span, + merge_tensor_spans, +) + +# Span groups: each entry is (tensor_key, start_key, end_key) within the +# nested sub-dict (e.g. inside ``embed``). When all three are present in +# both origin and incoming, the tensor is merged via ``merge_tensor_spans`` +# instead of naive ``torch.cat``. +_SPAN_GROUPS: dict[str, tuple[tuple[str, str, str], ...]] = { + "embed": ( + ("decode", "decode_token_start", "decode_token_end"), + ("cached_decode", "cached_decode_token_start", "cached_decode_token_end"), + ), +} + + +def _normalize_override_keys(raw: Any) -> set[tuple[str, str] | str]: + out: set[tuple[str, str] | str] = set() + if not isinstance(raw, Iterable) or isinstance(raw, (str, bytes)): + return out + for k in raw: + if isinstance(k, list): + out.add(tuple(k)) # type: ignore[arg-type] + elif isinstance(k, tuple): + out.add(k) + elif isinstance(k, str): + out.add(k) + return out + + +def _merge_value(origin_val: Any, incoming_val: Any) -> Any: + if isinstance(incoming_val, torch.Tensor) and isinstance(origin_val, torch.Tensor): + return torch.cat([origin_val, incoming_val], dim=0) + if isinstance(incoming_val, list) and isinstance(origin_val, list): + return origin_val + incoming_val + return incoming_val + + +def _merge_subdict( + origin_sub: dict[str, Any], + incoming_sub: dict[str, Any], + *, + type_key: str, + override_keys: set[tuple[str, str] | str], +) -> dict[str, Any]: + merged = dict(origin_sub) + span_groups = _SPAN_GROUPS.get(type_key, ()) + span_handled: set[str] = set() + + for tensor_key, start_key, end_key in span_groups: + if tensor_key not in incoming_sub: + continue + if (type_key, tensor_key) in override_keys or tensor_key in override_keys: + continue + merged_span = merge_tensor_spans( + get_tensor_span(origin_sub, tensor_key=tensor_key, start_key=start_key, end_key=end_key), + 
get_tensor_span(incoming_sub, tensor_key=tensor_key, start_key=start_key, end_key=end_key), + ) + if merged_span is None: + continue + tensor, start, end = merged_span + merged[tensor_key] = tensor + merged[start_key] = start + merged[end_key] = end + span_handled |= {tensor_key, start_key, end_key} + + for sub_key, sub_val in incoming_sub.items(): + if sub_key in span_handled: + continue + # `meta.finished` is always taken from incoming (terminal-state signal, + # not a value to accumulate). + if type_key == "meta" and sub_key == "finished": + merged[sub_key] = sub_val + continue + if (type_key, sub_key) in override_keys or sub_key in override_keys: + merged[sub_key] = sub_val + continue + if sub_key in origin_sub: + merged[sub_key] = _merge_value(origin_sub[sub_key], sub_val) + else: + merged[sub_key] = sub_val + + return merged + + +def merge_chunk_payloads(origin: dict[str, Any], incoming: dict[str, Any]) -> dict[str, Any]: + """Merge ``incoming`` chunk payload into ``origin`` (depth-2). + + Rules: + * For each top-level key whose value is a dict, recurse one level: + - tensors are concatenated along dim=0, + - lists are extended, + - scalars/None: incoming wins, + - ``embed.decode`` / ``embed.cached_decode``: span-aware merge when + the matching ``*_token_start``/``*_token_end`` fields are present + in both origin and incoming; otherwise naive cat. + - ``meta.finished`` is always taken from incoming. + - Sub-keys listed in ``incoming.meta.override_keys`` are replaced + (no merge). Entries may be either ``(type_key, sub_key)`` tuples + or bare ``sub_key`` strings. + * Top-level non-dict keys follow the tensor/list/scalar rules above. + + Returns a shallow-copy dict; sub-dicts are also copied. Tensors and lists + are not deep-copied (callers should treat the result as read-mostly). 
+ """ + override_meta = incoming.get("meta", {}) if isinstance(incoming.get("meta"), dict) else {} + override_keys = _normalize_override_keys(override_meta.get("override_keys", ())) + + merged: dict[str, Any] = dict(origin) + + for type_key, incoming_val in incoming.items(): + if isinstance(incoming_val, dict): + origin_sub = origin.get(type_key) + origin_sub = origin_sub if isinstance(origin_sub, dict) else {} + merged[type_key] = _merge_subdict(origin_sub, incoming_val, type_key=type_key, override_keys=override_keys) + continue + if type_key in override_keys: + merged[type_key] = incoming_val + continue + if type_key in origin: + merged[type_key] = _merge_value(origin[type_key], incoming_val) + else: + merged[type_key] = incoming_val + + return merged diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 215c3003897..c96c9b7c0dd 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -27,14 +27,13 @@ from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec +from vllm_omni.distributed.omni_connectors.utils.payload_merge import merge_chunk_payloads from vllm_omni.outputs import OmniConnectorOutput from vllm_omni.worker.payload_span import ( THINKER_DECODE_EMBEDDINGS_KEY, THINKER_DECODE_TOKEN_END_KEY, THINKER_DECODE_TOKEN_START_KEY, THINKER_OUTPUT_TOKEN_IDS_KEY, - get_tensor_span, - merge_tensor_spans, ) if TYPE_CHECKING: @@ -1795,82 +1794,18 @@ def _decrement_pending_save_count(self, request_id: str) -> None: def _accumulate_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: """Accumulate chunk payloads (concat tensors, extend lists). - Returns a **shallow copy** of the accumulated state so callers - (e.g. 
``_poll_single_request``) can store it in - ``_local_stage_payload_cache`` without aliasing the authoritative - ``_send_side_request_payload`` dict. + Delegates to ``merge_chunk_payloads`` for depth-2 merge semantics so + this accumulator agrees with + ``OmniChunkTransferAdapter._update_request_payload``. Returns a + shallow copy so callers (e.g. ``_poll_single_request``) can stash it + in ``_local_stage_payload_cache`` without aliasing the authoritative + ``_send_side_request_payload`` entry. """ if req_id not in self._send_side_request_payload: self._send_side_request_payload[req_id] = dict(payload_data) return dict(self._send_side_request_payload[req_id]) - origin = self._send_side_request_payload[req_id] - merged = dict(origin) - override_keys = payload_data.get("override_keys", ()) - drop_decode_span = False - decode_span_handled = False - for key, value in payload_data.items(): - if key == "finished": - merged[key] = value - continue - if key == THINKER_DECODE_EMBEDDINGS_KEY: - merged_span = merge_tensor_spans( - get_tensor_span( - origin, - tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, - start_key=THINKER_DECODE_TOKEN_START_KEY, - end_key=THINKER_DECODE_TOKEN_END_KEY, - ), - get_tensor_span( - payload_data, - tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, - start_key=THINKER_DECODE_TOKEN_START_KEY, - end_key=THINKER_DECODE_TOKEN_END_KEY, - ), - ) - if merged_span is not None: - merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = ( - merged_span - ) - decode_span_handled = True - continue - if isinstance(value, torch.Tensor) and key in origin: - if ( - THINKER_DECODE_TOKEN_START_KEY in origin - or THINKER_DECODE_TOKEN_END_KEY in origin - or THINKER_DECODE_TOKEN_START_KEY in payload_data - or THINKER_DECODE_TOKEN_END_KEY in payload_data - ): - logger.warning( - "[Stage-%s] req=%s falling back to legacy thinker decode " - "merge due to missing/invalid/non-contiguous span " - "metadata", - self._stage_id, - req_id, - ) - 
drop_decode_span = True - merged[key] = torch.cat([origin[key], value], dim=0) - continue - merged[key] = value - continue - if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}: - if decode_span_handled or drop_decode_span: - continue - merged[key] = value - continue - if key in override_keys: - merged[key] = value - continue - if isinstance(value, torch.Tensor) and key in origin: - merged[key] = torch.cat([origin[key], value], dim=0) - elif isinstance(value, list) and key in origin: - merged[key] = origin[key] + value - else: - merged[key] = value - - if drop_decode_span: - merged.pop(THINKER_DECODE_TOKEN_START_KEY, None) - merged.pop(THINKER_DECODE_TOKEN_END_KEY, None) + merged = merge_chunk_payloads(self._send_side_request_payload[req_id], payload_data) self._send_side_request_payload[req_id] = merged return dict(merged) From e7e3ff278176194f734479939542cc6e0f2d3cfe Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 08:52:12 +0000 Subject: [PATCH 48/53] delete unused encode_payload/decode_payload helpers Address PR #3149 review claim 4: the ``encode_payload`` / ``decode_payload`` helpers (added speculatively in 22635323) were never wired into the production wire path -- they used a different tensor encoding (``numpy().tobytes()``) than the actual transport ``OmniMsgpackEncoder`` (``view(torch.uint8)``), and only their own tests in ``TestNativeMsgspecEncoding`` referenced them. Delete the helpers, their module-level encoder/decoder instances, the unused ``_msgspec_enc_hook``, and the test class. ``_msgspec_dec_hook`` stays -- it's still used by ``to_struct``. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- tests/test_data_entry_keys.py | 54 ----------------------------------- vllm_omni/data_entry_keys.py | 25 ---------------- 2 files changed, 79 deletions(-) diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index 6d0b2797823..73dcc82c233 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -112,60 +112,6 @@ def test_context_in_error_message(self): validate_payload({"bad": 1}, context="my_call_site") -class TestNativeMsgspecEncoding: - """Phase 6 scaffolding: native msgspec encode/decode for OmniPayloadStruct.""" - - def test_encode_decode_round_trip_tensor(self): - from vllm_omni.data_entry_keys import decode_payload, encode_payload - - original = OmniPayloadStruct( - codes=CodesStruct(audio=torch.tensor([1, 2, 3, 4], dtype=torch.long)), - meta=MetaStruct(left_context_size=5, finished=torch.tensor(True)), - ) - wire = encode_payload(original) - assert isinstance(wire, bytes) - restored = decode_payload(wire) - assert isinstance(restored, OmniPayloadStruct) - assert torch.equal(restored.codes.audio, original.codes.audio) - assert restored.meta.left_context_size == 5 - assert bool(restored.meta.finished.item()) is True - - def test_encode_decode_round_trip_dtypes(self): - from vllm_omni.data_entry_keys import decode_payload, encode_payload - - # bfloat16 excluded: numpy() doesn't support it; callers must cast before serializing. 
- for dtype in (torch.float32, torch.float16, torch.int64, torch.bool): - original = OmniPayloadStruct(codes=CodesStruct(audio=torch.tensor([1, 0, 1], dtype=dtype))) - restored = decode_payload(encode_payload(original)) - assert restored.codes.audio.dtype == dtype, f"dtype mismatch for {dtype}" - - def test_encode_decode_preserves_shape(self): - from vllm_omni.data_entry_keys import decode_payload, encode_payload - - t = torch.randn(3, 4, 5) - original = OmniPayloadStruct(hidden_states=HiddenStatesStruct(output=t)) - restored = decode_payload(encode_payload(original)) - assert restored.hidden_states.output.shape == (3, 4, 5) - assert torch.allclose(restored.hidden_states.output, t) - - def test_encode_decode_speaker_language(self): - from vllm_omni.data_entry_keys import decode_payload, encode_payload - - original = OmniPayloadStruct(speaker="ethan", language="en") - restored = decode_payload(encode_payload(original)) - assert restored.speaker == "ethan" - assert restored.language == "en" - - def test_decode_rejects_unknown_field(self): - from vllm_omni.data_entry_keys import _OMNI_PAYLOAD_ENCODER, decode_payload - - # Manually craft msgpack with unknown top-level field - bad_dict = {"code_predictor_codes": [1, 2, 3]} - wire = _OMNI_PAYLOAD_ENCODER.encode(bad_dict) - with pytest.raises(msgspec.ValidationError, match="unknown field"): - decode_payload(wire) - - class TestWireEquivalenceStructVsDict: """Producer return-side migration invariant: encoding an ``OmniPayloadStruct`` via the connector serializer must decode to the same payload as encoding diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 20a044f9ca1..f01d0e49906 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -215,31 +215,6 @@ def _msgspec_dec_hook(typ: type, obj: Any) -> Any: raise NotImplementedError(f"no decoder for {typ}") -def _msgspec_enc_hook(obj: Any) -> Any: - if isinstance(obj, torch.Tensor): - return { - _TENSOR_MARKER: True, - "data": 
obj.detach().cpu().contiguous().numpy().tobytes(), - "shape": list(obj.shape), - "dtype": _dtype_to_name(obj.dtype), - } - raise NotImplementedError(f"no encoder for {type(obj).__name__}") - - -_OMNI_PAYLOAD_ENCODER = msgspec.msgpack.Encoder(enc_hook=_msgspec_enc_hook) -_OMNI_PAYLOAD_DECODER = msgspec.msgpack.Decoder(OmniPayloadStruct, dec_hook=_msgspec_dec_hook) - - -def encode_payload(struct: OmniPayloadStruct) -> bytes: - """Encode ``OmniPayloadStruct`` to msgpack bytes for cross-process transport.""" - return _OMNI_PAYLOAD_ENCODER.encode(struct) - - -def decode_payload(data: bytes) -> OmniPayloadStruct: - """Decode msgpack bytes back to ``OmniPayloadStruct``, validating the schema.""" - return _OMNI_PAYLOAD_DECODER.decode(data) - - def to_struct(payload: dict[str, Any]) -> OmniPayloadStruct: """Convert a payload dict into ``OmniPayloadStruct``, validating types. From 21e4f3d6c610a8be3e88d9e9553685583b2cb9dd Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 09:13:49 +0000 Subject: [PATCH 49/53] read next_stage_prompt_len from meta with legacy-flat fallback Address PR #3149 review claim 7: ``_extract_scheduling_metadata`` read ``next_stage_prompt_len`` only from the top level, but the schema defines it under ``meta`` (both ``OmniPayloadMeta`` TypedDict and ``MetaStruct``). Once a producer writes the field, the reader would silently drop it because no producer would put it at the top level. Mirror the existing ``left_context_size`` policy: prefer ``meta.next_stage_prompt_len``, fall back to top-level with a one-time warning. Reader now agrees with the schema and surfaces legacy-shape producers loudly. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- .../worker/omni_connector_model_runner_mixin.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index c96c9b7c0dd..366349ef0fc 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -369,16 +369,26 @@ def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None: def _extract_scheduling_metadata(cls, payload: OmniPayload) -> dict[str, Any]: """Extract only the fields the scheduler needs from a full payload.""" extracted: dict[str, Any] = {} - if "next_stage_prompt_len" in payload: + meta = payload.get("meta") if isinstance(payload, dict) else None + meta = meta if isinstance(meta, dict) else {} + + if "next_stage_prompt_len" in meta: + extracted["next_stage_prompt_len"] = meta["next_stage_prompt_len"] + elif "next_stage_prompt_len" in payload: + logger.warning_once( + "legacy flat 'next_stage_prompt_len' key in payload; expected 'meta.next_stage_prompt_len'" + ) extracted["next_stage_prompt_len"] = payload["next_stage_prompt_len"] + audio_codes = cls._payload_audio_codes(payload) if audio_codes is not None: extracted["code_predictor_codes"] = audio_codes - meta = payload.get("meta") - if isinstance(meta, dict) and "left_context_size" in meta: + + if "left_context_size" in meta: extracted["left_context_size"] = meta["left_context_size"] elif "left_context_size" in payload: logger.warning_once("legacy flat 'left_context_size' key in payload; expected 'meta.left_context_size'") + return extracted _NON_CONSUMABLE_PAYLOAD_KEYS = { From 7ecfe8a483fb923cc56f90478fbfef8326394d19 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 09:36:04 +0000 Subject: [PATCH 50/53] preserve empty top-level lists in build_mm_cpu Address PR #3149 review claim 9: 
f874199d's recursion refactor of ``_to_cpu`` accidentally dropped empty top-level lists. On main, ``build_mm_cpu({"k": []})`` returns ``{"k": []}`` (the ``elif v is not None`` catch-all keeps it). After f874199d, ``_to_cpu([])`` returns ``None``, and the outer ``if cpu_v is not None`` filter drops the key entirely. Restore main behavior by returning the empty list unchanged from ``_to_cpu``. The outer filter passes it through (``[] is not None``). Other behavior changes from f874199d are intentional (recursion into nested dicts, preserving non-tensor scalars inside dicts) and not reverted. ``None`` continues to be dropped at top level on both main and the branch. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- vllm_omni/utils/mm_outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 78f0531232a..cf2837a7ca1 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -47,7 +47,7 @@ def _to_cpu(value): return out or None if isinstance(value, list): if not value: - return None + return value return [_to_cpu(v) for v in value] return value From ba7e95230ff5c649a76e4b5fd0395bea532a0f19 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 09:43:04 +0000 Subject: [PATCH 51/53] add _send_single_request edge-case tests (struct-without-meta, empty struct) Address PR #3149 review test gap (claim 2 followup): pin two contracts in the chunk-adapter sender that were exercised but not tested. - ``test_send_single_request_struct_without_meta_does_not_crash`` asserts the ``payload.meta is not None`` guard introduced by this PR. Producers that emit only ``embed`` or ``codes`` (no scheduling metadata) must not AttributeError on ``payload_data.meta.finished``. - ``test_send_single_request_empty_struct_goes_on_wire`` pins the deliberate choice to NOT filter empty ``OmniPayloadStruct()``. 
The ``payload is None`` check is the only filter; producers that want to skip a chunk must return ``None``, not an empty struct. Skipped tautological checks (dict-return AttributeError tests Python's error rather than our logic; ``forbid_unknown_fields`` round-trip on cosyvoice3 ``runtime_info`` was already covered by existing ``TestValidatePayload`` and ``TestOmniPayloadStruct`` rejections of unknown top-level / sub-keys). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Divyansh Singhvi --- .../test_chunk_transfer_adapter.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index e90275bbbd8..3547f614ea1 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -12,7 +12,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.request import RequestStatus -from vllm_omni.data_entry_keys import MetaStruct, OmniPayload, OmniPayloadStruct +from vllm_omni.data_entry_keys import CodesStruct, MetaStruct, OmniPayload, OmniPayloadStruct from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -140,6 +140,47 @@ def test_save_async(build_adapter): assert task["is_finished"] is False +def test_send_single_request_struct_without_meta_does_not_crash(build_adapter, monkeypatch): + """Producer may return a struct with ``meta=None`` (e.g. payload that + carries only ``embed`` or ``codes``). The sender's ``meta is not None`` + guard handles this without AttributeError; ``finished_flag`` is None and + the cleanup path is not triggered. 
+ """ + adapter, _ = build_adapter(stage_id=1) + request = _req("req-no-meta", RequestStatus.WAITING, external_req_id="ext-no-meta") + + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([1, 2], dtype=torch.long)), + ) + cleanup_calls = [] + monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: cleanup_calls.append((a, kw))) + + adapter._send_single_request({"pooling_output": None, "request": request, "is_finished": False}) + + assert cleanup_calls == [] # no terminal cleanup; meta.finished is unobservable + + +def test_send_single_request_empty_struct_goes_on_wire(build_adapter, monkeypatch): + """Pin the contract: an explicitly empty ``OmniPayloadStruct()`` passes + the ``payload_data is None`` check and gets sent. To skip a chunk, the + producer must return ``None``, not an empty struct. (Filtering empty + structs at the adapter would require introspecting all struct fields on + every send and was rejected for cost vs. value.) 
+ """ + adapter, connector = build_adapter(stage_id=1) + request = _req("req-empty", RequestStatus.WAITING, external_req_id="ext-empty") + + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct() + monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: None) + + adapter._send_single_request({"pooling_output": None, "request": request, "is_finished": False}) + + assert connector.put.called + sent_payload = connector.put.call_args.kwargs["data"] + assert isinstance(sent_payload, OmniPayloadStruct) + assert sent_payload.meta is None # confirms it's the empty struct on the wire + + def test_send_single_request_cleans_up_after_finished_payload(build_adapter, monkeypatch): adapter, _ = build_adapter(stage_id=1) request = _req("req-finished", RequestStatus.FINISHED_STOPPED, external_req_id="ext-finished") From a56b45dbf3398a61683584f7a229d251bf41dbc9 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 28 Apr 2026 12:33:13 +0000 Subject: [PATCH 52/53] port pre-standardize accumulator semantics to nested layout Signed-off-by: Divyansh Singhvi --- .../omni_connectors/test_payload_merge.py | 163 ------------------ tests/worker/test_omni_connector_mixin.py | 42 ++--- vllm_omni/data_entry_keys.py | 6 - .../chunk_transfer_adapter.py | 41 +++-- .../omni_connectors/utils/payload_merge.py | 140 --------------- .../omni_connector_model_runner_mixin.py | 110 ++++++++---- 6 files changed, 133 insertions(+), 369 deletions(-) delete mode 100644 tests/distributed/omni_connectors/test_payload_merge.py delete mode 100644 vllm_omni/distributed/omni_connectors/utils/payload_merge.py diff --git a/tests/distributed/omni_connectors/test_payload_merge.py b/tests/distributed/omni_connectors/test_payload_merge.py deleted file mode 100644 index d42af0e9821..00000000000 --- a/tests/distributed/omni_connectors/test_payload_merge.py +++ /dev/null @@ -1,163 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project - -import pytest -import torch - -from vllm_omni.distributed.omni_connectors.utils.payload_merge import merge_chunk_payloads - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def _t(values, dtype=torch.float32): - return torch.tensor(values, dtype=dtype) - - -def test_depth2_tensor_concat_under_meta_embed_codes(): - origin = { - "embed": {"decode": _t([[1.0, 2.0]])}, - "codes": {"audio": torch.tensor([1, 2], dtype=torch.int64)}, - "meta": {"left_context_size": 4}, - } - incoming = { - "embed": {"decode": _t([[3.0, 4.0]])}, - "codes": {"audio": torch.tensor([3, 4], dtype=torch.int64)}, - "meta": {"left_context_size": 5, "finished": torch.tensor(False)}, - } - - merged = merge_chunk_payloads(origin, incoming) - - assert torch.equal(merged["embed"]["decode"], _t([[1.0, 2.0], [3.0, 4.0]])) - assert torch.equal(merged["codes"]["audio"], torch.tensor([1, 2, 3, 4], dtype=torch.int64)) - # `left_context_size` is a scalar and not in override_keys, so incoming wins. - assert merged["meta"]["left_context_size"] == 5 - # `finished` is always taken from incoming. - assert bool(merged["meta"]["finished"]) is False - - -def test_meta_finished_taken_from_incoming(): - origin = {"meta": {"finished": torch.tensor(False)}} - incoming = {"meta": {"finished": torch.tensor(True)}} - merged = merge_chunk_payloads(origin, incoming) - assert bool(merged["meta"]["finished"]) is True - - -def test_override_keys_replace_instead_of_merging(): - origin = { - "embed": {"prefill": _t([[1.0]])}, - "meta": {}, - } - incoming = { - "embed": {"prefill": _t([[9.0, 9.0]])}, - "meta": {"override_keys": [["embed", "prefill"]]}, - } - merged = merge_chunk_payloads(origin, incoming) - # prefill is replaced wholesale, not concatenated. 
- assert torch.equal(merged["embed"]["prefill"], _t([[9.0, 9.0]])) - - -def test_span_aware_merge_adjacent(): - origin = { - "embed": { - "decode": _t([[1.0], [2.0]]), - "decode_token_start": 10, - "decode_token_end": 12, - } - } - incoming = { - "embed": { - "decode": _t([[3.0], [4.0]]), - "decode_token_start": 12, - "decode_token_end": 14, - } - } - merged = merge_chunk_payloads(origin, incoming) - assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0], [3.0], [4.0]])) - assert merged["embed"]["decode_token_start"] == 10 - assert merged["embed"]["decode_token_end"] == 14 - - -def test_span_aware_merge_partial_overlap_trims_incoming(): - origin = { - "embed": { - "decode": _t([[1.0], [2.0], [3.0]]), - "decode_token_start": 10, - "decode_token_end": 13, - } - } - incoming = { - "embed": { - "decode": _t([[3.5], [4.0], [5.0]]), - "decode_token_start": 12, - "decode_token_end": 15, - } - } - merged = merge_chunk_payloads(origin, incoming) - # Overlap = 13 - 12 = 1, so first row of incoming is trimmed. 
- assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0], [3.0], [4.0], [5.0]])) - assert merged["embed"]["decode_token_start"] == 10 - assert merged["embed"]["decode_token_end"] == 15 - - -def test_span_aware_falls_back_to_naive_cat_without_span_metadata(): - origin = {"embed": {"decode": _t([[1.0]])}} - incoming = {"embed": {"decode": _t([[2.0]])}} - merged = merge_chunk_payloads(origin, incoming) - assert torch.equal(merged["embed"]["decode"], _t([[1.0], [2.0]])) - - -def test_list_extension_inside_subdict(): - origin = {"meta": {"some_list": [1, 2]}} - incoming = {"meta": {"some_list": [3, 4]}} - merged = merge_chunk_payloads(origin, incoming) - assert merged["meta"]["some_list"] == [1, 2, 3, 4] - - -def test_first_chunk_returns_dict_unchanged(): - incoming = {"embed": {"decode": _t([[1.0]])}} - merged = merge_chunk_payloads({}, incoming) - assert torch.equal(merged["embed"]["decode"], _t([[1.0]])) - - -def test_parity_chunk_adapter_vs_mixin_accumulator(): - """Both call sites should produce identical merged output for the same - chunk pair. This pins the contract that the two transports agree. 
- """ - from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( - OmniChunkTransferAdapter, - ) - - chunk1 = { - "embed": { - "decode": _t([[1.0], [2.0]]), - "decode_token_start": 0, - "decode_token_end": 2, - }, - "codes": {"audio": torch.tensor([1], dtype=torch.int64)}, - "meta": {"finished": torch.tensor(False)}, - } - chunk2 = { - "embed": { - "decode": _t([[3.0]]), - "decode_token_start": 2, - "decode_token_end": 3, - }, - "codes": {"audio": torch.tensor([2], dtype=torch.int64)}, - "meta": {"finished": torch.tensor(True)}, - } - - # Adapter path - adapter = OmniChunkTransferAdapter.__new__(OmniChunkTransferAdapter) - adapter.request_payload = {} - adapter._update_request_payload("r1", dict(chunk1)) - adapter_merged = adapter._update_request_payload("r1", dict(chunk2)) - - # Mixin path: emulate by calling merge_chunk_payloads directly the same - # way `_accumulate_payload` does (post-refactor it's a thin wrapper). - mixin_merged = merge_chunk_payloads(dict(chunk1), dict(chunk2)) - - assert torch.equal(adapter_merged["embed"]["decode"], mixin_merged["embed"]["decode"]) - assert adapter_merged["embed"]["decode_token_start"] == mixin_merged["embed"]["decode_token_start"] - assert adapter_merged["embed"]["decode_token_end"] == mixin_merged["embed"]["decode_token_end"] - assert torch.equal(adapter_merged["codes"]["audio"], mixin_merged["codes"]["audio"]) - assert bool(adapter_merged["meta"]["finished"]) == bool(mixin_merged["meta"]["finished"]) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 0e4539cff19..c8c948f7510 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -1019,10 +1019,12 @@ def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self ) host._request_ids_mapping["r1"] = "r1" payload = { - "thinker_decode_embeddings": torch.ones(1, 2), - "thinker_output_token_ids": [1], - "override_keys": 
["thinker_decode_embeddings", "thinker_output_token_ids"], - "finished": torch.tensor(False), + "embed": {"decode": torch.ones(1, 2)}, + "ids": {"output": [1]}, + "meta": { + "finished": torch.tensor(False), + "override_keys": [["embed", "decode"], ["ids", "output"]], + }, } host._accumulate_payload("r1", dict(payload)) @@ -1040,15 +1042,16 @@ def test_payload_consumable_ignores_token_horizon_only_updates(self): model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), ) payload = { - "thinker_output_token_ids": [1, 2, 3], - "finished": torch.tensor(False), - "override_keys": [ - "thinker_output_token_ids", - "thinker_decode_embeddings_token_start", - "thinker_decode_embeddings_token_end", - ], - "thinker_decode_embeddings_token_start": 2, - "thinker_decode_embeddings_token_end": 3, + "ids": {"output": [1, 2, 3]}, + "embed": {"decode_token_start": 2, "decode_token_end": 3}, + "meta": { + "finished": torch.tensor(False), + "override_keys": [ + ["ids", "output"], + ["embed", "decode_token_start"], + ["embed", "decode_token_end"], + ], + }, } self.assertFalse(host._payload_is_consumable(payload)) host.shutdown_omni_connectors() @@ -1060,9 +1063,9 @@ def test_payload_consumable_accepts_decode_embeddings(self): model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), ) payload = { - "thinker_output_token_ids": [1, 2, 3], - "thinker_decode_embeddings": torch.ones(1, 2), - "finished": torch.tensor(False), + "ids": {"output": [1, 2, 3]}, + "embed": {"decode": torch.ones(1, 2)}, + "meta": {"finished": torch.tensor(False)}, } self.assertTrue(host._payload_is_consumable(payload)) host.shutdown_omni_connectors() @@ -1083,15 +1086,14 @@ def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self): host._omni_connector.get.side_effect = [ ( { - "thinker_decode_embeddings": torch.ones(1, 2), - "finished": torch.tensor(False), + "embed": {"decode": torch.ones(1, 2)}, + "meta": {"finished": torch.tensor(False)}, }, 1, ), ( { 
- "next_stage_prompt_len": 7, - "finished": torch.tensor(False), + "meta": {"next_stage_prompt_len": 7, "finished": torch.tensor(False)}, }, 1, ), diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index f01d0e49906..b239ff5ed0b 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -113,15 +113,9 @@ class EmbeddingsStruct(_StructBase): prefill: torch.Tensor | None = None prefill_shape: list[int] | None = None decode: torch.Tensor | None = None - # Token-position span for `decode`: half-open [start, end). - # Producers may set these per-chunk to enable span-aware merge across - # chunks (see merge_chunk_payloads). When absent, the accumulator falls - # back to naive tensor concat. decode_token_start: int | None = None decode_token_end: int | None = None cached_decode: torch.Tensor | None = None - cached_decode_token_start: int | None = None - cached_decode_token_end: int | None = None tts_bos: torch.Tensor | None = None tts_eos: torch.Tensor | None = None tts_pad: torch.Tensor | None = None diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index b6ae06b2237..2bdb1136976 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -14,7 +14,6 @@ from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec from ..utils.logging import get_connector_logger -from ..utils.payload_merge import merge_chunk_payloads from .base import OmniTransferAdapterBase logger = get_connector_logger(__name__) @@ -200,18 +199,40 @@ def _poll_single_request(self, request: Request): return False def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: - """Update the stored payload for *req_id* with the latest chunk. 
- - Delegates to ``merge_chunk_payloads`` for depth-2 merge semantics so - the chunk-adapter accumulator agrees with - ``OmniConnectorModelRunnerMixin._accumulate_payload``. - """ + """Update the stored payload for *req_id* with the latest chunk.""" if req_id not in self.request_payload: self.request_payload[req_id] = payload_data return payload_data - merged = merge_chunk_payloads(self.request_payload[req_id], payload_data) - self.request_payload[req_id] = merged - return merged + origin = self.request_payload[req_id] + raw_ok = payload_data.get("meta", {}).pop("override_keys", []) + override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} + + for key, value in payload_data.items(): + if isinstance(value, dict): + origin_sub = origin.get(key) + if not isinstance(origin_sub, dict): + continue + for qual, qval in value.items(): + if key == "meta" and qual == "finished": + continue + if (key, qual) in override_keys: + continue + osv = origin_sub.get(qual) + if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): + value[qual] = torch.cat([osv, qval], dim=0) + elif isinstance(qval, list) and isinstance(osv, list): + value[qual] = osv + qval + else: + if key in override_keys: + continue + ov = origin.get(key) + if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): + payload_data[key] = torch.cat([ov, value], dim=0) + elif isinstance(value, list) and isinstance(ov, list): + payload_data[key] = ov + value + + self.request_payload[req_id] = payload_data + return payload_data def _send_single_request(self, task: dict): raw_po = task["pooling_output"] diff --git a/vllm_omni/distributed/omni_connectors/utils/payload_merge.py b/vllm_omni/distributed/omni_connectors/utils/payload_merge.py deleted file mode 100644 index d58fc0a749c..00000000000 --- a/vllm_omni/distributed/omni_connectors/utils/payload_merge.py +++ /dev/null @@ -1,140 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to 
the vLLM project -"""Shared depth-2 chunk-payload merge for accumulating connector chunks. - -Used by both ``OmniChunkTransferAdapter._update_request_payload`` and -``OmniConnectorModelRunnerMixin._accumulate_payload`` so the two transports -agree on how nested payload dicts combine across chunks. -""" - -from collections.abc import Iterable -from typing import Any - -import torch - -from vllm_omni.worker.payload_span import ( - get_tensor_span, - merge_tensor_spans, -) - -# Span groups: each entry is (tensor_key, start_key, end_key) within the -# nested sub-dict (e.g. inside ``embed``). When all three are present in -# both origin and incoming, the tensor is merged via ``merge_tensor_spans`` -# instead of naive ``torch.cat``. -_SPAN_GROUPS: dict[str, tuple[tuple[str, str, str], ...]] = { - "embed": ( - ("decode", "decode_token_start", "decode_token_end"), - ("cached_decode", "cached_decode_token_start", "cached_decode_token_end"), - ), -} - - -def _normalize_override_keys(raw: Any) -> set[tuple[str, str] | str]: - out: set[tuple[str, str] | str] = set() - if not isinstance(raw, Iterable) or isinstance(raw, (str, bytes)): - return out - for k in raw: - if isinstance(k, list): - out.add(tuple(k)) # type: ignore[arg-type] - elif isinstance(k, tuple): - out.add(k) - elif isinstance(k, str): - out.add(k) - return out - - -def _merge_value(origin_val: Any, incoming_val: Any) -> Any: - if isinstance(incoming_val, torch.Tensor) and isinstance(origin_val, torch.Tensor): - return torch.cat([origin_val, incoming_val], dim=0) - if isinstance(incoming_val, list) and isinstance(origin_val, list): - return origin_val + incoming_val - return incoming_val - - -def _merge_subdict( - origin_sub: dict[str, Any], - incoming_sub: dict[str, Any], - *, - type_key: str, - override_keys: set[tuple[str, str] | str], -) -> dict[str, Any]: - merged = dict(origin_sub) - span_groups = _SPAN_GROUPS.get(type_key, ()) - span_handled: set[str] = set() - - for tensor_key, start_key, end_key in 
span_groups: - if tensor_key not in incoming_sub: - continue - if (type_key, tensor_key) in override_keys or tensor_key in override_keys: - continue - merged_span = merge_tensor_spans( - get_tensor_span(origin_sub, tensor_key=tensor_key, start_key=start_key, end_key=end_key), - get_tensor_span(incoming_sub, tensor_key=tensor_key, start_key=start_key, end_key=end_key), - ) - if merged_span is None: - continue - tensor, start, end = merged_span - merged[tensor_key] = tensor - merged[start_key] = start - merged[end_key] = end - span_handled |= {tensor_key, start_key, end_key} - - for sub_key, sub_val in incoming_sub.items(): - if sub_key in span_handled: - continue - # `meta.finished` is always taken from incoming (terminal-state signal, - # not a value to accumulate). - if type_key == "meta" and sub_key == "finished": - merged[sub_key] = sub_val - continue - if (type_key, sub_key) in override_keys or sub_key in override_keys: - merged[sub_key] = sub_val - continue - if sub_key in origin_sub: - merged[sub_key] = _merge_value(origin_sub[sub_key], sub_val) - else: - merged[sub_key] = sub_val - - return merged - - -def merge_chunk_payloads(origin: dict[str, Any], incoming: dict[str, Any]) -> dict[str, Any]: - """Merge ``incoming`` chunk payload into ``origin`` (depth-2). - - Rules: - * For each top-level key whose value is a dict, recurse one level: - - tensors are concatenated along dim=0, - - lists are extended, - - scalars/None: incoming wins, - - ``embed.decode`` / ``embed.cached_decode``: span-aware merge when - the matching ``*_token_start``/``*_token_end`` fields are present - in both origin and incoming; otherwise naive cat. - - ``meta.finished`` is always taken from incoming. - - Sub-keys listed in ``incoming.meta.override_keys`` are replaced - (no merge). Entries may be either ``(type_key, sub_key)`` tuples - or bare ``sub_key`` strings. - * Top-level non-dict keys follow the tensor/list/scalar rules above. 
- - Returns a shallow-copy dict; sub-dicts are also copied. Tensors and lists - are not deep-copied (callers should treat the result as read-mostly). - """ - override_meta = incoming.get("meta", {}) if isinstance(incoming.get("meta"), dict) else {} - override_keys = _normalize_override_keys(override_meta.get("override_keys", ())) - - merged: dict[str, Any] = dict(origin) - - for type_key, incoming_val in incoming.items(): - if isinstance(incoming_val, dict): - origin_sub = origin.get(type_key) - origin_sub = origin_sub if isinstance(origin_sub, dict) else {} - merged[type_key] = _merge_subdict(origin_sub, incoming_val, type_key=type_key, override_keys=override_keys) - continue - if type_key in override_keys: - merged[type_key] = incoming_val - continue - if type_key in origin: - merged[type_key] = _merge_value(origin[type_key], incoming_val) - else: - merged[type_key] = incoming_val - - return merged diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 366349ef0fc..f868438b0c3 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -27,15 +27,14 @@ from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec -from vllm_omni.distributed.omni_connectors.utils.payload_merge import merge_chunk_payloads from vllm_omni.outputs import OmniConnectorOutput from vllm_omni.worker.payload_span import ( - THINKER_DECODE_EMBEDDINGS_KEY, - THINKER_DECODE_TOKEN_END_KEY, - THINKER_DECODE_TOKEN_START_KEY, - THINKER_OUTPUT_TOKEN_IDS_KEY, + get_tensor_span, + merge_tensor_spans, ) +_EMBED_SPAN_GROUPS: tuple[tuple[str, str, str], ...] 
= (("decode", "decode_token_start", "decode_token_end"),) + if TYPE_CHECKING: from vllm_omni.distributed.omni_connectors.connectors.base import ( OmniConnectorBase, @@ -391,14 +390,14 @@ def _extract_scheduling_metadata(cls, payload: OmniPayload) -> dict[str, Any]: return extracted - _NON_CONSUMABLE_PAYLOAD_KEYS = { - "finished", - "override_keys", - "next_stage_prompt_len", - "left_context_size", - THINKER_OUTPUT_TOKEN_IDS_KEY, - THINKER_DECODE_TOKEN_START_KEY, - THINKER_DECODE_TOKEN_END_KEY, + _NON_CONSUMABLE_PAYLOAD_KEYS: set[tuple[str, str]] = { + ("meta", "finished"), + ("meta", "override_keys"), + ("meta", "next_stage_prompt_len"), + ("meta", "left_context_size"), + ("ids", "output"), + ("embed", "decode_token_start"), + ("embed", "decode_token_end"), } @staticmethod @@ -448,23 +447,29 @@ def _payload_is_consumable(cls, payload: OmniPayload | None) -> bool: if not isinstance(payload, dict) or not payload: return False - decode_embeddings = payload.get(THINKER_DECODE_EMBEDDINGS_KEY) - if isinstance(decode_embeddings, torch.Tensor): - if decode_embeddings.ndim == 0: - return True - return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0 + embed = payload.get("embed") + if isinstance(embed, dict): + decode_embeddings = embed.get("decode") + if isinstance(decode_embeddings, torch.Tensor): + if decode_embeddings.ndim == 0: + return True + return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0 audio_codes = cls._payload_audio_codes(payload) if audio_codes is not None: if isinstance(audio_codes, torch.Tensor): return audio_codes.numel() > 0 - # Codec code 0 is valid; non-empty code payloads are consumable. 
if hasattr(audio_codes, "__len__"): return len(audio_codes) > 0 return True for key, value in payload.items(): - if key in cls._NON_CONSUMABLE_PAYLOAD_KEYS: + if isinstance(value, dict): + for sk, sv in value.items(): + if (key, sk) in cls._NON_CONSUMABLE_PAYLOAD_KEYS: + continue + if cls._payload_value_has_content(sv): + return True continue if cls._payload_value_has_content(value): return True @@ -1802,20 +1807,65 @@ def _decrement_pending_save_count(self, request_id: str) -> None: # ------------------------------------------------------------------ # def _accumulate_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: - """Accumulate chunk payloads (concat tensors, extend lists). - - Delegates to ``merge_chunk_payloads`` for depth-2 merge semantics so - this accumulator agrees with - ``OmniChunkTransferAdapter._update_request_payload``. Returns a - shallow copy so callers (e.g. ``_poll_single_request``) can stash it - in ``_local_stage_payload_cache`` without aliasing the authoritative - ``_send_side_request_payload`` entry. 
- """ + """Accumulate chunk payloads (concat tensors, extend lists).""" if req_id not in self._send_side_request_payload: self._send_side_request_payload[req_id] = dict(payload_data) return dict(self._send_side_request_payload[req_id]) - merged = merge_chunk_payloads(self._send_side_request_payload[req_id], payload_data) + origin = self._send_side_request_payload[req_id] + merged = dict(origin) + raw_ok = payload_data.get("meta", {}).get("override_keys", []) if isinstance(payload_data, dict) else [] + override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} + + for key, value in payload_data.items(): + if isinstance(value, dict): + origin_sub = origin.get(key) + merged_sub = dict(origin_sub) if isinstance(origin_sub, dict) else {} + span_handled: set[str] = set() + if key == "embed" and isinstance(origin_sub, dict): + for tk, sk, ek in _EMBED_SPAN_GROUPS: + if tk not in value or (key, tk) in override_keys: + continue + span = merge_tensor_spans( + get_tensor_span(origin_sub, tensor_key=tk, start_key=sk, end_key=ek), + get_tensor_span(value, tensor_key=tk, start_key=sk, end_key=ek), + ) + if span is None: + continue + t, s, e = span + merged_sub[tk] = t + merged_sub[sk] = s + merged_sub[ek] = e + span_handled |= {tk, sk, ek} + for qual, qval in value.items(): + if qual in span_handled: + continue + if key == "meta" and qual == "finished": + merged_sub[qual] = qval + continue + if (key, qual) in override_keys: + merged_sub[qual] = qval + continue + osv = merged_sub.get(qual) + if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): + merged_sub[qual] = torch.cat([osv, qval], dim=0) + elif isinstance(qval, list) and isinstance(osv, list): + merged_sub[qual] = osv + qval + else: + merged_sub[qual] = qval + merged[key] = merged_sub + else: + if key in override_keys: + merged[key] = value + continue + ov = origin.get(key) + if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): + merged[key] = torch.cat([ov, value], dim=0) + 
elif isinstance(value, list) and isinstance(ov, list): + merged[key] = ov + value + else: + merged[key] = value + self._send_side_request_payload[req_id] = merged return dict(merged) From 036a918a3c5895a75f8d76bfa0a819c9cb0cf74c Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Tue, 12 May 2026 03:19:30 +0000 Subject: [PATCH 53/53] fix(test): use attribute access on OmniPayloadStruct in async chunk test Signed-off-by: Divyansh Singhvi --- .../stage_input_processors/test_qwen3_tts_async_chunk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index 62d0dcec2b1..950ae213f72 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -214,13 +214,13 @@ def test_connector_initial_chunk_config_overrides_dynamic_ic(): p1 = _call(tm, "r", n_frames=4) assert p1 is not None - assert len(p1["codes"]["audio"]) == _Q * 4 + assert len(p1.codes.audio) == _Q * 4 # Only the first chunk uses the small size; the next emit is 4+25. assert _call(tm, "r", n_frames=25) is None p2 = _call(tm, "r", n_frames=29) assert p2 is not None - assert p2["meta"]["left_context_size"] == 4 + assert p2.meta.left_context_size == 4 @pytest.mark.parametrize(