diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
index 6ec4630a45..e37d3f74df 100644
--- a/tests/e2e/offline_inference/test_voxcpm2.py
+++ b/tests/e2e/offline_inference/test_voxcpm2.py
@@ -100,3 +100,31 @@ def test_voxcpm2_voice_clone_002(voxcpm2_engine):
     audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
     duration_s = audio.shape[0] / SAMPLE_RATE
     assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm2_prefill_decode_mixed_batch_003(voxcpm2_engine):
+    """Regression: prefill+decode mixed batch must not crash (PR #2903)."""
+    long_prompt = (
+        "This is a deliberately long prompt that will stay in the decode "
+        "phase for many steps so that subsequent shorter prompts keep "
+        "entering prefill alongside it, reproducing the prefill plus "
+        "decode mixed batch scheduling pattern."
+    )
+    short_prompts = [
+        "Hello one.",
+        "Hello two.",
+        "Hello three.",
+        "Hello four.",
+    ]
+    requests = [{"prompt": long_prompt}] + [{"prompt": p} for p in short_prompts]
+
+    outputs = voxcpm2_engine.generate(requests)
+    assert len(outputs) == len(requests)
+
+    for i, out in enumerate(outputs):
+        audio = _extract_audio(out.outputs[0].multimodal_output)
+        duration_s = audio.shape[0] / SAMPLE_RATE
+        assert 0.1 < duration_s < 30.0, f"Request {i} audio duration out of range: {duration_s:.2f}s"
diff --git a/tests/model_executor/models/voxcpm2/__init__.py b/tests/model_executor/models/voxcpm2/__init__.py
new file mode 100644
index 0000000000..208f01a7cb
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
new file mode 100644
index 0000000000..5d8a35636b
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression tests for VoxCPM2 talker per-request state lifecycle."""
+
+from __future__ import annotations
+
+import pytest
+
+pytest.importorskip("torch")
+pytest.importorskip("librosa")
+
+from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (  # noqa: E402
+    VoxCPM2TalkerForConditionalGeneration,
+    _RequestState,
+)
+
+
+def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration:
+    talker = VoxCPM2TalkerForConditionalGeneration.__new__(
+        VoxCPM2TalkerForConditionalGeneration
+    )
+    talker._active_states = {}
+    talker._current_request_id = None
+    talker._pending_requests = []
+    talker._results_queue = []
+    talker._audio_queue = []
+    talker._deferred_cleanup_ids = set()
+    talker._max_batch_size = 4
+    talker._active_state_warn_threshold = 512
+    talker._active_state_warned = False
+    return talker
+
+
+def _seed_cached_decode(talker, req_id: str) -> _RequestState:
+    state = _RequestState(request_id=req_id)
+    state.prefill_completed = True
+    state.decode_step_count = 5
+    talker._active_states[req_id] = state
+    return state
+
+
+class TestStateEvictionContract:
+    def test_pending_requests_is_not_used_for_eviction(self) -> None:
+        talker = _make_bare_talker()
+
+        cached_ids = [f"req-{i}" for i in range(4)]
+        for rid in cached_ids:
+            _seed_cached_decode(talker, rid)
+
+        walked_so_far = ["req-new", cached_ids[0], cached_ids[1]]
+        talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far]
+
+        # Creating the new request's state (the prefill path) must not
+        # consult _pending_requests to evict the cached decode states.
+        talker._get_or_create_state("req-new")
+
+        for rid in cached_ids:
+            assert rid in talker._active_states
+            assert talker._active_states[rid].prefill_completed is True
+
+    def test_on_requests_finished_defers_cleanup(self) -> None:
+        talker = _make_bare_talker()
+        _seed_cached_decode(talker, "req-A")
+        _seed_cached_decode(talker, "req-B")
+
+        talker.on_requests_finished({"req-A"})
+
+        assert "req-A" in talker._active_states
+        assert "req-A" in talker._deferred_cleanup_ids
+
+    def test_flush_deferred_cleanup_removes_only_finished(self) -> None:
+        talker = _make_bare_talker()
+        _seed_cached_decode(talker, "req-A")
+        _seed_cached_decode(talker, "req-B")
+        talker.on_requests_finished(["req-A"])
+
+        talker._flush_deferred_cleanup()
+
+        assert "req-A" not in talker._active_states
+        assert "req-B" in talker._active_states
+        assert talker._deferred_cleanup_ids == set()
+
+    def test_current_request_id_cleared_when_matching(self) -> None:
+        talker = _make_bare_talker()
+        _seed_cached_decode(talker, "req-A")
+        talker._current_request_id = "req-A"
+
+        talker.on_requests_finished({"req-A"})
+        talker._flush_deferred_cleanup()
+
+        assert talker._current_request_id is None
+
+    def test_current_request_id_preserved_when_not_finished(self) -> None:
+        talker = _make_bare_talker()
+        _seed_cached_decode(talker, "req-A")
+        _seed_cached_decode(talker, "req-B")
+        talker._current_request_id = "req-B"
+
+        talker.on_requests_finished({"req-A"})
+        talker._flush_deferred_cleanup()
+
+        assert talker._current_request_id == "req-B"
+
+
+class TestLeakWarnGuard:
+    def test_warn_fires_once_over_threshold(self, monkeypatch) -> None:
+        from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk
+
+        calls: list[str] = []
+
+        def _capture(msg, *args, **kwargs):
+            calls.append(msg % args if args else msg)
+
+        monkeypatch.setattr(tk.logger, "warning", _capture)
+
+        talker = _make_bare_talker()
+        talker._active_state_warn_threshold = 3
+
+        for i in range(4):
+            talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}")
+
+        talker._get_or_create_state("new-1")
+        talker._get_or_create_state("new-2")
+
+        leak_warnings = [m for m in calls if "cleanup path leak" in m]
+        assert len(leak_warnings) == 1
+        assert talker._active_state_warned is True
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
index 7d54cd17f5..3724528898 100644
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -41,6 +41,11 @@
 _ENABLE_PROFILING = os.environ.get("VOXCPM2_PROFILE", "0") == "1"
 
 
+# Lower bound for the _active_states leak-warn threshold. The effective
+# threshold is max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_batch_size), so
+# small deployments get a sensible floor rather than a low, noisy one.
+_ACTIVE_STATE_LEAK_WARN_MIN = 512
+
 
 def is_cjk_char(c: str) -> bool:
     """Check if a character is a CJK ideograph."""
@@ -440,6 +445,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self._results_queue: list[tuple[str, torch.Tensor | None]] = []
         self._audio_queue: list[tuple[str, Any]] = []
         self._deferred_cleanup_ids: set[str] = set()
+        self._active_state_warn_threshold = max(
+            _ACTIVE_STATE_LEAK_WARN_MIN, 4 * self._max_batch_size
+        )
+        # One-shot by design: warn at most once per process to avoid log spam.
+        self._active_state_warned = False
 
     @property
     def tts(self) -> nn.Module:
@@ -449,9 +459,23 @@ def tts(self) -> nn.Module:
     # -------------------- request state management --------------------
 
     def _get_or_create_state(self, request_id: str) -> _RequestState:
-        if request_id not in self._active_states:
-            self._active_states[request_id] = _RequestState(request_id=request_id)
-        return self._active_states[request_id]
+        state = self._active_states.get(request_id)
+        if state is None:
+            state = _RequestState(request_id=request_id)
+            self._active_states[request_id] = state
+            if (
+                not self._active_state_warned
+                and len(self._active_states) > self._active_state_warn_threshold
+            ):
+                logger.warning(
+                    "VoxCPM2: _active_states size=%d exceeds threshold %d "
+                    "(max_batch_size=%d); possible cleanup path leak",
+                    len(self._active_states),
+                    self._active_state_warn_threshold,
+                    self._max_batch_size,
+                )
+                self._active_state_warned = True
+        return state
 
     def _switch_to_request(self, request_id: str) -> _RequestState:
         if request_id != self._current_request_id:
@@ -1087,14 +1111,10 @@ def preprocess(
         is_prefill = span_len > 1
 
         if is_prefill:
-            # Evict stale states
-            pending_ids = {rid for rid, *_ in self._pending_requests}
-            pending_ids.add(req_id)
-            if self._current_request_id:
-                pending_ids.add(self._current_request_id)
-            for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]:
-                self._cleanup_request(rid)
-
+            # Do not evict state here: _pending_requests is a per-step prefix,
+            # not the full batch. Cleanup is driven by on_requests_finished ->
+            # _flush_deferred_cleanup (fed by vLLM scheduler._free_request via
+            # gpu_ar_model_runner.py).
             real = info_dict.get("text_token_ids")
             token_ids = input_ids.tolist() if real is None else real[0]
             # Fail-fast: unsplit multichar Chinese IDs in input_ids means the
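
Note on the lifecycle contract this diff encodes: the sketch below is illustrative only. TalkerStates and its method names are invented for this note; the real logic lives in VoxCPM2TalkerForConditionalGeneration (_active_states, _deferred_cleanup_ids, _flush_deferred_cleanup). Under those assumptions, it shows why a finished request is only marked in on_requests_finished and freed later: the step that reports the request finished may still touch its state, so removal has to wait for a safe boundary, and state creation never evicts other requests.

    # Minimal, self-contained sketch of the deferred-cleanup contract.
    # TalkerStates is a hypothetical stand-in, not vLLM-Omni code.
    class TalkerStates:
        def __init__(self) -> None:
            self.active: dict[str, object] = {}
            self.deferred: set[str] = set()

        def get_or_create(self, request_id: str) -> object:
            # Prefill path: only ever adds state. It must not evict other
            # requests, because the batch visible here is a per-step prefix.
            return self.active.setdefault(request_id, object())

        def on_requests_finished(self, request_ids) -> None:
            # Scheduler callback: the in-flight step may still read these
            # states, so only mark them for later removal.
            self.deferred.update(request_ids)

        def flush(self) -> None:
            # Step boundary: nothing references the states anymore; free them.
            for rid in self.deferred:
                self.active.pop(rid, None)
            self.deferred.clear()


    states = TalkerStates()
    states.get_or_create("req-A")
    states.on_requests_finished({"req-A"})
    assert "req-A" in states.active       # still alive mid-step
    states.flush()
    assert "req-A" not in states.active   # freed at the boundary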