Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/core/sched/test_omni_scheduler_mixin_timeouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit coverage for _process_pending_input_timeouts.

Verifies that the mixin correctly *delegates* timed-out requests to the
base scheduler's ``finish_requests`` API with ``RequestStatus.FINISHED_ERROR``.
The end-to-end effect (queue removal + status set + per-request cleanup +
client-facing FINISHED_ERROR emission) is the responsibility of upstream
vLLM's ``finish_requests`` implementation and is covered by upstream tests;
this file only asserts the wiring from the mixin to that API.
"""

from __future__ import annotations

from types import SimpleNamespace

import pytest

from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin


class _FakeCoordinator:
def __init__(self, timed_out_ids):
self._timed_out_ids = set(timed_out_ids)
self.calls = []

def collect_timed_out_request_ids(self, timeout_s):
self.calls.append(timeout_s)
return set(self._timed_out_ids)


class _FakeScheduler(OmniSchedulerMixin):
    """Scheduler stand-in that records ``finish_requests`` calls.

    Exposes the ``requests`` and ``input_coordinator`` attributes set from
    the constructor arguments, plus a ``finish_calls`` log for assertions.
    """

    def __init__(self, requests, coordinator):
        self.finish_calls = []
        self.input_coordinator = coordinator
        self.requests = requests

    def finish_requests(self, req_ids, status):
        """Log the call as ``(set of ids, status)`` instead of finishing anything."""
        recorded = (set(req_ids), status)
        self.finish_calls.append(recorded)


def test_process_pending_input_timeouts_delegates_to_finish_requests():
    """A timed-out request still present in ``self.requests`` reaches ``finish_requests``."""
    req_id = "stuck-req"
    coord = _FakeCoordinator(timed_out_ids={req_id})
    tracked = {req_id: SimpleNamespace(request_id=req_id)}
    scheduler = _FakeScheduler(tracked, coord)

    scheduler._process_pending_input_timeouts()

    # The coordinator is polled exactly once, with a positive timeout.
    assert len(coord.calls) == 1, "coordinator should be polled once"
    polled_timeout = coord.calls[0]
    assert polled_timeout > 0, "timeout must be positive when enabled"

    # Exactly one finish_requests call, carrying just the stuck request id.
    assert len(scheduler.finish_calls) == 1
    finished_ids, status = scheduler.finish_calls[0]
    assert finished_ids == {req_id}
    # Compare by member name so this test does not have to import the
    # upstream RequestStatus enum.
    status_name = getattr(status, "name", str(status))
    assert status_name.endswith("FINISHED_ERROR")


def test_process_pending_input_timeouts_skips_already_freed_request():
    """Timed-out id no longer in self.requests must not be forwarded."""
    coord = _FakeCoordinator(timed_out_ids={"already-freed"})
    scheduler = _FakeScheduler(requests={}, coordinator=coord)

    scheduler._process_pending_input_timeouts()

    # The coordinator is still polled once with a positive timeout. Checking
    # the length first gives a clean assertion failure (not an IndexError)
    # if it was never polled.
    assert len(coord.calls) == 1, "coordinator should be polled once"
    assert coord.calls[0] > 0, "timeout must be positive when enabled"
    # An id that is absent from self.requests is dropped rather than forwarded.
    assert scheduler.finish_calls == []


def test_process_pending_input_timeouts_noop_without_coordinator():
    """No coordinator => no finish_requests call, no crash."""
    # _FakeScheduler built with coordinator=None has the same attributes
    # (empty requests, input_coordinator=None, empty finish_calls) and the
    # same recording finish_requests as a dedicated coordinator-less class,
    # so no extra local class is needed.
    scheduler = _FakeScheduler(requests={}, coordinator=None)

    scheduler._process_pending_input_timeouts()

    assert scheduler.finish_calls == []


def test_process_pending_input_timeouts_disabled_when_timeout_zero(monkeypatch):
    """With DEFAULT_INPUT_WAIT_TIMEOUT_S patched to 0.0, nothing is polled or finished."""
    from vllm_omni.core.sched import omni_scheduler_mixin

    # Patch the module-level constant; monkeypatch restores it after the test.
    monkeypatch.setattr(omni_scheduler_mixin, "DEFAULT_INPUT_WAIT_TIMEOUT_S", 0.0)

    pending = {"r1": SimpleNamespace(request_id="r1")}
    coord = _FakeCoordinator(timed_out_ids={"r1"})
    scheduler = _FakeScheduler(requests=pending, coordinator=coord)

    scheduler._process_pending_input_timeouts()

    assert coord.calls == [], "coordinator must not be polled when timeout is disabled"
    assert scheduler.finish_calls == []


if __name__ == "__main__":
    # pytest.main returns the exit code instead of exiting. Forward it so
    # running this file as a script does not exit 0 when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
97 changes: 73 additions & 24 deletions tests/core/sched/test_omni_scheduling_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod
from vllm_omni.core.sched.omni_scheduling_coordinator import (
OmniSchedulingCoordinator,
uses_qwen3_omni_full_payload_input_coordinator,
uses_full_payload_input_coordinator,
)

# ------------------------------------------------------------------ #
Expand Down Expand Up @@ -92,47 +92,96 @@ def remove_requests(self, requests):


class TestFullPayloadCoordinatorSelection(unittest.TestCase):
def test_qwen3_omni_talker_and_code2wav_use_full_payload_input_coordinator(self):
for model_stage in ("talker", "code2wav"):
"""Tests for the (model_arch, model_stage) whitelist gate.

