28 changes: 28 additions & 0 deletions tests/e2e/offline_inference/test_voxcpm2.py
@@ -100,3 +100,31 @@ def test_voxcpm2_voice_clone_002(voxcpm2_engine):
audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
duration_s = audio.shape[0] / SAMPLE_RATE
assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"


@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_voxcpm2_prefill_decode_mixed_batch_003(voxcpm2_engine):
"""Regression: prefill+decode mixed batch must not crash (PR #2903)."""
long_prompt = (
"This is a deliberately long prompt that will stay in the decode "
"phase for many steps so that subsequent shorter prompts keep "
"entering prefill alongside it, reproducing the prefill plus "
"decode mixed batch scheduling pattern."
)
short_prompts = [
"Hello one.",
"Hello two.",
"Hello three.",
"Hello four.",
]
requests = [{"prompt": long_prompt}] + [{"prompt": p} for p in short_prompts]

outputs = voxcpm2_engine.generate(requests)
assert len(outputs) == len(requests)

for i, out in enumerate(outputs):
audio = _extract_audio(out.outputs[0].multimodal_output)
duration_s = audio.shape[0] / SAMPLE_RATE
assert 0.1 < duration_s < 30.0, f"Request {i} audio duration out of range: {duration_s:.2f}s"
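
Note: the new regression case can be run on its own with, e.g., `pytest tests/e2e/offline_inference/test_voxcpm2.py -k prefill_decode_mixed_batch_003` (assuming the usual test environment plus the single L4-class GPU requested by the `hardware_test` marker).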
2 changes: 2 additions & 0 deletions tests/model_executor/models/voxcpm2/__init__.py
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
121 changes: 121 additions & 0 deletions tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Regression tests for VoxCPM2 talker per-request state lifecycle."""

from __future__ import annotations

import pytest

torch = pytest.importorskip("torch")
pytest.importorskip("librosa")

from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( # noqa: E402
VoxCPM2TalkerForConditionalGeneration,
_RequestState,
)


def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration:
talker = VoxCPM2TalkerForConditionalGeneration.__new__(VoxCPM2TalkerForConditionalGeneration)
talker._active_states = {}
talker._current_request_id = None
talker._pending_requests = []
talker._results_queue = []
talker._audio_queue = []
talker._deferred_cleanup_ids = set()
talker._max_batch_size = 4
talker._active_state_warn_threshold = 512
talker._active_state_warned = False
return talker


def _seed_cached_decode(talker, req_id: str) -> _RequestState:
state = _RequestState(request_id=req_id)
state.prefill_completed = True
state.decode_step_count = 5
talker._active_states[req_id] = state
return state


class TestStateEvictionContract:
def test_pending_requests_is_not_used_for_eviction(self) -> None:
talker = _make_bare_talker()

cached_ids = [f"req-{i}" for i in range(4)]
for rid in cached_ids:
_seed_cached_decode(talker, rid)

walked_so_far = ["req-new", cached_ids[0], cached_ids[1]]
talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far]

for rid in cached_ids:
assert rid in talker._active_states
assert talker._active_states[rid].prefill_completed is True

def test_on_requests_finished_defers_cleanup(self) -> None:
talker = _make_bare_talker()
_seed_cached_decode(talker, "req-A")
_seed_cached_decode(talker, "req-B")

talker.on_requests_finished({"req-A"})

assert "req-A" in talker._active_states
assert "req-A" in talker._deferred_cleanup_ids

def test_flush_deferred_cleanup_removes_only_finished(self) -> None:
talker = _make_bare_talker()
_seed_cached_decode(talker, "req-A")
_seed_cached_decode(talker, "req-B")
talker.on_requests_finished(["req-A"])

talker._flush_deferred_cleanup()

assert "req-A" not in talker._active_states
assert "req-B" in talker._active_states
assert talker._deferred_cleanup_ids == set()

def test_current_request_id_cleared_when_matching(self) -> None:
talker = _make_bare_talker()
_seed_cached_decode(talker, "req-A")
talker._current_request_id = "req-A"

talker.on_requests_finished({"req-A"})
talker._flush_deferred_cleanup()

assert talker._current_request_id is None

def test_current_request_id_preserved_when_not_finished(self) -> None:
talker = _make_bare_talker()
_seed_cached_decode(talker, "req-A")
_seed_cached_decode(talker, "req-B")
talker._current_request_id = "req-B"

talker.on_requests_finished({"req-A"})
talker._flush_deferred_cleanup()

assert talker._current_request_id == "req-B"


class TestLeakWarnGuard:
def test_warn_fires_once_over_threshold(self, monkeypatch) -> None:
from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk

calls: list[str] = []

def _capture(msg, *args, **kwargs):
calls.append(msg % args if args else msg)

monkeypatch.setattr(tk.logger, "warning", _capture)

talker = _make_bare_talker()
talker._active_state_warn_threshold = 3

for i in range(4):
talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}")

talker._get_or_create_state("new-1")
talker._get_or_create_state("new-2")

leak_warnings = [m for m in calls if "cleanup path leak" in m]
assert len(leak_warnings) == 1
assert talker._active_state_warned is True
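
For reviewers, a minimal standalone sketch of the lifecycle these tests pin down (illustrative names only, not the vLLM-Omni API): finishing a request only marks its id for deferred cleanup, and the flush step is the single place state is actually dropped.

# Minimal sketch of the deferred-cleanup contract asserted above; the real
# logic lives in VoxCPM2TalkerForConditionalGeneration (voxcpm2_talker.py).
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class _TrackerSketch:
    active: dict[str, object] = field(default_factory=dict)
    deferred: set[str] = field(default_factory=set)
    current_id: str | None = None

    def on_requests_finished(self, finished_ids) -> None:
        # Mark only; never evict mid-step, so in-flight decode state survives.
        self.deferred.update(finished_ids)

    def flush_deferred_cleanup(self) -> None:
        for rid in self.deferred:
            self.active.pop(rid, None)
            if self.current_id == rid:
                self.current_id = None
        self.deferred.clear()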
37 changes: 26 additions & 11 deletions vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -41,6 +41,11 @@

_ENABLE_PROFILING = os.environ.get("VOXCPM2_PROFILE", "0") == "1"

# Lower bound for the _active_states leak-warn threshold. The effective
# threshold is max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_batch_size), so small
# deployments get a usable floor rather than a tiny threshold that warns
# under normal load.
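# For example, max_batch_size=4 gives max(512, 16) = 512, while
# max_batch_size=256 gives max(512, 1024) = 1024.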
_ACTIVE_STATE_LEAK_WARN_MIN = 512


def is_cjk_char(c: str) -> bool:
"""Check if a character is a CJK ideograph."""
Expand Down Expand Up @@ -440,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._results_queue: list[tuple[str, torch.Tensor | None]] = []
self._audio_queue: list[tuple[str, Any]] = []
self._deferred_cleanup_ids: set[str] = set()
self._active_state_warn_threshold = max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * self._max_batch_size)
# one-shot by design: fires at most once per process to avoid log spam.
self._active_state_warned = False

@property
def tts(self) -> nn.Module:
@@ -449,9 +457,20 @@ def tts(self) -> nn.Module:
# -------------------- request state management --------------------

def _get_or_create_state(self, request_id: str) -> _RequestState:
if request_id not in self._active_states:
self._active_states[request_id] = _RequestState(request_id=request_id)
return self._active_states[request_id]
state = self._active_states.get(request_id)
if state is None:
state = _RequestState(request_id=request_id)
self._active_states[request_id] = state
if len(self._active_states) > self._active_state_warn_threshold and not self._active_state_warned:
logger.warning(
"VoxCPM2: _active_states size=%d exceeds threshold %d "
"(max_batch_size=%d); possible cleanup path leak",
len(self._active_states),
self._active_state_warn_threshold,
self._max_batch_size,
)
self._active_state_warned = True
return state

def _switch_to_request(self, request_id: str) -> _RequestState:
if request_id != self._current_request_id:
Expand Down Expand Up @@ -1087,14 +1106,10 @@ def preprocess(
is_prefill = span_len > 1

if is_prefill:
# Evict stale states
pending_ids = {rid for rid, *_ in self._pending_requests}
pending_ids.add(req_id)
if self._current_request_id:
pending_ids.add(self._current_request_id)
for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]:
self._cleanup_request(rid)

# Do not evict state here: _pending_requests is a per-step prefix,
# not the full batch. Cleanup is driven by on_requests_finished ->
# _flush_deferred_cleanup (fed by vLLM scheduler._free_request via
# gpu_ar_model_runner.py).
real = info_dict.get("text_token_ids")
token_ids = input_ids.tolist() if real is None else real[0]
# Fail-fast: unsplit multichar Chinese IDs in input_ids means the