Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 126 additions & 5 deletions tests/core/sched/test_omni_scheduling_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

import unittest
from types import SimpleNamespace
from unittest import mock

import torch

import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod
from vllm_omni.core.sched.omni_scheduling_coordinator import (
OmniSchedulingCoordinator,
uses_async_chunk_coordinator,
uses_full_payload_input_coordinator,
)

Expand Down Expand Up @@ -204,6 +206,26 @@ def test_ready_request_transitions_to_waiting(self):
self.assertEqual(req.status, RequestStatus.WAITING)
self.assertIn("r1", coord.requests_with_ready_chunks)

def test_late_ready_before_queue_insertion_is_retained(self):
    # A chunk-ready signal may show up before its request has been surfaced
    # into any queue. The coordinator has to remember that readiness (rather
    # than lose it when the connector output is cleared) so that a later
    # cycle still treats the request as ready.
    coordinator = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)

    # Cycle 1: "r1" is reported ready while no queue contains it.
    empty_waiting = MockQueue([])
    coordinator.process_pending_chunks(
        empty_waiting,
        [],
        chunk_ready_req_ids={"r1"},
        chunk_finished_req_ids=set(),
    )
    self.assertIn("r1", coordinator.requests_with_ready_chunks, "late ready must be retained")

    # Cycle 2: "r1" finally appears as a fresh WAITING request, but this
    # cycle's ready set is empty because the connector output was consumed
    # in cycle 1. Since the readiness was retained, the request must stay
    # schedulable instead of being parked into WAITING_FOR_CHUNK.
    late_request = _make_request("r1", status=RequestStatus.WAITING)
    waiting_queue = MockQueue([late_request])
    coordinator.process_pending_chunks(
        waiting_queue,
        [],
        chunk_ready_req_ids=set(),
        chunk_finished_req_ids=set(),
    )
    self.assertEqual(late_request.status, RequestStatus.WAITING, "ready-before-insertion must not be parked")
    self.assertIn(late_request, waiting_queue, "request must remain schedulable in the waiting queue")

def test_non_ready_stays_waiting_for_chunk(self):
coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)

Expand Down Expand Up @@ -471,7 +493,7 @@ def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self):

self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
self.assertEqual(len(coord._waiting_for_input), 1)
self.assertEqual(len(coord.pending_input_registrations), 1)
self.assertEqual(len(coord.pending_connector_registrations), 1)