The init_omni_connectors arch allowlist is keyed by ``model_arch`` and
is a superset of the stages registered here -- consumer-wait stages
must be registered explicitly in ``_FULL_PAYLOAD_INPUT_STAGES``, while
the init allowlist covers both producer- and consumer-side runners.
These tests pin which ``(arch, stage)`` pairs the gate fires for today.
"""

# Expected whitelist (model_arch, model_stage). Hardcoded to avoid the
# tautology of importing _FULL_PAYLOAD_INPUT_STAGES and asserting it
# against itself; any drift between this matrix and the whitelist will
# fail loudly here.
EXPECTED_FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset(
{
("Qwen3OmniMoeForConditionalGeneration", "talker"),
("Qwen3OmniMoeForConditionalGeneration", "code2wav"),
("Qwen2_5OmniForConditionalGeneration", "talker"),
("Qwen2_5OmniForConditionalGeneration", "code2wav"),
("CovoAudioForConditionalGeneration", "code2wav"),
("MiMoAudioModel", "code2wav"),
("Qwen3TTSCode2Wav", "code2wav"),
("CosyVoice3Model", "cosyvoice3_code2wav"),
("DyninOmniForConditionalGeneration", "token2image"),
("DyninOmniForConditionalGeneration", "token2audio"),
}
)

def test_whitelist_matches_expected_matrix(self):
"""_FULL_PAYLOAD_INPUT_STAGES must equal the hardcoded expected matrix.

