diff --git a/tests/test_fix_mimo_audio_is_finished.py b/tests/test_fix_mimo_audio_is_finished.py new file mode 100644 index 00000000000..f6dc8091320 --- /dev/null +++ b/tests/test_fix_mimo_audio_is_finished.py @@ -0,0 +1,210 @@ +""" +Reproduction test for GitHub Issue #1569: + llm2code2wav_async_chunk() got an unexpected keyword argument 'is_finished' + +The call site in chunk_transfer_adapter.py:211 passes is_finished=is_finished, +but the function signature in mimo_audio.py:53 does NOT accept this parameter. +This causes a TypeError crash when using MiMo-Audio in async chunk mode. + +We import the module directly via importlib to bypass the vllm_omni __init__.py +which has heavy dependencies (aenum, vllm, etc.) not available in test env. +""" + +import importlib.util +import inspect +import os +import sys +import types +import unittest +from collections import defaultdict +from unittest.mock import MagicMock + +import torch + +# Repository root (one level up from tests/) +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# Will be set in setUpModule / restored in tearDownModule +_saved_sys_modules = None +llm2code2wav_async_chunk = None + + +def _load_module_directly(module_name: str, file_path: str): + """Load a Python module directly from file path, bypassing package __init__.py.""" + spec = importlib.util.spec_from_file_location(module_name, file_path) + mod = importlib.util.module_from_spec(spec) + + # Provide stubs for dependencies that would fail to import + if "vllm.inputs" not in sys.modules: + stub = types.ModuleType("vllm.inputs") + stub.TextPrompt = type("TextPrompt", (), {}) + sys.modules["vllm"] = types.ModuleType("vllm") + sys.modules["vllm.inputs"] = stub + if "vllm.logger" not in sys.modules: + vllm_logger = types.ModuleType("vllm.logger") + import logging + vllm_logger.init_logger = lambda name: logging.getLogger(name) + sys.modules["vllm.logger"] = vllm_logger + if "vllm_omni.inputs.data" not in sys.modules: + stub_data = types.ModuleType("vllm_omni.inputs.data") + stub_data.OmniTokensPrompt = type("OmniTokensPrompt", (), {}) + sys.modules.setdefault("vllm_omni", types.ModuleType("vllm_omni")) + sys.modules.setdefault("vllm_omni.inputs", types.ModuleType("vllm_omni.inputs")) + sys.modules["vllm_omni.inputs.data"] = stub_data + if "vllm_omni.model_executor.models.mimo_audio.config_mimo_audio" not in sys.modules: + stub_config = types.ModuleType("vllm_omni.model_executor.models.mimo_audio.config_mimo_audio") + stub_config.TALKER_CODEC_PAD_TOKEN_ID = 0 # Default pad token ID + sys.modules.setdefault("vllm_omni.model_executor", types.ModuleType("vllm_omni.model_executor")) + sys.modules.setdefault("vllm_omni.model_executor.models", types.ModuleType("vllm_omni.model_executor.models")) + sys.modules.setdefault( + "vllm_omni.model_executor.models.mimo_audio", + types.ModuleType("vllm_omni.model_executor.models.mimo_audio"), + ) + sys.modules["vllm_omni.model_executor.models.mimo_audio.config_mimo_audio"] = stub_config + + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + return mod + + +def setUpModule(): + """Snapshot sys.modules, then load the target module with stubs.""" + global _saved_sys_modules, llm2code2wav_async_chunk + _saved_sys_modules = sys.modules.copy() + + mimo_audio_path = os.path.join( + REPO_ROOT, "vllm_omni", "model_executor", "stage_input_processors", "mimo_audio.py" + ) + mimo_audio = _load_module_directly("mimo_audio_test_target", mimo_audio_path) + llm2code2wav_async_chunk = mimo_audio.llm2code2wav_async_chunk + + +def tearDownModule(): + """Restore sys.modules to its original state so other tests are unaffected.""" + sys.modules.clear() + sys.modules.update(_saved_sys_modules) + + +class TestMiMoAudioIsFinishedParam(unittest.TestCase): + """Test that llm2code2wav_async_chunk accepts is_finished keyword argument.""" + + def _make_transfer_manager(self): + """Create a mock transfer_manager with the attributes used by the function.""" + tm = MagicMock() + tm.code_prompt_token_ids = defaultdict(list) + return tm + + def _make_request(self, is_finished=False, external_req_id="test-req-001"): + """Create a mock request object.""" + req = MagicMock() + req.is_finished.return_value = is_finished + req.external_req_id = external_req_id + return req + + def _make_pooling_output_with_codes(self): + """Create a pooling_output dict with valid code_predictor_codes (shape: [1,1,8,4]).""" + codes = torch.ones(1, 1, 8, 4, dtype=torch.long) + return {"code_predictor_codes": codes} + + def test_calling_with_is_finished_kwarg_raises_typeerror(self): + """ + REPRODUCTION: This test demonstrates the bug from Issue #1569. + The call site passes is_finished=True, but the function + does not accept this parameter, causing TypeError. + """ + tm = self._make_transfer_manager() + request = self._make_request(is_finished=True) + pooling_output = self._make_pooling_output_with_codes() + + # This is exactly how _send_single_request calls the function + # (chunk_transfer_adapter.py:207-211) + try: + result = llm2code2wav_async_chunk( + transfer_manager=tm, + pooling_output=pooling_output, + request=request, + is_finished=True, + ) + # If we get here, the fix is applied + print("PASS: llm2code2wav_async_chunk accepts is_finished parameter (fix applied)") + self.assertIsNotNone(result) + except TypeError as e: + if "unexpected keyword argument 'is_finished'" in str(e): + self.fail( + f"BUG REPRODUCED (Issue #1569): {e}\n" + "The function does not accept 'is_finished' keyword argument, " + "but chunk_transfer_adapter.py:211 passes it." + ) + else: + raise + + def test_is_finished_true_sets_finished_flag_in_payload(self): + """When is_finished=True, the returned payload should have finished=True.""" + tm = self._make_transfer_manager() + request = self._make_request(is_finished=False, external_req_id="test-req-002") + pooling_output = self._make_pooling_output_with_codes() + + result = llm2code2wav_async_chunk( + transfer_manager=tm, + pooling_output=pooling_output, + request=request, + is_finished=True, + ) + + # With is_finished=True, the function should emit a payload even if chunk is not full + self.assertIsNotNone(result, "Should return payload when is_finished=True") + self.assertIn("finished", result) + self.assertTrue( + result["finished"].item(), + "finished flag should be True when is_finished=True" + ) + + def test_is_finished_false_buffers_until_chunk_full(self): + """When is_finished=False and chunk is not full, should return None (buffering).""" + tm = self._make_transfer_manager() + request = self._make_request(is_finished=False, external_req_id="test-req-003") + pooling_output = self._make_pooling_output_with_codes() + + result = llm2code2wav_async_chunk( + transfer_manager=tm, + pooling_output=pooling_output, + request=request, + is_finished=False, + ) + + self.assertIsNone(result, "Should return None when chunk is not full and not finished") + + def test_empty_codes_with_is_finished_returns_sentinel(self): + """When codes are empty but is_finished=True, should return a finished sentinel.""" + tm = self._make_transfer_manager() + request = self._make_request(is_finished=False) + pooling_output = {"code_predictor_codes": None} + + result = llm2code2wav_async_chunk( + transfer_manager=tm, + pooling_output=pooling_output, + request=request, + is_finished=True, + ) + + self.assertIsNotNone(result, "Should return sentinel when is_finished=True and codes are None") + self.assertTrue(result["finished"].item()) + + def test_signature_matches_thinker2talker_pattern(self): + """ + Verify function signature matches the pattern used by thinker2talker_async_chunk, + which is the working reference implementation in qwen3_omni.py. + """ + sig = inspect.signature(llm2code2wav_async_chunk) + params = list(sig.parameters.keys()) + + self.assertIn("is_finished", params, + "is_finished must be in function signature (matches thinker2talker_async_chunk pattern)") + + is_finished_param = sig.parameters["is_finished"] + self.assertEqual(is_finished_param.default, False, + "is_finished should default to False") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index cb0657675d6..e3e64e9779f 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -51,7 +51,7 @@ def _make_finished_sentinel() -> dict[str, Any]: def llm2code2wav_async_chunk( - transfer_manager: Any, pooling_output: dict[str, Any], request: Any + transfer_manager: Any, pooling_output: dict[str, Any], request: Any, is_finished: bool = False ) -> dict[str, Any] | None: """ Async chunk version: convert stage-0 pooling_output to code2wav payload (pooling / connector accumulation). @@ -60,19 +60,19 @@ def llm2code2wav_async_chunk( returns payload only when chunk_size is full or request is finished; returns None when waiting. """ if "code_predictor_codes" not in pooling_output: - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None code_predictor_codes = pooling_output["code_predictor_codes"] if code_predictor_codes is None: - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None if isinstance(code_predictor_codes, torch.Tensor): if code_predictor_codes.numel() == 0: - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None elif hasattr(code_predictor_codes, "__len__"): @@ -81,14 +81,14 @@ def llm2code2wav_async_chunk( if isinstance(code_predictor_codes, torch.Tensor): if not code_predictor_codes.any(): - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None code_tensor = code_predictor_codes.to(torch.long) else: code_tensor = torch.tensor(code_predictor_codes, dtype=torch.long) if not code_tensor.any(): - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None @@ -101,7 +101,7 @@ def llm2code2wav_async_chunk( code_final = prepend_and_flatten_colmajor(code_tensor, pad_vec) code_list = code_final.tolist() if sum(code_list) == 0: - if request.is_finished(): + if is_finished: return _make_finished_sentinel() return None @@ -113,7 +113,7 @@ def llm2code2wav_async_chunk( transfer_manager.code_prompt_token_ids[request_id].append(code_list) length = len(transfer_manager.code_prompt_token_ids[request_id]) chunk_length = length % chunk_size - if chunk_length != 0 and not request.is_finished(): + if chunk_length != 0 and not is_finished: return None context_length = chunk_length if chunk_length != 0 else chunk_size @@ -123,7 +123,7 @@ def llm2code2wav_async_chunk( "code_predictor_codes": ( torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist() ), - "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + "finished": torch.tensor(is_finished, dtype=torch.bool), } return info