def test_async_chunk_mode_does_not_auto_transition(self):
"""In async_chunk mode, fresh WAITING requests should NOT be
Expand All @@ -494,7 +516,7 @@ def test_async_chunk_mode_does_not_auto_transition(self):

self.assertEqual(req.status, RequestStatus.WAITING)

def test_pending_input_registrations(self):
def test_pending_connector_registrations(self):
coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)

req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
Expand All @@ -507,8 +529,8 @@ def test_pending_input_registrations(self):
stage_recv_req_ids=set(),
)

self.assertEqual(len(coord.pending_input_registrations), 1)
self.assertEqual(coord.pending_input_registrations[0].request_id, "r1")
self.assertEqual(len(coord.pending_connector_registrations), 1)
self.assertEqual(coord.pending_connector_registrations[0].request_id, "r1")

def test_idle_cycles_retain_received_marker_before_request_appears(self):
coord = OmniSchedulingCoordinator(
Expand All @@ -533,7 +555,7 @@ def test_idle_cycles_retain_received_marker_before_request_appears(self):
coord.process_pending_full_payload_inputs(waiting, running, stage_recv_req_ids=set())

self.assertEqual(late_req.status, RequestStatus.WAITING)
self.assertEqual(coord.pending_input_registrations, [])
self.assertEqual(coord.pending_connector_registrations, [])
self.assertIn("late", coord._full_payload_input_received)
self.assertIn("late", coord.finished_requests)

Expand Down Expand Up @@ -861,5 +883,104 @@ def test_overflow_does_not_strand_request(self):
self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status")


class TestAsyncChunkCoordinatorGate(unittest.TestCase):
    """`uses_async_chunk_coordinator` selects the coordinator+mixin path for
    allowlisted async-chunk archs on SharedMemory only; everyone else (empty
    allowlist today, Mooncake, sync) stays on the legacy adapter.

    Each test temporarily patches the module-level allowlist so that one
    (arch, stage) pair is allowlisted, then probes the gate with lightweight
    ``SimpleNamespace`` stand-ins for the model config. Independent cases run
    under ``subTest`` so that one failing case is reported with its label and
    does not hide the remaining cases.
    """

    _SM = {"name": "SharedMemoryConnector"}
    _MOONCAKE = {"name": "MooncakeStoreConnector"}
    # The single (model_arch, model_stage) pair injected into the allowlist.
    _ALLOWLISTED_KEY = ("Qwen3OmniMoeForConditionalGeneration", "talker")

    def _patched_allowlist(self):
        # Context manager that swaps in an allowlist holding only _ALLOWLISTED_KEY.
        return mock.patch.object(
            coord_mod,
            "_ASYNC_CHUNK_COORDINATOR_STAGES",
            frozenset({self._ALLOWLISTED_KEY}),
        )

    def _model_config(self, **overrides):
        # Build a config stand-in that passes the gate unless a field is overridden.
        fields = {
            "async_chunk": True,
            "model_arch": self._ALLOWLISTED_KEY[0],
            "model_stage": self._ALLOWLISTED_KEY[1],
            "stage_connector_config": self._SM,
        }
        fields.update(overrides)
        return SimpleNamespace(**fields)

    def test_allowlisted_sharedmemory_fires(self):
        connector_cases = {
            "explicit SharedMemory connector": self._SM,
            # default (no connector config) is SharedMemory -> also fires
            "default connector (config is None)": None,
        }
        with self._patched_allowlist():
            for label, connector_config in connector_cases.items():
                with self.subTest(label):
                    mc = self._model_config(stage_connector_config=connector_config)
                    self.assertTrue(uses_async_chunk_coordinator(mc))

    def test_mooncake_stays_on_adapter(self):
        with self._patched_allowlist():
            mc = self._model_config(stage_connector_config=self._MOONCAKE)
            self.assertFalse(uses_async_chunk_coordinator(mc))

    def test_sync_or_non_allowlisted_does_not_fire(self):
        rejected_cases = {
            "async_chunk=False": {"async_chunk": False},
            "non-allowlisted arch": {"model_arch": "MiMoAudioModel", "model_stage": "code2wav"},
            "non-allowlisted stage of an allowlisted arch": {"model_stage": "thinker"},
        }
        with self._patched_allowlist():
            for label, overrides in rejected_cases.items():
                with self.subTest(label):
                    mc = self._model_config(**overrides)
                    self.assertFalse(uses_async_chunk_coordinator(mc))


class TestAsyncChunkRecvRegistration(unittest.TestCase):
    """Regression guard for bg-thread recv registration of parked requests.

    A request parked while waiting for an async chunk MUST appear in the
    CARRIED ``pending_connector_registrations``. The old
    ``pending_chunk_registrations`` was never carried/consumed, so the runner
    never called register_chunk_recv, the bg thread never polled, and the
    request hung until the 300s timeout. In async-chunk mode the full-payload
    pass runs AFTER process_pending_chunks each cycle, so that pass must NOT
    re-clear the chunk registrations.
    """

    def test_parked_chunk_request_registered_and_survives_full_payload_pass(self):
        coordinator = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
        parked = _make_request("r1", status=RequestStatus.WAITING)
        waiting_queue = MockQueue([parked])
        running_reqs: list = []

        # Nothing is ready yet, so the request is parked as WAITING_FOR_CHUNK
        # and must also be registered for recv.
        coordinator.process_pending_chunks(
            waiting_queue,
            running_reqs,
            chunk_ready_req_ids=set(),
            chunk_finished_req_ids=set(),
        )
        self.assertEqual(parked.status, RequestStatus.WAITING_FOR_CHUNK)
        ids_after_chunk_pass = [handle.request_id for handle in coordinator.pending_connector_registrations]
        self.assertIn(
            "r1",
            ids_after_chunk_pass,
            "parked async-chunk request must be registered for bg recv polling",
        )

        # The full-payload pass follows on every cycle and must keep the
        # registration in place.
        coordinator.process_pending_full_payload_inputs(waiting_queue, running_reqs, stage_recv_req_ids=set())
        ids_after_payload_pass = [handle.request_id for handle in coordinator.pending_connector_registrations]
        self.assertIn(
            "r1",
            ids_after_payload_pass,
            "full-payload pass must not drop async-chunk recv registrations",
        )


# Run the unittest suites in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni import (
Qwen3OmniMoeForConditionalGeneration,
)

# Module-level pytest marks: every test in this module is tagged with the
# `core_model` and `cpu` markers.
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def _make_minimal_omni() -> Qwen3OmniMoeForConditionalGeneration:
    """Return a bare model instance carrying only the stubs these tests set.

    The instance is allocated with ``__new__`` so ``__init__`` never runs. The
    talker is replaced by a stand-in whose ``text_projection`` adds 10 to its
    input, and the TTS pad / EOS embeddings are constant 2-element vectors.
    """
    model_cls = Qwen3OmniMoeForConditionalGeneration
    stub = model_cls.__new__(model_cls)
    # "+10" makes it easy to tell which input row reached the projection.
    stub.talker = SimpleNamespace(text_projection=lambda x: x + 10)
    stub.tts_pad_embed = torch.full((2,), -1.0)
    stub.tts_eos_embed = torch.full((2,), -2.0)
    return stub


def test_async_chunk_decode_consumes_cached_handoff_decode() -> None:
    """Only ``cached_decode`` is supplied; the first cached row is expected.

    The expected output ``[11, 12]`` is cached row ``[1, 2]`` after the stub
    projection (+10).
    """
    omni = _make_minimal_omni()
    cached_rows = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    handoff = {
        "embed": {"cached_decode": cached_rows},
        "meta": {"num_processed_tokens": 1, "prefill_consumed_text_tokens": 1},
    }
    side_effects: dict = {}

    result = omni._thinker_decode_to_talker_decode(handoff, torch.device("cpu"), side_effects)

    expected = torch.tensor([11.0, 12.0])
    assert torch.equal(result, expected)
    assert side_effects["_advance_num_processed_tokens"] is True


def test_async_chunk_decode_appends_current_decode_after_cached_prefix() -> None:
    """Two cached rows plus one current ``decode`` row; the current row is expected.

    The expected output ``[15, 16]`` is the ``decode`` row ``[5, 6]`` after the
    stub projection (+10).
    """
    omni = _make_minimal_omni()
    cached_rows = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    current_row = torch.tensor([[5.0, 6.0]])
    handoff = {
        "embed": {"cached_decode": cached_rows, "decode": current_row},
        "meta": {"num_processed_tokens": 3, "prefill_consumed_text_tokens": 1},
    }
    side_effects: dict = {}

    result = omni._thinker_decode_to_talker_decode(handoff, torch.device("cpu"), side_effects)

    expected = torch.tensor([15.0, 16.0])
    assert torch.equal(result, expected)
    assert side_effects["_advance_num_processed_tokens"] is True


def test_async_chunk_decode_uses_accumulated_decode_when_cache_is_prefix() -> None:
    """The cached row equals the first ``decode`` row; the second ``decode`` row is expected.

    The expected output ``[13, 14]`` is ``decode`` row ``[3, 4]`` after the stub
    projection (+10).
    """
    omni = _make_minimal_omni()
    cached_prefix = torch.tensor([[1.0, 2.0]])
    accumulated_rows = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    handoff = {
        "embed": {"cached_decode": cached_prefix, "decode": accumulated_rows},
        "meta": {"num_processed_tokens": 2, "prefill_consumed_text_tokens": 1},
    }
    side_effects: dict = {}

    result = omni._thinker_decode_to_talker_decode(handoff, torch.device("cpu"), side_effects)

    expected = torch.tensor([13.0, 14.0])
    assert torch.equal(result, expected)
    assert side_effects["_advance_num_processed_tokens"] is True
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""code2wav async finish-sentinel terminal flush.

The producer runner sends every in-step codec chunk with ``finished=False`` and
emits a separate finish sentinel next cycle (empty payload + the
``ASYNC_FINISH_SENTINEL_KEY`` marker the legacy adapter never sets). On that
marker, ``talker2code2wav_async_chunk`` must flush the trailing partial codec
chunk that the live ``is_finished`` branch would otherwise have flushed, reusing
the same context math, without re-appending.
"""

