diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py index 5e19465e224..c9459d8ec43 100644 --- a/tests/core/sched/test_chunk_scheduling_coordinator.py +++ b/tests/core/sched/test_chunk_scheduling_coordinator.py @@ -222,7 +222,7 @@ def test_generation_mode(self): self.assertEqual(req.prompt_token_ids, [10, 20, 30]) self.assertEqual(req.num_computed_tokens, 0) self.assertIsNone(req.additional_information) - self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25}) + self.assertEqual(req._omni_initial_model_buffer, {"meta": {"left_context_size": 25}}) class TestChunkCoordinatorPostprocess(unittest.TestCase): diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index 73f5d8938ac..b8c1761cdef 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -12,7 +12,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.request import RequestStatus -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import CodesStruct, MetaStruct, OmniPayload, OmniPayloadStruct from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -140,11 +140,54 @@ def test_save_async(build_adapter): assert task["is_finished"] is False +def test_send_single_request_struct_without_meta_does_not_crash(build_adapter, monkeypatch): + """Producer may return a struct with ``meta=None`` (e.g. payload that + carries only ``embed`` or ``codes``). The sender's ``meta is not None`` + guard handles this without AttributeError; ``finished_flag`` is None and + the cleanup path is not triggered. + """ + adapter, _ = build_adapter(stage_id=1) + request = _req("req-no-meta", RequestStatus.WAITING, external_req_id="ext-no-meta") + + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([1, 2], dtype=torch.long)), + ) + cleanup_calls = [] + monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: cleanup_calls.append((a, kw))) + + adapter._send_single_request({"pooling_output": None, "request": request, "is_finished": False}) + + assert cleanup_calls == [] # no terminal cleanup; meta.finished is unobservable + + +def test_send_single_request_empty_struct_goes_on_wire(build_adapter, monkeypatch): + """Pin the contract: an explicitly empty ``OmniPayloadStruct()`` passes + the ``payload_data is None`` check and gets sent. To skip a chunk, the + producer must return ``None``, not an empty struct. (Filtering empty + structs at the adapter would require introspecting all struct fields on + every send and was rejected for cost vs. value.) + """ + adapter, connector = build_adapter(stage_id=1) + request = _req("req-empty", RequestStatus.WAITING, external_req_id="ext-empty") + + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct() + monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: None) + + adapter._send_single_request({"pooling_output": None, "request": request, "is_finished": False}) + + assert connector.put.called + sent_payload = connector.put.call_args.kwargs["data"] + assert isinstance(sent_payload, OmniPayloadStruct) + assert sent_payload.meta is None # confirms it's the empty struct on the wire + + def test_send_single_request_cleans_up_after_finished_payload(build_adapter, monkeypatch): adapter, _ = build_adapter(stage_id=1) request = _req("req-finished", RequestStatus.FINISHED_STOPPED, external_req_id="ext-finished") - adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": True} + adapter.custom_process_next_stage_input_func = lambda **kwargs: OmniPayloadStruct( + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)) + ) cleanup_calls = [] monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: cleanup_calls.append((a, kw))) @@ -604,3 +647,42 @@ def _super_finish(_self, request_ids, finished_status): OmniARScheduler.finish_requests(sched, ["r1"], RequestStatus.FINISHED_ABORTED) assert order == ["adapter", "super"] + + +def test_wire_round_trip_struct_to_dict_contract(): + """Pin the wire contract: encoding ``OmniPayloadStruct`` and decoding it + yields a dict equivalent to ``to_dict(struct)``. + + The chunk-adapter sender uses struct attribute access while the receiver + uses dict-key access. This works only because ``OmniMsgpackDecoder`` has + no target type and decodes structs back to plain dicts. If this test + breaks, the receiver's dict access will silently drop fields or KeyError. + """ + from vllm_omni.data_entry_keys import CodesStruct, to_dict + from vllm_omni.distributed.omni_connectors.utils.serialization import ( + OmniMsgpackDecoder, + OmniMsgpackEncoder, + ) + + struct = OmniPayloadStruct( + meta=MetaStruct( + finished=torch.tensor(True, dtype=torch.bool), + left_context_size=12, + ), + codes=CodesStruct(audio=torch.tensor([1, 2, 3], dtype=torch.int64)), + ) + + encoded = OmniMsgpackEncoder().encode(struct) + decoded = OmniMsgpackDecoder().decode(encoded) + + assert isinstance(decoded, dict) + assert isinstance(decoded["meta"], dict) + assert isinstance(decoded["meta"]["finished"], torch.Tensor) + assert bool(decoded["meta"]["finished"].item()) is True + assert decoded["meta"]["left_context_size"] == 12 + assert torch.equal(decoded["codes"]["audio"], torch.tensor([1, 2, 3], dtype=torch.int64)) + + expected = to_dict(struct) + assert set(decoded.keys()) == set(expected.keys()) + assert set(decoded["meta"].keys()) == set(expected["meta"].keys()) + assert set(decoded["codes"].keys()) == set(expected["codes"].keys()) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 134320403b2..bc6a2646639 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -164,11 +164,12 @@ def test_forward_prefers_token_offset_when_present(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 2, - "left_context_size": 1, + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": {"left_context_size": 2}, } ] @@ -193,10 +194,12 @@ def test_forward_falls_back_to_left_context_size_for_backward_compat(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "left_context_size": 2, + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": {"left_context_size": 2}, } ] @@ -214,10 +217,12 @@ def test_forward_ignores_single_request_padded_tail_tokens(): model = _make_code2wav_model(with_stride_cfg=True) runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 0, + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": {"left_context_size": 0}, } ] @@ -238,10 +243,12 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): runtime_info = [ { - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "prefix_ids": [101, 102], + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "ids": {"prompt": [101, 102]}, "generated_len": 3, } ] @@ -275,12 +282,16 @@ def test_forward_reuses_streaming_cache_state_between_chunks(): ) runtime_info = [ { - "req_id": ["rid-stream"], - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 0, - "stream_finished": torch.tensor(False), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": { + "req_id": ["rid-stream"], + "stream_finished": torch.tensor(False), + "left_context_size": 0, + }, } ] @@ -321,12 +332,16 @@ def test_forward_clears_streaming_cache_on_terminal_chunk(): ) runtime_info = [ { - "req_id": ["rid-stream"], - "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), - "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), - "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), - "token_offset": 0, - "stream_finished": torch.tensor(False), + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": { + "req_id": ["rid-stream"], + "stream_finished": torch.tensor(False), + "left_context_size": 0, + }, } ] @@ -338,7 +353,7 @@ def test_forward_clears_streaming_cache_on_terminal_chunk(): ) assert "rid-stream" in model._stream_vocoder_cache_by_req - runtime_info[0]["stream_finished"] = torch.tensor(True) + runtime_info[0]["meta"]["stream_finished"] = torch.tensor(True) out = model.forward( input_ids=torch.tensor([0, 1, 2], dtype=torch.long), positions=torch.tensor([0, 1, 2], dtype=torch.long), diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index a762ce415e2..d54533fd0cb 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -64,9 +64,11 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): external_req_id="rid-0", output_token_ids=[1, 2, 6562, 3], additional_information={ - "speech_token": [torch.tensor([[11, 12, 13]])], - "speech_feat": [torch.tensor([[[0.1, 0.2], [0.3, 0.4]]])], - "embedding": [torch.tensor([[0.5, 0.6]])], + "embed": { + "speech_token": [torch.tensor([[11, 12, 13]])], + "speech_feat": [torch.tensor([[[0.1, 0.2], [0.3, 0.4]]])], + "embedding": [torch.tensor([[0.5, 0.6]])], + }, }, is_finished=lambda: True, ) @@ -79,15 +81,14 @@ def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): ) assert payload is not None - assert payload["meta"]["finished"].item() is True - assert payload["codes"]["audio"] == [1, 2, 3] - assert payload["token_offset"] == 0 - assert payload["left_context_size"] == 0 - assert payload["req_id"] == ["rid-0"] - assert payload["stream_finished"].item() is True - assert "speech_token" in payload - assert "speech_feat" in payload - assert "embedding" in payload + assert payload.meta.finished.item() is True + assert payload.codes.audio.tolist() == [1, 2, 3] + assert payload.meta.left_context_size == 0 + assert payload.meta.req_id == ["rid-0"] + assert payload.meta.stream_finished.item() is True + assert payload.embed.speech_token is not None + assert payload.embed.speech_feat is not None + assert payload.embed.embedding is not None def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes(): @@ -107,8 +108,8 @@ def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes ) assert payload is not None - assert payload["codes"]["audio"] == [] - assert payload["meta"]["finished"].item() is True + assert payload.codes.audio.tolist() == [] + assert payload.meta.finished.item() is True def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): @@ -134,8 +135,8 @@ def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens(): ) assert payload1 is not None - assert payload1["codes"]["audio"] == [1, 2] - assert payload1["token_offset"] == 0 + assert payload1.codes.audio.tolist() == [1, 2] + assert payload1.meta.left_context_size == 0 assert payload2 is None @@ -164,9 +165,9 @@ def test_talker2code2wav_async_chunk_waits_for_prelookahead_and_emits_cumulative assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"] == [1, 2, 3] - assert payload_ready["token_offset"] == 0 - assert payload_ready["meta"]["finished"].item() is False + assert payload_ready.codes.audio.tolist() == [1, 2, 3] + assert payload_ready.meta.left_context_size == 0 + assert payload_ready.meta.finished.item() is False def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): @@ -193,13 +194,13 @@ def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset(): ) assert payload_stream is not None - assert payload_stream["meta"]["finished"].item() is False - assert payload_stream["codes"]["audio"] == [3, 4, 5] - assert payload_stream["token_offset"] == 0 + assert payload_stream.meta.finished.item() is False + assert payload_stream.codes.audio.tolist() == [3, 4, 5] + assert payload_stream.meta.left_context_size == 0 assert payload_final is not None - assert payload_final["meta"]["finished"].item() is True - assert payload_final["codes"]["audio"] == [3, 4, 5, 6] - assert payload_final["token_offset"] == 2 + assert payload_final.meta.finished.item() is True + assert payload_final.codes.audio.tolist() == [3, 4, 5, 6] + assert payload_final.meta.left_context_size == 2 def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): @@ -208,7 +209,7 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): external_req_id="rid-pad", output_token_ids=[8, 9, 10], additional_information={ - "speech_token": [torch.tensor([[1, 2, 3]])], + "embed": {"speech_token": [torch.tensor([[1, 2, 3]])]}, }, is_finished=lambda: False, ) @@ -229,8 +230,8 @@ def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk(): assert payload_pending is None assert payload_ready is not None - assert payload_ready["codes"]["audio"] == [8, 9, 10, 11] - assert payload_ready["token_offset"] == 0 + assert payload_ready.codes.audio.tolist() == [8, 9, 10, 11] + assert payload_ready.meta.left_context_size == 0 def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio(): @@ -256,8 +257,8 @@ def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio( ) assert payload_stream is not None - assert payload_stream["meta"]["finished"].item() is False - assert payload_stream["codes"]["audio"] == [3, 4] + assert payload_stream.meta.finished.item() is False + assert payload_stream.codes.audio.tolist() == [3, 4] assert payload_final is not None - assert payload_final["meta"]["finished"].item() is True - assert payload_final["codes"]["audio"] == [] + assert payload_final.meta.finished.item() is True + assert payload_final.codes.audio.tolist() == [] diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py index b30da97800b..acdd95372b8 100644 --- a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py +++ b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py @@ -20,16 +20,16 @@ def test_flush_remaining_codes_when_no_codes_accumulated_missing_request_id(): """No entry for request_id: treat as empty, return finished sentinel with empty codes.""" tm = SimpleNamespace(code_prompt_token_ids={}) out = _flush_remaining_codes(tm, "missing", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"] == _sentinel()["codes"]["audio"] - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == _sentinel()["codes"]["audio"] + assert out.meta.finished.item() is True def test_flush_remaining_codes_when_no_codes_accumulated_empty_list(): """Explicit empty accumulation list returns the same sentinel.""" tm = SimpleNamespace(code_prompt_token_ids={"r": []}) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) - assert out["codes"]["audio"] == [] - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == [] + assert out.meta.finished.item() is True def test_flush_remaining_codes_partial_chunk_remaining(): @@ -41,8 +41,8 @@ def test_flush_remaining_codes_partial_chunk_remaining(): code_prompt_token_ids={"r": [[1], [2], [3], [4], [5], [6], [7]]}, ) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) - assert out["meta"]["finished"].item() is True - assert out["codes"]["audio"] == [4, 5, 6, 7] + assert out.meta.finished.item() is True + assert out.codes.audio.tolist() == [4, 5, 6, 7] def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): @@ -52,7 +52,7 @@ def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size(): ) out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3) # context_length = chunk_size = 3, end_index = min(6, 6) -> all 6 - assert out["codes"]["audio"] == [1, 2, 3, 4, 5, 6] + assert out.codes.audio.tolist() == [1, 2, 3, 4, 5, 6] @pytest.mark.parametrize( @@ -74,5 +74,5 @@ def test_flush_remaining_codes_context_window_end_index( tm = SimpleNamespace(code_prompt_token_ids={"r": accumulated}) out = _flush_remaining_codes(tm, "r", chunk_size=chunk_size, left_context_size=left_context) expected_flat = list(range(length - expected_end_index, length)) - assert out["codes"]["audio"] == expected_flat - assert out["meta"]["finished"].item() is True + assert out.codes.audio.tolist() == expected_flat + assert out.meta.finished.item() is True diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index e4045f37fb3..950ae213f72 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -80,8 +80,8 @@ def test_eof_marker_when_finished_empty(): request=_req("r", finished=True), is_finished=True, ) - assert p["codes"] == {"audio": []} - assert p["meta"]["finished"].item() is True + assert p.codes.audio.tolist() == [] + assert p.meta.finished.item() is True def test_flush_on_finish(): @@ -94,8 +94,8 @@ def test_flush_on_finish(): is_finished=True, ) assert p is not None - assert p["meta"]["finished"].item() is True - assert len(p["codes"]["audio"]) == _Q * 24 + assert p.meta.finished.item() is True + assert len(p.codes.audio) == _Q * 24 _CASES = [ @@ -149,8 +149,8 @@ def test_streaming_phases(config, n_frames, finished, expected): else: exp_ctx, exp_window = expected assert payload is not None - assert payload["meta"]["left_context_size"] == exp_ctx - assert len(payload["codes"]["audio"]) == _Q * exp_window + assert payload.meta.left_context_size == exp_ctx + assert len(payload.codes.audio) == _Q * exp_window def test_dynamic_ic_adapts_to_load(): @@ -160,14 +160,14 @@ def test_dynamic_ic_adapts_to_load(): # Low load (1/8) -> IC=2 -> emit at 2 p1 = _call(tm, "r", n_frames=2) assert p1 is not None - assert len(p1["codes"]["audio"]) == _Q * 2 + assert len(p1.codes.audio) == _Q * 2 # High load on a new request: active=6/8 -> IC=8 -> emit at 8 for i in range(4): tm.code_prompt_token_ids[f"other-{i}"] = [[0]] p2 = _call(tm, "new-high-load", n_frames=8) assert p2 is not None - assert len(p2["codes"]["audio"]) == _Q * 8 + assert len(p2.codes.audio) == _Q * 8 # Requests past initial phase still count in load factor tm2 = _tm(max_num_seqs=4) @@ -176,7 +176,7 @@ def test_dynamic_ic_adapts_to_load(): # active=4/4=1.0 -> IC=16 p3 = _call(tm2, "new", n_frames=16) assert p3 is not None - assert len(p3["codes"]["audio"]) == _Q * 16 + assert len(p3.codes.audio) == _Q * 16 def test_ic_load_change_mid_request(): @@ -195,7 +195,7 @@ def test_ic_load_change_mid_request(): assert _call(tm, "r", n_frames=25) is None p3 = _call(tm, "r", n_frames=27) assert p3 is not None - assert p3["meta"]["left_context_size"] == 2 + assert p3.meta.left_context_size == 2 # A *new* request under high load gets IC=16 (not IC=2). # Frame 2 would emit under IC=2 but must hold under IC=16. @@ -214,13 +214,13 @@ def test_connector_initial_chunk_config_overrides_dynamic_ic(): p1 = _call(tm, "r", n_frames=4) assert p1 is not None - assert len(p1["codes"]["audio"]) == _Q * 4 + assert len(p1.codes.audio) == _Q * 4 # Only the first chunk uses the small size; the next emit is 4+25. assert _call(tm, "r", n_frames=25) is None p2 = _call(tm, "r", n_frames=29) assert p2 is not None - assert p2["meta"]["left_context_size"] == 4 + assert p2.meta.left_context_size == 4 @pytest.mark.parametrize( @@ -268,8 +268,8 @@ def test_first_streaming_chunk_prepends_ref_code_context(): ) assert payload is not None - assert payload["meta"]["left_context_size"] == 2 - assert len(payload["codes"]["audio"]) == _Q * 12 + assert payload.meta.left_context_size == 2 + assert len(payload.codes.audio) == _Q * 12 def test_ref_code_context_applies_to_all_streaming_chunks(): @@ -290,8 +290,8 @@ def test_ref_code_context_applies_to_all_streaming_chunks(): assert payload is not None # ref_code (2 frames) prepended as left context on second chunk too - assert payload["meta"]["left_context_size"] == 10 + 2 - assert len(payload["codes"]["audio"]) == _Q * (35 + 2) + assert payload.meta.left_context_size == 10 + 2 + assert len(payload.codes.audio) == _Q * (35 + 2) def test_ref_code_context_can_be_buffered_before_first_emit(): @@ -325,8 +325,8 @@ def test_ref_code_context_can_be_buffered_before_first_emit(): assert payload is not None # ref_code (2 frames) is kept (not popped) for subsequent chunks - assert payload["meta"]["left_context_size"] == 2 - assert len(payload["codes"]["audio"]) == _Q * 12 + assert payload.meta.left_context_size == 2 + assert len(payload.codes.audio) == _Q * 12 assert rid in tm.request_payload @@ -353,7 +353,7 @@ def test_non_async_processor_prepends_ref_code_and_sets_trim_context(): assert len(prompts) == 1 prompt = prompts[0] - assert prompt["additional_information"] == {"meta": {"left_context_size": [2]}} + assert prompt["additional_information"] == {"meta": {"left_context_size": 2}} assert prompt["prompt_token_ids"] == [ 9, 8, @@ -401,4 +401,4 @@ def test_non_async_processor_filters_out_of_range_codec_values(): prompt = prompts[0] # Only ref_code (1 frame) + 2 valid frames = 3 frames * 4 quantizers = 12 codes assert len(prompt["prompt_token_ids"]) == 4 * 3 - assert prompt["additional_information"] == {"meta": {"left_context_size": [1]}} + assert prompt["additional_information"] == {"meta": {"left_context_size": 1}} diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py index 7d6fc6e74c9..6738c763aec 100644 --- a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py @@ -57,8 +57,8 @@ def test_latent2vae_async_chunk_serializes_latent_payload(): ) assert payload is not None - assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool)) - recovered = _decode_serialized_latent(payload["code_predictor_codes"]) + assert torch.equal(payload.meta.finished, torch.tensor(False, dtype=torch.bool)) + recovered = _decode_serialized_latent(payload.codes.audio.tolist()) torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0)) @@ -70,10 +70,8 @@ def test_latent2vae_async_chunk_returns_terminal_marker_without_latent(): is_finished=False, ) - assert payload == { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } + assert payload.codes.audio.tolist() == [] + assert torch.equal(payload.meta.finished, torch.tensor(True, dtype=torch.bool)) def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk(): diff --git a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py index a61c11b5e73..7e63ede9e61 100644 --- a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py @@ -114,8 +114,8 @@ def test_flush_tail_when_finished(): ) assert payload is not None - assert payload["meta"]["finished"].item() is True - codes = payload["codes"]["audio"] + assert payload.meta.finished.item() is True + codes = payload.codes.audio # Format: [ctx_frames, context_length, ...flat_codes] assert len(codes) >= 2 # At least ctx_frames + context_length header ctx_frames = codes[0] @@ -137,8 +137,8 @@ def test_eof_marker_when_finished_with_no_frames(): request=request, ) - assert payload["codes"] == {"audio": []} - assert payload["meta"]["finished"].item() is True + assert payload.codes.audio.tolist() == [] + assert payload.meta.finished.item() is True def test_normal_chunk_emission(): @@ -162,7 +162,7 @@ def test_normal_chunk_emission(): # A chunk should be emitted assert payload is not None - codes = payload["codes"]["audio"] + codes = payload.codes.audio ctx_frames = codes[0] context_length = codes[1] assert ctx_frames == 20 # 25 - 5(chunk_size_at_begin) @@ -189,7 +189,7 @@ def test_small_initial_chunks(): ) assert payload is not None - codes = payload["codes"]["audio"] + codes = payload.codes.audio ctx_frames = codes[0] context_length = codes[1] assert ctx_frames == 0 @@ -238,19 +238,16 @@ def test_context_handling_format(): ) assert payload is not None - codes = payload["codes"]["audio"] - # First two elements are ctx_frames and context_length - ctx_frames = codes[0] - context_length = codes[1] + codes = payload.codes.audio + # codes is a 1-D long tensor: [ctx_frames, context_length, ...flat_codes] + ctx_frames = int(codes[0].item()) + context_length = int(codes[1].item()) flat_codes = codes[2:] - # The window has ctx_frames + context_length frames - assert isinstance(ctx_frames, int) - assert isinstance(context_length, int) assert ctx_frames >= 0 assert context_length > 0 # flat_codes = total_window_frames * codebook_dim total_window_frames = ctx_frames + context_length - assert len(flat_codes) == total_window_frames * 2 # codebook_dim=2 + assert flat_codes.numel() == total_window_frames * 2 # codebook_dim=2 def test_none_pooling_output_not_finished_returns_none(): diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py index f4a50e677de..73dcc82c233 100644 --- a/tests/test_data_entry_keys.py +++ b/tests/test_data_entry_keys.py @@ -1,41 +1,175 @@ -"""Tests for data_entry_keys: TypedDict payload structure, flatten/unflatten, serialize/deserialize.""" +"""Tests for data_entry_keys.""" +import msgspec +import pytest import torch from vllm_omni.data_entry_keys import ( + CodesStruct, + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, + MetaStruct, OmniPayload, + OmniPayloadStruct, deserialize_payload, flatten_payload, serialize_payload, + to_dict, + to_struct, unflatten_payload, ) from vllm_omni.engine import AdditionalInformationPayload -class TestOmniPayload: - def test_nested_payload_structure(self): - """Verify OmniPayload can be constructed with nested dicts.""" - payload: OmniPayload = { - "hidden_states": {"output": torch.tensor([1.0])}, - "embed": {"prefill": torch.tensor([2.0])}, - "codes": {"audio": torch.tensor([3.0])}, - "ids": {"all": [1, 2, 3]}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, +class TestOmniPayloadStruct: + """Runtime-validated mirror of OmniPayload (msgspec.Struct).""" + + def test_to_struct_validates_dict(self): + d = {"meta": {"left_context_size": 25, "finished": torch.tensor(False)}} + s = to_struct(d) + assert s.meta.left_context_size == 25 + + def test_to_struct_rejects_legacy_flat_top_level(self): + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"code_predictor_codes": torch.zeros(3, 8)}) + + def test_to_struct_rejects_legacy_flat_meta_field(self): + # `left_context_size` at top level (legacy) instead of under `meta` + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"left_context_size": 25}) + + def test_to_struct_rejects_typo_in_subkey(self): + with pytest.raises(msgspec.ValidationError, match="unknown field"): + to_struct({"meta": {"finisheed": True}}) + + def test_to_struct_rejects_wrong_type(self): + with pytest.raises(msgspec.ValidationError, match="Expected"): + to_struct({"meta": {"left_context_size": "not_an_int"}}) + + def test_round_trip_dict_struct_dict(self): + original = { + "meta": {"left_context_size": 7, "finished": torch.tensor(True)}, + "codes": {"audio": torch.zeros(2, 8)}, + "hidden_states": {"output": torch.zeros(4, 16)}, + } + s = to_struct(original) + d = to_dict(s) + assert sorted(d.keys()) == sorted(original.keys()) + for top, sub in original.items(): + assert sorted(d[top].keys()) == sorted(sub.keys()) + + def test_to_dict_drops_unset_fields(self): + s = OmniPayloadStruct(meta=MetaStruct(left_context_size=10)) + d = to_dict(s) + assert d == {"meta": {"left_context_size": 10}} + + def test_struct_with_all_categories(self): + d = { + "hidden_states": {"output": torch.zeros(1)}, + "embed": {"prefill": torch.zeros(1), "tts_bos": torch.zeros(1)}, + "ids": {"all": [1, 2], "prompt": [1]}, + "codes": {"audio": torch.zeros(1)}, + "meta": {"left_context_size": 3, "num_processed_tokens": 7}, } - assert torch.equal(payload["hidden_states"]["output"], torch.tensor([1.0])) - assert torch.equal(payload["embed"]["prefill"], torch.tensor([2.0])) - assert torch.equal(payload["codes"]["audio"], torch.tensor([3.0])) - assert payload["ids"]["all"] == [1, 2, 3] - assert payload["meta"]["finished"].item() is True + s = to_struct(d) + assert isinstance(s.hidden_states, HiddenStatesStruct) + assert isinstance(s.embed, EmbeddingsStruct) + assert isinstance(s.ids, IdsStruct) + assert isinstance(s.codes, CodesStruct) + assert isinstance(s.meta, MetaStruct) + assert s.ids.all == [1, 2] + assert s.meta.num_processed_tokens == 7 + + +class TestValidatePayload: + def test_raises_on_unknown_top_level(self): + from vllm_omni.data_entry_keys import validate_payload + + with pytest.raises(msgspec.ValidationError, match="unknown field"): + validate_payload({"code_predictor_codes": torch.zeros(3, 8)}, context="test_boundary") + + def test_raises_on_unknown_sub_key(self): + from vllm_omni.data_entry_keys import validate_payload - def test_partial_payload(self): - """OmniPayload fields are all optional (total=False).""" - payload: OmniPayload = {"meta": {"finished": torch.tensor(False, dtype=torch.bool)}} - assert payload["meta"]["finished"].item() is False + with pytest.raises(msgspec.ValidationError, match="unknown field"): + validate_payload({"meta": {"finisheed": True}}) - def test_empty_payload(self): - payload: OmniPayload = {} - assert len(payload) == 0 + def test_none_is_ok(self): + from vllm_omni.data_entry_keys import validate_payload + + validate_payload(None) # should not raise + + def test_valid_payload_passes(self): + from vllm_omni.data_entry_keys import validate_payload + + validate_payload({"meta": {"left_context_size": 5}}) + + def test_context_in_error_message(self): + from vllm_omni.data_entry_keys import validate_payload + + with pytest.raises(msgspec.ValidationError, match="my_call_site"): + validate_payload({"bad": 1}, context="my_call_site") + + +class TestWireEquivalenceStructVsDict: + """Producer return-side migration invariant: encoding an ``OmniPayloadStruct`` + via the connector serializer must decode to the same payload as encoding + the equivalent ``to_dict(struct)``. + + Guards against regressions where the wire format diverges between the two + paths (e.g. msgspec adds a struct tag, or ``to_dict`` drops a non-default + sub-field that ``omit_defaults`` retains). + """ + + @staticmethod + def _round_trip(obj): + from vllm_omni.distributed.omni_connectors.utils.serialization import ( + OmniMsgpackDecoder, + OmniMsgpackEncoder, + ) + + return OmniMsgpackDecoder().decode(OmniMsgpackEncoder().encode(obj)) + + @staticmethod + def _assert_decoded_equal(a, b): + if isinstance(a, dict): + assert isinstance(b, dict) + assert sorted(a.keys()) == sorted(b.keys()) + for k in a: + TestWireEquivalenceStructVsDict._assert_decoded_equal(a[k], b[k]) + elif isinstance(a, torch.Tensor): + assert isinstance(b, torch.Tensor) + assert a.dtype == b.dtype + assert a.shape == b.shape + assert torch.equal(a, b) + elif isinstance(a, list): + assert isinstance(b, list) and len(a) == len(b) + for x, y in zip(a, b, strict=True): + TestWireEquivalenceStructVsDict._assert_decoded_equal(x, y) + else: + assert a == b + + def test_basic_payload(self): + struct = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([1, 2, 3], dtype=torch.long)), + meta=MetaStruct(left_context_size=10, finished=torch.tensor(True, dtype=torch.bool)), + ) + self._assert_decoded_equal(self._round_trip(struct), self._round_trip(to_dict(struct))) + + def test_nested_sub_structs(self): + # Exercises depth-2 sub-struct encoding (embed.*) which was the case that + # exposed schema drift in the #1829 migration. + struct = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor([5, 6], dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(False, dtype=torch.bool)), + embed=EmbeddingsStruct( + speech_token=torch.tensor([[1, 2, 3]]), + speech_feat=torch.tensor([[[0.1, 0.2], [0.3, 0.4]]]), + embedding=torch.tensor([[0.5, 0.6]]), + ), + ) + self._assert_decoded_equal(self._round_trip(struct), self._round_trip(to_dict(struct))) class TestFlattenPayload: @@ -76,9 +210,6 @@ def test_hidden_states_layers_expanded(self): assert torch.equal(flat["hidden_states.layer_24"], torch.tensor([3.0])) assert "hidden_states.layers" not in flat - def test_empty_payload(self): - assert flatten_payload({}) == {} - def test_mixed_nested_and_top_level(self): nested: OmniPayload = { "codes": {"audio": torch.tensor([1.0])}, @@ -118,9 +249,6 @@ def test_hidden_states_layers_collected(self): assert torch.equal(nested["hidden_states"]["layers"][0], torch.tensor([2.0])) assert torch.equal(nested["hidden_states"]["layers"][24], torch.tensor([3.0])) - def test_empty_payload(self): - assert unflatten_payload({}) == {} - class TestFlattenUnflattenRoundTrip: def test_round_trip_simple(self): diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 3575f62b18c..b8eb899fa32 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -1044,10 +1044,12 @@ def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self ) host._request_ids_mapping["r1"] = "r1" payload = { - "thinker_decode_embeddings": torch.ones(1, 2), - "thinker_output_token_ids": [1], - "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"], - "finished": torch.tensor(False), + "embed": {"decode": torch.ones(1, 2)}, + "ids": {"output": [1]}, + "meta": { + "finished": torch.tensor(False), + "override_keys": [["embed", "decode"], ["ids", "output"]], + }, } host._accumulate_payload("r1", dict(payload)) @@ -1065,15 +1067,16 @@ def test_payload_consumable_ignores_token_horizon_only_updates(self): model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), ) payload = { - "thinker_output_token_ids": [1, 2, 3], - "finished": torch.tensor(False), - "override_keys": [ - "thinker_output_token_ids", - "thinker_decode_embeddings_token_start", - "thinker_decode_embeddings_token_end", - ], - "thinker_decode_embeddings_token_start": 2, - "thinker_decode_embeddings_token_end": 3, + "ids": {"output": [1, 2, 3]}, + "embed": {"decode_token_start": 2, "decode_token_end": 3}, + "meta": { + "finished": torch.tensor(False), + "override_keys": [ + ["ids", "output"], + ["embed", "decode_token_start"], + ["embed", "decode_token_end"], + ], + }, } self.assertFalse(host._payload_is_consumable(payload)) host.shutdown_omni_connectors() @@ -1085,9 +1088,9 @@ def test_payload_consumable_accepts_decode_embeddings(self): model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), ) payload = { - "thinker_output_token_ids": [1, 2, 3], - "thinker_decode_embeddings": torch.ones(1, 2), - "finished": torch.tensor(False), + "ids": {"output": [1, 2, 3]}, + "embed": {"decode": torch.ones(1, 2)}, + "meta": {"finished": torch.tensor(False)}, } self.assertTrue(host._payload_is_consumable(payload)) host.shutdown_omni_connectors() @@ -1108,15 +1111,14 @@ def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self): host._omni_connector.get.side_effect = [ ( { - "thinker_decode_embeddings": torch.ones(1, 2), - "finished": torch.tensor(False), + "embed": {"decode": torch.ones(1, 2)}, + "meta": {"finished": torch.tensor(False)}, }, 1, ), ( { - "next_stage_prompt_len": 7, - "finished": torch.tensor(False), + "meta": {"next_stage_prompt_len": 7, "finished": torch.tensor(False)}, }, 1, ), diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index f22aad9d46d..b239ff5ed0b 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -1,41 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Structured payload types for inter-stage communication. -Adding a new model? -~~~~~~~~~~~~~~~~~~~ -Every key you put into the inter-stage payload (``additional_information``, -``multimodal_output``, ``pooling_output``) **must** use the nested -``OmniPayload`` TypedDict structure. For each category, every known -qualifier is an explicit field so misspellings are caught statically. - -Categories +Categories under ``OmniPayload``: hidden_states – intermediate / output hidden-state tensors embed – embedding tensors (prefill, decode, special tokens) ids – token-ID sequences codes – codec / audio code tensors meta – scalar metadata, control flags, shapes - -This module provides: -- Structured ``TypedDict`` types for static type checking (``OmniPayload``) -- ``serialize_payload`` / ``deserialize_payload`` for transport across - process boundaries via ``AdditionalInformationPayload`` """ from __future__ import annotations from typing import TYPE_CHECKING, Any, TypedDict +import msgspec import numpy as np import torch if TYPE_CHECKING: from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload -# ── Structured payload types ── -# These are TypedDicts (plain dicts at runtime, zero overhead) that give -# static type checking and IDE autocomplete for inter-stage payloads. -# Every field is optional (total=False) because each stage only populates -# the subset it needs. - class HiddenStates(TypedDict, total=False): output: torch.Tensor @@ -54,6 +39,8 @@ class Embeddings(TypedDict, total=False): tts_pad_projected: torch.Tensor voice: torch.Tensor speech_feat: torch.Tensor + speech_token: torch.Tensor + embedding: torch.Tensor thinker_reply: torch.Tensor @@ -72,6 +59,8 @@ class Ids(TypedDict, total=False): class OmniPayloadMeta(TypedDict, total=False): finished: torch.Tensor + stream_finished: torch.Tensor + req_id: list[str] left_context_size: int override_keys: list[tuple[str, str]] num_processed_tokens: int @@ -105,45 +94,187 @@ class OmniPayload(TypedDict, total=False): request_id: str -# ── Keys whose values are nested dicts (TypedDict sub-categories) ── -_NESTED_KEYS = frozenset({"hidden_states", "embed", "ids", "codes", "meta"}) - -# Sub-TypedDict for each nested category, used by runtime validation. -_NESTED_SCHEMAS: dict[str, type] = { - "hidden_states": HiddenStates, - "embed": Embeddings, - "ids": Ids, - "codes": Codes, - "meta": OmniPayloadMeta, +# ── msgspec.Struct mirror of the TypedDicts (runtime-validated) ── + + +class _StructBase(msgspec.Struct, omit_defaults=True, kw_only=True, forbid_unknown_fields=True): + pass + + +class HiddenStatesStruct(_StructBase): + output: torch.Tensor | None = None + output_shape: list[int] | None = None + trailing_text: torch.Tensor | None = None + last: torch.Tensor | None = None + layers: dict[int, torch.Tensor] | None = None + + +class EmbeddingsStruct(_StructBase): + prefill: torch.Tensor | None = None + prefill_shape: list[int] | None = None + decode: torch.Tensor | None = None + decode_token_start: int | None = None + decode_token_end: int | None = None + cached_decode: torch.Tensor | None = None + tts_bos: torch.Tensor | None = None + tts_eos: torch.Tensor | None = None + tts_pad: torch.Tensor | None = None + tts_pad_projected: torch.Tensor | None = None + voice: torch.Tensor | None = None + speech_feat: torch.Tensor | None = None + speech_token: torch.Tensor | None = None + embedding: torch.Tensor | None = None + thinker_reply: torch.Tensor | None = None + + +class CodesStruct(_StructBase): + audio: torch.Tensor | None = None + ref: torch.Tensor | None = None + + +class IdsStruct(_StructBase): + all: list[int] | None = None + prompt: list[int] | None = None + output: list[int] | None = None + speech_token: list[int] | None = None + prior_image: list[int] | None = None + + +class MetaStruct(_StructBase): + finished: torch.Tensor | None = None + stream_finished: torch.Tensor | None = None + req_id: list[str] | None = None + left_context_size: int | None = None + override_keys: list[tuple[str, str]] | None = None + num_processed_tokens: int | None = None + next_stage_prompt_len: int | None = None + ar_width: int | None = None + eol_token_id: int | None = None + visual_token_start_id: int | None = None + visual_token_end_id: int | None = None + gen_token_mask: torch.Tensor | None = None + omni_task: list[str] | None = None + height: int | None = None + width: int | None = None + decode_flag: bool | None = None + codec_streaming: bool | None = None + ref_code_len: int | None = None + talker_prefill_offset: int | None = None + codec_chunk_frames: int | None = None + codec_left_context_frames: int | None = None + code_flat_numel: int | None = None + omni_final_stage_id: int | None = None + + +class OmniPayloadStruct(_StructBase): + hidden: torch.Tensor | None = None + hidden_states: HiddenStatesStruct | None = None + embed: EmbeddingsStruct | None = None + ids: IdsStruct | None = None + codes: CodesStruct | None = None + meta: MetaStruct | None = None + latent: torch.Tensor | None = None + generated_len: int | None = None + model_outputs: list[torch.Tensor] | None = None + mtp_inputs: tuple[torch.Tensor, torch.Tensor] | None = None + speaker: list[str] | str | None = None + language: list[str] | str | None = None + request_id: str | None = None + past_key_values: list[int] | None = None + kv_metadata: dict[str, Any] | None = None + + +_NESTED_STRUCTS: dict[str, type[_StructBase]] = { + "hidden_states": HiddenStatesStruct, + "embed": EmbeddingsStruct, + "ids": IdsStruct, + "codes": CodesStruct, + "meta": MetaStruct, } -_ROOT_KEYS: frozenset[str] = frozenset(OmniPayload.__annotations__.keys()) + +_TENSOR_MARKER = "__tensor__" -def assert_payload(payload: dict[str, Any], *, context: str = "payload") -> None: - """Validate ``payload`` matches the ``OmniPayload`` nested schema. +def _msgspec_dec_hook(typ: type, obj: Any) -> Any: + """Bridge non-msgspec types when decoding bytes/dicts into Structs.""" + if typ is torch.Tensor: + if isinstance(obj, torch.Tensor): + return obj + if isinstance(obj, dict) and obj.get(_TENSOR_MARKER): + arr = np.frombuffer(obj["data"], dtype=np.dtype(obj["dtype"])) + arr = arr.reshape(obj["shape"]) + return torch.from_numpy(arr.copy()) + raise TypeError(f"cannot decode {type(obj).__name__} into torch.Tensor") + raise NotImplementedError(f"no decoder for {typ}") - TypedDict is a static-only contract in Python; this helper closes the - loop at runtime by rejecting: - * non-dict payloads - * top-level keys not declared on ``OmniPayload`` - * nested-category values that aren't dicts - * sub-keys not declared on the matching nested TypedDict - Call at producer/consumer boundaries when a schema violation should - crash the pipeline instead of silently degrading audio quality. +def to_struct(payload: dict[str, Any]) -> OmniPayloadStruct: + """Convert a payload dict into ``OmniPayloadStruct``, validating types. + + Raises ``msgspec.ValidationError`` on: + * unknown top-level keys (typos, legacy flat keys) + * unknown sub-keys under any nested category + * type mismatches (e.g., ``meta.left_context_size`` not an ``int``) """ - assert isinstance(payload, dict), f"{context}: expected dict, got {type(payload).__name__}" - extra_top = set(payload) - _ROOT_KEYS - assert not extra_top, f"{context}: unknown top-level keys {sorted(extra_top)!r}" - for nested_key, schema in _NESTED_SCHEMAS.items(): - if nested_key not in payload: + return msgspec.convert(payload, OmniPayloadStruct, dec_hook=_msgspec_dec_hook) + + +def validate_payload(payload: dict[str, Any] | None, *, context: str = "payload") -> None: + """Validate a payload matches the ``OmniPayload`` schema, raising on drift. + + Wraps :func:`to_struct` and re-raises ``msgspec.ValidationError`` with + the call-site ``context`` prepended. ``None`` is allowed (treated as + "no payload to check"). + """ + if payload is None: + return + try: + to_struct(payload) + except msgspec.ValidationError as exc: + raise msgspec.ValidationError(f"{context}: {exc}") from exc + + +def to_dict(struct: OmniPayloadStruct) -> dict[str, Any]: + """Convert ``OmniPayloadStruct`` to a plain dict, dropping ``None`` fields.""" + out: dict[str, Any] = {} + for field in OmniPayloadStruct.__struct_fields__: + value = getattr(struct, field) + if value is None: continue - sub = payload[nested_key] - assert isinstance(sub, dict), f"{context}: payload[{nested_key!r}] must be dict, got {type(sub).__name__}" - known_sub = frozenset(schema.__annotations__.keys()) - extra_sub = set(sub) - known_sub - assert not extra_sub, f"{context}: payload[{nested_key!r}] unknown sub-keys {sorted(extra_sub)!r}" + if isinstance(value, _StructBase): + sub: dict[str, Any] = {} + for sk in value.__struct_fields__: + sv = getattr(value, sk) + if sv is not None: + sub[sk] = sv + if sub: + out[field] = sub + else: + out[field] = value + return out + + +_DTYPE_TO_NAME: dict[torch.dtype, str] = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float64: "float64", + torch.int64: "int64", + torch.int32: "int32", + torch.int16: "int16", + torch.int8: "int8", + torch.uint8: "uint8", + torch.bool: "bool", +} + + +def _dtype_to_name(dtype: torch.dtype) -> str: + return _DTYPE_TO_NAME.get(dtype, str(dtype).replace("torch.", "")) + + +# ── Keys whose values are nested dicts (TypedDict sub-categories) ── +_NESTED_KEYS = frozenset({"hidden_states", "embed", "ids", "codes", "meta"}) def flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: @@ -192,25 +323,6 @@ def unflatten_payload(flat: dict[str, Any]) -> dict[str, Any]: return result -# ── dtype helpers ── -_DTYPE_TO_NAME: dict[torch.dtype, str] = { - torch.float32: "float32", - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float64: "float64", - torch.int64: "int64", - torch.int32: "int32", - torch.int16: "int16", - torch.int8: "int8", - torch.uint8: "uint8", - torch.bool: "bool", -} - - -def _dtype_to_name(dtype: torch.dtype) -> str: - return _DTYPE_TO_NAME.get(dtype, str(dtype).replace("torch.", "")) - - def _serialize_tensor(t: torch.Tensor) -> AdditionalInformationEntry: from vllm_omni.engine import AdditionalInformationEntry diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 4535c2596dd..2bdb1136976 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -3,12 +3,13 @@ import importlib from collections import defaultdict, deque +from collections.abc import Callable from typing import Any import torch from vllm.v1.request import Request, RequestStatus -from vllm_omni.data_entry_keys import unflatten_payload +from vllm_omni.data_entry_keys import OmniPayloadStruct, unflatten_payload from ..factory import OmniConnectorFactory from ..utils.config import ConnectorSpec @@ -43,7 +44,7 @@ def __init__(self, vllm_config: Any): super().__init__(model_config) self.model_mode = getattr(model_config, "worker_type", None) or "ar" # State specific to Chunk management - self.custom_process_next_stage_input_func = None + self.custom_process_next_stage_input_func: Callable[..., OmniPayloadStruct | None] | None = None custom_process_next_stage_input_func = getattr(model_config, "custom_process_next_stage_input_func", None) if custom_process_next_stage_input_func: module_path, func_name = custom_process_next_stage_input_func.rsplit(".", 1) @@ -162,7 +163,11 @@ def _poll_single_request(self, request: Request): if meta.get("finished"): self.finished_requests.add(req_id) - new_ids = payload_data.get("codes", {}).get("audio", []) + new_ids = payload_data.get("codes", {}).get("audio") + if isinstance(new_ids, torch.Tensor): + new_ids = new_ids.tolist() + elif new_ids is None: + new_ids = [] request.prompt_token_ids = new_ids prev_info = getattr(request, "additional_information", None) info = dict(prev_info) if isinstance(prev_info, dict) else {} @@ -202,21 +207,29 @@ def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> raw_ok = payload_data.get("meta", {}).pop("override_keys", []) override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} - for type_key, new_val in payload_data.items(): - if not isinstance(new_val, dict): - continue - origin_sub = origin.get(type_key) - if not isinstance(origin_sub, dict): - continue - for qual, value in new_val.items(): - if type_key == "meta" and qual == "finished": + for key, value in payload_data.items(): + if isinstance(value, dict): + origin_sub = origin.get(key) + if not isinstance(origin_sub, dict): continue - if (type_key, qual) in override_keys: + for qual, qval in value.items(): + if key == "meta" and qual == "finished": + continue + if (key, qual) in override_keys: + continue + osv = origin_sub.get(qual) + if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): + value[qual] = torch.cat([osv, qval], dim=0) + elif isinstance(qval, list) and isinstance(osv, list): + value[qual] = osv + qval + else: + if key in override_keys: continue - if isinstance(value, torch.Tensor) and qual in origin_sub: - new_val[qual] = torch.cat([origin_sub[qual], value], dim=0) - elif isinstance(value, list) and qual in origin_sub: - new_val[qual] = origin_sub[qual] + value + ov = origin.get(key) + if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): + payload_data[key] = torch.cat([ov, value], dim=0) + elif isinstance(value, list) and isinstance(ov, list): + payload_data[key] = ov + value self.request_payload[req_id] = payload_data return payload_data @@ -232,7 +245,7 @@ def _send_single_request(self, task: dict): chunk_id = self.put_req_chunk[external_req_id] connector_put_key = f"{external_req_id}_{stage_id}_{chunk_id}" # Process payload in save_loop thread - payload_data = None + payload_data: OmniPayloadStruct | None = None if self.custom_process_next_stage_input_func: try: payload_data = self.custom_process_next_stage_input_func( @@ -245,7 +258,7 @@ def _send_single_request(self, task: dict): except Exception as e: logger.error(f"Failed to use custom_process_input_func for payload extraction: {e}") - if not payload_data: + if payload_data is None: return success, size, metadata = self.connector.put( @@ -258,7 +271,12 @@ def _send_single_request(self, task: dict): if success: self.put_req_chunk[external_req_id] += 1 logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}") - finished_flag = payload_data.get("meta", {}).get("finished", payload_data.get("finished")) + # Sender uses struct attr access here; the receive path in + # `_load_one_request` / `_update_request_payload` reads dict keys. + # That asymmetry is intentional: `OmniMsgpackDecoder` is type-erased + # (no target type), so the wire round-trips struct -> dict. If you + # change the schema, update both ends — see test_wire_round_trip. + finished_flag = payload_data.meta.finished if payload_data.meta is not None else None is_payload_finished = False if isinstance(finished_flag, torch.Tensor): is_payload_finished = finished_flag.numel() == 1 and bool(finished_flag.item()) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 9134b292b7d..b9bfff2635e 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -32,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm_omni.data_entry_keys import EmbeddingsStruct, OmniPayloadStruct, to_dict, to_struct from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config from vllm_omni.model_executor.models.cosyvoice3.utils import ( concat_text_with_prompt_ids, @@ -379,43 +380,6 @@ def _create_llm_vllm_config(self, parent_config: VllmConfig) -> VllmConfig: # Use parent's cache config - critical for PagedAttention to work correctly return parent_config.with_hf_config(qwen_hf_config, architectures=["Qwen2Model"]) - @staticmethod - def _as_tensor(value: object) -> torch.Tensor | None: - """Extract tensor payload from runtime info fields.""" - if isinstance(value, list): - if not value: - return None - value = value[0] - if isinstance(value, torch.Tensor): - return value - return None - - @staticmethod - def _as_str(value: object) -> str | None: - """Extract string payload from runtime info fields.""" - if isinstance(value, list): - if not value: - return None - value = value[0] - if value is None: - return None - return str(value) - - @staticmethod - def _as_bool(value: object) -> bool: - """Extract boolean payload from runtime info fields.""" - if isinstance(value, list): - if not value: - return False - value = value[0] - if isinstance(value, torch.Tensor): - if value.numel() == 0: - return False - return bool(value.reshape(-1)[0].item()) - if value is None: - return False - return bool(value) - @staticmethod def _cross_fade_audio(audio: torch.Tensor, prev_tail: torch.Tensor) -> torch.Tensor: """Blend previous chunk tail into current chunk head using a Hamming window. @@ -710,12 +674,17 @@ def forward( multimodal_outputs = {} if "speech_token" in kwargs: - # Wrap in lists to pass through gpu_ar_model_runner shape filtering - multimodal_outputs = { - "speech_token": [kwargs.get("speech_token")], - "speech_feat": [kwargs.get("speech_feat")], - "embedding": [kwargs.get("embedding")], - } + # Prompt conditioning tensors for code2wav: live under + # ``embed.*`` per OmniPayloadStruct schema. + multimodal_outputs = to_dict( + OmniPayloadStruct( + embed=EmbeddingsStruct( + speech_token=kwargs.get("speech_token"), + speech_feat=kwargs.get("speech_feat"), + embedding=kwargs.get("embedding"), + ), + ) + ) return OmniOutput(text_hidden_states=hidden_states, multimodal_outputs=multimodal_outputs) elif self.model_stage == "cosyvoice3_code2wav": @@ -738,23 +707,29 @@ def forward( runtime_info = [] for idx, req_ids in enumerate(request_ids_list): - info = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} - req_id = self._as_str(info.get("req_id")) if info else None - stream_finished = self._as_bool(info.get("stream_finished")) if info else False - speech_token = self._as_tensor(info.get("speech_token")) if info else None - speech_feat = self._as_tensor(info.get("speech_feat")) if info else None - embedding = self._as_tensor(info.get("embedding")) if info else None + raw = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {} + payload = to_struct(raw) + meta = payload.meta + embed = payload.embed + + req_id = meta.req_id[0] if (meta and meta.req_id) else None + stream_finished = ( + bool(meta.stream_finished.item()) if (meta and meta.stream_finished is not None) else False + ) + speech_token = embed.speech_token if embed else None + speech_feat = embed.speech_feat if embed else None + embedding = embed.embedding if embed else None if speech_token is None or speech_feat is None or embedding is None: if stream_finished and req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"): with self._stream_audio_cache_lock: self._stream_vocoder_cache_by_req.pop(req_id, None) audios[idx] = self._stitch_stream_audio(req_id, empty_audio, stream_finished) - if ( - req_ids.numel() > 0 - and info - and ("token_offset" in info or "left_context_size" in info or "generated_len" in info) + if req_ids.numel() > 0 and ( + (meta and meta.left_context_size is not None) or payload.generated_len is not None ): - info_keys = ",".join(sorted(info.keys())) if info else "" + info_keys = ",".join( + sorted(f for f in payload.__struct_fields__ if getattr(payload, f) is not None) + ) logger.warning_once( "CosyVoice3 code2wav missing prompt conditioning for non-empty codec tokens: " "raw_len=%d info_keys=%s", @@ -780,18 +755,11 @@ def forward( # `generated_len` is injected for many models by the generic # runner, so only explicit chunk-routing fields should switch # code2wav into the streaming path. - uses_streaming_decode = bool(info) and ( - "stream_finished" in info or "token_offset" in info or "left_context_size" in info + uses_streaming_decode = meta and ( + meta.stream_finished is not None or meta.left_context_size is not None ) if uses_streaming_decode: - token_offset = 0 - try: - if info and "token_offset" in info: - token_offset = max(0, int(info.get("token_offset", 0))) - elif info: - token_offset = max(0, int(info.get("left_context_size", 0))) - except (TypeError, ValueError): - token_offset = 0 + token_offset = max(0, meta.left_context_size or 0) cache_state = None if req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"): diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 31cdecebe85..fb4721ca25b 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -137,13 +137,7 @@ def forward( break meta = info.get("meta", {}) if "left_context_size" in meta: - # left_context_size may come through serialization as an int, [int], or tensor([int]). - value = meta["left_context_size"] - if isinstance(value, list): - value = value[0] if value else 0 - if isinstance(value, torch.Tensor): - value = value.reshape(-1)[0].item() if value.numel() > 0 else 0 - left_context_size[i] = int(value) + left_context_size[i] = meta["left_context_size"] for i, req_ids in enumerate(request_ids_list): if req_ids.numel() < 1: parsed.append((0, 0)) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 377d8d19b33..cf1ca39ee59 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -8,9 +8,29 @@ import torch from vllm.inputs import TextPrompt +from vllm_omni.data_entry_keys import ( + CodesStruct, + EmbeddingsStruct, + MetaStruct, + OmniPayloadStruct, +) from vllm_omni.inputs.data import OmniTokensPrompt +def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None: + """Wrap prompt_payload's flat speech_token/speech_feat/embedding tensors into EmbeddingsStruct.""" + speech_token = prompt_payload.get("speech_token") + speech_feat = prompt_payload.get("speech_feat") + embedding = prompt_payload.get("embedding") + if speech_token is None and speech_feat is None and embedding is None: + return None + return EmbeddingsStruct( + speech_token=speech_token, + speech_feat=speech_feat, + embedding=embedding, + ) + + def _ensure_list(x: Any) -> list[Any]: if hasattr(x, "_x"): return list(x._x) @@ -88,7 +108,7 @@ def talker2code2wav_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: """CosyVoice3 async_chunk processor: talker token stream -> code2wav chunks.""" with nullcontext(): request_id = request.external_req_id @@ -114,16 +134,18 @@ def talker2code2wav_async_chunk( if not isinstance(request_state, dict) or "_cosyvoice3_async_state" not in request_state: with nullcontext(): info = _decode_additional_information(getattr(request, "additional_information", None)) + info_embed = info.get("embed", {}) if isinstance(info, dict) else {} prompt_payload = {} for key in ("speech_token", "speech_feat", "embedding"): - value = _to_cpu_tensor(info.get(key)) + value = _to_cpu_tensor(info_embed.get(key)) if value is not None: prompt_payload[key] = value if isinstance(pooling_output, dict): + po_embed = pooling_output.get("embed", {}) if isinstance(pooling_output.get("embed"), dict) else {} for key in ("speech_token", "speech_feat", "embedding"): if key in prompt_payload: continue - value = _to_cpu_tensor(pooling_output.get(key)) + value = _to_cpu_tensor(po_embed.get(key)) if value is not None: prompt_payload[key] = value prompt_token = prompt_payload.get("speech_token") @@ -176,27 +198,29 @@ def talker2code2wav_async_chunk( if length <= 0: if not finished: return None - payload: dict[str, Any] = { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return payload + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, + ) emitted_token_len = int(state.get("emitted_token_len", 0)) if finished and length <= emitted_token_len: - payload = { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True state["terminal_sent"] = True - return payload + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + embed=embed_struct, + ) with nullcontext(): token_hop_len = max(1, int(state.get("token_hop_len", chunk_size))) @@ -220,18 +244,22 @@ def talker2code2wav_async_chunk( with nullcontext(): code_predictor_codes = [int(frame[0]) for frame in token_frames[:prefix_len]] - payload = { - "codes": {"audio": code_predictor_codes}, - "meta": {"finished": torch.tensor(finished, dtype=torch.bool)}, - "token_offset": token_offset, - "left_context_size": token_offset, - "req_id": [request_id], - "stream_finished": torch.tensor(finished, dtype=torch.bool), - } + embed_struct = None if not state.get("sent_prompt", False): - payload.update(state.get("prompt_payload", {})) + embed_struct = _build_prompt_embed_struct(state.get("prompt_payload", {})) state["sent_prompt"] = True + payload = OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(code_predictor_codes, dtype=torch.long)), + meta=MetaStruct( + finished=torch.tensor(finished, dtype=torch.bool), + stream_finished=torch.tensor(finished, dtype=torch.bool), + req_id=[request_id], + left_context_size=token_offset, + ), + embed=embed_struct, + ) + if not finished: state["emitted_token_len"] = emitted_token_len + this_token_hop_len state["token_hop_len"] = min( diff --git a/vllm_omni/model_executor/stage_input_processors/fish_speech.py b/vllm_omni/model_executor/stage_input_processors/fish_speech.py index 542ebefd4ea..6eac326a3d3 100644 --- a/vllm_omni/model_executor/stage_input_processors/fish_speech.py +++ b/vllm_omni/model_executor/stage_input_processors/fish_speech.py @@ -5,6 +5,12 @@ import torch from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayloadStruct, +) + logger = init_logger(__name__) @@ -59,7 +65,7 @@ def slow_ar_to_dac_decoder_async_chunk( pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: """Async streaming processor: emit code chunks as they are produced. Accumulates per-step codes and emits fixed-size chunks with left context @@ -107,10 +113,10 @@ def slow_ar_to_dac_decoder_async_chunk( if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) return None in_initial_phase = initial_chunk_size > 0 and length <= chunk_size @@ -138,9 +144,12 @@ def slow_ar_to_dac_decoder_async_chunk( # Pack into codebook-major flat codes. stacked_frames = torch.stack(window_frames, dim=0) - code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1).tolist() - - return { - "codes": {"audio": code_predictor_codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, dtype=torch.bool)}, - } + code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1) + + return OmniPayloadStruct( + codes=CodesStruct(audio=code_predictor_codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(finished, dtype=torch.bool), + ), + ) diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 39071f760e8..b1df30cb1fb 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -4,6 +4,12 @@ from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, +) from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID @@ -53,9 +59,12 @@ def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torc return y_col_major -def _make_finished_sentinel() -> dict[str, Any]: +def _make_finished_sentinel() -> OmniPayloadStruct: """Return a minimal payload with finished=True so Stage-1 can end the request.""" - return {"codes": {"audio": []}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}} + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) def _flush_remaining_codes( @@ -63,7 +72,7 @@ def _flush_remaining_codes( request_id: str, chunk_size: int, left_context_size: int, -) -> dict[str, Any]: +) -> OmniPayloadStruct: """Flush any accumulated but unsent codes when the request finishes.""" accumulated = transfer_manager.code_prompt_token_ids.get(request_id, []) if not accumulated: @@ -76,18 +85,18 @@ def _flush_remaining_codes( # Align with qwen3_omni talker2code2wav_async_chunk: decoder strip uses explicit frame count. left_ctx_frames = max(0, min(length - context_length, left_context_size)) - flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1).tolist() + flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1) - return { - "codes": {"audio": flat_codes}, - "meta": { - "left_context_size": left_ctx_frames, - "codec_chunk_frames": chunk_size, - "codec_left_context_frames": left_context_size, - "code_flat_numel": len(flat_codes), - "finished": torch.tensor(True, dtype=torch.bool), - }, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=flat_codes), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=int(flat_codes.numel()), + finished=torch.tensor(True, dtype=torch.bool), + ), + ) def _is_codes_empty(codes: Any) -> bool: @@ -114,10 +123,10 @@ def _to_code_tensor(codes: Any) -> torch.Tensor | None: def llm2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: """ Async chunk version: convert stage-0 pooling_output to code2wav payload (pooling / connector accumulation). @@ -167,18 +176,16 @@ def llm2code2wav_async_chunk( left_ctx_frames = max(0, min(length - context_length, left_context_size)) flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist() - return { - "codes": { - "audio": flat_codes, - }, - "meta": { - "left_context_size": left_ctx_frames, - "codec_chunk_frames": chunk_size, - "codec_left_context_frames": left_context_size, - "code_flat_numel": len(flat_codes), - "finished": torch.tensor(is_finished, dtype=torch.bool), - }, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(flat_codes)), + meta=MetaStruct( + left_context_size=left_ctx_frames, + codec_chunk_frames=chunk_size, + codec_left_context_frames=left_context_size, + code_flat_numel=len(flat_codes), + finished=torch.tensor(is_finished, dtype=torch.bool), + ), + ) def llm2code2wav( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index a69ef522749..472dcb93386 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -1,7 +1,14 @@ import torch from vllm.inputs import TextPrompt -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.inputs.data import OmniTokensPrompt TALKER_CODEC_PAD_TOKEN_ID = 8292 @@ -30,17 +37,15 @@ def thinker2talker( mm: OmniPayload = output.multimodal_output latent = mm["latent"] thinker_hidden_states = latent.clone().detach().to(latent.device) - additional_information = { - "hidden_states": { - "output": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32), - "output_shape": list(thinker_hidden_states[prompt_token_ids_len:].shape), - }, - "embed": { - "prefill": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32), - "prefill_shape": list(thinker_hidden_states[:prompt_token_ids_len].shape), - }, - "ids": {"prompt": prompt_token_ids, "output": thinker_output_ids}, - } + decode_hidden = thinker_hidden_states[prompt_token_ids_len:].to(torch.float32) + prefill_hidden = thinker_hidden_states[:prompt_token_ids_len].to(torch.float32) + additional_information = to_dict( + OmniPayloadStruct( + hidden_states=HiddenStatesStruct(output=decode_hidden, output_shape=list(decode_hidden.shape)), + embed=EmbeddingsStruct(prefill=prefill_hidden, prefill_shape=list(prefill_hidden.shape)), + ids=IdsStruct(prompt=list(prompt_token_ids), output=list(thinker_output_ids)), + ) + ) talker_inputs.append( OmniTokensPrompt( prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID] diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 0eea1c2c038..63403619e9b 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -11,7 +11,16 @@ from vllm.inputs import TextPrompt from vllm.platforms import current_platform -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ( + CodesStruct, + EmbeddingsStruct, + HiddenStatesStruct, + IdsStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.stage_input_processors.tts_utils import ( @@ -275,10 +284,10 @@ def _get_streaming_codec_delta_len( def thinker2talker_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -) -> list[dict[str, Any]]: +) -> OmniPayloadStruct | None: """ Process thinker outputs to create talker inputs. 1. thinker's text generation outputs (token IDs + hidden states) @@ -294,7 +303,8 @@ def thinker2talker_async_chunk( thinker_hs = pooling_output.get("hidden_states", {}) thinker_layers = thinker_hs.get("layers", {}) if isinstance(thinker_hs, dict) else {} - thinker_embed = pooling_output.get("embed", {}) if isinstance(pooling_output.get("embed", {}), dict) else {} + thinker_embed_raw = pooling_output.get("embed", {}) + thinker_embed = thinker_embed_raw if isinstance(thinker_embed_raw, dict) else {} thinker_emb = _layer_tensor(thinker_layers, _EMBED_LAYER_KEY) thinker_hid = _layer_tensor(thinker_layers, _HIDDEN_LAYER_KEY) if thinker_emb is None or thinker_hid is None: @@ -305,82 +315,62 @@ def thinker2talker_async_chunk( thinker_hid is not None, ) return None + speaker = extract_speaker_from_request(request) + language = extract_language_from_request(request) + + def _maybe_cpu(t: Any) -> torch.Tensor | None: + return t.detach().cpu() if isinstance(t, torch.Tensor) else None if chunk_id == 0: - all_token_ids = request.all_token_ids # prefill + decode - prompt_token_ids = request.prompt_token_ids - # Convert ConstantList to regular list for OmniSerializer serialization - all_token_ids = _ensure_list(all_token_ids) - prompt_token_ids = _ensure_list(prompt_token_ids) - payload: OmniPayload = { - "embed": { - "prefill": thinker_emb.detach().cpu(), - # Provide thinker-side TTS token embeddings for talker projection - "tts_bos": thinker_embed.get("tts_bos").detach().cpu() - if isinstance(thinker_embed.get("tts_bos"), torch.Tensor) - else None, - "tts_eos": thinker_embed.get("tts_eos").detach().cpu() - if isinstance(thinker_embed.get("tts_eos"), torch.Tensor) - else None, - "tts_pad": thinker_embed.get("tts_pad").detach().cpu() - if isinstance(thinker_embed.get("tts_pad"), torch.Tensor) - else None, - }, - "hidden_states": {"output": thinker_hid.detach().cpu()}, - "ids": {"all": all_token_ids, "prompt": prompt_token_ids}, - "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)}, - } - talker_additional_info = payload - speaker = extract_speaker_from_request(request) - if speaker is not None: - talker_additional_info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - talker_additional_info["language"] = language + all_token_ids = _ensure_list(request.all_token_ids) + prompt_token_ids = _ensure_list(request.prompt_token_ids) + payload = OmniPayloadStruct( + embed=EmbeddingsStruct( + prefill=thinker_emb.detach().cpu(), + tts_bos=_maybe_cpu(thinker_embed.get("tts_bos")), + tts_eos=_maybe_cpu(thinker_embed.get("tts_eos")), + tts_pad=_maybe_cpu(thinker_embed.get("tts_pad")), + ), + hidden_states=HiddenStatesStruct(output=thinker_hid.detach().cpu()), + ids=IdsStruct(all=all_token_ids, prompt=prompt_token_ids), + meta=MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)), + speaker=speaker, + language=language, + ) if transfer_manager.request_payload.get(request_id) is None: if not is_finished: - transfer_manager.request_payload[request_id] = talker_additional_info + transfer_manager.request_payload[request_id] = to_dict(payload) return None else: save_payload = transfer_manager.request_payload.pop(request_id) - talker_additional_info["embed"]["prefill"] = torch.cat( - ( - save_payload.get("embed", {}).get("prefill"), - talker_additional_info.get("embed", {}).get("prefill"), - ), - dim=0, + payload.embed.prefill = torch.cat( + (save_payload.get("embed", {}).get("prefill"), payload.embed.prefill), dim=0 ) - talker_additional_info["hidden_states"]["output"] = torch.cat( - ( - save_payload.get("hidden_states", {}).get("output"), - talker_additional_info.get("hidden_states", {}).get("output"), - ), - dim=0, + payload.hidden_states.output = torch.cat( + (save_payload.get("hidden_states", {}).get("output"), payload.hidden_states.output), dim=0 ) else: - output_token_ids = request.output_token_ids - # Convert ConstantList to regular list for OmniSerializer serialization - output_token_ids = _ensure_list(output_token_ids) - - talker_additional_info: OmniPayload = { - "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)}, - } - speaker = extract_speaker_from_request(request) - if speaker is not None: - talker_additional_info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - talker_additional_info["language"] = language - + output_token_ids = _ensure_list(request.output_token_ids) + meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)) if output_token_ids: - talker_additional_info["meta"]["override_keys"] = [("embed", "decode"), ("ids", "output")] - talker_additional_info["embed"] = {"decode": thinker_emb.detach().cpu()} - talker_additional_info["ids"] = {"output": output_token_ids} + meta.override_keys = [("embed", "decode"), ("ids", "output")] + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()), + ids=IdsStruct(output=output_token_ids), + speaker=speaker, + language=language, + ) else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. - talker_additional_info["embed"] = {"prefill": thinker_emb.detach().cpu()} - talker_additional_info["hidden_states"] = {"output": thinker_hid.detach().cpu()} - return talker_additional_info + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(prefill=thinker_emb.detach().cpu()), + hidden_states=HiddenStatesStruct(output=thinker_hid.detach().cpu()), + speaker=speaker, + language=language, + ) + return payload def thinker2talker( @@ -459,36 +449,26 @@ def thinker2talker( except Exception as exc: logger.warning("[PD] Could not merge prefill embeddings: %s", exc) - payload: OmniPayload = { - "embed": { - "prefill": thinker_emb, - "tts_bos": _resolve_tts_token_embedding( + payload = OmniPayloadStruct( + embed=EmbeddingsStruct( + prefill=thinker_emb, + tts_bos=_resolve_tts_token_embedding( "tts_bos", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device ), - "tts_eos": _resolve_tts_token_embedding( + tts_eos=_resolve_tts_token_embedding( "tts_eos", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device ), - "tts_pad": _resolve_tts_token_embedding( + tts_pad=_resolve_tts_token_embedding( "tts_pad", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device ), - }, - "hidden_states": { - "output": thinker_hid, - }, - "ids": { - "all": thinker_sequences, - "prompt": thinker_input_ids, - }, - } - info = payload - speaker = extract_speaker_from_prompt(prompt, index=i) - if speaker is not None: - info["speaker"] = speaker - language = extract_language_from_prompt(prompt, index=i) - if language is not None: - info["language"] = language - - prompt_len = _compute_talker_prompt_ids_length(payload, device=device) + ), + hidden_states=HiddenStatesStruct(output=thinker_hid), + ids=IdsStruct(all=thinker_sequences, prompt=thinker_input_ids), + speaker=extract_speaker_from_prompt(prompt, index=i), + language=extract_language_from_prompt(prompt, index=i), + ) + info = to_dict(payload) + prompt_len = _compute_talker_prompt_ids_length(info, device=device) talker_inputs.append( OmniTokensPrompt( @@ -509,10 +489,10 @@ def thinker2talker( def talker2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: OmniEngineCoreRequest, is_finished: bool = False, -): +) -> OmniPayloadStruct | None: """ Pooling version. """ @@ -565,17 +545,15 @@ def talker2code2wav_async_chunk( left_context_size = max(0, min(length - context_length, left_context_size_config)) end_index = min(length, left_context_size + context_length) - codes = ( - torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]) - .transpose(0, 1) - .reshape(-1) - .tolist() - ) + codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).transpose(0, 1).reshape(-1) - return { - "codes": {"audio": codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(is_finished, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(is_finished, dtype=torch.bool), + ), + ) def talker2code2wav( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 8badab79f45..35edbfc1e2c 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -5,6 +5,13 @@ import torch from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, + to_dict, +) from vllm_omni.model_executor.stage_input_processors.chunk_size_utils import ( compute_dynamic_initial_chunk_size, max_ic_for_chunk_size, @@ -95,18 +102,13 @@ def talker2code2wav( ref_code_len = 0 # Code2Wav expects codebook-major flat: [Q*num_frames] codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist() - additional_information: dict[str, Any] = {} - if ref_code_len > 0: - additional_information["meta"] = {"left_context_size": [ref_code_len]} - # Propagate speaker and language from the original prompt so they are - # available as runtime_additional_information in later pipeline stages, - # consistent with qwen3-omni and qwen2.5-omni stage input processors. - speaker = extract_speaker_from_prompt(prompt, index=i) - if speaker is not None: - additional_information["speaker"] = speaker - language = extract_language_from_prompt(prompt, index=i) - if language is not None: - additional_information["language"] = language + additional_information = to_dict( + OmniPayloadStruct( + meta=MetaStruct(left_context_size=ref_code_len) if ref_code_len > 0 else None, + speaker=extract_speaker_from_prompt(prompt, index=i), + language=extract_language_from_prompt(prompt, index=i), + ) + ) code2wav_inputs.append( OmniTokensPrompt( prompt_token_ids=codec_codes, @@ -118,7 +120,7 @@ def talker2code2wav( return code2wav_inputs -def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: +def _extract_last_frame(pooling_output: OmniPayload) -> torch.Tensor | None: audio_codes = pooling_output.get("codes", {}).get("audio") if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: return None @@ -134,10 +136,10 @@ def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: def talker2code2wav_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any] | None, + pooling_output: OmniPayload | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) request_payload = getattr(transfer_manager, "request_payload", None) @@ -209,10 +211,10 @@ def talker2code2wav_async_chunk( if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) return None use_first_chunk = initial_chunk_size > 0 and initial_chunk_size < chunk_size @@ -249,19 +251,17 @@ def talker2code2wav_async_chunk( num_quantizers = len(window_frames[0]) num_frames = len(window_frames) - code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)] + code_predictor_codes = torch.tensor( + [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)], + dtype=torch.long, + ) - info: dict[str, Any] = { - "codes": {"audio": code_predictor_codes}, - "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, dtype=torch.bool)}, - } - # Propagate speaker and language from the request so they are available - # as runtime_additional_information in subsequent pipeline stages, consistent - # with qwen3-omni and qwen2.5-omni stage input processors. - speaker = extract_speaker_from_request(request) - if speaker is not None: - info["speaker"] = speaker - language = extract_language_from_request(request) - if language is not None: - info["language"] = language - return info + return OmniPayloadStruct( + codes=CodesStruct(audio=code_predictor_codes), + meta=MetaStruct( + left_context_size=left_context_size, + finished=torch.tensor(finished, dtype=torch.bool), + ), + speaker=extract_speaker_from_request(request), + language=extract_language_from_request(request), + ) diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py index c6bff820ca6..125335cd263 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxcpm.py +++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py @@ -5,6 +5,11 @@ import torch from vllm.inputs import TextPrompt +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayloadStruct, +) from vllm_omni.inputs.data import OmniTokensPrompt _VOXCPM_LATENT_MAGIC = 131071 @@ -74,12 +79,19 @@ def latent2vae( return vae_inputs +def _eof_payload() -> OmniPayloadStruct: + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) + + def latent2vae_async_chunk( transfer_manager: Any, pooling_output: dict[str, Any] | None, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload).""" # Kept for callback signature compatibility with OmniChunkTransferAdapter. _ = transfer_manager @@ -87,28 +99,17 @@ def latent2vae_async_chunk( if callable(getattr(request, "is_finished", None)): finished_request = finished_request or _coerce_finished_flag(request.is_finished()) if not isinstance(pooling_output, dict): - if finished_request: - return { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } - return None + return _eof_payload() if finished_request else None latent = pooling_output.get("latent_audio_feat") if isinstance(latent, torch.Tensor) and latent.numel() == 0: latent = None if latent is None: - if finished_request: - return { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } - return None + return _eof_payload() if finished_request else None serialized_codes = _serialize_latent_to_codes(latent) - out: dict[str, Any] = { - "code_predictor_codes": serialized_codes, - "finished": torch.tensor(finished_request, dtype=torch.bool), - } - return out + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.tensor(serialized_codes, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(finished_request, dtype=torch.bool)), + ) diff --git a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py index fc32ce6b179..be251f9c90b 100644 --- a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py @@ -4,6 +4,12 @@ from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm_omni.data_entry_keys import ( + CodesStruct, + MetaStruct, + OmniPayload, + OmniPayloadStruct, +) from vllm_omni.inputs.data import OmniTokensPrompt logger = init_logger(__name__) @@ -29,7 +35,7 @@ def generator2tokenizer( return tokenizer_inputs -def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: +def _extract_last_frame(pooling_output: OmniPayload) -> torch.Tensor | None: audio = pooling_output.get("audio") if not isinstance(audio, torch.Tensor) or audio.numel() == 0: return None @@ -38,10 +44,10 @@ def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: def generator2tokenizer_async_chunk( transfer_manager: Any, - pooling_output: dict[str, Any], + pooling_output: OmniPayload, request: Any, is_finished: bool = False, -) -> dict[str, Any] | None: +) -> OmniPayloadStruct | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) @@ -71,10 +77,10 @@ def generator2tokenizer_async_chunk( # finished and nothing was produced, emit an EOF marker. if length <= 0: if finished: - return { - "codes": {"audio": []}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct(audio=torch.empty(0, dtype=torch.long)), + meta=MetaStruct(finished=torch.tensor(True, dtype=torch.bool)), + ) return None # Use a small chunk size at begin @@ -94,7 +100,12 @@ def generator2tokenizer_async_chunk( # Pack context + chunk into codebook-major flat codes for adapter. code_predictor_codes = torch.tensor(window_frames).reshape(-1).tolist() - return { - "codes": {"audio": [int(ctx_frames)] + [int(context_length)] + code_predictor_codes}, - "meta": {"finished": torch.tensor(finished, dtype=torch.bool)}, - } + return OmniPayloadStruct( + codes=CodesStruct( + audio=torch.tensor( + [int(ctx_frames), int(context_length)] + code_predictor_codes, + dtype=torch.long, + ), + ), + meta=MetaStruct(finished=torch.tensor(finished, dtype=torch.bool)), + ) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 66d4e6ffe04..cf2837a7ca1 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -28,28 +28,30 @@ def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: if multimodal_outputs: for k, v in multimodal_outputs.items(): - if isinstance(v, torch.Tensor): - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor): - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list) and len(v) > 0: - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - elif v is not None: - mm_cpu[k] = v + cpu_v = _to_cpu(v) + if cpu_v is not None: + mm_cpu[k] = cpu_v return mm_cpu +def _to_cpu(value): + """Recursively detach + move tensors to CPU; preserve dict/list nesting.""" + if isinstance(value, torch.Tensor): + return value.detach().to("cpu").contiguous() + if isinstance(value, dict): + out = {} + for k, v in value.items(): + cpu_v = _to_cpu(v) + if cpu_v is not None: + out[k] = cpu_v + return out or None + if isinstance(value, list): + if not value: + return value + return [_to_cpu(v) for v in value] + return value + + def to_payload_element( element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None ): @@ -77,7 +79,10 @@ def to_payload_element( # Every other case is shared between prefix cache (passthrough data) # and running a model without prefix caching. elif isinstance(element, dict): - return {sk: sv[start:end].contiguous() for sk, sv in element.items()} + return { + sk: to_payload_element(sv, idx, start, end, pass_lists_through=pass_lists_through, seq_len=seq_len) + for sk, sv in element.items() + } elif isinstance(element, list): # For lists, clone tensors to avoid cross-request aliasing if pass_lists_through: diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 1decf13cb69..f87529106d1 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -261,7 +261,7 @@ def _maybe_update_prefix_cache( self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, num_tokens_padded=num_tokens_padded, @@ -289,7 +289,7 @@ def _maybe_get_combined_prefix_cache_tensors( combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_scheduled_tokens=num_scheduled_tokens, ) return combined_hidden_states, combined_multimodal_outputs @@ -954,15 +954,21 @@ def propose_draft_token_ids(sampled_token_ids): if combined_multimodal_outputs: # Prefix cache enabled; all items have already been processed # and split apart for each request as needed, and all tensors - # have already been detached to the CPU. The only exception is - # lists, which we keep as passthrough data for consistent behavior - # in postprocess. + # have already been detached to the CPU. Lists are kept as + # passthrough data for consistent behavior in postprocess. + # Recurse into nested dicts so list-valued sub-keys (e.g. + # embed.tts_bos = [tensor]) are unwrapped to bare tensors + # at the leaves; downstream flatten_payload then yields a + # wire-clean dict[str, torch.Tensor]. + def _unwrap_lists(v): + if isinstance(v, list): + return v[idx] if idx < len(v) else v[0] + if isinstance(v, dict): + return {k: _unwrap_lists(sv) for k, sv in v.items()} + return v + for mm_key in combined_multimodal_outputs.keys(): - value = combined_multimodal_outputs[mm_key][rid] - if isinstance(value, list): - mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] - else: - mm_payload[mm_key] = value + mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid]) else: # Prefix cache disabled; we still need to process the data diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 25f6c040cb5..6955aebd857 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -24,18 +24,17 @@ from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger +from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec from vllm_omni.outputs import OmniConnectorOutput from vllm_omni.worker.payload_span import ( - THINKER_DECODE_EMBEDDINGS_KEY, - THINKER_DECODE_TOKEN_END_KEY, - THINKER_DECODE_TOKEN_START_KEY, - THINKER_OUTPUT_TOKEN_IDS_KEY, get_tensor_span, merge_tensor_spans, ) +_EMBED_SPAN_GROUPS: tuple[tuple[str, str, str], ...] = (("decode", "decode_token_start", "decode_token_end"),) + if TYPE_CHECKING: from vllm_omni.distributed.omni_connectors.connectors.base import ( OmniConnectorBase, @@ -345,15 +344,15 @@ def prune_inactive_requests(self, active_req_ids: Any) -> set[str]: # Local payload cache (RFC §2.4 – Model Runner ownership) # ------------------------------------------------------------------ # - def put_local_stage_payload(self, req_id: str, payload: dict[str, Any]) -> None: + def put_local_stage_payload(self, req_id: str, payload: OmniPayload) -> None: """Store a full stage payload in the local cache.""" self._local_stage_payload_cache[req_id] = payload - def get_local_stage_payload(self, req_id: str) -> dict[str, Any] | None: + def get_local_stage_payload(self, req_id: str) -> OmniPayload | None: """Read a stage payload without removing it.""" return self._local_stage_payload_cache.get(req_id) - def pop_local_stage_payload(self, req_id: str) -> dict[str, Any] | None: + def pop_local_stage_payload(self, req_id: str) -> OmniPayload | None: """Remove and return a stage payload (consume after use).""" return self._local_stage_payload_cache.pop(req_id, None) @@ -370,29 +369,39 @@ def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None: # ------------------------------------------------------------------ # @classmethod - def _extract_scheduling_metadata(cls, payload: dict[str, Any]) -> dict[str, Any]: + def _extract_scheduling_metadata(cls, payload: OmniPayload) -> dict[str, Any]: """Extract only the fields the scheduler needs from a full payload.""" extracted: dict[str, Any] = {} - if "next_stage_prompt_len" in payload: + meta = payload.get("meta") if isinstance(payload, dict) else None + meta = meta if isinstance(meta, dict) else {} + + if "next_stage_prompt_len" in meta: + extracted["next_stage_prompt_len"] = meta["next_stage_prompt_len"] + elif "next_stage_prompt_len" in payload: + logger.warning_once( + "legacy flat 'next_stage_prompt_len' key in payload; expected 'meta.next_stage_prompt_len'" + ) extracted["next_stage_prompt_len"] = payload["next_stage_prompt_len"] + audio_codes = cls._payload_audio_codes(payload) if audio_codes is not None: extracted["code_predictor_codes"] = audio_codes - meta = payload.get("meta") - if isinstance(meta, dict) and "left_context_size" in meta: + + if "left_context_size" in meta: extracted["left_context_size"] = meta["left_context_size"] elif "left_context_size" in payload: logger.warning_once("legacy flat 'left_context_size' key in payload; expected 'meta.left_context_size'") + return extracted - _NON_CONSUMABLE_PAYLOAD_KEYS = { - "finished", - "override_keys", - "next_stage_prompt_len", - "left_context_size", - THINKER_OUTPUT_TOKEN_IDS_KEY, - THINKER_DECODE_TOKEN_START_KEY, - THINKER_DECODE_TOKEN_END_KEY, + _NON_CONSUMABLE_PAYLOAD_KEYS: set[tuple[str, str]] = { + ("meta", "finished"), + ("meta", "override_keys"), + ("meta", "next_stage_prompt_len"), + ("meta", "left_context_size"), + ("ids", "output"), + ("embed", "decode_token_start"), + ("embed", "decode_token_end"), } @staticmethod @@ -431,7 +440,7 @@ def _payload_audio_codes(payload: Any) -> Any: return None @classmethod - def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool: + def _payload_is_consumable(cls, payload: OmniPayload | None) -> bool: """Return True when an async payload can drive a real forward step. Metadata-only wake-ups should not transition WAITING_FOR_CHUNK requests @@ -442,23 +451,29 @@ def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool: if not isinstance(payload, dict) or not payload: return False - decode_embeddings = payload.get(THINKER_DECODE_EMBEDDINGS_KEY) - if isinstance(decode_embeddings, torch.Tensor): - if decode_embeddings.ndim == 0: - return True - return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0 + embed = payload.get("embed") + if isinstance(embed, dict): + decode_embeddings = embed.get("decode") + if isinstance(decode_embeddings, torch.Tensor): + if decode_embeddings.ndim == 0: + return True + return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0 audio_codes = cls._payload_audio_codes(payload) if audio_codes is not None: if isinstance(audio_codes, torch.Tensor): return audio_codes.numel() > 0 - # Codec code 0 is valid; non-empty code payloads are consumable. if hasattr(audio_codes, "__len__"): return len(audio_codes) > 0 return True for key, value in payload.items(): - if key in cls._NON_CONSUMABLE_PAYLOAD_KEYS: + if isinstance(value, dict): + for sk, sv in value.items(): + if (key, sk) in cls._NON_CONSUMABLE_PAYLOAD_KEYS: + continue + if cls._payload_value_has_content(sv): + return True continue if cls._payload_value_has_content(value): return True @@ -1796,85 +1811,66 @@ def _decrement_pending_save_count(self, request_id: str) -> None: # Payload accumulation (ported from OmniChunkTransferAdapter) # ------------------------------------------------------------------ # - def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: - """Accumulate chunk payloads (concat tensors, extend lists). - - Returns a **shallow copy** of the accumulated state so callers - (e.g. ``_poll_single_request``) can store it in - ``_local_stage_payload_cache`` without aliasing the authoritative - ``_send_side_request_payload`` dict. - """ + def _accumulate_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPayload: + """Accumulate chunk payloads (concat tensors, extend lists).""" if req_id not in self._send_side_request_payload: self._send_side_request_payload[req_id] = dict(payload_data) return dict(self._send_side_request_payload[req_id]) origin = self._send_side_request_payload[req_id] merged = dict(origin) - override_keys = payload_data.get("override_keys", ()) - drop_decode_span = False - decode_span_handled = False + raw_ok = payload_data.get("meta", {}).get("override_keys", []) if isinstance(payload_data, dict) else [] + override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} + for key, value in payload_data.items(): - if key == "finished": - merged[key] = value - continue - if key == THINKER_DECODE_EMBEDDINGS_KEY: - merged_span = merge_tensor_spans( - get_tensor_span( - origin, - tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, - start_key=THINKER_DECODE_TOKEN_START_KEY, - end_key=THINKER_DECODE_TOKEN_END_KEY, - ), - get_tensor_span( - payload_data, - tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, - start_key=THINKER_DECODE_TOKEN_START_KEY, - end_key=THINKER_DECODE_TOKEN_END_KEY, - ), - ) - if merged_span is not None: - merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = ( - merged_span - ) - decode_span_handled = True - continue - if isinstance(value, torch.Tensor) and key in origin: - if ( - THINKER_DECODE_TOKEN_START_KEY in origin - or THINKER_DECODE_TOKEN_END_KEY in origin - or THINKER_DECODE_TOKEN_START_KEY in payload_data - or THINKER_DECODE_TOKEN_END_KEY in payload_data - ): - logger.warning( - "[Stage-%s] req=%s falling back to legacy thinker decode " - "merge due to missing/invalid/non-contiguous span " - "metadata", - self._stage_id, - req_id, + if isinstance(value, dict): + origin_sub = origin.get(key) + merged_sub = dict(origin_sub) if isinstance(origin_sub, dict) else {} + span_handled: set[str] = set() + if key == "embed" and isinstance(origin_sub, dict): + for tk, sk, ek in _EMBED_SPAN_GROUPS: + if tk not in value or (key, tk) in override_keys: + continue + span = merge_tensor_spans( + get_tensor_span(origin_sub, tensor_key=tk, start_key=sk, end_key=ek), + get_tensor_span(value, tensor_key=tk, start_key=sk, end_key=ek), ) - drop_decode_span = True - merged[key] = torch.cat([origin[key], value], dim=0) - continue - merged[key] = value - continue - if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}: - if decode_span_handled or drop_decode_span: - continue - merged[key] = value - continue - if key in override_keys: - merged[key] = value - continue - if isinstance(value, torch.Tensor) and key in origin: - merged[key] = torch.cat([origin[key], value], dim=0) - elif isinstance(value, list) and key in origin: - merged[key] = origin[key] + value + if span is None: + continue + t, s, e = span + merged_sub[tk] = t + merged_sub[sk] = s + merged_sub[ek] = e + span_handled |= {tk, sk, ek} + for qual, qval in value.items(): + if qual in span_handled: + continue + if key == "meta" and qual == "finished": + merged_sub[qual] = qval + continue + if (key, qual) in override_keys: + merged_sub[qual] = qval + continue + osv = merged_sub.get(qual) + if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): + merged_sub[qual] = torch.cat([osv, qval], dim=0) + elif isinstance(qval, list) and isinstance(osv, list): + merged_sub[qual] = osv + qval + else: + merged_sub[qual] = qval + merged[key] = merged_sub else: - merged[key] = value + if key in override_keys: + merged[key] = value + continue + ov = origin.get(key) + if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): + merged[key] = torch.cat([ov, value], dim=0) + elif isinstance(value, list) and isinstance(ov, list): + merged[key] = ov + value + else: + merged[key] = value - if drop_decode_span: - merged.pop(THINKER_DECODE_TOKEN_START_KEY, None) - merged.pop(THINKER_DECODE_TOKEN_END_KEY, None) self._send_side_request_payload[req_id] = merged return dict(merged)