Catches both accidental additions (which would silently enable the
consumer-wait gate for a new arch) and accidental removals (which
would silently disable an enabled arch).
"""
from vllm_omni.core.sched.omni_scheduling_coordinator import _FULL_PAYLOAD_INPUT_STAGES

self.assertEqual(
frozenset(_FULL_PAYLOAD_INPUT_STAGES),
self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES,
msg="_FULL_PAYLOAD_INPUT_STAGES drifted from the expected matrix; "
"update EXPECTED_FULL_PAYLOAD_INPUT_STAGES if intentional.",
)

def test_all_whitelisted_arch_stage_pairs_fire_gate(self):
"""Every (arch, stage) pair in the expected matrix must fire
the gate when stage_id > 0 and async_chunk=False.
"""
for arch, stage in self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES:
model_config = SimpleNamespace(
stage_id=1,
async_chunk=False,
model_arch="Qwen3OmniMoeForConditionalGeneration",
model_stage=model_stage,
model_arch=arch,
model_stage=stage,
)
self.assertTrue(
uses_full_payload_input_coordinator(model_config),
msg=f"expected gate to fire for {arch}/{stage}",
)

self.assertTrue(uses_qwen3_omni_full_payload_input_coordinator(model_config))

def test_async_chunk_and_non_qwen3_omni_do_not_use_full_payload_input_coordinator(self):
def test_other_arch_or_stage_or_mode_does_not_fire(self):
cases = [
SimpleNamespace(
stage_id=1,
async_chunk=True,
model_arch="Qwen3OmniMoeForConditionalGeneration",
model_stage="talker",
stage_id=1, async_chunk=True, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="talker"
),
SimpleNamespace(
stage_id=1,
async_chunk=False,
model_arch="Qwen3TTSForConditionalGeneration",
model_stage="code2wav",
stage_id=0, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="thinker"
),
SimpleNamespace(
stage_id=1,
async_chunk=False,
model_arch="Qwen2_5OmniForConditionalGeneration",
model_stage="talker",
model_arch="Qwen3OmniMoeForConditionalGeneration",
model_stage="some_future_stage",
),
SimpleNamespace(
stage_id=0,
async_chunk=False,
model_arch="Qwen3OmniMoeForConditionalGeneration",
model_stage="thinker",
stage_id=1, async_chunk=False, model_arch="Qwen3TTSForConditionalGeneration", model_stage="code2wav"
),
SimpleNamespace(
stage_id=1, async_chunk=False, model_arch="MingFlashOmniForConditionalGeneration", model_stage="talker"
),
SimpleNamespace(stage_id=1, async_chunk=False, model_arch=None, model_stage="talker"),
SimpleNamespace(
stage_id=1, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage=None
),
]

for model_config in cases:
self.assertFalse(uses_qwen3_omni_full_payload_input_coordinator(model_config))
self.assertFalse(
uses_full_payload_input_coordinator(model_config),
msg=f"expected gate OFF for {model_config}",
)


class TestChunkCoordinatorStateTransition(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,13 @@ def inference(self, speech_feat, finalize=True):
model = object.__new__(CosyVoice3Code2Wav)
nn.Module.__init__(model)
model.hift = DummyHiFT()
model._forward_mel = lambda **_: torch.ones((1, 80, 8), dtype=torch.float32)
forward_mel_calls = []

def fake_forward_mel(**kwargs):
forward_mel_calls.append(kwargs)
return torch.ones((1, 80, 8), dtype=torch.float32)

model._forward_mel = fake_forward_mel

out = model.forward(
token=torch.tensor([[1, 2, 3]], dtype=torch.int32),
Expand All @@ -304,3 +310,4 @@ def inference(self, speech_feat, finalize=True):

assert out.shape == (1, 1, 8)
assert model.hift.finalize_calls == [True]
assert forward_mel_calls[0]["token_offset_tokens"] == 0
Original file line number Diff line number Diff line change
Expand Up @@ -138,27 +138,6 @@ def _make_sampling_metadata(
)


def test_split_request_ids_uses_seq_token_counts():
    """Counts [2, 2, 2] over five ids give chunks of 2, 2 and a final chunk of 1."""
    model_cls, _ = _cosyvoice3_model_and_runner()
    flat_ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long)

    pieces = model_cls._split_request_ids(flat_ids, [2, 2, 2])

    as_lists = [piece.tolist() for piece in pieces]
    assert as_lists == [[10, 11], [12, 13], [14]]


def test_split_request_ids_honors_single_request_seq_token_counts():
    """A single count of 3 over five ids gives one chunk holding the first three ids."""
    model_cls, _ = _cosyvoice3_model_and_runner()
    flat_ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long)

    pieces = model_cls._split_request_ids(flat_ids, [3])

    as_lists = [piece.tolist() for piece in pieces]
    assert as_lists == [[10, 11, 12]]


def test_sanitize_codec_tokens_filters_out_of_range():
    """From [-1, 0, 3, 4, 99] only 0 and 3 remain after sanitizing."""
    model = _make_code2wav_model()
    raw_tokens = torch.tensor([-1, 0, 3, 4, 99], dtype=torch.long)

    kept = model._sanitize_codec_tokens(raw_tokens)

    assert kept.tolist() == [0, 3]


def test_forward_prefers_token_offset_when_present():
model = _make_code2wav_model()

Expand Down Expand Up @@ -265,6 +244,31 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata():
assert len(model.code2wav.forward_streaming_calls) == 0
call = model.code2wav.forward_calls[0]
assert call["token"].tolist() == [[0, 1, 2]]
assert call["token_offset_tokens"] == 0


def test_forward_uses_non_stream_talker_prefill_offset():
    """``meta.talker_prefill_offset`` (3) arrives at code2wav as ``token_offset_tokens``."""
    model = _make_code2wav_model()

    embed = {
        "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
        "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
        "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
    }
    meta = {"talker_prefill_offset": 3}
    runtime_info = [{"embed": embed, "meta": meta}]

    input_ids = torch.tensor([0, 1, 2], dtype=torch.long)
    positions = torch.tensor([0, 1, 2], dtype=torch.long)
    model.forward(
        input_ids=input_ids,
        positions=positions,
        model_intermediate_buffer=runtime_info,
        seq_token_counts=[3],
    )

    first_call = model.code2wav.forward_calls[0]
    assert first_call["token_offset_tokens"] == 3


def test_forward_reuses_streaming_cache_state_between_chunks():
Expand Down
Loading
Loading