from types import SimpleNamespace

import torch

from vllm_omni.data_entry_keys import ASYNC_FINISH_SENTINEL_KEY
from vllm_omni.model_executor.stage_input_processors.qwen3_omni import talker2code2wav_async_chunk


def _tm(accumulated, chunk_frames=4, left_frames=25):
return SimpleNamespace(
code_prompt_token_ids=dict(accumulated),
connector=SimpleNamespace(
config={"extra": {"codec_chunk_frames": chunk_frames, "codec_left_context_frames": left_frames}}
),
)


def _sentinel_payload():
    """Return a payload whose only entry is the finish-sentinel marker set to True."""
    marker_only = {}
    marker_only[ASYNC_FINISH_SENTINEL_KEY] = True
    return marker_only


def test_finish_sentinel_flushes_partial_tail():
    """With 6 frames and a chunk size of 4, the sentinel must flush the 2-frame tail."""
    # Six single-codebook frames holding the values 1..6.
    frames = [torch.tensor([[value]]) for value in range(1, 7)]
    manager = _tm({"r": frames}, chunk_frames=4, left_frames=25)
    request = SimpleNamespace(external_req_id="r")

    flushed = talker2code2wav_async_chunk(manager, _sentinel_payload(), request, is_finished=True)

    assert flushed is not None
    assert bool(flushed.meta.finished) is True
    # context_length = 6 % 4 = 2; left = min(6-2, 25) = 4; end_index = min(6, 4+2) = 6.
    assert flushed.meta.left_context_size == 4
    audio_codes = flushed.codes.audio
    assert isinstance(audio_codes, torch.Tensor)
    # Six single-codebook frames flatten to six elements.
    assert audio_codes.numel() == 6


