diff --git a/tests/core/sched/test_omni_scheduler_mixin_timeouts.py b/tests/core/sched/test_omni_scheduler_mixin_timeouts.py new file mode 100644 index 00000000000..053c926cc24 --- /dev/null +++ b/tests/core/sched/test_omni_scheduler_mixin_timeouts.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit coverage for _process_pending_input_timeouts. + +Verifies that the mixin correctly *delegates* timed-out requests to the +base scheduler's ``finish_requests`` API with ``RequestStatus.FINISHED_ERROR``. +The end-to-end effect (queue removal + status set + per-request cleanup + +client-facing FINISHED_ERROR emission) is the responsibility of upstream +vLLM's ``finish_requests`` implementation and is covered by upstream tests; +this file only asserts the wiring from the mixin to that API. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin + + +class _FakeCoordinator: + def __init__(self, timed_out_ids): + self._timed_out_ids = set(timed_out_ids) + self.calls = [] + + def collect_timed_out_request_ids(self, timeout_s): + self.calls.append(timeout_s) + return set(self._timed_out_ids) + + +class _FakeScheduler(OmniSchedulerMixin): + def __init__(self, requests, coordinator): + self.requests = requests + self.input_coordinator = coordinator + self.finish_calls = [] + + def finish_requests(self, req_ids, status): + self.finish_calls.append((set(req_ids), status)) + + +def test_process_pending_input_timeouts_delegates_to_finish_requests(): + """Timed-out request present in self.requests is forwarded to finish_requests.""" + req_id = "stuck-req" + requests = {req_id: SimpleNamespace(request_id=req_id)} + coord = _FakeCoordinator(timed_out_ids={req_id}) + scheduler = _FakeScheduler(requests, coord) + + scheduler._process_pending_input_timeouts() + + assert len(coord.calls) == 1, "coordinator should be polled once" + assert coord.calls[0] > 0, "timeout must be positive when enabled" + + assert len(scheduler.finish_calls) == 1 + finished_ids, status = scheduler.finish_calls[0] + assert finished_ids == {req_id} + # RequestStatus is the upstream enum; the mixin imports it as + # RequestStatus.FINISHED_ERROR. Check by name to avoid hard import here. + assert getattr(status, "name", str(status)).endswith("FINISHED_ERROR") + + +def test_process_pending_input_timeouts_skips_already_freed_request(): + """Timed-out id no longer in self.requests must not be forwarded.""" + coord = _FakeCoordinator(timed_out_ids={"already-freed"}) + scheduler = _FakeScheduler(requests={}, coordinator=coord) + + scheduler._process_pending_input_timeouts() + + assert coord.calls == [coord.calls[0]] and coord.calls[0] > 0 + assert scheduler.finish_calls == [] + + +def test_process_pending_input_timeouts_noop_without_coordinator(): + """No coordinator => no finish_requests call, no crash.""" + + class _NoCoord(OmniSchedulerMixin): + def __init__(self): + self.requests = {} + self.input_coordinator = None + self.finish_calls = [] + + def finish_requests(self, req_ids, status): + self.finish_calls.append((set(req_ids), status)) + + scheduler = _NoCoord() + scheduler._process_pending_input_timeouts() + assert scheduler.finish_calls == [] + + +def test_process_pending_input_timeouts_disabled_when_timeout_zero(monkeypatch): + """Setting DEFAULT_INPUT_WAIT_TIMEOUT_S <= 0 disables the safety net.""" + from vllm_omni.core.sched import omni_scheduler_mixin + + monkeypatch.setattr(omni_scheduler_mixin, "DEFAULT_INPUT_WAIT_TIMEOUT_S", 0.0) + + coord = _FakeCoordinator(timed_out_ids={"r1"}) + scheduler = _FakeScheduler(requests={"r1": SimpleNamespace(request_id="r1")}, coordinator=coord) + scheduler._process_pending_input_timeouts() + assert coord.calls == [], "coordinator must not be polled when timeout is disabled" + assert scheduler.finish_calls == [] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index 8c7cc20b0d9..1b36cd784d8 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -16,7 +16,7 @@ import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) # ------------------------------------------------------------------ # @@ -92,47 +92,96 @@ def remove_requests(self, requests): class TestFullPayloadCoordinatorSelection(unittest.TestCase): - def test_qwen3_omni_talker_and_code2wav_use_full_payload_input_coordinator(self): - for model_stage in ("talker", "code2wav"): + """Tests for the (model_arch, model_stage) whitelist gate. + + The init_omni_connectors arch allowlist is keyed by ``model_arch`` and + is a superset of the stages registered here -- consumer-wait stages + must be registered explicitly in ``_FULL_PAYLOAD_INPUT_STAGES``, while + the init allowlist covers both producer- and consumer-side runners. + These tests pin which ``(arch, stage)`` pairs the gate fires for today. + """ + + # Expected whitelist (model_arch, model_stage). Hardcoded to avoid the + # tautology of importing _FULL_PAYLOAD_INPUT_STAGES and asserting it + # against itself; any drift between this matrix and the whitelist will + # fail loudly here. + EXPECTED_FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( + { + ("Qwen3OmniMoeForConditionalGeneration", "talker"), + ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + ("Qwen2_5OmniForConditionalGeneration", "talker"), + ("Qwen2_5OmniForConditionalGeneration", "code2wav"), + ("CovoAudioForConditionalGeneration", "code2wav"), + ("MiMoAudioModel", "code2wav"), + ("Qwen3TTSCode2Wav", "code2wav"), + ("CosyVoice3Model", "cosyvoice3_code2wav"), + ("DyninOmniForConditionalGeneration", "token2image"), + ("DyninOmniForConditionalGeneration", "token2audio"), + } + ) + + def test_whitelist_matches_expected_matrix(self): + """_FULL_PAYLOAD_INPUT_STAGES must equal the hardcoded expected matrix. + + Catches both accidental additions (which would silently enable the + consumer-wait gate for a new arch) and accidental removals (which + would silently disable an enabled arch). + """ + from vllm_omni.core.sched.omni_scheduling_coordinator import _FULL_PAYLOAD_INPUT_STAGES + + self.assertEqual( + frozenset(_FULL_PAYLOAD_INPUT_STAGES), + self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES, + msg="_FULL_PAYLOAD_INPUT_STAGES drifted from the expected matrix; " + "update EXPECTED_FULL_PAYLOAD_INPUT_STAGES if intentional.", + ) + + def test_all_whitelisted_arch_stage_pairs_fire_gate(self): + """Every (arch, stage) pair in the expected matrix must fire + the gate when stage_id > 0 and async_chunk=False. + """ + for arch, stage in self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES: model_config = SimpleNamespace( stage_id=1, async_chunk=False, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage=model_stage, + model_arch=arch, + model_stage=stage, + ) + self.assertTrue( + uses_full_payload_input_coordinator(model_config), + msg=f"expected gate to fire for {arch}/{stage}", ) - self.assertTrue(uses_qwen3_omni_full_payload_input_coordinator(model_config)) - - def test_async_chunk_and_non_qwen3_omni_do_not_use_full_payload_input_coordinator(self): + def test_other_arch_or_stage_or_mode_does_not_fire(self): cases = [ SimpleNamespace( - stage_id=1, - async_chunk=True, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage="talker", + stage_id=1, async_chunk=True, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="talker" ), SimpleNamespace( - stage_id=1, - async_chunk=False, - model_arch="Qwen3TTSForConditionalGeneration", - model_stage="code2wav", + stage_id=0, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="thinker" ), SimpleNamespace( stage_id=1, async_chunk=False, - model_arch="Qwen2_5OmniForConditionalGeneration", - model_stage="talker", + model_arch="Qwen3OmniMoeForConditionalGeneration", + model_stage="some_future_stage", ), SimpleNamespace( - stage_id=0, - async_chunk=False, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage="thinker", + stage_id=1, async_chunk=False, model_arch="Qwen3TTSForConditionalGeneration", model_stage="code2wav" + ), + SimpleNamespace( + stage_id=1, async_chunk=False, model_arch="MingFlashOmniForConditionalGeneration", model_stage="talker" + ), + SimpleNamespace(stage_id=1, async_chunk=False, model_arch=None, model_stage="talker"), + SimpleNamespace( + stage_id=1, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage=None ), ] - for model_config in cases: - self.assertFalse(uses_qwen3_omni_full_payload_input_coordinator(model_config)) + self.assertFalse( + uses_full_payload_input_coordinator(model_config), + msg=f"expected gate OFF for {model_config}", + ) class TestChunkCoordinatorStateTransition(unittest.TestCase): diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py index 0e071f724e5..bf2261cb920 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py @@ -293,7 +293,13 @@ def inference(self, speech_feat, finalize=True): model = object.__new__(CosyVoice3Code2Wav) nn.Module.__init__(model) model.hift = DummyHiFT() - model._forward_mel = lambda **_: torch.ones((1, 80, 8), dtype=torch.float32) + forward_mel_calls = [] + + def fake_forward_mel(**kwargs): + forward_mel_calls.append(kwargs) + return torch.ones((1, 80, 8), dtype=torch.float32) + + model._forward_mel = fake_forward_mel out = model.forward( token=torch.tensor([[1, 2, 3]], dtype=torch.int32), @@ -304,3 +310,4 @@ def inference(self, speech_feat, finalize=True): assert out.shape == (1, 1, 8) assert model.hift.finalize_calls == [True] + assert forward_mel_calls[0]["token_offset_tokens"] == 0 diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 956d32528bb..b0afc95921a 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -138,27 +138,6 @@ def _make_sampling_metadata( ) -def test_split_request_ids_uses_seq_token_counts(): - CosyVoice3Model, _ = _cosyvoice3_model_and_runner() - ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long) - chunks = CosyVoice3Model._split_request_ids(ids, [2, 2, 2]) - assert [c.tolist() for c in chunks] == [[10, 11], [12, 13], [14]] - - -def test_split_request_ids_honors_single_request_seq_token_counts(): - CosyVoice3Model, _ = _cosyvoice3_model_and_runner() - ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long) - chunks = CosyVoice3Model._split_request_ids(ids, [3]) - assert [c.tolist() for c in chunks] == [[10, 11, 12]] - - -def test_sanitize_codec_tokens_filters_out_of_range(): - model = _make_code2wav_model() - raw = torch.tensor([-1, 0, 3, 4, 99], dtype=torch.long) - clean = model._sanitize_codec_tokens(raw) - assert clean.tolist() == [0, 3] - - def test_forward_prefers_token_offset_when_present(): model = _make_code2wav_model() @@ -265,6 +244,31 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): assert len(model.code2wav.forward_streaming_calls) == 0 call = model.code2wav.forward_calls[0] assert call["token"].tolist() == [[0, 1, 2]] + assert call["token_offset_tokens"] == 0 + + +def test_forward_uses_non_stream_talker_prefill_offset(): + model = _make_code2wav_model() + + runtime_info = [ + { + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "meta": {"talker_prefill_offset": 3}, + } + ] + + model.forward( + input_ids=torch.tensor([0, 1, 2], dtype=torch.long), + positions=torch.tensor([0, 1, 2], dtype=torch.long), + model_intermediate_buffer=runtime_info, + seq_token_counts=[3], + ) + + assert model.code2wav.forward_calls[0]["token_offset_tokens"] == 3 def test_forward_reuses_streaming_cache_state_between_chunks(): diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index d54533fd0cb..6debabb0e3f 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -6,13 +6,19 @@ import torch -from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import talker2code2wav_async_chunk, text2flow +from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + talker2code2wav_async_chunk, + text2flow, + text2flow_full_payload, + text2flow_token_only, +) -def _source_output(request_id: str, prompt_ids: list[int], out_ids: list[int], mm: dict): +def _source_output(request_id: str, prompt_ids: list[int], out_ids: list[int], mm: dict, finished: bool = True): return SimpleNamespace( request_id=request_id, prompt_token_ids=prompt_ids, + finished=finished, outputs=[SimpleNamespace(token_ids=out_ids, cumulative_token_ids=out_ids, multimodal_output=mm)], ) @@ -45,8 +51,8 @@ def _transfer_manager( def test_text2flow_supports_batched_source_outputs(): source_outputs = [ - _source_output("req-0", [10, 11], [1, 2, 3], {"speech_token": torch.tensor([[1, 2]])}), - _source_output("req-1", [20, 21], [4, 5], {"speech_token": torch.tensor([[3, 4]])}), + _source_output("req-0", [10, 11], [1, 2, 3], {"speech_token": torch.tensor([[8, 9]])}), + _source_output("req-1", [20, 21], [4, 5], {"speech_token": torch.tensor([[6, 7]])}), ] outputs = text2flow(source_outputs=source_outputs, prompt=None) @@ -58,6 +64,65 @@ def test_text2flow_supports_batched_source_outputs(): assert outputs[1]["additional_information"]["ids"]["prompt"] == [20, 21] +def test_text2flow_strips_reference_speech_prefix_from_cumulative_ids(): + source_outputs = [ + _source_output("req-0", [10, 11], [8, 9, 1, 2, 3], {"speech_token": torch.tensor([[8, 9]])}), + ] + + outputs = text2flow(source_outputs=source_outputs, prompt=None) + + assert outputs[0]["prompt_token_ids"] == [1, 2, 3] + + +def test_text2flow_token_only_strips_reference_speech_prefix_from_cumulative_ids(): + source_outputs = [ + _source_output( + "req-strip", + [10, 11], + [4, 5, 1, 2, 3], + {"embed": {"speech_token": torch.tensor([[4, 5]])}}, + ) + ] + + outputs = text2flow_token_only(source_outputs=source_outputs, prompt=None) + + assert len(outputs) == 1 + assert outputs[0]["prompt_token_ids"] == [1, 2, 3] + assert outputs[0]["additional_information"]["ids"]["prompt"] == [10, 11] + + +def test_text2flow_token_only_marks_prompt_trim_for_stop_token_completion(): + source_outputs = [ + _source_output( + "req-stop", + [10, 11], + [4, 5, 1, 2, 6562], + {"embed": {"speech_token": torch.tensor([[4, 5]])}}, + ) + ] + + outputs = text2flow_token_only(source_outputs=source_outputs, prompt=None) + + assert outputs[0]["prompt_token_ids"] == [1, 2, 6562] + assert outputs[0]["additional_information"]["meta"]["talker_prefill_offset"] == 2 + + +def test_text2flow_full_payload_does_not_send_codec_ids(): + payload = text2flow_full_payload( + None, + { + "embed.speech_token": torch.tensor([[1, 2]], dtype=torch.long), + "codes.audio": torch.tensor([7, 8, 9], dtype=torch.long), + }, + SimpleNamespace(), + ) + + assert payload is not None + assert "codes" not in payload + assert "next_stage_prompt_len" not in payload["meta"] + assert torch.equal(payload["embed"]["speech_token"], torch.tensor([[1, 2]], dtype=torch.long)) + + def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): transfer_manager = _transfer_manager() request = SimpleNamespace( diff --git a/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py b/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py new file mode 100644 index 00000000000..0cb61a972f2 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Light coverage for qwen2_5_omni.thinker2talker_full_payload. + +Covers the finish-reason-aware stop-row trim contract: when the request +status is FINISHED_STOPPED, the builder must drop one row from the +accumulated hidden states (vLLM v1 appends the sampled stop token to +output_token_ids before check_stop, so the trailing hidden-state row +corresponds to the stop emission and must not reach the talker). +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + thinker2talker_full_payload, +) + + +def _make_request( + prompt_token_ids, + output_token_ids, + status_name: str | None = "FINISHED_STOPPED", +): + status = SimpleNamespace(name=status_name) if status_name else None + return SimpleNamespace( + request_id="r1", + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + all_token_ids=list(prompt_token_ids) + list(output_token_ids), + status=status, + sampling_params=None, + ) + + +def test_finished_stopped_trims_one_decode_row(): + """FINISHED_STOPPED: drop trailing hidden-state row so talker does not + consume the stop-emission row. + """ + prompt = [1, 2, 3] + output = [10, 11, 12] + request = _make_request(prompt, output, status_name="FINISHED_STOPPED") + # 6 prompt+output rows + 1 stop-emission row = 7 hidden rows total. + hidden = torch.arange(7 * 4, dtype=torch.float32).reshape(7, 4) + pooling = {"hidden": hidden} + + payload = thinker2talker_full_payload(transfer_manager=None, pooling_output=pooling, request=request) + + assert payload is not None + # ids.output had one trailing stop row dropped: 3 - 1 = 2 remaining. + assert payload["ids"]["output"] == output[:-1] + # embed.prefill must cover only the prompt rows. + assert payload["embed"]["prefill"].shape[0] == len(prompt) + # hidden_states.output covers the decode rows minus the dropped stop row. + assert payload["hidden_states"]["output"].shape[0] == len(output) - 1 + + +def test_finished_length_capped_keeps_all_rows(): + """FINISHED_LENGTH_CAPPED: no row drop; hidden_states.output covers + all decode rows. + """ + prompt = [1, 2, 3] + output = [10, 11, 12] + request = _make_request(prompt, output, status_name="FINISHED_LENGTH_CAPPED") + hidden = torch.arange(6 * 4, dtype=torch.float32).reshape(6, 4) + pooling = {"hidden": hidden} + + payload = thinker2talker_full_payload(transfer_manager=None, pooling_output=pooling, request=request) + + assert payload is not None + assert payload["ids"]["output"] == output + assert payload["embed"]["prefill"].shape[0] == len(prompt) + assert payload["hidden_states"]["output"].shape[0] == len(output) + + +def test_missing_hidden_returns_none(): + """Defensive: pooling_output without "hidden" returns None.""" + request = _make_request([1, 2], [3], status_name="FINISHED_STOPPED") + assert thinker2talker_full_payload(transfer_manager=None, pooling_output={}, request=request) is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index f11a4654ec2..a04a75de875 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -168,6 +168,7 @@ def test_talker2code2wav_full_payload_keeps_all_zero_codec_rows() -> None: def test_thinker2talker_full_payload_packs_complete_tensors() -> None: + """Full-payload path drops the terminal thinker row before talker prefill.""" request = SimpleNamespace( request_id="thinker", prompt_token_ids=[151644, 872], @@ -187,3 +188,741 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None: assert payload["embed"]["prefill"].device.type == "cpu" assert payload["hidden_states"]["output"].device.type == "cpu" assert payload["next_stage_prompt_len"] > 0 + assert payload["embed"]["prefill"].shape[0] == 2 + assert payload["hidden_states"]["output"].shape[0] == 2 + + +def test_accumulator_replaces_keys_in_replace_set() -> None: + """REPLACE-key semantics: subsequent emissions of the same key replace, not append.""" + from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin + + class _StubMixin(OmniConnectorModelRunnerMixin): + def __init__(self): + self._pending_full_payload_send = {} + self._full_payload_replace_keys_cached = frozenset({"model_outputs"}) + + stub = _StubMixin() + stub.accumulate_full_payload_output( + "req1", + { + "model_outputs": torch.tensor([[1.0, 2.0]]), + "hidden_states.output": torch.tensor([[10.0]]), + }, + request=None, + ) + stub.accumulate_full_payload_output( + "req1", + { + "model_outputs": torch.tensor([[3.0, 4.0]]), + "hidden_states.output": torch.tensor([[20.0]]), + }, + request=None, + ) + output, _ = stub._materialize_full_payload_entry(stub._pending_full_payload_send["req1"]) + # model_outputs REPLACED (second value only): + assert torch.equal(output["model_outputs"], torch.tensor([[3.0, 4.0]])) + # hidden_states.output CONCATENATED: + assert torch.equal(output["hidden_states.output"], torch.tensor([[10.0], [20.0]])) + + +def test_accumulator_concat_default_when_no_replace_keys() -> None: + """Default semantics: 2-D+ tensors concat across emissions when not in replace_keys.""" + from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin + + class _StubMixin(OmniConnectorModelRunnerMixin): + def __init__(self): + self._pending_full_payload_send = {} + self._full_payload_replace_keys_cached = frozenset() + + stub = _StubMixin() + stub.accumulate_full_payload_output( + "req1", + {"embed.prefill": torch.tensor([[1.0]])}, + request=None, + ) + stub.accumulate_full_payload_output( + "req1", + {"embed.prefill": torch.tensor([[2.0]])}, + request=None, + ) + output, _ = stub._materialize_full_payload_entry(stub._pending_full_payload_send["req1"]) + assert torch.equal(output["embed.prefill"], torch.tensor([[1.0], [2.0]])) + + +def test_covo_audio_llm2code2wav_token_only_smoke() -> None: + """Smoke: covo_audio token-only builder returns placeholder prompts sized to audio_codes count.""" + # source_outputs is a list of objects with .outputs[0].token_ids + from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX + from vllm_omni.model_executor.stage_input_processors.covo_audio import ( + llm2code2wav_token_only, + ) + + class _Out: + def __init__(self, tids): + self.token_ids = tids + + class _Wrapper: + def __init__(self, tids): + self.outputs = [_Out(tids)] + + # 3 codec tokens + 2 non-codec + src = [_Wrapper([COVO_AUDIO_TOKEN_INDEX + 0, COVO_AUDIO_TOKEN_INDEX + 1, COVO_AUDIO_TOKEN_INDEX + 2, 100, 200])] + out = llm2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + assert out[0]["additional_information"] is None + + +def test_covo_audio_llm2code2wav_full_payload_smoke() -> None: + """Smoke: covo_audio producer-side payload builder returns audio_codes + finished.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX + from vllm_omni.model_executor.stage_input_processors.covo_audio import ( + llm2code2wav_full_payload, + ) + + req = SimpleNamespace( + output_token_ids=[COVO_AUDIO_TOKEN_INDEX + 5, COVO_AUDIO_TOKEN_INDEX + 6, 99], + ) + payload = llm2code2wav_full_payload(None, {}, req) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6] + assert payload["meta"]["finished"].item() is True + + +def test_dynin_omni_token_only_smoke() -> None: + """Smoke: dynin_omni token-only builders return placeholders.""" + from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( + token2text_to_token2image_token_only, + ) + + class _Out: + def __init__(self, tids, mm=None): + self.token_ids = tids + self.multimodal_output = mm + + class _Wrapper: + def __init__(self, tids, mm=None): + self.outputs = [_Out(tids, mm)] + self.request_id = "r0" + + class _Stage: + def __init__(self, outs): + self.engine_outputs = outs + + src = [_Wrapper([10, 11, 12])] + out = token2text_to_token2image_token_only([_Stage(src)], [0]) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + assert out[0]["additional_information"] is None + + +def test_dynin_omni_full_payload_smoke() -> None: + """Smoke: dynin_omni producer-side payload builder returns nested OmniPayload + carries metadata.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( + token2text_to_token2image_full_payload, + ) + + pooling = {"token_ids": [1, 2, 3]} + req = SimpleNamespace(output_token_ids=[], additional_information={"speaker": ["alice"]}) + payload = token2text_to_token2image_full_payload(None, pooling, req) + assert payload is not None + assert payload["codes"]["audio"] == [1, 2, 3] + assert payload["meta"]["finished"].item() is True + # additional_information is normalized + carried forward (speaker stays list-wrapped). + assert payload.get("speaker") == ["alice"] + + +def test_qwen2_5_omni_talker2code2wav_token_only_smoke() -> None: + """Smoke: qwen2_5_omni talker→code2wav token_only marker + boundary strip.""" + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_token_only, + ) + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, tids): + self.outputs = [_Out(tids)] + + # 3 inner codes wrapped by START + END + src = [_Wrap([TALKER_CODEC_START_TOKEN_ID, 10, 11, 12, TALKER_CODEC_END_TOKEN_ID])] + out = talker2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + assert out[0]["additional_information"] is None + + +def test_qwen2_5_omni_talker2code2wav_full_payload_smoke() -> None: + """Smoke: qwen2_5_omni producer-side payload builder strips boundaries.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_full_payload, + ) + + req = SimpleNamespace( + output_token_ids=[TALKER_CODEC_START_TOKEN_ID, 5, 6, 7, TALKER_CODEC_END_TOKEN_ID], + ) + payload = talker2code2wav_full_payload(None, {}, req) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6, 7] + assert payload["meta"]["finished"].item() is True + + +def test_qwen2_5_omni_talker2code2wav_filters_control_tokens_and_placeholders() -> None: + """Qwen2.5 code2wav receives codec ids only, not talker prompt/control ids.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_full_payload, + talker2code2wav_token_only, + ) + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, tids): + self.outputs = [_Out(tids)] + + raw_ids = [ + TALKER_CODEC_START_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + 5, + 6, + TALKER_CODEC_END_TOKEN_ID, + -1, + -1, + ] + + token_only = talker2code2wav_token_only([_Wrap(raw_ids)]) + assert len(token_only) == 1 + assert len(token_only[0]["prompt_token_ids"]) == 4 + + payload = talker2code2wav_full_payload(None, {}, SimpleNamespace(output_token_ids=raw_ids)) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6, 6, 6] + assert payload["meta"]["finished"].item() is True + + +def test_mimo_audio_llm2code2wav_token_only_smoke() -> None: + """Smoke: mimo_audio token-only builder sizes prompt.""" + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + llm2code2wav_token_only, + ) + + class _Out: + def __init__(self, mm): + self.multimodal_output = mm + + class _Wrap: + def __init__(self, mm): + self.outputs = [_Out(mm)] + + # 3 batch rows of [1, 8, 4]: prepend_and_flatten_colmajor → 3*1*4*9 = 108 + codes = torch.arange(96, dtype=torch.long).reshape(3, 1, 8, 4) + codes = codes.clamp(min=1) # ensure nonzero so zero-row filter doesn't drop them + src = [_Wrap({"codes": {"audio": codes}})] + out = llm2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 108 + assert out[0]["additional_information"] is None + + +def test_mimo_audio_llm2code2wav_full_payload_smoke() -> None: + """Smoke: mimo_audio producer-side payload builder reads flat codes.audio + flattens.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + TALKER_CODEC_PAD_TOKEN_ID, + llm2code2wav_full_payload, + ) + + # Simulate accumulator output: 2 steps of [1, 1, 8, 4] CONCAT'd → [2, 1, 8, 4] + audio = torch.arange(2 * 1 * 8 * 4, dtype=torch.long).reshape(2, 1, 8, 4) + audio = audio.clamp(min=1) # avoid zero-row drop + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=[]) + payload = llm2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert "codes" in payload and "audio" in payload["codes"] + # Flattened length = numel + B*4 (per-batch pad_vec prepended by prepend_and_flatten_colmajor) + batch_size = int(audio.shape[0]) + assert len(payload["codes"]["audio"]) == audio.numel() + batch_size * 4 + # prepend_and_flatten_colmajor: PAD appears at column start in col-major flatten. + # For shape [B=2, 1, 9, 4], each column has 1 PAD then 8 codec vals → PAD at indices 0, 9, 18, 27. + out = payload["codes"]["audio"] + assert out[0] == TALKER_CODEC_PAD_TOKEN_ID + assert out[9] == TALKER_CODEC_PAD_TOKEN_ID + assert payload["meta"]["finished"].item() is True + + +def test_mimo_audio_full_payload_nested_fallback() -> None: + """Back-compat: full_payload still works if runtime returns nested codes.audio.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + llm2code2wav_full_payload, + ) + + audio = torch.arange(1 * 1 * 8 * 4, dtype=torch.long).reshape(1, 1, 8, 4) + audio = audio.clamp(min=1) + pooling_output = {"codes": {"audio": audio}} # nested, not flat + req = SimpleNamespace(output_token_ids=[]) + payload = llm2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert len(payload["codes"]["audio"]) == audio.numel() + int(audio.shape[0]) * 4 + + +def test_qwen3_tts_talker2code2wav_token_only_smoke() -> None: + """Smoke: qwen3_tts token-only sizes placeholder.""" + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_token_only, + ) + + class _Out: + def __init__(self, mm, tids): + self.multimodal_output = mm + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, mm, tids): + self.outputs = [_Out(mm, tids)] + self.finished = True + + # 3 valid codec frames Q=16; non-zero & under codebook size + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + mm = {"codes": {"audio": audio}} + src = [_Wrap(mm, list(range(10)))] # seq_len = 9; 3 < 9, no trim + out = talker2code2wav_token_only(src) + assert len(out) == 1 + # Codebook-major flat: 16 * 3 = 48 + assert len(out[0]["prompt_token_ids"]) == 48 + + +def test_qwen3_tts_talker2code2wav_full_payload_smoke() -> None: + """Smoke: qwen3_tts full_payload reads flat codes.audio + flattens codebook-major.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + # 3 valid codec frames [3, 16] CONCAT'd from per-step emits via flatten + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=list(range(10))) # seq_len = 9 + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert "codes" in payload and "audio" in payload["codes"] + # codebook-major: shape [3, 16] -> [16, 3] -> flatten = 48 entries + assert isinstance(payload["codes"]["audio"], torch.Tensor) + assert payload["codes"]["audio"].shape == (48,) + expected = audio.transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected) + assert payload["meta"]["finished"].item() is True + + +def test_qwen3_tts_full_payload_with_ref_code() -> None: + """Exact: ref_code is prepended (not appended) to audio, ref_code_len trims + ref, and the flatten is codebook-major. Protects against ref-append-position + regressions, ref_code_len-not-applied bugs, and flatten-order regressions.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + # Audio: 3 frames [3, 16] (no filter drops these — all positive, in-range). + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + # Ref code: 2 frames [2, 16] (already 2-D), distinct value range so we can + # detect the prepend ordering. + ref = torch.arange(2 * 16, dtype=torch.long).reshape(2, 16) + 100 + pooling_output = { + "codes.audio": audio, + "codes.ref": [ref], + "meta.ref_code_len": torch.tensor([2], dtype=torch.int32), + } + req = SimpleNamespace(output_token_ids=list(range(10))) # seq_len = 9 > 3, no audio crop + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + + # Exact expected: ref (prepended) + audio (no crop since seq_len > rows), then + # transpose [5, 16] -> [16, 5] and flatten row-major (codebook-major). + expected = torch.cat([ref, audio], dim=0).transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected), ( + f"codec flatten mismatch -- got first 8 = {payload['codes']['audio'][:8].tolist()}, " + f"expected first 8 = {expected[:8].tolist()}" + ) + assert payload["codes"]["audio"].shape == (80,) # 16 quantizers * (2 ref + 3 audio) frames + + # Sanity guards: first codebook-major column = [ref[0,0], ref[1,0], audio[0,0], ...], + # so the prepend order must put 100 before 1. + first_col = payload["codes"]["audio"][:5].tolist() + assert first_col == [100, 116, 1, 17, 33], ( + f"first column wrong: {first_col} -- ref likely appended instead of prepended" + ) + + +def test_qwen3_tts_full_payload_nested_fallback() -> None: + """Back-compat: full_payload works if pooler returns un-flattened nested dict.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + audio = torch.arange(2 * 16, dtype=torch.long).reshape(2, 16) + 1 + pooling_output = {"codes": {"audio": audio}} # nested, not flat + req = SimpleNamespace(output_token_ids=list(range(10))) + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert isinstance(payload["codes"]["audio"], torch.Tensor) + assert payload["codes"]["audio"].shape == (32,) # 16 * 2 + + +def test_qwen3_tts_code2wav_prefers_connector_tensor_payload() -> None: + """Code2Wav should consume connector codec tensor instead of placeholder zeros.""" + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + _codec_ids_from_payload_or_input, + ) + + placeholder = torch.zeros(6, dtype=torch.long) + codec = torch.arange(12, dtype=torch.long) + + out = _codec_ids_from_payload_or_input( + placeholder, + {"codes": {"audio": codec}}, + ) + + assert torch.equal(out, codec) + + +def test_qwen3_tts_code2wav_accepts_legacy_list_payload() -> None: + """Back-compat: old list full-payloads still override placeholder tokens.""" + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + _codec_ids_from_payload_or_input, + ) + + placeholder = torch.zeros(6, dtype=torch.long) + + out = _codec_ids_from_payload_or_input( + placeholder, + {"codes": {"audio": [1, 2, 3, 4]}}, + ) + + assert torch.equal(out, torch.tensor([1, 2, 3, 4], dtype=torch.long)) + + +def test_qwen3_tts_code2wav_forward_decodes_connector_payload() -> None: + """Forward should decode real connector codes, not token-only placeholders.""" + from collections import Counter + + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + Qwen3TTSCode2Wav, + ) + + class _Decoder: + def __init__(self): + self.last_codes = None + + def chunked_decode(self, codes, **_kwargs): + self.last_codes = codes.detach().clone() + return codes.sum(dim=1).to(torch.float32) + + decoder = _Decoder() + model = Qwen3TTSCode2Wav.__new__(Qwen3TTSCode2Wav) + torch.nn.Module.__init__(model) + model.decoder = decoder + model._num_quantizers = 2 + model._total_upsample = 1 + model._output_sample_rate = 24000 + model._decode_chunk_frames = 300 + model._decode_left_context_frames = 25 + model._decode_batch_bucket_frames = [] + model._decode_batch_max_size = 0 + model._decode_variable_chunk_batch_min_frames = 326 + model._logged_codec_stats = True + model._logged_malformed_codec_lengths = set() + model._batch_stats_enabled = False + model._batch_stats_log_every = 0 + model._batch_stats_forwards = 0 + model._batch_stats_groups = 0 + model._batch_stats_requests = 0 + model._batch_stats_padded_frames = 0 + model._batch_stats_decoded_frames = 0 + model._batch_stats_actual_frames = Counter() + model._batch_stats_bucket_groups = Counter() + + payload_codes = torch.tensor([1, 3, 2, 4], dtype=torch.long) + out = model.forward( + input_ids=torch.zeros(4, dtype=torch.long), + runtime_additional_information=[{"codes": {"audio": payload_codes}, "meta": {}}], + ) + + assert decoder.last_codes is not None + assert torch.equal(decoder.last_codes, torch.tensor([[[1, 3], [2, 4]]], dtype=torch.long)) + assert torch.equal(out.multimodal_outputs["model_outputs"][0], torch.tensor([3.0, 7.0])) + + +def test_qwen3_tts_codec_filter_and_crop_edge_cases() -> None: + """Regression gate for codec filter + seq_len crop on both token_only and full_payload. + + Mixes valid / all-zero / negative / >=_CODEBOOK_SIZE rows. Asserts: + - Token-only placeholder length matches Q * (#kept rows after crop). + - Full-payload codes.audio matches the exact codebook-major flatten + of the kept-and-cropped rows. + + Protects against future cleanup reverting the codex P2 #3 (negative + codec filter) or the _CODEBOOK_SIZE upper bound. + """ + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + _CODEBOOK_SIZE, + talker2code2wav_full_payload, + talker2code2wav_token_only, + ) + + Q = 4 # simulated num_quantizers (default is 16; small here for readability) + # 7 rows: valid / all-zero / negative / out-of-range / boundary-valid / valid / valid. + audio_rows = [ + [10, 20, 30, 40], # row 0: valid -> KEEP + [0, 0, 0, 0], # row 1: all-zero -> DROP + [50, -1, 60, 70], # row 2: negative -> DROP + [100, _CODEBOOK_SIZE, 110, 120], # row 3: >= 2048 -> DROP + [200, _CODEBOOK_SIZE - 1, 210, 220], # row 4: boundary 2047 -> KEEP + [300, 310, 320, 330], # row 5: valid -> KEEP + [400, 410, 420, 430], # row 6: valid -> KEEP + ] + audio = torch.tensor(audio_rows, dtype=torch.long) + kept = [audio_rows[i] for i in (0, 4, 5, 6)] # 4 rows after filter + + # === token_only path === + # cumulative_token_ids of length 4 -> seq_len = 3 -> crop kept[-3:] = rows {4, 5, 6} + class _Out: + def __init__(self, ctids, mm): + self.cumulative_token_ids = ctids + self.multimodal_output = mm + + class _Wrap: + def __init__(self, ctids, mm): + self.outputs = [_Out(ctids, mm)] + self.finished = True + + mm = {"codes": {"audio": audio}, "meta": {}} + src = [_Wrap(ctids=[1, 2, 3, 4], mm=mm)] + out = talker2code2wav_token_only(src, prompt=None) + assert len(out) == 1 + # No ref_code -> ref_frames = 0; expected prompt_len = Q * (#kept-after-crop) = 4 * 3 = 12 + assert len(out[0]["prompt_token_ids"]) == Q * 3 + + # === full_payload path === + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=[1, 2, 3, 4]) # seq_len = 3 + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + # After filter + crop, kept rows = [row4, row5, row6] = [[200,2047,210,220],[300,310,320,330],[400,410,420,430]] + # Codebook-major flatten: transpose [3, Q] -> [Q, 3] -> reshape(-1) + cropped = torch.tensor(kept[-3:], dtype=torch.long) + expected = cropped.transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected) + # Sanity: confirm the boundary-valid 2047 survived (codex P2 #3 regression guard). + assert _CODEBOOK_SIZE - 1 in payload["codes"]["audio"].tolist() + # Sanity: confirm no negative or >=_CODEBOOK_SIZE codec id leaked through. + assert bool(((payload["codes"]["audio"] >= 0) & (payload["codes"]["audio"] < _CODEBOOK_SIZE)).all()) + + +def test_cosyvoice3_text2flow_token_only_smoke() -> None: + """Smoke: cosyvoice3 token-only carries ids.prompt only.""" + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_token_only, + ) + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + self.multimodal_output = {} + + class _Wrap: + def __init__(self, output_tids, prompt_tids): + self.outputs = [_Out(output_tids)] + self.prompt_token_ids = prompt_tids + self.finished = True + + # multimodal_output has embed.* + we expect token_only to preserve it. + import torch + + embed = {"speech_token": torch.zeros(2, 4)} + src = [_Wrap(output_tids=[10, 20, 30], prompt_tids=[1, 2, 3, 4])] + src[0].outputs[0].multimodal_output = {"embed": embed} + out = text2flow_token_only(src) + assert len(out) == 1 + # prompt_token_ids is the talker's cumulative_token_ids (real codec tokens, not zeros). + assert out[0]["prompt_token_ids"] == [10, 20, 30] + # additional_information carries ids.prompt PLUS the original multimodal_output (embed.* still inline). + # Heavy embed.* removal pending the model_intermediate_buffer plumbing on the code2wav side. + assert out[0]["additional_information"]["ids"]["prompt"] == [1, 2, 3, 4] + assert "embed" in out[0]["additional_information"] + + +def test_cosyvoice3_text2flow_full_payload_smoke() -> None: + """Smoke: cosyvoice3 producer-side reads flat embed.* keys.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_full_payload, + ) + + speech_token = torch.randn(4, 8) + speech_feat = torch.randn(4, 16) + embedding = torch.randn(1, 32) + pooling_output = { + "embed.speech_token": speech_token, + "embed.speech_feat": speech_feat, + "embed.embedding": embedding, + } + req = SimpleNamespace(external_req_id="r-1") + payload = text2flow_full_payload(None, pooling_output, req) + assert payload is not None + assert "embed" in payload + assert torch.equal(payload["embed"]["speech_token"], speech_token) + assert torch.equal(payload["embed"]["speech_feat"], speech_feat) + assert torch.equal(payload["embed"]["embedding"], embedding) + assert payload["meta"]["finished"].item() is True + + +def test_cosyvoice3_text2flow_full_payload_nested_fallback() -> None: + """Back-compat: full_payload works if pooler returns un-flattened nested embed.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_full_payload, + ) + + speech_token = torch.randn(3, 8) + pooling_output = {"embed": {"speech_token": speech_token}} # nested, not flat + req = SimpleNamespace(external_req_id="r-2") + payload = text2flow_full_payload(None, pooling_output, req) + assert payload is not None + assert "speech_token" in payload["embed"] + assert torch.equal(payload["embed"]["speech_token"], speech_token) + + +def test_cosyvoice3_full_payload_replace_keys_present() -> None: + """Confirm _FULL_PAYLOAD_REPLACE_KEYS lists the three embed.* keys.""" + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + _FULL_PAYLOAD_REPLACE_KEYS, + ) + + assert _FULL_PAYLOAD_REPLACE_KEYS == frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) + + +def test_ming_flash_omni_thinker2talker_token_only_smoke() -> None: + """Smoke: ming_flash_omni token-only carries voice metadata.""" + from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( + thinker2talker_token_only, + ) + + class _Out: + def __init__(self, text): + self.text = text + + class _Wrap: + def __init__(self, text): + self.outputs = [_Out(text)] + + class _Prompt: + def __init__(self, info): + self.additional_information = info + + src = [_Wrap("hello world")] + prompt = _Prompt({"voice_name": "ZH_FEMALE", "prompt_text": "ref text"}) + out = thinker2talker_token_only(src, prompt=prompt) + assert len(out) == 1 + assert out[0]["prompt_token_ids"] == [0] # talker self-tokenizes; dummy id + info = out[0]["additional_information"] + assert info["text"] == "hello world" + assert info["voice_name"] == "ZH_FEMALE" + assert info["prompt_text"] == "ref text" + assert info["ming_task"] == "omni" + + +def test_qwen2_5_omni_thinker2talker_token_only_smoke() -> None: + """Smoke: qwen2_5_omni thinker token-only allocates prompt slots; bulk payload ships via connector.""" + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + thinker2talker_token_only, + ) + + class _Wrap: + def __init__(self, prompt_tids, rid): + self.outputs = [object()] + self.prompt_token_ids = prompt_tids + self.request_id = rid + + class _Prompt(dict): + pass + + src = [_Wrap(prompt_tids=[1, 2, 3, 4, 5], rid="r-1")] + prompt = [_Prompt(multi_modal_data=None)] + out = thinker2talker_token_only(src, prompt=prompt) + assert len(out) == 1 + expected_prompt_len = 1 + len([1, 2, 3, 4, 5]) + 1 + assert len(out[0]["prompt_token_ids"]) == expected_prompt_len + assert out[0]["prompt_token_ids"][0] == TALKER_CODEC_START_TOKEN_ID + assert out[0]["prompt_token_ids"][-1] == TALKER_CODEC_END_TOKEN_ID + assert all(t == TALKER_CODEC_PAD_TOKEN_ID for t in out[0]["prompt_token_ids"][1:-1]) + assert out[0]["additional_information"] is None + + +def test_qwen2_5_omni_thinker2talker_full_payload_noop() -> None: + """thinker2talker_full_payload returns None when pooling_output lacks the "hidden" key (defensive).""" + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + thinker2talker_full_payload, + ) + + payload = thinker2talker_full_payload(None, {"any": "thing"}, None) + assert payload is None diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index c54a39a1e38..0db1d4fbae8 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -1841,12 +1841,19 @@ def test_async_chunk_dispatches_processors(self): ) assert async_stages[1].custom_process_input_func is None - # async_chunk=False → stage 0 has no streaming processor, stage 1's - # batch-end processor wires up. + # async_chunk=False → stage 0 ships the bulk codec via the + # worker-connector full-payload producer; stage 1 wires the + # ``_token_only`` placeholder so the orchestrator emits no + # legacy ``additional_information``-shaped input (PR3 sync- + # via-connector data plane). sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False)) - assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args + assert ( + sync_stages[0] + .yaml_engine_args["custom_process_next_stage_input_func"] + .endswith("talker2code2wav_full_payload") + ) assert sync_stages[1].custom_process_input_func is not None - assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav") + assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav_token_only") def test_async_chunk_dispatches_qwen3_omni_processors(self): import runpy @@ -1883,6 +1890,44 @@ def test_async_chunk_dispatches_qwen3_omni_processors(self): .endswith("talker2code2wav_full_payload") ) + def test_ming_flash_omni_topology(self): + """Guard ming_flash_omni's PR3 cleanup: stage 0 has no full-payload + producer hook (the connector path was removed as fake -- arch is not + in ``_FULL_PAYLOAD_INPUT_STAGES``), and stage 1 still wires the + legacy ``thinker2talker`` (custom_process_input_func) plus the + ``thinker2talker_token_only`` placeholder (sync_process_input_func). + Merge under either async_chunk mode must not re-introduce a + stage-0 full-payload hook.""" + from vllm_omni.config.stage_config import DeployConfig, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["ming_flash_omni"] + + stage0, stage1 = pipeline.stages + assert stage0.custom_process_next_stage_input_func is None, ( + "ming_flash_omni stage 0 must not declare a full-payload producer " + "(connector path is not active for this arch)." + ) + assert stage1.custom_process_input_func is not None + assert stage1.custom_process_input_func.endswith("thinker2talker") + assert stage1.sync_process_input_func is not None + assert stage1.sync_process_input_func.endswith("thinker2talker_token_only") + + # async_chunk=True must now be rejected: removing the fake hook means + # there is no next-stage input processor for the validator to accept. + # (Positive consequence -- users can't accidentally enable async_chunk + # on an arch that doesn't actually support it.) + import pytest as _pytest + + with _pytest.raises(ValueError, match="async_chunk=True"): + merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True)) + + # async_chunk=False merges cleanly and stage-0 yaml_engine_args carries + # no spurious full-payload hook. + merged = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False)) + assert "custom_process_next_stage_input_func" not in merged[0].yaml_engine_args, ( + "stage-0 full-payload hook unexpectedly re-appeared in yaml_engine_args" + ) + class TestSamplingConstraintsPrecedence: """Test that pipeline sampling_constraints override deploy defaults.""" diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index b834d8733b0..9c8640b51bd 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -400,16 +400,23 @@ def _make_full_payload_accumulation_runner( model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="talker", async_chunk=False, + final_output=False, + custom_process_next_stage_input_func="module.full_payload", ): runner = object.__new__(OmniConnectorModelRunnerMixin) runner.model_config = SimpleNamespace( model_arch=model_arch, model_stage=model_stage, async_chunk=async_chunk, + final_output=final_output, + custom_process_next_stage_input_func=custom_process_next_stage_input_func, ) runner._custom_process_func = object() runner._pending_full_payload_send = {} runner._stage_id = 1 + # Non-None sentinel: the gate short-circuits to False when no connector + # is configured at all (terminal stages in pipelines with no connector). + runner._omni_connector = object() return runner @@ -425,7 +432,7 @@ def test_accumulate_full_payload_output_preserves_aligned_all_zero_qwen3_omni_co def test_accumulate_full_payload_output_keeps_misaligned_all_zero_qwen3_omni_codec_rows(): - # After removing the sender-side zero filter, the accumulator keeps every + # After removing the sender-side zero filter, the full-payload accumulator keeps every # codec row including misaligned all-zero rows. The downstream consumer # (_extract_qwen3_full_payload_codec_rows) is the authoritative crop and # filters by output_token_ids. @@ -472,18 +479,45 @@ def test_accumulate_full_payload_output_keeps_all_zero_qwen3_omni_prefill_placeh def test_full_payload_output_accumulation_hook_matrix(): + """Producer-side gate: fires iff an explicit next-stage payload hook is loaded. + + A derived `*_full_payload` helper from `custom_process_input_func` is not + enough: terminal/input-only consumer stages must not enqueue orphan + downstream payloads. + """ + # Thinker / talker producer stages: explicit next-stage payload hook -> gate fires. assert _make_full_payload_accumulation_runner(model_stage="thinker")._should_accumulate_full_payload_output() assert _make_full_payload_accumulation_runner(model_stage="talker")._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner(model_stage="code2wav")._should_accumulate_full_payload_output() + + # Terminal stage: even if _load_custom_func derived a builder from + # custom_process_input_func, final output stages are not producers. + runner = _make_full_payload_accumulation_runner(model_stage="code2wav", final_output=True) + assert not runner._should_accumulate_full_payload_output() + + # Input-only consumer stage without an explicit producer hook must not + # accumulate/send just because a same-module *_full_payload helper exists. + runner = _make_full_payload_accumulation_runner( + model_stage="token2audio", + custom_process_next_stage_input_func=None, + ) + assert not runner._should_accumulate_full_payload_output() + + # async_chunk mode -> gate off. assert not _make_full_payload_accumulation_runner( model_stage="talker", async_chunk=True )._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner( - model_arch="Qwen3TTSForConditionalGeneration" - )._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner( - model_arch="Qwen2_5OmniForConditionalGeneration" - )._should_accumulate_full_payload_output() + + # Non-qwen3 arches: gate is arch-agnostic, but if the fixture's arch + # does not configure a connector payload builder, its runtime + # `_custom_process_func` is None. Emulate that. + runner = _make_full_payload_accumulation_runner(model_arch="Qwen3TTSForConditionalGeneration") + runner._custom_process_func = None + runner._should_accumulate_full_payload_output_cached = None + assert not runner._should_accumulate_full_payload_output() + runner = _make_full_payload_accumulation_runner(model_arch="Qwen2_5OmniForConditionalGeneration") + runner._custom_process_func = None + runner._should_accumulate_full_payload_output_cached = None + assert not runner._should_accumulate_full_payload_output() def test_sync_local_stage_payloads_retains_payload_until_request_is_active(): diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 09ee55ba972..3d51765a9af 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -22,9 +22,8 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) -from vllm_omni.core.sched.output import OmniSchedulerOutput from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -81,7 +80,7 @@ def __init__(self, *args, **kwargs): if getattr(model_config, "async_chunk", False): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_qwen3_omni_full_payload_input_coordinator(model_config): + if uses_full_payload_input_coordinator(model_config): self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), @@ -211,18 +210,8 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] for req in list(queue): if getattr(req, "status", None) == RequestStatus.FINISHED_ABORTED: queue.remove(req) - connector_output = self._latest_omni_connector_output - self._latest_omni_connector_output = None - if self.input_coordinator: - if connector_output and connector_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, connector_output.request_metadata, model_mode="ar" - ) - self.input_coordinator.process_pending_full_payload_inputs( - self.waiting, - self.running, - connector_output.stage_recv_req_ids if connector_output else set(), - ) + self._consume_pending_connector_output(model_mode="ar") + self._process_pending_input_timeouts() if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) @@ -278,13 +267,9 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] finished_reqs = {} # Wrap in omni scheduler output to carry transfer metadata. - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(scheduler_output, name) for name in base_fields} - input_regs = self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - return OmniSchedulerOutput( - **base_data, + return self._wrap_omni_scheduler_output( + scheduler_output, finished_requests_needing_kv_transfer=finished_reqs, - pending_input_registrations=input_regs, ) def update_from_output( @@ -581,15 +566,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - omni_output = getattr(model_runner_output, "omni_connector_output", None) - if omni_output is not None: - self._latest_omni_connector_output = omni_output - if self.input_coordinator and omni_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, - omni_output.request_metadata, - model_mode="ar", - ) + self._capture_omni_connector_output(model_runner_output) # Free blocks that were held for transfer (kv_ready and # active_kv_transfers updates already done before the per-request loop). @@ -668,70 +645,73 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) - # 2. Omni Specific: Check if we need to transfer KV - if self._should_transfer_kv_for_request(request_id): - already_triggered = request_id in self.transfer_triggered_requests - is_active = request_id in self.active_kv_transfers - - if already_triggered: - if is_active: - # It triggered but hasn't finished yet. We MUST wait. - logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. Waiting.") + # Mirror the generation scheduler's try/finally pattern so the + # input_coordinator entry is always pruned along every return path, + # including the early returns for in-flight / waiting KV transfers + # below. _free_input_coordinator_request is a no-op when the + # coordinator is None, so the unconditional finally is safe. + try: + # 2. Omni Specific: Check if we need to transfer KV + if self._should_transfer_kv_for_request(request_id): + already_triggered = request_id in self.transfer_triggered_requests + is_active = request_id in self.active_kv_transfers + + if already_triggered: + if is_active: + # It triggered but hasn't finished yet. We MUST wait. + logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. Waiting.") + self.waiting_for_transfer_free.add(request_id) + kv_xfer_params = None + return kv_xfer_params + elif request_id in self.waiting_for_transfer_free: + # Blocks held until KV extraction completes in a future step. + return None + else: + logger.debug( + f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). " + "Freeing immediately." + ) + else: self.waiting_for_transfer_free.add(request_id) - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - kv_xfer_params = None + confirmed_computed = self._get_confirmed_num_computed_tokens(request) + self._mark_request_for_kv_transfer(request_id, confirmed_computed) + # Return KV transfer metadata so it propagates to RequestOutput + if request_id in self.requests_needing_kv_transfer: + transfer_data = self.requests_needing_kv_transfer[request_id] + kv_xfer_params = { + "past_key_values": transfer_data["block_ids"], + "kv_metadata": { + "seq_len": transfer_data["seq_len"], + "block_ids": transfer_data["block_ids"], + }, + } + # Also update request.additional_information for good measure + add_info = getattr(request, "additional_information", None) + # If additional_information is an AdditionalInformationPayload-like object, + # unpack it into a plain dict. + if ( + add_info is not None + and hasattr(add_info, "entries") + and isinstance(getattr(add_info, "entries"), dict) + ): + request.additional_information = deserialize_additional_information(add_info) + add_info = request.additional_information + if add_info is None: + request.additional_information = {} + add_info = request.additional_information + if isinstance(add_info, dict): + add_info.update(kv_xfer_params) + return kv_xfer_params - elif request_id in self.waiting_for_transfer_free: - # Blocks held until KV extraction completes in a future step. - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - return None - else: - logger.debug( - f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). " - "Freeing immediately." - ) - else: - self.waiting_for_transfer_free.add(request_id) - confirmed_computed = self._get_confirmed_num_computed_tokens(request) - self._mark_request_for_kv_transfer(request_id, confirmed_computed) - # Return KV transfer metadata so it propagates to RequestOutput - if request_id in self.requests_needing_kv_transfer: - transfer_data = self.requests_needing_kv_transfer[request_id] - kv_xfer_params = { - "past_key_values": transfer_data["block_ids"], - "kv_metadata": {"seq_len": transfer_data["seq_len"], "block_ids": transfer_data["block_ids"]}, - } - # Also update request.additional_information for good measure - add_info = getattr(request, "additional_information", None) - # If additional_information is an AdditionalInformationPayload-like object, - # unpack it into a plain dict. - if ( - add_info is not None - and hasattr(add_info, "entries") - and isinstance(getattr(add_info, "entries"), dict) - ): - request.additional_information = deserialize_additional_information(add_info) - add_info = request.additional_information - if add_info is None: - request.additional_information = {} - add_info = request.additional_information - if isinstance(add_info, dict): - add_info.update(kv_xfer_params) - - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - return kv_xfer_params - - # 3. Standard Freeing - delay_free_blocks |= connector_delay_free_blocks - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - if not delay_free_blocks: - self._free_blocks(request) - return kv_xfer_params + # 3. Standard Freeing + delay_free_blocks |= connector_delay_free_blocks + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + finally: + self._free_input_coordinator_request(request_id) def _free_blocks(self, request: Request): # Helper to match base class structure if not directly available diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 957b7e5f677..8d665574213 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -26,9 +26,9 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) -from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData, OmniSchedulerOutput +from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self._pending_finish_reqs: list[Request] = [] self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_qwen3_omni_full_payload_input_coordinator(model_config): + if uses_full_payload_input_coordinator(model_config): self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), @@ -82,18 +82,8 @@ def schedule(self) -> SchedulerOutput: # Temporary queue: preserve waiting order, do not disturb non-diffusion requests skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 - connector_output = self._latest_omni_connector_output - self._latest_omni_connector_output = None - if self.input_coordinator: - if connector_output and connector_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, connector_output.request_metadata, model_mode="generation" - ) - self.input_coordinator.process_pending_full_payload_inputs( - self.waiting, - self.running, - connector_output.stage_recv_req_ids if connector_output else set(), - ) + self._consume_pending_connector_output(model_mode="generation") + self._process_pending_input_timeouts() if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) @@ -227,14 +217,7 @@ def schedule(self) -> SchedulerOutput: res = super().schedule() if self.input_coordinator: self.input_coordinator.restore_queues(self.waiting, self.running) - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(res, name) for name in base_fields} - return OmniSchedulerOutput( - **base_data, - pending_input_registrations=( - self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - ), - ) + return self._wrap_omni_scheduler_output(res) # Compute common prefix blocks (aligned with v1) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) @@ -362,14 +345,7 @@ def schedule(self) -> SchedulerOutput: if self.input_coordinator: self.input_coordinator.restore_queues(self.waiting, self.running) - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(scheduler_output, name) for name in base_fields} - return OmniSchedulerOutput( - **base_data, - pending_input_registrations=( - self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - ), - ) + return self._wrap_omni_scheduler_output(scheduler_output) def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[tuple[str, int]]: """Handles the finish signal from outside the scheduler. @@ -683,15 +659,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - omni_output = getattr(model_runner_output, "omni_connector_output", None) - if omni_output is not None: - self._latest_omni_connector_output = omni_output - if self.input_coordinator and omni_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, - omni_output.request_metadata, - model_mode="generation", - ) + self._capture_omni_connector_output(model_runner_output) return engine_core_outputs diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 606739e9087..570fa554545 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -1,8 +1,37 @@ from __future__ import annotations +import os +from typing import Any + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreEventType from vllm.v1.request import Request, RequestStatus, StreamingUpdate +from vllm_omni.core.sched.output import OmniChunkRecvHandle, OmniSchedulerOutput + +logger = init_logger(__name__) + +# Upper bound on how long a request may sit in full-payload-input wait +# (the state ``OmniSchedulingCoordinator`` records via ``_waiting_since``) +# before the scheduler force-fails it. Defends against stuck consumer-side +# requests when the producer drops a full-payload, send fails, or recv +# never arrives. Override per-deployment via +# VLLM_OMNI_INPUT_WAIT_TIMEOUT_S; set <=0 to disable the safety net. +# +# Scope: this constant only covers the full-payload coordinator path +# (``input_coordinator``). The async-chunk path uses +# ``chunk_transfer_adapter`` and is not affected by this constant. +_INPUT_WAIT_TIMEOUT_RAW = os.environ.get("VLLM_OMNI_INPUT_WAIT_TIMEOUT_S", "300") +try: + DEFAULT_INPUT_WAIT_TIMEOUT_S: float = float(_INPUT_WAIT_TIMEOUT_RAW) +except ValueError: + logger.warning( + "Invalid VLLM_OMNI_INPUT_WAIT_TIMEOUT_S=%r; falling back to 300 seconds.", + _INPUT_WAIT_TIMEOUT_RAW, + ) + DEFAULT_INPUT_WAIT_TIMEOUT_S = 300.0 + class OmniSchedulerMixin: """Shared scheduler helpers for omni-specific request handling.""" @@ -13,6 +42,111 @@ def _free_input_coordinator_request(self, request_id: str) -> None: if input_coordinator is not None: input_coordinator.free_finished_request(request_id) + # ------------------------------------------------------------------ # + # Shared scheduler/output helpers (lift the AR / generation duplicates) + # ------------------------------------------------------------------ # + + def _consume_pending_connector_output(self, model_mode: str) -> None: + """Drain ``self._latest_omni_connector_output`` into the coordinator. + + Called at the top of every ``schedule()`` cycle. Identical between + AR and generation schedulers except for the ``model_mode`` argument + forwarded to ``update_request_metadata``. + """ + connector_output = getattr(self, "_latest_omni_connector_output", None) + self._latest_omni_connector_output = None + input_coordinator = getattr(self, "input_coordinator", None) + if input_coordinator is None: + return + if connector_output and connector_output.request_metadata: + input_coordinator.update_request_metadata( + self.requests, connector_output.request_metadata, model_mode=model_mode + ) + input_coordinator.process_pending_full_payload_inputs( + self.waiting, + self.running, + connector_output.stage_recv_req_ids if connector_output else set(), + ) + + def _process_pending_input_timeouts(self) -> None: + """Force-fail requests waiting on the full-payload coordinator too long. + + Called at the top of every ``schedule()`` cycle, right after + ``_consume_pending_connector_output``. Without this hook, a request + whose producer dropped a payload would sit in the + full-payload-input wait state indefinitely (the runner mixin + protects ``_pending_load_reqs`` from prune sweeps). + + Reads ``_waiting_since`` timestamps maintained by the input + coordinator and delegates to the base scheduler's + ``finish_requests`` to mark expired requests FINISHED_ERROR. + Disabled when ``DEFAULT_INPUT_WAIT_TIMEOUT_S`` is <= 0. + + Scope: only covers ``input_coordinator`` (full-payload path). + Async-chunk requests park in ``chunk_transfer_adapter`` instead + and are not handled here -- if a similar safety net is needed + for the chunk path, it belongs in the chunk adapter. + """ + if DEFAULT_INPUT_WAIT_TIMEOUT_S <= 0: + return + input_coordinator = getattr(self, "input_coordinator", None) + if input_coordinator is None: + return + timed_out_ids = input_coordinator.collect_timed_out_request_ids(timeout_s=DEFAULT_INPUT_WAIT_TIMEOUT_S) + if not timed_out_ids: + return + present_ids = {req_id for req_id in timed_out_ids if req_id in self.requests} + if not present_ids: + return + logger.warning( + "Marking %d request(s) as FINISHED_ERROR after waiting > %.0fs for connector input: %s", + len(present_ids), + DEFAULT_INPUT_WAIT_TIMEOUT_S, + sorted(present_ids), + ) + self.finish_requests(present_ids, RequestStatus.FINISHED_ERROR) + + def _capture_omni_connector_output(self, model_runner_output: Any) -> None: + """Stash the model runner's omni_connector_output for next schedule(). + + Called at the tail of every ``update_from_output()`` -- identical + between AR and generation schedulers. Only stashes the output; + applying the metadata is the responsibility of + ``_consume_pending_connector_output()`` at the start of the next + ``schedule()`` cycle. Applying it twice (once here, once on + consume) is unsafe under ``update_request_metadata`` in + generation mode, which resets ``prompt_token_ids`` / + ``_output_token_ids`` / ``num_computed_tokens`` and would + clobber any progress between the two calls. + """ + omni_output = getattr(model_runner_output, "omni_connector_output", None) + if omni_output is None: + return + self._latest_omni_connector_output = omni_output + + def _wrap_omni_scheduler_output( + self, + base: SchedulerOutput, + *, + finished_requests_needing_kv_transfer: dict | None = None, + pending_input_registrations: list[OmniChunkRecvHandle] | None = None, + ) -> OmniSchedulerOutput: + """Wrap a base ``SchedulerOutput`` in ``OmniSchedulerOutput``. + + Pulls each base ``SchedulerOutput`` dataclass field via ``getattr`` + and forwards optional omni-specific fields. Lifted from 4 separate + copy-pastes between AR (1) and generation (3) schedulers. + """ + base_data = {name: getattr(base, name) for name in SchedulerOutput.__dataclass_fields__} + input_coordinator = getattr(self, "input_coordinator", None) + if pending_input_registrations is None: + pending_input_registrations = input_coordinator.pending_input_registrations if input_coordinator else [] + return OmniSchedulerOutput( + **base_data, + finished_requests_needing_kv_transfer=finished_requests_needing_kv_transfer or {}, + pending_input_registrations=pending_input_registrations, + ) + def _replace_session_with_streaming_update( self, session: Request, diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 6c32ed4cda8..4056fd93861 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -19,16 +19,64 @@ from vllm.logger import init_logger from vllm.v1.request import Request, RequestStatus +from vllm_omni.core.sched.output import OmniChunkRecvHandle + logger = init_logger(__name__) -def uses_qwen3_omni_full_payload_input_coordinator(model_config: Any) -> bool: - return ( - getattr(model_config, "stage_id", 0) > 0 - and not getattr(model_config, "async_chunk", False) - and getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration" - and getattr(model_config, "model_stage", None) in {"talker", "code2wav"} +# (arch, model_stage) pairs that route their full_payload stage input via +# the worker connector and therefore need the scheduler-side coordinator to +# park requests in WAITING_FOR_INPUT until the recv side delivers. This set +# must stay aligned with the arch scope of `init_omni_connectors` in +# gpu_ar_model_runner.py and gpu_generation_model_runner.py. Adding a stage +# here without also wiring its worker connector init produces a permanent +# Stage 1 hang (gate parks the request, no transport ever releases it). +# +_FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( + { + ("Qwen3OmniMoeForConditionalGeneration", "talker"), + ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + # qwen2_5_omni thinker->talker uses the real full-payload + # producer builder (text_hidden_states routed via + # pooler_output["hidden"] -> accumulator -> connector). Both + # stages of qwen2_5_omni are enabled. + ("Qwen2_5OmniForConditionalGeneration", "talker"), + ("Qwen2_5OmniForConditionalGeneration", "code2wav"), + # covo_audio: fused_thinker_talker (Stage 0) -> code2wav (Stage 1). + ("CovoAudioForConditionalGeneration", "code2wav"), + # mimo_audio: fused_thinker_talker (Stage 0) -> code2wav (Stage 1). + ("MiMoAudioModel", "code2wav"), + # qwen3_tts: Qwen3TTSTalkerForConditionalGeneration (Stage 0) + # -> Qwen3TTSCode2Wav (Stage 1). Stage 1 is the consumer. + ("Qwen3TTSCode2Wav", "code2wav"), + # cosyvoice3: cosyvoice3_talker (Stage 0) -> cosyvoice3_code2wav (Stage 1). + ("CosyVoice3Model", "cosyvoice3_code2wav"), + # dynin: token2text (Stage 0) -> token2image (Stage 1) -> + # token2audio (Stage 2). Producer wires via + # custom_process_next_stage_input_func: *_full_payload in deploy yaml. + ("DyninOmniForConditionalGeneration", "token2image"), + ("DyninOmniForConditionalGeneration", "token2audio"), + } +) + + +def uses_full_payload_input_coordinator(model_config: Any) -> bool: + """Returns True iff this stage parks pending requests in + WAITING_FOR_INPUT awaiting a full_payload delivery on the worker connector. + + Gated by (model_arch, model_stage) — see _FULL_PAYLOAD_INPUT_STAGES for the + rationale on why this is a whitelist instead of a marker-driven structural + gate. + """ + if getattr(model_config, "stage_id", 0) <= 0: + return False + if getattr(model_config, "async_chunk", False): + return False + key = ( + getattr(model_config, "model_arch", None), + getattr(model_config, "model_stage", None), ) + return key in _FULL_PAYLOAD_INPUT_STAGES class OmniSchedulingCoordinator: @@ -59,7 +107,11 @@ def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: # Requests waiting for full_payload stage input (WAITING_FOR_INPUT). self._waiting_for_input: deque[Any] = deque() - self.pending_input_registrations: list[Any] = [] + # Per-cycle list of minimal handles to ship to the model runner so it + # can call register_chunk_recv(). Typed concretely (not list[Any]) so + # the surrounding OmniSchedulerOutput stays msgspec-friendly across + # default, PD-disagg, and multi-node executor IPC paths. + self.pending_input_registrations: list[OmniChunkRecvHandle] = [] # Monotonic timestamp recording when each request first entered # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by @@ -166,7 +218,12 @@ def process_pending_full_payload_inputs( self._waiting_since.setdefault(request.request_id, time.monotonic()) to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append(request) + self.pending_input_registrations.append( + OmniChunkRecvHandle( + request_id=request.request_id, + external_req_id=getattr(request, "external_req_id", None), + ) + ) elif request.status == RequestStatus.WAITING_FOR_INPUT: if request.request_id in stage_recv_req_ids: request.status = RequestStatus.WAITING @@ -174,7 +231,12 @@ def process_pending_full_payload_inputs( else: to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append(request) + self.pending_input_registrations.append( + OmniChunkRecvHandle( + request_id=request.request_id, + external_req_id=getattr(request, "external_req_id", None), + ) + ) if to_remove: # Use the bulk-remove helper: one O(N) sweep instead of N # repeated O(N) removes from a list-backed queue. diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 800881d9ff8..29cd872998f 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import Any from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request @@ -72,9 +71,26 @@ class OmniCachedRequestData(CachedRequestData): additional_information: dict[str, dict | None] +@dataclass +class OmniChunkRecvHandle: + """Minimal identifier carried from scheduler to runner for chunk-recv + registration. + + The runner's ``register_chunk_recv`` only consumes ``request_id`` and + ``external_req_id`` from each pending request, so we ship just those + two fields instead of the full Request object. Concrete typing + keeps msgspec serialization deterministic across IPC (default, + PD-disagg, multi-node executor variants) and avoids the + ``list[Any]`` fallback path. + """ + + request_id: str + external_req_id: str | None = None + + @dataclass class OmniSchedulerOutput(SchedulerOutput): """Scheduler output with omni-specific transfer metadata.""" finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) - pending_input_registrations: list[Any] = field(default_factory=list) + pending_input_registrations: list[OmniChunkRecvHandle] = field(default_factory=list) diff --git a/vllm_omni/deploy/dynin_omni_ci.yaml b/vllm_omni/deploy/dynin_omni_ci.yaml index 525b7d888c2..2ddc281e7a7 100644 --- a/vllm_omni/deploy/dynin_omni_ci.yaml +++ b/vllm_omni/deploy/dynin_omni_ci.yaml @@ -14,6 +14,7 @@ stage_args: worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler engine_output_type: latent + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload trust_remote_code: true gpu_memory_utilization: 0.5 enforce_eager: true @@ -36,6 +37,7 @@ stage_args: worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler engine_output_type: latent + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2image_to_token2audio_full_payload trust_remote_code: true gpu_memory_utilization: 0.2 enforce_eager: true diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 5023307ff8c..2dc81174cb0 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -784,12 +784,14 @@ def forward( else: self._stream_vocoder_cache_by_req[req_id] = new_cache_state else: + token_offset = max(0, meta.talker_prefill_offset or 0) if meta else 0 tts_speech = self.code2wav.forward( token=token.unsqueeze(0), prompt_token=speech_token[:1], prompt_feat=speech_feat[:1], embedding=embedding[:1], n_timesteps=10, + token_offset_tokens=token_offset, ) audio = tts_speech.reshape(-1).to(dtype=torch.float32) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py index 186a258c809..cb3228c13a7 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py @@ -292,6 +292,7 @@ def forward( prompt_feat: torch.Tensor, embedding: torch.Tensor, n_timesteps: int = 10, + token_offset_tokens: int = 0, ) -> torch.Tensor: """Generate audio waveform from speech tokens.""" feat = self._forward_mel( @@ -300,7 +301,7 @@ def forward( prompt_feat=prompt_feat, embedding=embedding, n_timesteps=n_timesteps, - token_offset_tokens=0, + token_offset_tokens=token_offset_tokens, streaming=False, finalize=True, ) diff --git a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py b/vllm_omni/model_executor/models/cosyvoice3/pipeline.py index 4480a0dd831..ed35c93bd13 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py +++ b/vllm_omni/model_executor/models/cosyvoice3/pipeline.py @@ -31,6 +31,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.text2flow_full_payload", sampling_constraints={ # merged speech stop token (logsumexp of all 200 stop logits) "stop_token_ids": [6562], @@ -44,7 +45,8 @@ final_output=True, final_output_type="audio", engine_output_type="latent", - sync_process_input_func=f"{_PROC}.text2flow", + custom_process_input_func=f"{_PROC}.text2flow", + sync_process_input_func=f"{_PROC}.text2flow_token_only", ), ), ) diff --git a/vllm_omni/model_executor/models/covo_audio/pipeline.py b/vllm_omni/model_executor/models/covo_audio/pipeline.py index 5b1a31d6ea8..97053e3286f 100644 --- a/vllm_omni/model_executor/models/covo_audio/pipeline.py +++ b/vllm_omni/model_executor/models/covo_audio/pipeline.py @@ -29,6 +29,7 @@ owns_tokenizer=True, requires_multimodal_data=True, engine_output_type="latent", + custom_process_next_stage_input_func=f"{_PROC}.llm2code2wav_full_payload", sampling_constraints={ "detokenize": True, "stop_token_ids": [151645], @@ -44,6 +45,7 @@ final_output_type="audio", engine_output_type="audio", custom_process_input_func=f"{_PROC}.llm2code2wav", + sync_process_input_func=f"{_PROC}.llm2code2wav_token_only", sampling_constraints={"detokenize": False}, ), ), diff --git a/vllm_omni/model_executor/models/mimo_audio/pipeline.py b/vllm_omni/model_executor/models/mimo_audio/pipeline.py index 70d14ef78aa..126c901763c 100644 --- a/vllm_omni/model_executor/models/mimo_audio/pipeline.py +++ b/vllm_omni/model_executor/models/mimo_audio/pipeline.py @@ -39,6 +39,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.llm2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.llm2code2wav_full_payload", sampling_constraints={ "detokenize": True, # Stop once the speech/text interleaved span ends. Code2Wav @@ -55,7 +56,8 @@ final_output=True, final_output_type="audio", engine_output_type="audio", - sync_process_input_func=f"{_PROC}.llm2code2wav", + custom_process_input_func=f"{_PROC}.llm2code2wav", + sync_process_input_func=f"{_PROC}.llm2code2wav_token_only", sampling_constraints={"detokenize": False}, ), ), diff --git a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py index a9d66fbc22b..a1e1ef4699a 100644 --- a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py +++ b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py @@ -58,6 +58,7 @@ engine_output_type="audio", tokenizer_subdir="talker/llm", custom_process_input_func=f"{_PROC}.thinker2talker", + sync_process_input_func=f"{_PROC}.thinker2talker_token_only", ), ), ) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py index de0644803b5..afd0a92a531 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py @@ -30,6 +30,7 @@ requires_multimodal_data=True, engine_output_type="latent", sampling_constraints={"detokenize": True}, + custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", ), StagePipelineConfig( stage_id=1, @@ -38,6 +39,8 @@ input_sources=(0,), engine_output_type="latent", custom_process_input_func=f"{_PROC}.thinker2talker", + sync_process_input_func=f"{_PROC}.thinker2talker_token_only", + custom_process_next_stage_input_func=f"{_PROC}.talker2code2wav_full_payload", sampling_constraints={ "detokenize": True, "stop_token_ids": [8294], @@ -52,6 +55,7 @@ final_output_type="audio", engine_output_type="audio", custom_process_input_func=f"{_PROC}.talker2code2wav", + sync_process_input_func=f"{_PROC}.talker2code2wav_token_only", sampling_constraints={"detokenize": True}, ), ), @@ -74,6 +78,7 @@ requires_multimodal_data=True, engine_output_type="latent", sampling_constraints={"detokenize": True}, + custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", ), ), ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py index 5051715ceac..7f50931ddf7 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py +++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py @@ -26,6 +26,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.talker2code2wav_full_payload", sampling_constraints={ "detokenize": False, "stop_token_ids": [2150], @@ -40,7 +41,14 @@ final_output_type="audio", engine_output_type="audio", model_arch="Qwen3TTSCode2Wav", - sync_process_input_func=f"{_PROC}.talker2code2wav", + # ``sync_process_input_func`` is the only input-proc override for + # this stage in sync (non-async-chunk) mode: a length-only + # ``_token_only`` placeholder. The bulk codec payload itself + # ships via the worker connector from stage 0's + # ``talker2code2wav_full_payload`` producer. Under async_chunk + # mode no pre-stage processing is needed -- chunks deliver + # directly to the consumer. + sync_process_input_func=f"{_PROC}.talker2code2wav_token_only", sampling_constraints={"detokenize": True}, extras={"tts_args": {"max_instructions_length": 500}}, ), diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index b46018e3616..b7c3c9d6e46 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -25,6 +25,27 @@ logger = init_logger(__name__) +def _codec_ids_from_payload_or_input( + input_ids: torch.Tensor, + runtime_info: dict[str, Any] | None, +) -> torch.Tensor: + """Prefer connector-delivered codec ids over token placeholders. + + In non-async full-payload mode, the scheduler only needs placeholder + token ids for allocation. The real codec sequence is delivered through + model_intermediate_buffer as ``codes.audio``. + """ + if isinstance(runtime_info, dict): + codes = runtime_info.get("codes") + if isinstance(codes, dict): + audio = codes.get("audio") + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + return audio.reshape(-1).to(device=input_ids.device, dtype=torch.long) + if isinstance(audio, (list, tuple)) and audio: + return torch.as_tensor(audio, device=input_ids.device, dtype=torch.long).reshape(-1) + return input_ids.reshape(-1).to(dtype=torch.long) + + class Qwen3TTSCode2Wav(nn.Module): """Stage-1 code2wav model for Qwen3-TTS (GenerationModelRunner). Consumes frame-aligned codec tokens from input_ids and decodes waveform @@ -239,6 +260,7 @@ def forward( multimodal_outputs={"model_outputs": [empty], "sr": [sr_tensor]}, ) + runtime_infos = runtime_additional_information or [] ids = input_ids.reshape(-1).to(dtype=torch.long) request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts")) @@ -246,14 +268,18 @@ def forward( valid_codes_qf: list[torch.Tensor] = [] valid_indices: list[int] = [] left_context_size = [0] * len(request_ids_list) - if runtime_additional_information is not None: - for i, info in enumerate(runtime_additional_information): + if runtime_infos: + for i, info in enumerate(runtime_infos): if i >= len(left_context_size): break + if not isinstance(info, dict): + continue meta = info.get("meta", {}) if "left_context_size" in meta: left_context_size[i] = meta["left_context_size"] for i, req_ids in enumerate(request_ids_list): + runtime_info = runtime_infos[i] if i < len(runtime_infos) else None + req_ids = _codec_ids_from_payload_or_input(req_ids, runtime_info) if req_ids.numel() < 1: parsed.append((0, 0)) continue diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml index 131a0d1cd70..024443e8d16 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml @@ -6,6 +6,7 @@ stage_args: max_batch_size: 1 engine_args: model_stage: token2text + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload model_arch: DyninOmniForConditionalGeneration worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml index 4a54f8188aa..1ab65f0fab9 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml @@ -6,6 +6,7 @@ stage_args: max_batch_size: 1 engine_args: model_stage: token2text + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload model_arch: DyninOmniForConditionalGeneration worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index cf1ca39ee59..4c7245e773f 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -7,6 +7,7 @@ import numpy as np import torch from vllm.inputs import TextPrompt +from vllm.logger import init_logger from vllm_omni.data_entry_keys import ( CodesStruct, @@ -16,6 +17,10 @@ ) from vllm_omni.inputs.data import OmniTokensPrompt +logger = init_logger(__name__) + +_COSYVOICE3_SPEECH_TOKEN_SIZE = 6561 + def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None: """Wrap prompt_payload's flat speech_token/speech_feat/embedding tensors into EmbeddingsStruct.""" @@ -46,6 +51,52 @@ def _ensure_list(x: Any) -> list[Any]: return [x] +def _to_token_id_list(value: Any) -> list[int]: + if value is None: + return [] + if isinstance(value, torch.Tensor): + value = value.detach().to("cpu").reshape(-1).tolist() + token_ids: list[int] = [] + for item in _ensure_list(value): + if isinstance(item, torch.Tensor): + token_ids.extend(_to_token_id_list(item)) + continue + if isinstance(item, (list, tuple)): + token_ids.extend(_to_token_id_list(item)) + continue + token_ids.append(int(item)) + return token_ids + + +def _strip_prompt_prefix(output_ids: list[Any], prefix_ids: list[Any]) -> list[Any]: + if prefix_ids and len(output_ids) >= len(prefix_ids) and output_ids[: len(prefix_ids)] == prefix_ids: + return output_ids[len(prefix_ids) :] + return output_ids + + +def _prompt_speech_token_ids(multi_modal_data: dict[str, Any]) -> list[int]: + speech_token = multi_modal_data.get("speech_token") + if speech_token is None: + embed = multi_modal_data.get("embed") + if isinstance(embed, dict): + speech_token = embed.get("speech_token") + return _to_token_id_list(speech_token) + + +def _has_speech_stop_token(output_ids: list[Any]) -> bool: + return any(token_id >= _COSYVOICE3_SPEECH_TOKEN_SIZE for token_id in _to_token_id_list(output_ids)) + + +def _set_non_stream_prompt_trim(additional_info: dict[str, Any], prompt_speech_len: int) -> None: + if prompt_speech_len <= 0: + return + meta = additional_info.get("meta") + if not isinstance(meta, dict): + meta = {} + additional_info["meta"] = meta + meta["talker_prefill_offset"] = prompt_speech_len + + def _to_cpu_tensor(x: Any) -> torch.Tensor | None: if isinstance(x, list): if not x: @@ -95,9 +146,14 @@ def text2flow( if multi_modal_data is None: raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") - output_ids = _ensure_list(output.cumulative_token_ids) prefix_ids = _ensure_list(source_output.prompt_token_ids) + raw_output_ids = _ensure_list(output.cumulative_token_ids) + prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data) + output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids) + output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) additional_info = dict(multi_modal_data) + if _has_speech_stop_token(raw_output_ids): + _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) additional_info.setdefault("ids", {})["prompt"] = prefix_ids engine_inputs.append(OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=additional_info)) return engine_inputs @@ -271,3 +327,107 @@ def talker2code2wav_async_chunk( state["emitted_chunks"] = int(state.get("emitted_chunks", 0)) + 1 return payload + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path). +# cosyvoice3 talker emits `multimodal_outputs={"embed": {"speech_token": t, +# "speech_feat": t, "embedding": t}}` ONLY at prefill (decode steps emit +# `{}`). After flatten_payload these become flat top-level keys +# `embed.speech_token` etc., persisted across decode steps by the +# full-payload accumulator (decode doesn't re-emit them). Shipping via the connector +# keeps the orchestrator off the heavy-tensor path. +# ============================================================================ + +# All three embed tensors are emitted once at prefill and must REPLACE-not- +# CONCAT across the (already trivial) per-request accumulator history so a +# regression where decode unexpectedly re-emits them does not silently +# duplicate the prefill tensor. See mixin._FULL_PAYLOAD_REPLACE_KEYS. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) + + +def text2flow_token_only( + source_outputs: list, + prompt: OmniTokensPrompt | TextPrompt = None, + _requires_multimodal_data: bool = True, +): + """Sync-side builder for the non-async-chunk text→flow path. + + CosyVoice3 sync keeps codec ids on the legacy token path. Some vLLM v1 + histories include the source prompt prefix, so strip it only when it is an + exact leading match. + """ + del prompt + engine_inputs: list[OmniTokensPrompt] = [] + for source_output in source_outputs: + if not source_output.finished: + continue + output = source_output.outputs[0] + prefix_ids = _ensure_list(source_output.prompt_token_ids) + raw_output_ids = _ensure_list(output.cumulative_token_ids) + output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids) + multi_modal_data = output.multimodal_output + if multi_modal_data is None: + raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") + prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data) + output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) + additional_info: dict[str, Any] = dict(multi_modal_data) + if _has_speech_stop_token(raw_output_ids): + _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) + additional_info.setdefault("ids", {})["prompt"] = prefix_ids + engine_inputs.append( + OmniTokensPrompt( + prompt_token_ids=output_ids, + additional_information=additional_info, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return engine_inputs + + +def text2flow_full_payload( + transfer_manager, + pooling_output, + request, +): + """Producer-side payload builder. + + Reads prefill-emitted `embed.{speech_token, speech_feat, embedding}` from + the accumulator and ships prompt conditioning as a connector payload. + The downstream flow stage reads these from `model_intermediate_buffer` + (see cosyvoice3.py:671 in the code2wav forward — runtime_info pickup). + """ + del transfer_manager + rid = getattr(request, "external_req_id", None) or getattr(request, "request_id", "?") + if not isinstance(pooling_output, dict): + logger.warning( + "cosyvoice3.text2flow_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) + return None + embed_out: dict[str, Any] = {} + for key in ("speech_token", "speech_feat", "embedding"): + v = pooling_output.get(f"embed.{key}") + if v is None: + nested = pooling_output.get("embed") + if isinstance(nested, dict): + v = nested.get(key) + if isinstance(v, torch.Tensor) and v.numel() > 0: + embed_out[key] = v + if not embed_out: + logger.warning( + "cosyvoice3.text2flow_full_payload: no embed.{speech_token,speech_feat,embedding} " + "found in pooling_output (keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) + return None + return { + "meta": { + "finished": torch.tensor(True, dtype=torch.bool), + }, + "embed": embed_out, + } diff --git a/vllm_omni/model_executor/stage_input_processors/covo_audio.py b/vllm_omni/model_executor/stage_input_processors/covo_audio.py index a0a964bdd2f..7b5ed8c266d 100644 --- a/vllm_omni/model_executor/stage_input_processors/covo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/covo_audio.py @@ -1,33 +1,103 @@ # Copyright 2026 Tencent. from typing import Any +import torch +from vllm.logger import init_logger + from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX +logger = init_logger(__name__) + +# Per-model REPLACE-keys for the full-payload accumulator (none for covo_audio: +# the producer side does not emit per-step hidden_states / model_outputs; +# llm2code2wav_full_payload reads token_ids directly from `request`). +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + + +def _filter_audio_codes(token_ids: list[int]) -> list[int]: + """Filter codec-range token ids and rebase by COVO_AUDIO_TOKEN_INDEX.""" + audio_codes = [t - COVO_AUDIO_TOKEN_INDEX for t in token_ids if t >= COVO_AUDIO_TOKEN_INDEX] + if not audio_codes: + audio_codes = [-1] + return audio_codes + def llm2code2wav( source_outputs: list[Any], prompt: Any = None, requires_multimodal_data: bool = False, ) -> list[OmniTokensPrompt]: + """Legacy orchestrator-path builder (retained for async_chunk + back-compat). + + The non-async-chunk path now goes through ``llm2code2wav_token_only`` + + worker connector + ``llm2code2wav_full_payload``. + """ talker_outputs = source_outputs code2wav_inputs = [] - for i, talker_output in enumerate(talker_outputs): + for talker_output in talker_outputs: output = talker_output.outputs[0] - token_ids = output.token_ids + audio_codes = _filter_audio_codes(list(output.token_ids)) + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=audio_codes, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return code2wav_inputs - audio_codes = [t - COVO_AUDIO_TOKEN_INDEX for t in token_ids if t >= COVO_AUDIO_TOKEN_INDEX] - if not audio_codes: - audio_codes = [-1] +def llm2code2wav_token_only( + source_outputs: list[Any], + prompt: Any = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Sync-side placeholder for the non-async-chunk Stage-1 input. + Returns an OmniTokensPrompt sized to the code2wav stage's expected + prefill length (one slot per audio code). The actual codec ids are + delivered via the worker connector payload built by + ``llm2code2wav_full_payload``. + """ + code2wav_inputs: list[OmniTokensPrompt] = [] + for output_wrapper in source_outputs: + output = output_wrapper.outputs[0] + audio_codes = _filter_audio_codes(list(output.token_ids)) code2wav_inputs.append( OmniTokensPrompt( - prompt_token_ids=audio_codes, + prompt_token_ids=[0] * len(audio_codes), + additional_information=None, multi_modal_data=None, mm_processor_kwargs=None, ) ) - return code2wav_inputs + + +def llm2code2wav_full_payload( + transfer_manager: Any, + pooling_output: dict[str, Any], + request: Any, +) -> dict[str, Any] | None: + """Producer-side payload builder for the worker connector data plane. + + covo_audio's fused_thinker_talker stage emits codec ids via + ``request.output_token_ids`` (token-ids only -- no + hidden_states or embed tensors), so the connector payload is + just the filtered audio codes plus a finished marker. + """ + output_token_ids = list(getattr(request, "output_token_ids", None) or []) + if not output_token_ids: + logger.warning( + "covo_audio.llm2code2wav_full_payload: empty output_token_ids for req=%s; consumer wait gate may hang.", + getattr(request, "request_id", "?"), + ) + return None + audio_codes = _filter_audio_codes(output_token_ids) + return { + "codes": {"audio": audio_codes}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py index 87cecc1033d..5a5804de394 100644 --- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py @@ -5,9 +5,12 @@ import torch from vllm.inputs import TextPrompt +from vllm.logger import init_logger from vllm_omni.inputs.data import OmniTokensPrompt +logger = init_logger(__name__) + def _to_prompt_dict(prompt_item: OmniTokensPrompt | TextPrompt | str | None) -> dict[str, Any]: if isinstance(prompt_item, dict): @@ -146,3 +149,128 @@ def token2image_to_token2audio( requires_multimodal_data: bool = False, ): return _bridge_tokens(source_outputs, prompt, requires_multimodal_data) + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path). +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. dynin_omni's +# producer model emits new chunks per step (token_ids / runtime_info_json), +# all of which use the default CONCAT/replace semantics — no model_outputs +# entry needs explicit REPLACE. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + + +def _build_full_payload(pooling_output: dict[str, Any] | None, request: Any) -> dict[str, Any] | None: + """Producer-side payload builder: assemble dynin_omni connector payload. + + Reads token_ids from ``pooling_output["token_ids"]`` (preferred) or + ``request.output_token_ids`` (fallback). Reads structured non-tensor + metadata from ``pooling_output["runtime_info_json"]`` (JSON-in-uint8) + if present, falling back to ``pooling_output["runtime_info"]`` dict. + Carries forward ``request.additional_information`` so prompt-side + metadata (speaker / language / detok_id) survives the IPC boundary. + """ + if not isinstance(pooling_output, dict): + pooling_output = {} + + token_ids = _to_token_id_list(pooling_output.get("token_ids")) + if not token_ids: + token_ids = _to_token_id_list(pooling_output.get("text_tokens")) + if not token_ids and request is not None: + token_ids = _to_token_id_list(getattr(request, "output_token_ids", None)) + if not token_ids: + logger.warning( + "dynin_omni._build_full_payload: no token_ids found in pooling_output " + "(keys=%s) or request.output_token_ids for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + getattr(request, "request_id", "?"), + ) + return None + + src_additional_info = getattr(request, "additional_information", {}) if request is not None else {} + if not isinstance(src_additional_info, dict): + src_additional_info = {} + + runtime_bridge_info = _decode_runtime_bridge_info(pooling_output.get("runtime_info_json")) + if not runtime_bridge_info: + runtime_bridge_info = pooling_output.get("runtime_info", {}) or {} + + payload = _normalize_additional_info(src_additional_info) + payload.update(_normalize_additional_info(runtime_bridge_info)) + payload["detok_id"] = [_to_int(pooling_output.get("detok_id"), default=_to_int(payload.get("detok_id"), default=0))] + # Use nested OmniPayload shape so the scheduling-metadata extractor in + # OmniConnectorModelRunnerMixin reads codes.audio and meta.finished + # (flat keys at the top level are silently dropped with a warning). + payload["codes"] = {"audio": token_ids} + payload["meta"] = {"finished": torch.tensor(True, dtype=torch.bool)} + return payload + + +def token2text_to_token2image_full_payload( + transfer_manager: Any, + pooling_output: dict[str, Any], + request: Any, +) -> dict[str, Any] | None: + """Producer-side payload builder for the Stage-0 → Stage-1 (text → image) transition.""" + del transfer_manager + return _build_full_payload(pooling_output, request) + + +def token2image_to_token2audio_full_payload( + transfer_manager: Any, + pooling_output: dict[str, Any], + request: Any, +) -> dict[str, Any] | None: + """Producer-side payload builder for the Stage-1 → Stage-2 (image → audio) transition.""" + del transfer_manager + return _build_full_payload(pooling_output, request) + + +def _token_only_from_source(source_outputs: list[Any]) -> list[OmniTokensPrompt]: + """Length-only placeholder list mirroring ``_bridge_tokens`` token counts.""" + inputs: list[OmniTokensPrompt] = [] + for source_output in source_outputs: + output = source_output.outputs[0] + mm_out = getattr(output, "multimodal_output", None) or {} + token_ids = _to_token_id_list(mm_out.get("token_ids")) + if not token_ids: + token_ids = _to_token_id_list(mm_out.get("text_tokens")) + if not token_ids: + token_ids = list(getattr(output, "token_ids", []) or []) + if not token_ids: + token_ids = [0] + inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * len(token_ids), + additional_information=None, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return inputs + + +def token2text_to_token2image_token_only( + stage_list, + engine_input_source, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Sync-side placeholder for Stage-1 input (token2image).""" + source_stage_id = engine_input_source[0] if engine_input_source else 0 + source_outputs = stage_list[source_stage_id].engine_outputs + return _token_only_from_source(source_outputs) + + +def token2image_to_token2audio_token_only( + stage_list, + engine_input_source, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Sync-side placeholder for Stage-2 input (token2audio).""" + source_stage_id = engine_input_source[0] if engine_input_source else 0 + source_outputs = stage_list[source_stage_id].engine_outputs + return _token_only_from_source(source_outputs) diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index b0b6e887857..2443ad2c479 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -147,6 +147,18 @@ def llm2code2wav_async_chunk( Accumulates codes in connector per request_id, returns payload only when chunk_size is full or request is finished; returns None when waiting. """ + # Null guard: chunk_transfer_adapter calls this every emit step + # including no-output steps where pooling_output is None. + if pooling_output is None or not isinstance(pooling_output, dict): + if is_finished: + connector = getattr(transfer_manager, "connector", None) + raw_cfg = getattr(connector, "config", {}) or {} + cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} + chunk_size = int(cfg.get("codec_chunk_frames", 3)) + left_context_size = int(cfg.get("codec_left_context_frames", 3)) + request_id = getattr(request, "external_req_id", None) + return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size) + return None connector = getattr(transfer_manager, "connector", None) raw_cfg = getattr(connector, "config", {}) or {} cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} @@ -300,3 +312,137 @@ def llm2code2wav( ) return code2wav_inputs + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path). +# AR runner's `flatten_payload` converts the model emit +# `multimodal_outputs={"codes": {"audio": ...}}` to flat +# `pooling_output["codes.audio"]` before the full-payload accumulator runs, so default +# CONCAT semantics build the full codec tensor across all decode steps. +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. mimo_audio's +# producer side emits per-step codec frames that should be CONCAT'd across +# steps (not REPLACE'd), so this stays empty. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + + +def _filter_zero_codec_rows(codec_codes: torch.Tensor) -> torch.Tensor: + """Drop zero-padded codec rows from a 4-D `[N, 1, 8, 4]` tensor. + + Mirrors the zero-row filter in the orchestrator-path `llm2code2wav` + body (see this file's ``llm2code2wav`` around line 224). + """ + if codec_codes.ndim != 4 or codec_codes.numel() == 0: + return codec_codes + is_all_zero = (codec_codes == 0).all(dim=(1, 2, 3)) + nonzero_idx = (~is_all_zero).nonzero(as_tuple=True)[0] + if len(nonzero_idx) == 0: + # All rows are zero-padded; return an empty tensor so the caller + # can detect this via numel()==0 and skip the request. + return codec_codes[:0] + if len(nonzero_idx) < codec_codes.shape[0]: + return codec_codes[nonzero_idx] + return codec_codes + + +def llm2code2wav_token_only( + source_outputs: list, + _prompt=None, + _requires_multimodal_data: bool = False, +) -> list: + """Sync-side placeholder for the non-async-chunk Stage-1 (code2wav) input. + + Returns an ``OmniTokensPrompt`` sized to the orchestrator-shape codec + length so the consumer runtime allocates the right number of slots. + The actual codec ids are delivered via the worker connector payload + built by ``llm2code2wav_full_payload``. + """ + from vllm_omni.inputs.data import OmniTokensPrompt + + code2wav_inputs: list = [] + for output_wrapper in source_outputs: + out = output_wrapper.outputs[0] + mm = out.multimodal_output if hasattr(out, "multimodal_output") else None + mm = mm if isinstance(mm, dict) else {} + mm_codes = mm.get("codes", {}) if isinstance(mm, dict) else {} + prompt_len = 0 + if isinstance(mm_codes, dict) and "audio" in mm_codes: + audio = mm_codes["audio"] + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + audio = audio.to(torch.long) + audio = _filter_zero_codec_rows(audio) + # +B*4 per batch row for the prepended pad_vec (see prepend_and_flatten_colmajor) + batch_size = int(audio.shape[0]) if audio.ndim >= 1 else 1 + prompt_len = int(audio.numel()) + batch_size * 4 + if prompt_len > MAX_CODE2WAV_TOKENS: + prompt_len = MAX_CODE2WAV_TOKENS + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * prompt_len, + additional_information=None, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return code2wav_inputs + + +def llm2code2wav_full_payload( + transfer_manager, + pooling_output: dict, + request, +) -> dict | None: + """Producer-side payload builder for the worker connector data plane. + + AR runner's ``flatten_payload`` converts the per-step model emit + ``{"codes": {"audio": ...}}`` to ``pooling_output["codes.audio"]``. + The accumulator CONCATs per-step tensors along dim 0, so by flush + time this holds the full ``[total_steps, 1, 8, 4]`` codec tensor. + + A back-compat fallback to nested ``pooling_output["codes"]["audio"]`` + is kept in case a future runtime path bypasses `flatten_payload`. + """ + del transfer_manager + rid = getattr(request, "request_id", "?") + if not isinstance(pooling_output, dict): + logger.warning( + "mimo_audio.llm2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) + return None + codec_codes = pooling_output.get("codes.audio") + if codec_codes is None: + # Back-compat fallback for un-flattened pooler emits. + codes = pooling_output.get("codes") + if isinstance(codes, dict): + codec_codes = codes.get("audio") + if not isinstance(codec_codes, torch.Tensor) or codec_codes.numel() == 0: + logger.warning( + "mimo_audio.llm2code2wav_full_payload: missing/empty codes.audio " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) + return None + codec_codes = codec_codes.to(torch.long) + codec_codes = _filter_zero_codec_rows(codec_codes) + if codec_codes.numel() == 0: + logger.warning( + "mimo_audio.llm2code2wav_full_payload: codec_codes empty after _filter_zero_codec_rows for req=%s.", + rid, + ) + return None + + pad_vec = torch.tensor([TALKER_CODEC_PAD_TOKEN_ID] * 4) + code_final = prepend_and_flatten_colmajor(codec_codes, pad_vec).tolist() + if len(code_final) > MAX_CODE2WAV_TOKENS: + code_final = code_final[:MAX_CODE2WAV_TOKENS] + + return { + "codes": {"audio": code_final}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } diff --git a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py index e0d538cb3b0..938018856f4 100644 --- a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py @@ -535,9 +535,97 @@ def thinker2talker( return talker_inputs +# ming_flash_omni is not in ``_OMNI_CONNECTOR_INIT_ARCHS`` or +# ``_FULL_PAYLOAD_INPUT_STAGES``, so the worker connector is not +# initialised for this arch and the consumer never waits on a connector +# payload. Data flows through ``additional_information`` written by +# ``thinker2talker_token_only`` (wired as ``sync_process_input_func`` +# in the pipeline) or the legacy ``thinker2talker`` (wired as +# ``custom_process_input_func``). + + +def thinker2talker_token_only( + source_outputs: list[Any], + prompt: OmniTokensPrompt | TextPrompt | None = None, + _requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Sync-side builder for the non-async-chunk thinker→talker path. + + Ports the legacy ``thinker2talker`` body to the new stage-input-processor signature + (``source_outputs`` instead of ``stage_list, engine_input_source``). + Body is otherwise identical: extracts the + generated text from each thinker output and packages it with the + request's voice/speaker additional_information for the talker. + """ + if not isinstance(prompt, list): + prompt = [prompt] + + talker_inputs: list[OmniTokensPrompt] = [] + for i, source_output in enumerate(source_outputs): + output = source_output.outputs[0] + + generated_text = output.text if hasattr(output, "text") and output.text else "" + + original_prompt = prompt[i] if i < len(prompt) else None + additional_info: dict[str, Any] = {} + if original_prompt is not None and hasattr(original_prompt, "additional_information"): + additional_info = original_prompt.additional_information or {} + + spk_emb = additional_info.get("spk_emb", None) + if isinstance(spk_emb, list) and spk_emb and not hasattr(spk_emb[0], "device"): + import torch + + spk_emb = torch.tensor(spk_emb, dtype=torch.float32).unsqueeze(0) + + talker_info = { + "ming_task": "omni", + "text": generated_text, + "spk_emb": spk_emb, + "voice_name": additional_info.get("voice_name", "DB30"), + "prompt_text": additional_info.get("prompt_text", None), + "prompt_wav_lat": additional_info.get("prompt_wav_lat", None), + "prompt_wav_emb": additional_info.get("prompt_wav_emb", None), + "max_text_length": additional_info.get("max_text_length", 50), + } + + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0], + additional_information=talker_info, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return talker_inputs + + +thinker2talker_token_only._is_sync_input = True + + +def thinker2talker_full_payload( + transfer_manager, + pooling_output, + request, +): + """Producer-side payload builder — no-op. + + ming_flash_omni's thinker emits no heavy tensor to ship via the + worker connector (the bridge passes text only, and speaker metadata + arrives through the USER request's additional_information). + ming_flash_omni is not in ``_OMNI_CONNECTOR_INIT_ARCHS`` so this + function is never invoked at runtime; it is retained for forward + compatibility with the connector path. + """ + del transfer_manager, pooling_output, request + return None + + __all__ = [ "CFG_TEXT_SUFFIX", "expand_cfg_prompts", "thinker2imagegen", "thinker2talker", + "thinker2talker_full_payload", + "thinker2talker_token_only", ] diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 472dcb93386..225674360c2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -1,3 +1,5 @@ +import logging + import torch from vllm.inputs import TextPrompt @@ -11,6 +13,8 @@ ) from vllm_omni.inputs.data import OmniTokensPrompt +logger = logging.getLogger(__name__) + TALKER_CODEC_PAD_TOKEN_ID = 8292 TALKER_CODEC_START_TOKEN_ID = 8293 TALKER_CODEC_END_TOKEN_ID = 8294 @@ -86,3 +90,347 @@ def talker2code2wav( ) ) return code2wav_inputs + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path). +# Both transitions ship payloads via the worker connector +# (registered in ``_FULL_PAYLOAD_INPUT_STAGES`` in +# omni_scheduling_coordinator): +# - thinker->talker reads accumulated ``pooling_output["hidden"]`` and +# packs an OmniPayload-shaped dict (embed.prefill / +# hidden_states.output / ids.prompt / ids.output) for the talker, which +# the talker's ``talker_preprocess`` reads from +# ``model_intermediate_buffer``. The shape matches what legacy +# ``thinker2talker`` writes into ``additional_information`` as a debug +# fallback; ``thinker2talker_token_only`` only allocates prompt slots. +# - talker->code2wav strips TALKER_CODEC_{START,END} boundary tokens +# and ships the codec token ids. +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. qwen2_5_omni's +# producer side does not emit model_outputs through pooler_output (it ships +# token_ids on the request directly), so the empty set preserves correctness. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + + +def _strip_codec_boundaries(token_ids: list[int]) -> list[int]: + """Keep only real codec ids for the code2wav stage. + + The talker stream can contain prompt/control ids (START/PAD/END/MASK) in + addition to sampled codec ids. Code2wav expects codec ids only; carrying + the prompt PAD span forward can inflate the sequence enough to OOM on L4. + Async scheduling may also leave trailing ``-1`` placeholders, so preserve + their length by repeating the last valid codec id. + """ + tids = list(token_ids) + trailing_placeholder_count = 0 + while trailing_placeholder_count < len(tids) and tids[-1 - trailing_placeholder_count] == -1: + trailing_placeholder_count += 1 + + if tids and tids[-1] == TALKER_CODEC_END_TOKEN_ID: + tids = tids[:-1] + trailing_placeholder_count = 0 + + codec_ids = [tid for tid in tids if 0 <= tid < TALKER_CODEC_PAD_TOKEN_ID] + if trailing_placeholder_count > 0 and codec_ids: + codec_ids.extend([codec_ids[-1]] * trailing_placeholder_count) + return codec_ids + + +def talker2code2wav_token_only( + source_outputs, + _prompt: OmniTokensPrompt | TextPrompt = None, + _requires_multimodal_data: bool = False, +): + """Sync-side placeholder for Stage-2 input (code2wav). + + Returns OmniTokensPrompt sized to the stripped codec token count. + Actual codec ids are delivered via the worker connector payload built + by ``talker2code2wav_full_payload``. + """ + code2wav_inputs = [] + for talker_output in source_outputs: + output = talker_output.outputs[0] + token_ids = _strip_codec_boundaries(list(output.cumulative_token_ids)) + if not token_ids: + continue + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * len(token_ids), + additional_information=None, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return code2wav_inputs + + +def talker2code2wav_full_payload( + transfer_manager, + pooling_output: dict, + request, +) -> dict | None: + """Producer-side payload builder: ship the stripped codec ids via connector. + + Token-ids-only shape. The talker stage's output already + carries the codec ids on ``request.output_token_ids``; we strip the + boundary tokens and pack a minimal payload. + """ + del transfer_manager + rid = getattr(request, "request_id", "?") + token_ids = list(getattr(request, "output_token_ids", None) or []) + if not token_ids: + logger.warning( + "qwen2_5_omni.talker2code2wav_full_payload: empty output_token_ids " + "for req=%s; consumer wait gate may hang.", + rid, + ) + return None + token_ids = _strip_codec_boundaries(token_ids) + if not token_ids: + logger.warning( + "qwen2_5_omni.talker2code2wav_full_payload: codec ids empty after " + "stripping boundary tokens for req=%s; consumer wait gate may hang.", + rid, + ) + return None + return { + "codes": {"audio": token_ids}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path) -- thinker->talker. +# +# qwen2_5_omni's talker consumes the thinker's last-layer hidden state +# via Linear(3584, 896). The AR runner publishes those hidden states +# per decode step on ``pooling_output["hidden"]`` (unpacked from +# ``OmniOutput.text_hidden_states``); the full-payload accumulator +# concatenates them so ``thinker2talker_full_payload`` sees the full +# prefill+decode trajectory and packs an OmniPayload-shaped dict that +# the talker's ``talker_preprocess`` reads from +# ``model_intermediate_buffer``. ``thinker2talker_token_only`` only +# allocates the talker's codec prompt slots; legacy +# ``thinker2talker`` above remains as a debug fallback that bundles the +# same shape into ``additional_information``. +# ============================================================================ + + +def thinker2talker_token_only( + source_outputs, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +): + """Placeholder builder for the connector-driven thinker->talker path. + + Allocates the TALKER_CODEC_{START,PAD,END} prompt slots sized to the + thinker prompt length and forwards ``multi_modal_data``. The bulk + payload (hidden_states / embed / ids) ships exclusively through + ``thinker2talker_full_payload`` via the worker connector and lands + in ``model_intermediate_buffer`` before the talker's forward() runs. + + Consumer-wait gating is whitelist-driven via + ``_FULL_PAYLOAD_INPUT_STAGES`` (see the mixin + ``should_accumulate_full_payload_output`` docstring). + """ + thinker_outputs = source_outputs + talker_inputs = [] + if not isinstance(prompt, list): + prompt = [prompt] + multi_modal_data = { + thinker_output.request_id: p.get("multi_modal_data", None) for thinker_output, p in zip(thinker_outputs, prompt) + } + + for thinker_output in thinker_outputs: + prompt_token_ids = thinker_output.prompt_token_ids + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID] + + [TALKER_CODEC_PAD_TOKEN_ID] * (len(prompt_token_ids)) + + [TALKER_CODEC_END_TOKEN_ID], + additional_information=None, + multi_modal_data=( + multi_modal_data[thinker_output.request_id] + if requires_multimodal_data and multi_modal_data is not None + else None + ), + mm_processor_kwargs=None, + ) + ) + + return talker_inputs + + +def thinker2talker_full_payload( + transfer_manager, + pooling_output, + request, +): + """Producer-side payload builder for the worker-connector data plane. + + The AR runner emits per-step ``pooling_output["hidden"]`` (the + thinker's last-layer hidden states for the request span, unpacked + from ``OmniOutput.text_hidden_states``). The full-payload + accumulator concatenates those per-step rows across decode steps, so + by the time this builder fires the materialized + ``pooling_output["hidden"]`` contains the full prefill+decode + hidden-state trajectory of size + ``len(prompt_token_ids) + len(output_token_ids)``. + + We split it at ``len(prompt_token_ids)`` into prefill embeddings and + decode hidden states, then pack the ``OmniPayload``-shaped dict that + the talker's ``thinker_to_talker_process`` reads from + ``model_intermediate_buffer`` (keys ``embed.prefill`` / + ``hidden_states.output`` / ``ids.prompt`` / ``ids.output``). Shape + matches what legacy ``thinker2talker`` writes into + ``additional_information`` as a debug fallback, so the talker + consumes the same payload layout from either path. + + Like ``qwen3_omni.thinker2talker_full_payload``, we apply a + finish-reason-aware stop-row trim: vLLM v1 appends the sampled + token to ``output_token_ids`` before ``check_stop``, so a request + that finished via ``FINISHED_STOPPED`` has one extra accumulated + hidden-state row that the talker must not consume. Max-token + finishes need no drop. Status is read from the request when + available; otherwise we fall back to a last-token-in-stop-set + heuristic. + """ + del transfer_manager + rid = getattr(request, "request_id", "?") + if not isinstance(pooling_output, dict): + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) + return None + + hidden = pooling_output.get("hidden") + if not isinstance(hidden, torch.Tensor): + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: missing 'hidden' tensor " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) + return None + + def _ensure_list(x): + if x is None: + return [] + if hasattr(x, "_x"): + # vLLM wraps cached token-id lists in ConstantList-like objects. + return list(x._x) + if isinstance(x, list): + return list(x) + return list(x) + + prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", None)) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", None)) + all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) + if not all_token_ids: + all_token_ids = list(prompt_token_ids) + list(output_token_ids) + + # Length-aware trim of accumulated thinker output, finish-reason-aware. + # Mirror qwen3_omni.thinker2talker_full_payload's logic so a stop-finish + # does not leak an extra hidden-state row to the talker. + status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 + if stop_emission_drop == 0 and not status_name and output_token_ids: + # Worker-side CachedRequestState has no `.status` field in vLLM v1; + # fall back to a last-token-in-stop-set heuristic. + sampling_params = getattr(request, "sampling_params", None) + if sampling_params is not None: + stop_ids: set[int] = set() + ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + if not ignore_eos: + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + if stop_ids and output_token_ids[-1] in stop_ids: + stop_emission_drop = 1 + + # Trim accumulated thinker output based on stop_emission_drop computed + # above. Mirror qwen3_omni.thinker2talker_full_payload's contract: + # target_rows = len(all_token_ids) - stop_emission_drop + # which excludes the stop-emission row for FINISHED_STOPPED but keeps + # all rows for FINISHED_LENGTH_CAPPED (max_tokens) finishes. + if stop_emission_drop > 0 and len(output_token_ids) >= stop_emission_drop: + output_token_ids = output_token_ids[:-stop_emission_drop] + h = hidden.detach().cpu().to(torch.float32) + target_rows = max(0, len(all_token_ids) - stop_emission_drop) + if target_rows <= 0: + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: target_rows<=0 " + "(all_token_ids=%d, stop_drop=%d) for req=%s; nothing to ship.", + len(all_token_ids), + stop_emission_drop, + getattr(request, "request_id", "?"), + ) + return None + if h.dim() >= 1 and h.shape[0] > target_rows: + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: excess hidden rows " + "(got %d, target %d, stop_drop %d) for req=%s; trimming", + int(h.shape[0]), + target_rows, + stop_emission_drop, + getattr(request, "request_id", None), + ) + h = h[:target_rows] + + prompt_len = len(prompt_token_ids) + if h.shape[0] < prompt_len: + # Under-captured prefill -- defensively skip rather than ship a + # truncated payload that would confuse the talker's prefill path. + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: hidden rows=%d < prompt_len=%d " + "for req=%s; under-captured prefill, skipping payload.", + int(h.shape[0]), + prompt_len, + getattr(request, "request_id", "?"), + ) + return None + + prefill_hidden = h[:prompt_len] + decode_hidden = h[prompt_len:] + + payload: OmniPayload = to_dict( + OmniPayloadStruct( + hidden_states=HiddenStatesStruct( + output=decode_hidden, + output_shape=list(decode_hidden.shape), + ), + embed=EmbeddingsStruct( + prefill=prefill_hidden, + prefill_shape=list(prefill_hidden.shape), + ), + ids=IdsStruct( + prompt=list(prompt_token_ids), + output=list(output_token_ids), + ), + ) + ) + # Intentionally omit payload["meta"]: the thinker->talker transition + # carries no scheduler-relevant metadata (next_stage_prompt_len / + # left_context_size are not set on this edge). + return payload diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index fd7cfd2aa60..b671951b201 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -35,6 +35,12 @@ # Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer _EMBED_LAYER_KEY = "0" _HIDDEN_LAYER_KEY = "24" +# Per-model REPLACE-keys for the full-payload accumulator. Keys in this +# set use REPLACE semantics (subsequent emissions discard prior chunks) +# instead of CONCAT. qwen3-omni currently has none — model_outputs is +# not emitted by the thinker/talker forward. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + _QWEN3_CODEC_CODEBOOK_SIZE = 2048 _QWEN3_CODEC_PAD_TOKEN_ID = 4196 _QWEN3_CODEC_BOS_TOKEN_ID = 4197 @@ -119,19 +125,6 @@ def _is_valid_qwen3_codec_token_id(token_id: Any) -> bool: return 0 <= token_id < _QWEN3_CODEC_CODEBOOK_SIZE -def should_accumulate_qwen3_omni_full_payload_output( - model_config: Any, - custom_process_func: Any, -) -> bool: - """Return whether Qwen3-Omni should accumulate full-payload outputs.""" - return ( - custom_process_func is not None - and not getattr(model_config, "async_chunk", False) - and getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration" - and getattr(model_config, "model_stage", None) in {"thinker", "talker"} - ) - - def _extract_qwen3_full_payload_codec_rows( code_predictor_codes: torch.Tensor, output_token_ids: list[int], @@ -462,7 +455,13 @@ def thinker2talker_full_payload( request: OmniEngineCoreRequest, ) -> dict[str, Any] | None: """Pack complete thinker output for the non-async connector path.""" + rid = getattr(request, "request_id", None) if not isinstance(pooling_output, dict): + logger.warning( + "thinker2talker_full_payload: pooling_output not a dict (type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None layers = { @@ -475,11 +474,13 @@ def thinker2talker_full_payload( hidden = pooling_output.get("hidden") thinker_emb = hidden if isinstance(hidden, torch.Tensor) else None if thinker_emb is None or thinker_hid is None: - logger.debug( - "thinker2talker_full_payload: missing thinker tensors for req=%s (embed=%s hidden=%s)", - getattr(request, "request_id", None), + logger.warning( + "thinker2talker_full_payload: missing thinker tensors for req=%s " + "(embed=%s hidden=%s keys=%s); consumer wait gate may hang.", + rid, thinker_emb is not None, thinker_hid is not None, + list(pooling_output.keys()), ) return None @@ -493,7 +494,7 @@ def thinker2talker_full_payload( # The accumulator captures one hidden-state row per executed thinker # forward (prefill + every decode step including the one that emitted # the stop_token), so for a finished request thinker_emb has exactly one - # row more than the rows the talker should consume. async_chunk's + # row more than the rows the talker should consume. async_chunk's # chunk-0 path naturally captures only the prefill / non-stop portion, # which is why the [async_chunk] parametrization passes while [default] # over-generates one codec frame on short outputs (e.g. @@ -797,7 +798,14 @@ def talker2code2wav_full_payload( request: OmniEngineCoreRequest, ) -> dict[str, Any] | None: """Pack complete talker codec output for the non-async connector path.""" + rid = getattr(request, "request_id", None) if not isinstance(pooling_output, dict): + logger.warning( + "talker2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None code_predictor_codes = pooling_output.get("codes.audio") if code_predictor_codes is None: @@ -805,10 +813,19 @@ def talker2code2wav_full_payload( if isinstance(codes, dict): code_predictor_codes = codes.get("audio") if code_predictor_codes is None: + logger.warning( + "talker2code2wav_full_payload: missing codes.audio (keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None if not isinstance(code_predictor_codes, torch.Tensor): code_predictor_codes = torch.as_tensor(code_predictor_codes) if code_predictor_codes.numel() == 0: + logger.warning( + "talker2code2wav_full_payload: empty codes.audio for req=%s; consumer wait gate may hang.", + rid, + ) return None output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) @@ -818,6 +835,16 @@ def talker2code2wav_full_payload( list(output_token_ids), ) if code_predictor_codes.numel() == 0: + logger.warning( + "talker2code2wav_full_payload: no valid codec rows after filtering " + "(raw_shape=%s output_ids_len=%d aligned_rows=%s valid_rows=%s) for req=%s; " + "consumer wait gate may hang.", + raw_shape, + len(output_token_ids), + codec_stats["aligned_rows"], + codec_stats["valid_rows"], + rid, + ) return None codec_codes = code_predictor_codes.transpose(0, 1).cpu().reshape(-1).tolist() diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index faa7e4cc4d3..1ffbdb931a9 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -53,10 +53,13 @@ def talker2code2wav( # audio_codes may still contain zero-padded / invalid rows, so trim only # after filtering valid frames instead of trying to align EOS indices. seq_len = max(len(token_ids) - 1, 0) - # Filter invalid frames: zero-padded (EOS) and frames containing - # out-of-range values (e.g. stop_token_id=2150 exceeds codebook_size=2048). + # Filter invalid frames: zero-padded (EOS), out-of-range values (e.g. + # stop_token_id=2150 exceeds codebook_size=2048), and negative + # sentinels (e.g. -1 padding). _CODEBOOK_SIZE = 2048 - valid_mask = audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + valid_mask = ( + (audio_codes >= 0).all(dim=1) & audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + ) audio_codes = audio_codes[valid_mask] if seq_len > 0 and audio_codes.ndim == 2 and int(audio_codes.shape[0]) > seq_len: audio_codes = audio_codes[-seq_len:] @@ -280,3 +283,220 @@ def talker2code2wav_async_chunk( speaker=extract_speaker_from_request(request), language=extract_language_from_request(request), ) + + +# ============================================================================ +# Worker-connector data plane (non-async-chunk path). +# AR runner's `flatten_payload` converts the model emit +# `multimodal_outputs={"codes": {"audio": ..., "ref": ...}, +# "meta": {"ref_code_len": ..., "codec_streaming": ...}}` to flat dotted +# keys (`codes.audio`, `codes.ref`, `meta.ref_code_len`, +# `meta.codec_streaming`) before the full-payload accumulator runs. +# - codes.audio is 2-D so default CONCAT across steps builds the full sequence. +# - codes.ref is a list (not Tensor with dim>=2) so accumulator LATEST-wins +# keeps the prefill-emitted ref tensor across decode steps (which don't emit +# ref again). +# - meta.ref_code_len is 1-D so LATEST-wins; consumer reads [-1]. +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. qwen3_tts's +# producer side emits codec frames that should CONCAT (codes.audio) plus +# scalars/lists that are correctly handled by default LATEST-wins, so this +# stays empty. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + +_CODEBOOK_SIZE = 2048 +_NUM_QUANTIZERS_DEFAULT = 16 + + +def _filter_audio_codes_qwen3_tts(audio_codes: torch.Tensor) -> torch.Tensor: + """Filter zero-padded, out-of-range, and negative-padded codec frames. + + Mirrors the orchestrator-path body in `talker2code2wav` above. + """ + if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: + return audio_codes + if audio_codes.ndim != 2: + return audio_codes + valid_mask = ( + (audio_codes >= 0).all(dim=1) & audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + ) + return audio_codes[valid_mask] + + +def _coerce_ref_code_len(raw) -> int: + """Coerce mm["meta"]["ref_code_len"] / pooling_output["meta.ref_code_len"] + raw value (Tensor | int | None) into a non-negative int. Mirrors the + extraction inlined in the legacy ``talker2code2wav`` path; clamps any + negative input to 0 since downstream code treats this as a non-negative + frame count.""" + if isinstance(raw, torch.Tensor): + value = int(raw.reshape(-1)[-1].item()) if raw.numel() > 0 else 0 + elif raw is None: + value = 0 + else: + value = int(raw) + return max(value, 0) + + +def _normalize_ref_code(ref_code, num_quantizers: int, ref_code_len: int): + """Coerce ref_code into a [ref_len, Q] tensor or None. Mirrors orchestrator path.""" + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if not isinstance(ref_code, torch.Tensor) or ref_code.numel() == 0: + return None, 0 + ref_code = ref_code.to(torch.long).cpu().contiguous() + if ref_code.ndim == 1: + if ref_code.numel() % num_quantizers != 0: + return None, 0 + ref_code = ref_code.reshape(-1, num_quantizers) + elif ref_code.ndim != 2: + return None, 0 + if ref_code_len > 0 and int(ref_code.shape[0]) > ref_code_len: + ref_code = ref_code[:ref_code_len] + return ref_code, int(ref_code.shape[0]) + + +def talker2code2wav_token_only( + source_outputs: list, + prompt=None, + _requires_multimodal_data: bool = False, +) -> list: + """Sync-side placeholder for the non-async-chunk Stage-1 (code2wav) input. + + Sized to the expected codec token count (codebook-major flat: + Q * (ref_frames + audio_frames)). Speaker / language metadata are + extracted from `prompt` and threaded via `additional_information` + (orchestrator-style; same as the legacy `talker2code2wav` builder). + Actual codec ids are delivered via the worker connector payload built + by `talker2code2wav_full_payload`. + """ + from vllm_omni.inputs.data import OmniTokensPrompt + + code2wav_inputs: list = [] + for i, talker_output in enumerate(source_outputs): + if not talker_output.finished: + continue + output = talker_output.outputs[0] + mm = output.multimodal_output if hasattr(output, "multimodal_output") else None + mm = mm if isinstance(mm, dict) else {} + mm_codes = mm.get("codes", {}) if isinstance(mm, dict) else {} + token_ids = getattr(output, "cumulative_token_ids", []) or [] + seq_len = max(len(token_ids) - 1, 0) + + audio = mm_codes.get("audio") if isinstance(mm_codes, dict) else None + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + audio = audio.to(torch.long) + audio = _filter_audio_codes_qwen3_tts(audio) + if seq_len > 0 and audio.ndim == 2 and int(audio.shape[0]) > seq_len: + audio = audio[-seq_len:] + num_audio_frames = int(audio.shape[0]) if audio.ndim == 2 else 0 + num_quantizers = int(audio.shape[1]) if audio.ndim == 2 and audio.shape[1] > 0 else _NUM_QUANTIZERS_DEFAULT + else: + num_audio_frames = 0 + num_quantizers = _NUM_QUANTIZERS_DEFAULT + + ref_code_raw = mm_codes.get("ref") if isinstance(mm_codes, dict) else None + ref_code_len_raw = mm.get("meta", {}).get("ref_code_len") if isinstance(mm.get("meta"), dict) else None + ref_code_len = _coerce_ref_code_len(ref_code_len_raw) + _, ref_frames = _normalize_ref_code(ref_code_raw, num_quantizers, ref_code_len) + + # Codebook-major flat: Q * (ref_frames + audio_frames) + prompt_len = num_quantizers * (ref_frames + num_audio_frames) + + additional_info = to_dict( + OmniPayloadStruct( + meta=MetaStruct(left_context_size=ref_frames) if ref_frames > 0 else None, + speaker=extract_speaker_from_prompt(prompt, index=i), + language=extract_language_from_prompt(prompt, index=i), + ) + ) + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * prompt_len, + additional_information=additional_info if additional_info else None, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return code2wav_inputs + + +def talker2code2wav_full_payload( + transfer_manager, + pooling_output, + request, +): + """Producer-side payload builder. + + Reads accumulated codec from `pooling_output["codes.audio"]` (CONCAT + across steps via flatten_payload), latest `pooling_output["codes.ref"]` + (prefill-emitted), and latest `pooling_output["meta.ref_code_len"]`. + Replicates the orchestrator-path body of `talker2code2wav` (filter, + crop to seq_len, prepend ref, codebook-major flatten). + """ + del transfer_manager + rid = getattr(request, "request_id", "?") + if not isinstance(pooling_output, dict): + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) + return None + + # codes.audio — try flat dotted first (flatten_payload), then nested fallback. + audio = pooling_output.get("codes.audio") + if audio is None: + codes_nested = pooling_output.get("codes") + if isinstance(codes_nested, dict): + audio = codes_nested.get("audio") + if not isinstance(audio, torch.Tensor) or audio.numel() == 0: + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: missing/empty codes.audio " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) + return None + audio = audio.to(torch.long) + audio = _filter_audio_codes_qwen3_tts(audio) + if audio.numel() == 0: + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: audio empty after codec " + "filter (negative/all-zero/out-of-range rows dropped) for req=%s.", + rid, + ) + return None + + output_token_ids = list(getattr(request, "output_token_ids", None) or []) + seq_len = max(len(output_token_ids) - 1, 0) + if seq_len > 0 and audio.ndim == 2 and int(audio.shape[0]) > seq_len: + audio = audio[-seq_len:] + + num_quantizers = int(audio.shape[1]) if audio.ndim == 2 and audio.shape[1] > 0 else _NUM_QUANTIZERS_DEFAULT + + # meta.ref_code_len — flat dotted then nested fallback. + ref_code_len_raw = pooling_output.get("meta.ref_code_len") + if ref_code_len_raw is None: + meta_nested = pooling_output.get("meta") + if isinstance(meta_nested, dict): + ref_code_len_raw = meta_nested.get("ref_code_len") + ref_code_len = _coerce_ref_code_len(ref_code_len_raw) + + # codes.ref — flat dotted then nested fallback. + ref_code_raw = pooling_output.get("codes.ref") + if ref_code_raw is None: + codes_nested = pooling_output.get("codes") + if isinstance(codes_nested, dict): + ref_code_raw = codes_nested.get("ref") + ref_code, ref_frames = _normalize_ref_code(ref_code_raw, num_quantizers, ref_code_len) + if ref_code is not None: + audio = torch.cat([ref_code.to(audio.device), audio], dim=0) + + codec_codes = audio.transpose(0, 1).to(device="cpu", dtype=torch.long).reshape(-1).contiguous() + return { + "codes": {"audio": codec_codes}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 5da4cf9d870..60c059c5a44 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -84,11 +84,23 @@ def __init__(self, *args, **kwargs): self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) # Initialize KV cache manager (preserve vllm_config fallback behavior) self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) - # Only Qwen3-Omni currently consumes the connector-based full-payload - # handoff added in this PR. Other model architectures (e.g. Bagel - # diffusion) retain their pre-existing runner behavior so this PR - # does not perturb them. - if getattr(self.model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": + # Worker-connector init is gated by a per-`model_arch` allowlist + # (covers both producer-side and consumer-side runners for the + # arches below). Consumer-wait stages must be registered + # separately as `(model_arch, model_stage)` tuples in + # `omni_scheduling_coordinator._FULL_PAYLOAD_INPUT_STAGES`; + # forgetting that produces a Stage-1 hang on the consumer. + _OMNI_CONNECTOR_INIT_ARCHS = { + "Qwen3OmniMoeForConditionalGeneration", + "Qwen2_5OmniForConditionalGeneration", + "CovoAudioForConditionalGeneration", + "MiMoAudioModel", + "Qwen3TTSTalkerForConditionalGeneration", + "Qwen3TTSCode2Wav", + "CosyVoice3Model", + "DyninOmniForConditionalGeneration", + } + if getattr(self.model_config, "model_arch", None) in _OMNI_CONNECTOR_INIT_ARCHS: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, @@ -109,6 +121,57 @@ def _make_buffer(self, *size, dtype, numpy=True): with maybe_disable_pin_memory_for_ray(self, total_bytes): return super()._make_buffer(*size, dtype=dtype, numpy=numpy) + def _build_model_sampler_output_token_ids(self) -> list[list[int]]: + """Build decoded-token history for custom model samplers. + + vLLM only populates sampling_metadata.output_token_ids when penalties or + logits processors require it. CosyVoice3's custom RAS sampler also + depends on this history, so we reconstruct it directly from the input + batch for prefer_model_sampler models. + """ + req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", []) + req_ids = list(getattr(self.input_batch, "req_ids", [])) + output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))] + + sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None) + async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None) + prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None) + if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None: + return output_token_ids + + sampled_token_ids: list[list[int]] | None = None + for index, req_id in enumerate(req_ids): + prev_index = prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + req_history = output_token_ids[index] + if not req_history or req_history[-1] != -1: + continue + if sampled_token_ids is None: + assert async_copy_ready_event is not None + async_copy_ready_event.synchronize() + sampled_token_ids = sampled_token_ids_cpu.tolist() + new_ids = list(sampled_token_ids[prev_index]) + if not new_ids: + continue + num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) + first_placeholder = req_history.index(-1) + num_placeholders = len(req_history) - first_placeholder + num_to_replace = min(num_sampled_ids, num_placeholders) + req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace] + + for index, req_history in enumerate(output_token_ids): + if -1 in req_history: + output_token_ids[index] = req_history[: req_history.index(-1)] + + return output_token_ids + + def _sampling_metadata_for_model_sampler(self, sampling_metadata): + output_token_ids = self._build_model_sampler_output_token_ids() + if output_token_ids == sampling_metadata.output_token_ids: + return sampling_metadata + return replace(sampling_metadata, output_token_ids=output_token_ids) + def _request_final_stage_id(self, req_id: str) -> int | None: info = self.model_intermediate_buffer.get(req_id) if not isinstance(info, dict): @@ -974,7 +1037,7 @@ def propose_draft_token_ids(sampled_token_ids): req_hidden_states_cpu[rid] = hidden_states[start:end].detach().to("cpu").contiguous() pooler_output = [] - for rid in req_ids_output_copy: + for out_idx, rid in enumerate(req_ids_output_copy): if rid not in downstream_req_id_set: pooler_output.append({}) continue diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 175547ff31d..9f1060ed1aa 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -59,10 +59,18 @@ class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Scope full-payload connector init to Qwen3-Omni: other generation - # models (e.g. Bagel DiT) retain their pre-existing runner setup - # so this refactor does not perturb them. - if getattr(self.model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": + # Mirrors the init allowlist in gpu_ar_model_runner.py. + _OMNI_CONNECTOR_INIT_ARCHS = { + "Qwen3OmniMoeForConditionalGeneration", + "Qwen2_5OmniForConditionalGeneration", + "CovoAudioForConditionalGeneration", + "MiMoAudioModel", + "Qwen3TTSTalkerForConditionalGeneration", + "Qwen3TTSCode2Wav", + "CosyVoice3Model", + "DyninOmniForConditionalGeneration", + } + if getattr(self.model_config, "model_arch", None) in _OMNI_CONNECTOR_INIT_ARCHS: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, @@ -299,7 +307,6 @@ def execute_model( num_tokens_padded, intermediate_tensors, ) - # [Omni] Pass token counts per request for code2wav output slicing model_kwargs["seq_token_counts"] = tokens diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 0f164fea6df..f157e5db1a1 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -361,12 +361,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None # Remove finished requests from the cached states. # cleanup_finished_request lives on OmniConnectorModelRunnerMixin and - # is only safe to call once init_omni_connectors() has populated the - # mixin state. Archs that inherit the method via MRO without running - # that init must be skipped, so probe a mixin-owned attribute as the - # "state initialized" gate. + # is only safe to call once init_omni_connectors() has finished + # populating mixin state (it sets ``_omni_connector_initialized = True`` + # at the very end). Archs that inherit the method via MRO without + # running that init must be skipped, so gate on the explicit flag + # rather than probing private attribute names. cleanup_finished_request = ( - getattr(self, "cleanup_finished_request", None) if hasattr(self, "_request_ids_mapping") else None + getattr(self, "cleanup_finished_request", None) + if getattr(self, "_omni_connector_initialized", False) + else None ) for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index eb5b47b53e3..64b7e60e26c 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -46,6 +46,27 @@ logger = init_logger(__name__) +def should_accumulate_full_payload_output(model_config, custom_process_func) -> bool: + """Producer-side structural gate. + + Fires iff the stage explicitly declares a downstream full-payload + producer hook via ``custom_process_next_stage_input_func``. Consumer + stages may have ``custom_process_input_func`` values that can be + mechanically derived to ``*_full_payload`` helper names in the same + module; those are intentionally not enough to make the stage a producer. + """ + if custom_process_func is None: + return False + if getattr(model_config, "async_chunk", False): + return False + if getattr(model_config, "final_output", False): + return False + next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None) + if not isinstance(next_stage_func, str) or not next_stage_func: + return False + return getattr(model_config, "model_stage", None) is not None + + class OmniConnectorModelRunnerMixin: """Unified data-plane communication mixin for Model Runners. @@ -189,6 +210,13 @@ def init_omni_connectors( ) self._save_thread.start() + # Explicit "fully initialised" marker so other parts of the runner + # (e.g. _update_states cleanup) can branch on a stable contract + # instead of probing for private mixin attribute names. Must be set + # only after every field above has been bound, so a partially + # constructed mixin is never observable as initialised. + self._omni_connector_initialized = True + def shutdown_omni_connectors(self) -> None: """Stop background threads and release connector resources.""" self._stop_event.set() @@ -217,6 +245,30 @@ def cleanup_finished_request(self, req_id: str) -> None: saves is added to ``_deferred_send_cleanup`` so the bg save's decrement path drains it without leaving orphans. """ + # Force-flush any pending full-payload accumulator entry before + # cleanup proceeds. Without this, finished requests with no + # downstream consumer (e.g. text-only on multi-modal arch) leave + # the entry orphaned in _pending_full_payload_send across requests, + # which empirically destabilises subsequent thinker forwards by + # making prefix-cache reuse observe stale accumulator state. The + # flush is idempotent when the entry has already been flushed by the + # scheduler-driven path, but this cleanup path runs for every request, + # so skip it entirely when the request never accumulated a payload. + if req_id in self._pending_full_payload_send: + try: + self.flush_full_payload_outputs({req_id}) + except Exception: + # Cleanup must still proceed regardless of flush errors here -- + # we already gated on ``_omni_connector_initialized`` upstream, + # so any exception here reflects a real connector-side issue + # (shared memory corruption, background thread crash) worth + # surfacing rather than silently swallowing. + logger.warning( + "flush_full_payload_outputs(%s) raised during cleanup; continuing tear-down.", + req_id, + exc_info=True, + ) + ext_id = self._request_ids_mapping.pop(req_id, None) keys_to_clean: list[str] = [req_id] if ext_id is not None and ext_id != req_id: @@ -684,6 +736,27 @@ def _should_accumulate_full_payload_output(self) -> bool: _custom_process_func, both of which are set at init time. Avoid the per-step dynamic import inside the model decode loop. """ + if getattr(self, "_omni_connector", None) is None: + # No connector at all: send_full_payload_outputs would no-op. + # Skip the per-step accumulator+build that would otherwise be + # silently discarded. Defends against a terminal stage whose + # custom_process_input_func has a *_full_payload derivative in + # the same module (e.g. dynin stage 2 token2image_to_token2audio + # in pipelines that don't configure any connector at all). + # + # Known limitation: a *terminal-consumer* stage that has a + # connector configured for receiving upstream input is NOT + # caught here -- ``_omni_connector`` is non-None for it, and + # ``_load_custom_func`` may still resolve a ``*_full_payload`` + # derivative from this stage's ``custom_process_input_func``. + # In that case the accumulator builds payloads that + # ``send_full_payload_outputs`` later drops via its own + # connector-side checks (wasted CPU, not a functional bug). + # A topology-aware gate (explicit producer field or pipeline + # is_terminal info) would close the gap; that change is out + # of scope for this PR. + self._should_accumulate_full_payload_output_cached = False + return False cached = getattr(self, "_should_accumulate_full_payload_output_cached", None) if cached is not None: return cached @@ -691,19 +764,12 @@ def _should_accumulate_full_payload_output(self) -> bool: if model_config is None: self._should_accumulate_full_payload_output_cached = False return False - if getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": - from vllm_omni.model_executor.stage_input_processors.qwen3_omni import ( - should_accumulate_qwen3_omni_full_payload_output, - ) - - result = should_accumulate_qwen3_omni_full_payload_output( - model_config, - getattr(self, "_custom_process_func", None), - ) - self._should_accumulate_full_payload_output_cached = result - return result - self._should_accumulate_full_payload_output_cached = False - return False + result = should_accumulate_full_payload_output( + model_config, + getattr(self, "_custom_process_func", None), + ) + self._should_accumulate_full_payload_output_cached = result + return result @staticmethod def _new_full_payload_accumulator(output: dict[str, Any]): @@ -729,6 +795,58 @@ def _materialize_full_payload_entry(entry): output[k] = tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=0) return output, request + def _resolve_full_payload_replace_keys(self) -> frozenset: + """Per-model REPLACE-key set for the full-payload accumulator. + + Looked up from the stage-input-processor module that ships the model's sync builder + (`model_config.custom_process_input_func.__module__`). The module + declares ``_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str]``; if absent, + returns the empty set. + + Cached per instance. Keys in this set use REPLACE semantics in the + accumulator (subsequent emissions discard prior chunks) instead of + the default CONCAT semantics. Use for tensors that carry the full + result so far rather than per-step deltas (e.g. ``model_outputs``). + """ + cached = getattr(self, "_full_payload_replace_keys_cached", None) + if cached is not None: + return cached + proc = getattr(self, "_custom_process_func", None) + if proc is None: + self._full_payload_replace_keys_cached = frozenset() + return self._full_payload_replace_keys_cached + module_name = getattr(proc, "__module__", None) + if module_name is None: + self._full_payload_replace_keys_cached = frozenset() + return self._full_payload_replace_keys_cached + try: + import sys as _sys + + mod = _sys.modules.get(module_name) or importlib.import_module(module_name) + keys = getattr(mod, "_FULL_PAYLOAD_REPLACE_KEYS", frozenset()) + except ImportError: + logger.debug( + "Could not import stage input processor module %s while resolving " + "_FULL_PAYLOAD_REPLACE_KEYS; using CONCAT semantics for all keys.", + module_name, + exc_info=True, + ) + keys = frozenset() + if not isinstance(keys, (frozenset, set)): + logger.debug( + "Ignoring non-set _FULL_PAYLOAD_REPLACE_KEYS from %s: %s", + module_name, + type(keys).__name__, + ) + keys = frozenset() + self._full_payload_replace_keys_cached = frozenset(keys) + logger.debug( + "Resolved _FULL_PAYLOAD_REPLACE_KEYS for %s: %s", + module_name, + sorted(self._full_payload_replace_keys_cached), + ) + return self._full_payload_replace_keys_cached + def accumulate_full_payload_output( self, req_id: str, @@ -751,6 +869,7 @@ def accumulate_full_payload_output( The data is actually sent when ``flush_full_payload_outputs`` is called with the finished request IDs from the next scheduler cycle. """ + replace_keys = self._resolve_full_payload_replace_keys() existing = self._pending_full_payload_send.get(req_id) if existing is None: @@ -766,6 +885,19 @@ def accumulate_full_payload_output( for k, v in pooler_output.items(): if v is None: continue + if k in replace_keys: + # Explicit REPLACE semantics: the new value supersedes any + # prior chunks (e.g. `model_outputs` carries the full result + # so far, not an appendable per-step delta). + latest.pop(k, None) + if isinstance(v, torch.Tensor) and v.dim() >= 2: + chunks[k] = [v] + rows[k] = int(v.shape[0]) + else: + chunks.pop(k, None) + rows.pop(k, None) + latest[k] = v + continue if isinstance(v, torch.Tensor) and v.dim() >= 2: if k in chunks and chunks[k] and v.shape[1:] == chunks[k][0].shape[1:]: chunks[k].append(v) @@ -783,6 +915,10 @@ def accumulate_full_payload_output( def flush_full_payload_outputs(self, finished_req_ids: set[str]) -> None: """Send accumulated full_payload outputs for requests that just finished.""" + pending_req_ids = set(self._pending_full_payload_send.keys()) + if not (finished_req_ids & pending_req_ids): + return + logger.info( "[Stage-%s] flush_full_payload_outputs: finished_req_ids=%s, pending=%s", self._stage_id, @@ -927,11 +1063,11 @@ def register_chunk_recv(self, request: Any) -> None: if self._stage_id == 0: return request_id = request.request_id - self._request_ids_mapping[request_id] = getattr( - request, - "external_req_id", - request_id, - ) + # Explicit external_req_id=None must fall back to request_id; + # otherwise recv keys become `None__` and collide + # across requests. + ext = getattr(request, "external_req_id", None) + self._request_ids_mapping[request_id] = ext if ext is not None else request_id with self._lock: if request_id in self._stage_recv_req_ids: return @@ -2115,7 +2251,10 @@ def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str: if mapped is not None: return mapped if request is not None: - return getattr(request, "external_req_id", fallback_req_id) + # external_req_id may be explicitly None; fall back. + ext = getattr(request, "external_req_id", None) + if ext is not None: + return ext return fallback_req_id def _resolve_next_stage_id(self, model_config: Any) -> int: