diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index 54faa3b3af5..edf46eb9cca 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -16,6 +16,8 @@ talker2code2wav_async_chunk, ) +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + _FRAME = [1, 2, 3, 4] _Q = len(_FRAME) @@ -90,7 +92,7 @@ def test_flush_on_finish(): is_finished=True, ) assert p is not None - assert p["finished"].item() is True + assert p["finished"] is True assert len(p["code_predictor_codes"]) == _Q * 24 @@ -158,24 +160,29 @@ def test_dynamic_ic_adapts_to_load(): def test_ic_load_change_mid_request(): - """IC stateless: load spike mid-request shifts initial_coverage.""" + """IC is cached per request; a load spike only affects new requests.""" tm = _tm(chunk_frames=25, left_context=25, max_num_seqs=8) - # Low load -> IC=2 -> emit at frame 2 + # Low load -> IC=2 (cached for "r"), emit at frame 2 p1 = _call(tm, "r", n_frames=2) assert p1 is not None - # Spike load: 6 others -> IC=16 -> initial_coverage=16 + # Spike load: 6 others running for i in range(6): tm.code_prompt_token_ids[f"other-{i}"] = [[0]] * 10 - # adjusted=25-16=9, 9%25!=0 -> hold + # IC for "r" is still cached as 2. + # initial_coverage = ((25-1)//2)*2 = 24, first normal emit at 24+25=49 assert _call(tm, "r", n_frames=25) is None - - # First normal emit at 16+25=41 - p3 = _call(tm, "r", n_frames=41) + assert _call(tm, "r", n_frames=27) is None + p3 = _call(tm, "r", n_frames=49) assert p3 is not None - assert p3["left_context_size"] == 16 + + # A *new* request under high load gets IC=16 (not IC=2). + # Frame 2 would emit under IC=2 but must hold under IC=16. + assert _call(tm, "new_req", n_frames=2) is None + p4 = _call(tm, "new_req", n_frames=16) + assert p4 is not None @pytest.mark.parametrize( @@ -295,8 +302,13 @@ def test_non_async_processor_prepends_ref_code_and_sets_trim_context(): ], dtype=torch.long, ) - output = SimpleNamespace(multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code}) - stage = SimpleNamespace(engine_outputs=[SimpleNamespace(outputs=[output])]) + output = SimpleNamespace( + multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code}, + token_ids=list(range(3)), + ) + stage = SimpleNamespace( + engine_outputs=[SimpleNamespace(outputs=[output], finished=True)], + ) prompts = talker2code2wav(stage_list=[stage], engine_input_source=[0]) @@ -335,8 +347,13 @@ def test_non_async_processor_filters_out_of_range_codec_values(): ], dtype=torch.long, ) - output = SimpleNamespace(multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code}) - stage = SimpleNamespace(engine_outputs=[SimpleNamespace(outputs=[output])]) + output = SimpleNamespace( + multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code}, + token_ids=list(range(4)), + ) + stage = SimpleNamespace( + engine_outputs=[SimpleNamespace(outputs=[output], finished=True)], + ) prompts = talker2code2wav(stage_list=[stage], engine_input_source=[0])