def test_finish_sentinel_on_chunk_boundary_emits_flag_only():
    """With 4 frames and a chunk size of 4 there is no tail, so only the flag is sent."""
    # The last full chunk was already sent in-step, so the sentinel must not
    # re-send any codec data.
    frames = [torch.tensor([[value]]) for value in range(1, 5)]
    manager = _tm({"r": frames}, chunk_frames=4)
    request = SimpleNamespace(external_req_id="r")

    flushed = talker2code2wav_async_chunk(manager, _sentinel_payload(), request, is_finished=True)

    assert flushed is not None
    assert bool(flushed.meta.finished) is True
    assert flushed.codes is None, "boundary finish must not re-send the last full chunk"


def test_finish_sentinel_with_no_sent_chunks_emits_flag_only():
    """A sentinel for a request with no accumulated frames yields a finished flag and no codes."""
    manager = _tm({}, chunk_frames=4)
    unknown_request = SimpleNamespace(external_req_id="missing")

    flushed = talker2code2wav_async_chunk(manager, _sentinel_payload(), unknown_request, is_finished=True)

    assert flushed is not None
    assert bool(flushed.meta.finished) is True
    assert flushed.codes is None


def test_non_sentinel_empty_call_is_unchanged():
    """Without the sentinel marker, an empty or codeless payload still returns None."""
    # The adapter path never sets the marker, so its behaviour must remain
    # byte-identical to what it was before.
    held_frames = [torch.tensor([[1]]), torch.tensor([[2]])]
    manager = _tm({"r": held_frames}, chunk_frames=4)
    request = SimpleNamespace(external_req_id="r")

    for codeless_payload in ({"codes": {}}, {}):
        assert talker2code2wav_async_chunk(manager, codeless_payload, request, is_finished=True) is None
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def test_streaming_input_prefill_chunk_is_cached() -> None:
transfer_manager,
)

assert payload is None
assert payload is not None
assert payload.embed is None
assert payload.hidden_states is None
assert payload.ids is None
assert bool(payload.meta.finished) is False
cached = transfer_manager._pending_streaming_prefills["rt-1"]
assert torch.equal(cached["embed"]["prefill"], thinker_emb)
assert torch.equal(cached["hidden_states"]["output"], thinker_hid)
Expand Down Expand Up @@ -135,10 +139,12 @@ def test_streaming_input_prefill_flushes_with_next_decode_chunk() -> None:
)

assert payload is not None
assert payload.embed.prefill.shape == (3, 3)
assert payload.hidden_states.output.shape == (3, 3)
assert payload.embed.prefill.shape == (2, 3)
assert torch.equal(payload.embed.decode, thinker_emb)
assert payload.hidden_states.output.shape == (2, 3)
assert payload.ids.all == [151644, 872, 100]
assert payload.ids.prompt == [151644, 872]
assert payload.ids.output == [101]
assert "rt-2" not in transfer_manager._pending_streaming_prefills


Expand Down
Loading
Loading