diff --git a/examples/online_serving/bagel/run_server_stage_cli.sh b/examples/online_serving/bagel/run_server_stage_cli.sh new file mode 100644 index 00000000000..51639153f73 --- /dev/null +++ b/examples/online_serving/bagel/run_server_stage_cli.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Bagel multi-stage online serving startup script +# Starts stage 0 as master with API server, and stage 1 in headless mode + +MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}" +PORT="${PORT:-8091}" +MASTER_ADDRESS="${MASTER_ADDRESS:-127.0.0.1}" +MASTER_PORT="${MASTER_PORT:-8092}" +STAGE_CONFIGS_PATH="$(dirname "$0")/../../../vllm_omni/model_executor/stage_configs/bagel.yaml" + +echo "Starting Bagel multi-stage server..." +echo "Model: $MODEL" +echo "API Port: $PORT" +echo "Master Address: $MASTER_ADDRESS" +echo "Master Port: $MASTER_PORT" +echo "Stage Configs: $STAGE_CONFIGS_PATH" + +# Start stage 1 (DiT) in headless mode first +echo "Starting Stage 1 (DiT) in headless mode..." +vllm serve "$MODEL" --omni \ + --stage-configs-path "$STAGE_CONFIGS_PATH" \ + --stage-id 1 \ + --headless \ + -oma "$MASTER_ADDRESS" \ + -omp "$MASTER_PORT" & + +# Start stage 0 (Thinker) as master with API server +echo "Starting Stage 0 (Thinker) as master..." 
+vllm serve "$MODEL" --omni \ + --port "$PORT" \ + --stage-configs-path "$STAGE_CONFIGS_PATH" \ + --stage-id 0 \ + -oma "$MASTER_ADDRESS" \ + -omp "$MASTER_PORT" diff --git a/requirements/common.txt b/requirements/common.txt index 3954b4587db..7e5f5f9b260 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -14,3 +14,4 @@ imageio[ffmpeg]>=2.37.2 sox>=1.5.0 prettytable>=3.8.0 aenum==3.1.16 +pyzmq>=25.0.0 diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py index ba0390040bf..ce443d5a5c7 100644 --- a/tests/entrypoints/test_async_omni_diffusion_config.py +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -3,21 +3,23 @@ import pytest -from vllm_omni.entrypoints import omni as omni_module +from vllm_omni.entrypoints import utils as utils_module from vllm_omni.entrypoints.async_omni import AsyncOmni pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +MODEL = "riverclouds/qwen_image_random" + def test_default_stage_config_includes_cache_backend(monkeypatch): """Ensure cache_backend/cache_config are preserved in default diffusion stage.""" - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None) monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) omni = AsyncOmni( - model="dummy-model", + model=MODEL, cache_backend="cache_dit", cache_config='{"Fn_compute_blocks": 2}', vae_use_slicing=True, @@ -41,13 +43,13 @@ def test_default_stage_config_includes_cache_backend(monkeypatch): def test_default_cache_config_used_when_missing(monkeypatch): 
"""Ensure default cache_config is applied when cache_backend is set.""" - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None) monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) omni = AsyncOmni( - model="dummy-model", + model=MODEL, cache_backend="cache_dit", ) @@ -59,13 +61,13 @@ def test_default_cache_config_used_when_missing(monkeypatch): def test_default_stage_devices_from_sequence_parallel(monkeypatch): """Ensure devices list reflects sequence parallel size when no parallel_config is provided.""" - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None) monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) omni = AsyncOmni( - model="dummy-model", + model=MODEL, ulysses_degree=2, ring_degree=2, ) diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py index ab1194a25c8..8b7977d9c1e 100644 --- a/tests/entrypoints/test_omni_diffusion.py +++ b/tests/entrypoints/test_omni_diffusion.py @@ -18,6 +18,37 @@ category=DeprecationWarning, ) +MODEL = "riverclouds/qwen_image_random" + + +class _FakeStageRequestStats: + """Fake StageRequestStats object with necessary 
attributes aligned with real StageRequestStats.""" + + def __init__(self, **kwargs): + # Required fields (with defaults for testing) + self.batch_id = kwargs.get("batch_id", 0) + self.batch_size = kwargs.get("batch_size", 1) + self.num_tokens_in = kwargs.get("num_tokens_in", 0) + self.num_tokens_out = kwargs.get("num_tokens_out", 1) + self.stage_gen_time_ms = kwargs.get("stage_gen_time_ms", 10.0) + self.rx_transfer_bytes = kwargs.get("rx_transfer_bytes", 0) + self.rx_decode_time_ms = kwargs.get("rx_decode_time_ms", 0.0) + self.rx_in_flight_time_ms = kwargs.get("rx_in_flight_time_ms", 0.0) + self.stage_stats = kwargs.get("stage_stats", None) + + # Optional fields + self.stage_id = kwargs.get("stage_id", None) + self.final_output_type = kwargs.get("final_output_type", None) + self.request_id = kwargs.get("request_id", None) + self.postprocess_time_ms = kwargs.get("postprocess_time_ms", 0.0) + self.diffusion_metrics = kwargs.get("diffusion_metrics", None) + self.audio_generated_frames = kwargs.get("audio_generated_frames", 0) + + # Allow additional attributes for flexibility + for key, value in kwargs.items(): + if not hasattr(self, key): + setattr(self, key, value) + class _FakeEngineArgs(dict): """Fake engine args that can be used both as object attributes and as **kwargs.""" @@ -60,8 +91,10 @@ def put(self, item): def put_nowait(self, item): self._queue.put_nowait(item) - def get(self): - return self._queue.get() + def get(self, timeout=None): + if timeout is None: + return self._queue.get() + return self._queue.get(timeout=timeout) def get_nowait(self): return self._queue.get_nowait() @@ -70,6 +103,52 @@ def empty(self): return self._queue.empty() +class _FakeZmqQueue: + """Fake ZmqQueue that wraps _FakeQueue and matches ZmqQueue interface.""" + + def __init__( + self, + ctx=None, + socket_type=None, + *, + bind: str | None = None, + connect: str | None = None, + recv_timeout_ms: int | None = None, + send_timeout_ms: int | None = None, + ): + """Initialize fake 
ZMQ queue with same signature as real ZmqQueue.""" + self._queue = _FakeQueue(maxsize=0) + # Determine endpoint from bind or connect + path = bind if bind is not None else connect + self.endpoint = path or f"fake://zmq-endpoint-{id(self)}" + self._recv_timeout_ms = recv_timeout_ms + self._send_timeout_ms = send_timeout_ms + + def put(self, obj: Any) -> None: + """Send an object to the queue.""" + self._queue.put(obj) + + def put_nowait(self, obj: Any) -> None: + """Send an object to the queue without blocking.""" + self._queue.put_nowait(obj) + + def get(self, timeout: float | None = None) -> Any: + """Receive an object from the queue with optional timeout in seconds.""" + return self._queue.get(timeout=timeout) + + def get_nowait(self) -> Any: + """Receive an object from the queue without blocking.""" + return self._queue.get_nowait() + + def empty(self) -> bool: + """Check if the queue is empty without blocking.""" + return self._queue.empty() + + def close(self) -> None: + """Close the queue.""" + pass + + class _FakeStage: """Lightweight Stage stub for multi-process pipeline version with queue support.""" @@ -272,6 +351,19 @@ def _mock_get_context(method): monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + # Mock ZmqQueue to use _FakeZmqQueue + monkeypatch.setattr( + "vllm_omni.entrypoints.zmq_utils.ZmqQueue", + _FakeZmqQueue, + raising=False, + ) + # Also mock where ZmqQueue is imported/used + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.ZmqQueue", + _FakeZmqQueue, + raising=False, + ) + def _setup_ipc_mocks(monkeypatch): """Helper function to set up IPC function mocks.""" @@ -324,6 +416,65 @@ def build_and_log_summary(self, final_stage_id): ) +def _setup_connector_mocks(monkeypatch, omni_module=None): + """Helper function to set up connector mocks for stage-to-stage forwarding. + + If omni_module is provided, mocks directly on the module. 
Otherwise, uses string path. + """ + + # Mock initialize_orchestrator_connectors to return fake connectors + def _fake_initialize_orchestrator_connectors(config_path, worker_backend=None, shm_threshold_bytes=None): + # Create fake connectors for all stage-to-stage edges + # Each connector is just a mock object that will be passed to try_send_via_connector + fake_connectors = {} + # Add connectors for common edges (0->1, 1->2, etc.) + for i in range(10): # Support up to 10 stages + fake_connectors[(str(i), str(i + 1))] = MagicMock() + return None, fake_connectors + + if omni_module is not None: + # Mock directly on the omni module where it's used (after import) + monkeypatch.setattr(omni_module, "initialize_orchestrator_connectors", _fake_initialize_orchestrator_connectors) + else: + # Mock via string path (before import) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.initialize_orchestrator_connectors", + _fake_initialize_orchestrator_connectors, + raising=False, + ) + + +def _setup_connector_adapter_mock(monkeypatch, omni_module): + """Helper function to mock try_send_via_connector on the omni module. + + This must be called AFTER importing omni module, to mock the function where it's actually used. 
+ """ + + # Mock try_send_via_connector to always succeed + def _fake_try_send_via_connector( + connector, + stage_id, + next_stage_id, + req_id, + next_inputs, + sampling_params, + original_prompt, + next_stage_queue_submit_fn, + metrics, + ): + # Simulate successful send by calling the submit function + task = { + "request_id": req_id, + "engine_inputs": next_inputs, + "sampling_params": sampling_params, + } + next_stage_queue_submit_fn(task) + return True + + # Mock directly on the omni module where it's used + monkeypatch.setattr(omni_module, "try_send_via_connector", _fake_try_send_via_connector) + + @pytest.fixture(autouse=True) def mock_get_config(monkeypatch): """Auto-mock get_config and related model loading functions to avoid model path validation.""" @@ -449,10 +600,19 @@ def _mock_cached_file(path_or_repo_id, *args, **kwargs): def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): """Test that stage configs are auto-loaded when stage_configs_path is None.""" - def _fake_loader(model: str, base_engine_args=None): - return [ - _FakeStageConfig(fake_stage_config), - _FakeStageConfig(fake_stage_config), + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + cfg1 = dict(fake_stage_config) + cfg1["stage_id"] = 1 + return None, [ + _FakeStageConfig(cfg0), + _FakeStageConfig(cfg1), ] # Remove modules from cache BEFORE setting mocks @@ -471,10 +631,11 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) - # Mock load_stage_configs_from_model + # Mock load_and_resolve_stage_configs monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, 
raising=False, ) @@ -489,13 +650,13 @@ def _fake_loader(model: str, base_engine_args=None): # Import the module after mocks are set import vllm_omni.entrypoints.omni as omni_module - # Patch the imported function and class in the module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + # Patch the imported class in the module monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Verify: auto-loaded stage_configs and stage_list have consistent count assert isinstance(omni.stage_configs, list) assert len(omni.stage_configs) == 2 @@ -514,8 +675,13 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(fake_stage_config)] import sys @@ -531,14 +697,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -550,12 +712,12 @@ def _fake_loader(model: str, base_engine_args=None): import 
vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) with pytest.raises(ValueError): omni.generate(prompts=["hi"], sampling_params_list=[]) @@ -563,11 +725,18 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): """Test multi-stage generation pipeline with queue polling.""" stage_cfg0 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["processed_input"] = ["processed-for-stage-1"] - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -575,6 +744,8 @@ def _fake_loader(model: str, base_engine_args=None): "vllm_omni.entrypoints.utils", "vllm_omni.entrypoints.omni", "vllm_omni.entrypoints.omni_stage", + "vllm_omni.distributed.omni_connectors.adapter", + "vllm_omni.distributed.omni_connectors", ]: if module_name in sys.modules: del sys.modules[module_name] @@ -585,12 +756,7 @@ def _fake_loader(model: str, base_engine_args=None): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", 
_fake_loader, raising=False, ) @@ -602,8 +768,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -612,7 +781,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -625,7 +794,7 @@ def _fake_loader(model: str, base_engine_args=None): { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) # Stage 1 output (will be collected after stage 0 forwards to it) @@ -637,7 +806,7 @@ def _fake_loader(model: str, base_engine_args=None): { "request_id": expected_request_id, "engine_outputs": [{"stage": 1, "text": "s1"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) @@ -662,11 +831,18 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config): """Test single-stage generation pipeline with multiple inputs in one batch.""" stage_cfg0 = dict(fake_stage_config) - stage_cfg1 = dict(fake_stage_config) + 
stage_cfg0["stage_id"] = 0 stage_cfg0["final_output"] = False + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -674,6 +850,8 @@ def _fake_loader(model: str, base_engine_args=None): "vllm_omni.entrypoints.utils", "vllm_omni.entrypoints.omni", "vllm_omni.entrypoints.omni_stage", + "vllm_omni.distributed.omni_connectors.adapter", + "vllm_omni.distributed.omni_connectors", ]: if module_name in sys.modules: del sys.modules[module_name] @@ -684,7 +862,7 @@ def _fake_loader(model: str, base_engine_args=None): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -696,8 +874,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -706,7 +887,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, 
init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -718,28 +899,28 @@ def _fake_loader(model: str, base_engine_args=None): { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) omni.stage_list[0]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) omni.stage_list[1]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) omni.stage_list[1]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) @@ -768,12 +949,19 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): """Test that generate returns empty list when all stages have final_output=False.""" stage_cfg0 = dict(fake_stage_config) - stage_cfg1 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 stage_cfg0["final_output"] = False + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), 
_FakeStageConfig(stage_cfg1)] import sys @@ -781,6 +969,8 @@ def _fake_loader(model: str, base_engine_args=None): "vllm_omni.entrypoints.utils", "vllm_omni.entrypoints.omni", "vllm_omni.entrypoints.omni_stage", + "vllm_omni.distributed.omni_connectors.adapter", + "vllm_omni.distributed.omni_connectors", ]: if module_name in sys.modules: del sys.modules[module_name] @@ -791,7 +981,7 @@ def _fake_loader(model: str, base_engine_args=None): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -803,8 +993,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -813,7 +1006,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -847,12 +1040,19 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config): """Test that generate uses default sampling params when sampling_params_list is None.""" stage_cfg0 = dict(fake_stage_config) - stage_cfg1 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 
stage_cfg0["final_output"] = False + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -860,6 +1060,8 @@ def _fake_loader(model: str, base_engine_args=None): "vllm_omni.entrypoints.utils", "vllm_omni.entrypoints.omni", "vllm_omni.entrypoints.omni_stage", + "vllm_omni.distributed.omni_connectors.adapter", + "vllm_omni.distributed.omni_connectors", ]: if module_name in sys.modules: del sys.modules[module_name] @@ -870,7 +1072,7 @@ def _fake_loader(model: str, base_engine_args=None): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -882,8 +1084,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -892,7 +1097,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, 
init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -902,14 +1107,14 @@ def _fake_loader(model: str, base_engine_args=None): { "request_id": expected_request_id, "engine_outputs": [{"stage": 0}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) omni.stage_list[1]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) # Use the default sampling params @@ -919,8 +1124,13 @@ def _fake_loader(model: str, base_engine_args=None): def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): """Test that _wait_for_stages_ready handles timeout correctly.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(fake_stage_config)] import sys @@ -936,9 +1146,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -961,13 +1172,13 @@ def init_stage_worker(self, *args, **kwargs): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) 
from vllm_omni.entrypoints.omni import Omni # Use very short timeout - omni = Omni(model="any", init_timeout=0.01) + omni = Omni(model=MODEL, init_timeout=0.01) # Verify that no stages are ready assert len(omni._stages_ready) == 0 @@ -975,8 +1186,13 @@ def init_stage_worker(self, *args, **kwargs): def test_generate_handles_error_messages(monkeypatch, fake_stage_config): """Test that generate handles error messages from stages correctly.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(fake_stage_config)] import sys @@ -992,9 +1208,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -1006,8 +1223,8 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -1016,7 +1233,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -1034,7 +1251,7 @@ def 
_fake_loader(model: str, base_engine_args=None): { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "result"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), } ) @@ -1050,8 +1267,13 @@ def _fake_loader(model: str, base_engine_args=None): def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): """Test that close() sends shutdown signal to all input queues.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(fake_stage_config)] import sys @@ -1067,9 +1289,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -1081,12 +1304,12 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Call close omni.close() diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 4f05575ca59..49971f1b41c 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -18,6 +18,8 @@ 
category=DeprecationWarning, ) +MODEL = "riverclouds/qwen_image_random" + class _FakeEngineArgs(dict): """Fake engine args that can be used both as object attributes and as **kwargs.""" @@ -60,8 +62,10 @@ def put(self, item): def put_nowait(self, item): self._queue.put_nowait(item) - def get(self): - return self._queue.get() + def get(self, timeout=None): + if timeout is None: + return self._queue.get() + return self._queue.get(timeout=timeout) def get_nowait(self): return self._queue.get_nowait() @@ -70,6 +74,52 @@ def empty(self): return self._queue.empty() +class _FakeZmqQueue: + """Fake ZmqQueue that wraps _FakeQueue and matches ZmqQueue interface.""" + + def __init__( + self, + ctx=None, + socket_type=None, + *, + bind: str | None = None, + connect: str | None = None, + recv_timeout_ms: int | None = None, + send_timeout_ms: int | None = None, + ): + """Initialize fake ZMQ queue with same signature as real ZmqQueue.""" + self._queue = _FakeQueue(maxsize=0) + # Determine endpoint from bind or connect + path = bind if bind is not None else connect + self.endpoint = path or f"fake://zmq-endpoint-{id(self)}" + self._recv_timeout_ms = recv_timeout_ms + self._send_timeout_ms = send_timeout_ms + + def put(self, obj: Any) -> None: + """Send an object to the queue.""" + self._queue.put(obj) + + def put_nowait(self, obj: Any) -> None: + """Send an object to the queue without blocking.""" + self._queue.put_nowait(obj) + + def get(self, timeout: float | None = None) -> Any: + """Receive an object from the queue with optional timeout in seconds.""" + return self._queue.get(timeout=timeout) + + def get_nowait(self) -> Any: + """Receive an object from the queue without blocking.""" + return self._queue.get_nowait() + + def empty(self) -> bool: + """Check if the queue is empty without blocking.""" + return self._queue.empty() + + def close(self) -> None: + """Close the queue.""" + pass + + class _FakeStage: """Lightweight Stage stub for multi-process pipeline version with 
queue support.""" @@ -272,6 +322,19 @@ def _mock_get_context(method): monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + # Mock ZmqQueue to use _FakeZmqQueue + monkeypatch.setattr( + "vllm_omni.entrypoints.zmq_utils.ZmqQueue", + _FakeZmqQueue, + raising=False, + ) + # Also mock where ZmqQueue is imported/used + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.ZmqQueue", + _FakeZmqQueue, + raising=False, + ) + def _setup_ipc_mocks(monkeypatch): """Helper function to set up IPC function mocks.""" @@ -324,6 +387,65 @@ def build_and_log_summary(self, final_stage_id): ) +def _setup_connector_mocks(monkeypatch, omni_module=None): + """Helper function to set up connector mocks for stage-to-stage forwarding. + + If omni_module is provided, mocks directly on the module. Otherwise, uses string path. + """ + + # Mock initialize_orchestrator_connectors to return fake connectors + def _fake_initialize_orchestrator_connectors(config_path, worker_backend=None, shm_threshold_bytes=None): + # Create fake connectors for all stage-to-stage edges + # Each connector is just a mock object that will be passed to try_send_via_connector + fake_connectors = {} + # Add connectors for common edges (0->1, 1->2, etc.) + for i in range(10): # Support up to 10 stages + fake_connectors[(str(i), str(i + 1))] = MagicMock() + return None, fake_connectors + + if omni_module is not None: + # Mock directly on the omni module where it's used (after import) + monkeypatch.setattr(omni_module, "initialize_orchestrator_connectors", _fake_initialize_orchestrator_connectors) + else: + # Mock via string path (before import) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.initialize_orchestrator_connectors", + _fake_initialize_orchestrator_connectors, + raising=False, + ) + + +def _setup_connector_adapter_mock(monkeypatch, omni_module): + """Helper function to mock try_send_via_connector on the omni module. 
+ + This must be called AFTER importing omni module, to mock the function where it's actually used. + """ + + # Mock try_send_via_connector to always succeed + def _fake_try_send_via_connector( + connector, + stage_id, + next_stage_id, + req_id, + next_inputs, + sampling_params, + original_prompt, + next_stage_queue_submit_fn, + metrics, + ): + # Simulate successful send by calling the submit function + task = { + "request_id": req_id, + "engine_inputs": next_inputs, + "sampling_params": sampling_params, + } + next_stage_queue_submit_fn(task) + return True + + # Mock directly on the omni module where it's used + monkeypatch.setattr(omni_module, "try_send_via_connector", _fake_try_send_via_connector) + + @pytest.fixture(autouse=True) def mock_get_config(monkeypatch): """Auto-mock get_config and related model loading functions to avoid model path validation.""" @@ -449,10 +571,19 @@ def _mock_cached_file(path_or_repo_id, *args, **kwargs): def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): """Test that stage configs are auto-loaded when stage_configs_path is None.""" - def _fake_loader(model: str, base_engine_args=None): - return [ - _FakeStageConfig(fake_stage_config), - _FakeStageConfig(fake_stage_config), + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + cfg1 = dict(fake_stage_config) + cfg1["stage_id"] = 1 + return None, [ + _FakeStageConfig(cfg0), + _FakeStageConfig(cfg1), ] # Remove modules from cache BEFORE setting mocks @@ -471,10 +602,11 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) - # Mock load_stage_configs_from_model + # Mock load_and_resolve_stage_configs monkeypatch.setattr( - 
"vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -490,12 +622,14 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module # Patch the imported function and class in the module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Verify: auto-loaded stage_configs and stage_list have consistent count assert isinstance(omni.stage_configs, list) assert len(omni.stage_configs) == 2 @@ -514,8 +648,15 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + return None, [_FakeStageConfig(cfg0)] import sys @@ -531,14 +672,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - 
"vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -550,12 +687,12 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) with pytest.raises(ValueError): omni.generate(prompts=["hi"], sampling_params_list=[]) @@ -563,11 +700,18 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): """Test multi-stage generation pipeline with queue polling.""" stage_cfg0 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["processed_input"] = ["processed-for-stage-1"] - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -583,14 +727,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + 
"vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -602,8 +742,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -612,7 +755,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -662,12 +805,19 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): """Test that generate returns empty list when all stages have final_output=False.""" stage_cfg0 = dict(fake_stage_config) - stage_cfg1 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 stage_cfg0["final_output"] = False + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -683,9 +833,10 @@ def _fake_loader(model: str, 
base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -697,8 +848,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -707,7 +861,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -741,12 +895,19 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config): """Test that generate uses default sampling params when sampling_params_list is None.""" stage_cfg0 = dict(fake_stage_config) - stage_cfg1 = dict(fake_stage_config) + stage_cfg0["stage_id"] = 0 stage_cfg0["final_output"] = False + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["stage_id"] = 1 stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + def _fake_loader( + model: str, + 
stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -762,9 +923,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -776,8 +938,11 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + # Apply connector and adapter mocks after importing omni module + _setup_connector_mocks(monkeypatch, omni_module) + _setup_connector_adapter_mock(monkeypatch, omni_module) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") @@ -786,7 +951,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -813,8 +978,15 @@ def _fake_loader(model: str, base_engine_args=None): def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): """Test that _wait_for_stages_ready handles timeout correctly.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | 
None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + return None, [_FakeStageConfig(cfg0)] import sys @@ -830,9 +1002,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -855,13 +1028,13 @@ def init_stage_worker(self, *args, **kwargs): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs)) from vllm_omni.entrypoints.omni import Omni # Use very short timeout - omni = Omni(model="any", init_timeout=0.01) + omni = Omni(model=MODEL, init_timeout=0.01) # Verify that no stages are ready assert len(omni._stages_ready) == 0 @@ -869,8 +1042,15 @@ def init_stage_worker(self, *args, **kwargs): def test_generate_handles_error_messages(monkeypatch, fake_stage_config): """Test that generate handles error messages from stages correctly.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + return None, [_FakeStageConfig(cfg0)] import sys @@ -886,9 +1066,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - 
"vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -900,7 +1081,7 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) # Mock uuid.uuid4() to return a predictable value for request ID generation @@ -910,7 +1091,7 @@ def _fake_loader(model: str, base_engine_args=None): from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -944,8 +1125,15 @@ def _fake_loader(model: str, base_engine_args=None): def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): """Test that close() sends shutdown signal to all input queues.""" - def _fake_loader(model: str, base_engine_args=None): - return [_FakeStageConfig(fake_stage_config)] + def _fake_loader( + model: str, + stage_configs_path: str | None = None, + base_engine_args: dict | None = None, + default_stage_cfg_factory=None, + ): + cfg0 = dict(fake_stage_config) + cfg0["stage_id"] = 0 + return None, [_FakeStageConfig(cfg0)] import sys @@ -961,9 +1149,10 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", _fake_loader, raising=False, ) @@ -975,12 +1164,12 @@ def _fake_loader(model: str, base_engine_args=None): import vllm_omni.entrypoints.omni as omni_module 
- monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) from vllm_omni.entrypoints.omni import Omni - omni = Omni(model="any", init_timeout=1) + omni = Omni(model=MODEL, init_timeout=1) # Call close omni.close() diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index d00652a658b..08c4ce90388 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -38,7 +38,7 @@ logger = init_logger(__name__) -def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler): +def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None): """Weak reference cleanup function for AsyncOmni instances.""" if stage_list: for q in stage_in_queues: @@ -46,6 +46,13 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handle q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + for q in stage_out_queues: + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() for stage in stage_list: try: stage.stop_stage_worker() @@ -55,6 +62,8 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handle # Cancel output handler if output_handler is not None: output_handler.cancel() + if zmq_ctx is not None: + zmq_ctx.term() class AsyncOmni(OmniBase): @@ -108,8 +117,10 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: _weak_close_cleanup_async, self.stage_list, self._stage_in_queues, + self._stage_out_queues, self._ray_pg, self.output_handler, + self._zmq_ctx, ) def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) 
-> dict[str, Any]: diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index e28637f6aa8..73bb03805df 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -7,19 +7,37 @@ import argparse import os +import signal +from typing import Any +import msgspec.msgpack import uvloop +import zmq from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.utils import get_engine_client_zmq_addr +from vllm_omni.distributed.omni_connectors import ( + get_connectors_config_for_stage, + load_omni_transfer_config, +) +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + resolve_omni_kv_config_for_stage, +) from vllm_omni.entrypoints.cli.logo import log_logo +from vllm_omni.entrypoints.omni import OmniBase, omni_snapshot_download +from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.openai.api_server import omni_run_server +from vllm_omni.entrypoints.utils import inject_omni_kv_config logger = init_logger(__name__) +HANDSHAKE_TIMEOUT_MINS = 5 + DESCRIPTION = """Launch a local OpenAI-compatible API server to serve Omni models via HTTP. Supports both multi-stage LLM models and diffusion models. 
@@ -55,9 +73,15 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag - uvloop.run(omni_run_server(args)) + if args.headless: + run_headless(args) + else: + uvloop.run(omni_run_server(args)) def validate(self, args: argparse.Namespace) -> None: + if args.stage_id is not None and (args.omni_master_address is None or args.omni_master_port is None): + raise ValueError("--stage-id requires both --omni-master-address and --omni-master-port to be set") + # Skip validation for diffusion models as they have different requirements from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model @@ -94,6 +118,12 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.", ) + omni_config_group.add_argument( + "--stage-id", + type=int, + default=None, + help="Select and launch a single stage by stage_id.", + ) omni_config_group.add_argument( "--stage-init-timeout", type=int, @@ -142,6 +172,18 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="The address of the Ray cluster to connect to.", ) + omni_config_group.add_argument( + "--omni-master-address", + "-oma", + type=str, + help="Hostname or IP address of the Omni orchestrator (master).", + ) + omni_config_group.add_argument( + "--omni-master-port", + "-omp", + type=int, + help="Port of the Omni orchestrator (master).", + ) # Diffusion model specific arguments omni_config_group.add_argument( @@ -251,5 +293,125 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu return serve_parser +def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[str, Any]]: + omni_base = OmniBase.__new__(OmniBase) + return omni_base._create_default_diffusion_stage_cfg(vars(args)) + + +def run_headless(args: 
argparse.Namespace) -> None: + if args.api_server_count is not None and args.api_server_count > 1: + raise ValueError("api_server_count can't be set in headless mode") + if args.worker_backend != "multi_process": + raise ValueError("headless mode requires worker_backend=multi_process") + + model = omni_snapshot_download(args.model) + + omni_base = OmniBase.__new__(OmniBase) + args_dict = vars(args).copy() + args_dict["model"] = model + config_path, stage_configs = omni_base._resolve_stage_configs(model, args_dict) + + single_stage_id = args.stage_id + if single_stage_id is None: + if len(stage_configs) != 1: + raise ValueError("--stage-id is required in headless mode for multi-stage configs") + single_stage_id = getattr(stage_configs[0], "stage_id", 0) + + stage_config = None + for cfg in stage_configs: + if getattr(cfg, "stage_id", None) == single_stage_id: + stage_config = cfg + break + if stage_config is None: + raise ValueError(f"No stage matches stage_id={single_stage_id}.") + + # TODO(wuhang): Support connectors config by cli + transfer_config = load_omni_transfer_config(config_path, default_shm_threshold=args.shm_threshold_bytes) + connectors_config = get_connectors_config_for_stage(transfer_config, single_stage_id) + + omni_master_address = args.omni_master_address + omni_master_port = args.omni_master_port + + # Perform handshake with orchestrator to get dynamically allocated endpoints + with zmq.Context() as zmq_ctx: + handshake_endpoint = get_engine_client_zmq_addr( + local_only=False, host=omni_master_address, port=omni_master_port + ) + + with make_zmq_socket(zmq_ctx, handshake_endpoint, zmq.REQ, bind=False, linger=5000) as handshake_socket: + # TODO(wuhang): Define protocol in python dataclass. 
+ handshake_msg = {"type": "handshake", "stage_id": single_stage_id} + handshake_socket.send(msgspec.msgpack.encode(handshake_msg)) + + # Wait for response with timeout + if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): + raise RuntimeError( + f"Handshake timeout ({HANDSHAKE_TIMEOUT_MINS} minutes) for stage-{single_stage_id} " + f"at {handshake_endpoint}" + ) + + try: + response = msgspec.msgpack.decode(handshake_socket.recv()) + except msgspec.DecodeError as exc: + raise RuntimeError( + f"Handshake decode failed for stage-{single_stage_id} at {handshake_endpoint}: {exc}" + ) from exc + except Exception as exc: # pragma: no cover - unexpected decode errors + raise RuntimeError( + f"Unexpected error decoding handshake for stage-{single_stage_id} at {handshake_endpoint}: {exc}" + ) from exc + + if not response["ok"]: + error_msg = response["error"] + raise RuntimeError(f"Handshake failed for stage-{single_stage_id}: {error_msg}") + + in_endpoint, out_endpoint = response["in_endpoint"], response["out_endpoint"] + + logger.info( + f"[Headless] Stage-{single_stage_id} received endpoints via handshake: " + f"in={in_endpoint}, out={out_endpoint}" + ) + + shutdown_requested = False + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if shutdown_requested: + return + shutdown_requested = True + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + stage = OmniStage(stage_config, stage_init_timeout=args.stage_init_timeout) + stage.attach_queues(in_endpoint, out_endpoint) + + # Inject YAML-resolved connector config into omni_kv_config for in-engine usage. 
+ try: + omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage(transfer_config, single_stage_id) + if omni_conn_cfg: + inject_omni_kv_config(stage, omni_conn_cfg, omni_from, omni_to) # type: ignore + except Exception as e: + logger.debug("[Headless] Failed to inject omni connector config into stage-%s: %s", single_stage_id, e) + + old_env = os.environ.get("VLLM_LOGGING_PREFIX") + os.environ["VLLM_LOGGING_PREFIX"] = f"[Stage-{single_stage_id}] {'' if old_env is None else old_env}" + try: + stage.init_stage_worker( + model, + is_async=True, + shm_threshold_bytes=int(args.shm_threshold_bytes), + batch_timeout=int(args.batch_timeout), + connectors_config=connectors_config, + worker_backend="multi_process", + ignore_runtime_config=True, + ) + if stage._proc is not None: + stage._proc.join() + finally: + stage.stop_stage_worker() + + def cmd_init() -> list[CLISubcommand]: return [OmniServeCommand()] diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index f30cd7d368e..051b2acf0a1 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -3,6 +3,7 @@ import json import multiprocessing as mp import os +import threading import time import uuid import weakref @@ -11,10 +12,14 @@ from typing import Any, Literal, overload import huggingface_hub +import msgspec.msgpack +import zmq from omegaconf import OmegaConf from tqdm.auto import tqdm from vllm import SamplingParams from vllm.logger import init_logger +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.utils import get_engine_client_zmq_addr from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, @@ -33,12 +38,12 @@ from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load from vllm_omni.entrypoints.utils import ( + build_base_engine_args, get_final_stage_id_for_e2e, inject_omni_kv_config, - load_stage_configs_from_model, - 
load_stage_configs_from_yaml, - resolve_model_config_path, + load_and_resolve_stage_configs, ) +from vllm_omni.entrypoints.zmq_utils import ZmqQueue from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams from vllm_omni.metrics import OrchestratorAggregator, StageRequestStats from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -49,7 +54,16 @@ logger = init_logger(__name__) -def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): +def _weak_close_cleanup( + stage_list, + stage_in_queues, + stage_out_queues, + ray_pg, + zmq_ctx=None, + handshake_stop: threading.Event | None = None, + zmq_handshake_socket: zmq.Socket | None = None, + handshake_thread: threading.Thread | None = None, +): """Weak reference cleanup function for OmniBase instances.""" if stage_list: for q in stage_in_queues: @@ -57,6 +71,13 @@ def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + for q in stage_out_queues: + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() for stage in stage_list: try: stage.stop_stage_worker() @@ -64,6 +85,20 @@ def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): logger.warning(f"Failed to stop stage worker: {e}") try_close_ray(ray_pg) + # Gracefully shutdown handshake server thread + if handshake_stop is not None: + handshake_stop.set() + if handshake_thread is not None: + handshake_thread.join(timeout=2.0) + if handshake_thread.is_alive(): + logger.warning("Handshake server thread did not terminate gracefully within timeout") + + # Close ZMQ resources after thread has exited + if zmq_handshake_socket is not None: + zmq_handshake_socket.close(0) + if zmq_ctx is not None: + zmq_ctx.term() + def _dummy_snapshot_download(model_id): return model_id @@ -125,12 
+160,21 @@ def __init__(self, model: str, **kwargs: Any) -> None: # Stage management attributes self.stage_list: list[OmniStage] = [] - self._stage_in_queues: list[mp.Queue] = [] - self._stage_out_queues: list[mp.Queue] = [] + self._stage_in_queues: list[Any] = [] + self._stage_out_queues: list[Any] = [] self._stages_ready: set[int] = set() self._ray_pg = None self._queue_cls = None self._ctx = None + self._zmq_ctx: zmq.Context | None = None + self._zmq_master_address: str | None = None + self._zmq_master_port: int | None = None + self._zmq_handshake_socket: zmq.Socket | None = None + self._handshake_thread: threading.Thread | None = None + self._handshake_stop: threading.Event | None = None + self._handshake_endpoints: dict[int, tuple[str, str]] = {} + self._handshake_seen: set[int] = set() # Track which stage IDs have completed ZMQ handshake + self._single_stage_id: int | None = None # Optional: deploy only a specific stage ID # Initialize stages - each stage will create appropriate instance based on stage_type # Stage workers will automatically create OmniLLM or OmniDiffusion instances @@ -206,35 +250,25 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" return default_stage_cfg - def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: - """Initialize stage list management.""" - stage_init_timeout = kwargs.get("stage_init_timeout", 20) - shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) - init_timeout = kwargs.get("init_timeout", 300) - worker_backend = kwargs.get("worker_backend", "multi_process") - ray_address = kwargs.get("ray_address", None) - batch_timeout = kwargs.get("batch_timeout", 10) - stage_configs_path = kwargs.get("stage_configs_path", None) - log_stats = kwargs.get("log_stats", False) - - ### base engine args - tokenizer = kwargs.get("tokenizer", None) + def _resolve_stage_configs(self, model: str, kwargs: dict[str, 
Any]) -> tuple[str, list[Any]]: + """Resolve stage configs and inject defaults shared by orchestrator/headless.""" + # TODO(wuhang): + # Remove kwargs as parameters in the future. + # Use dataclass directly for engine args. + base_engine_args = build_base_engine_args(kwargs) - base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None + stage_configs_path = kwargs.get("stage_configs_path", None) # Load stage configurations from YAML - if stage_configs_path is None: - self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args) - if not self.stage_configs: - default_stage_cfg = self._create_default_diffusion_stage_cfg(kwargs) - self.stage_configs = OmegaConf.create(default_stage_cfg) - else: - self.config_path = stage_configs_path - self.stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args) + config_path, stage_configs = load_and_resolve_stage_configs( + model, + stage_configs_path, + base_engine_args, + default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs), + ) # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. 
- for cfg in self.stage_configs: + for cfg in stage_configs: try: if getattr(cfg, "stage_type", None) != "diffusion": continue @@ -253,6 +287,27 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: except Exception as e: logger.warning("Failed to inject LoRA config for stage: %s", e) + return config_path, stage_configs + + def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: + """Initialize stage list management.""" + stage_init_timeout = kwargs.get("stage_init_timeout", 20) + shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) + init_timeout = kwargs.get("init_timeout", 300) + worker_backend = kwargs.get("worker_backend", "multi_process") + ray_address = kwargs.get("ray_address", None) + batch_timeout = kwargs.get("batch_timeout", 10) + log_stats = kwargs.get("log_stats", False) + self._single_stage_id = kwargs.get("stage_id", None) + self._zmq_master_address = kwargs.get("omni_master_address", None) + if self._zmq_master_address is None: + self._zmq_master_address = "127.0.0.1" + logger.info("No omni_master_address provided, defaulting to localhost (127.0.0.1)") + self._zmq_master_port = kwargs.get("omni_master_port", None) + + # Resolve stage configs shared by orchestrator/headless paths. 
+ self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs) + # Initialize connectors self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes @@ -281,7 +336,7 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: self.stage_list = [st for _, st in results] self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] self.output_modalities = [st.final_output_type for st in self.stage_list] - logger.debug(f"[{self._name}] Loaded {len(self.stage_list)} stages") + logger.info(f"[{self._name}] Loaded {len(self.stage_list)} stages") if self.worker_backend == "ray": self._queue_cls = get_ray_queue_class() @@ -307,10 +362,34 @@ def _start_stages(self, model: str) -> None: self._ray_pg = create_placement_group( number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" ) + else: + # Initialize ZMQ context + if self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + + # Allocate endpoints for each stage + total_stages = len(self.stage_configs) + self._handshake_endpoints = {} + for sid in range(total_stages): + in_endpoint = get_engine_client_zmq_addr(local_only=False, host=self._zmq_master_address) + out_endpoint = get_engine_client_zmq_addr(local_only=False, host=self._zmq_master_address) + self._handshake_endpoints[sid] = (in_endpoint, out_endpoint) + logger.debug( + f"[{self._name}] Allocated endpoints for stage-{sid}: in={in_endpoint}, out={out_endpoint}" + ) + + # Start handshake server + self.start_handshake_server() for stage_id, stage in enumerate[OmniStage](self.stage_list): - in_q = self._queue_cls() - out_q = self._queue_cls() + if self.worker_backend == "ray": + in_q = self._queue_cls() + out_q = self._queue_cls() + else: + in_endpoint, out_endpoint = self._handshake_endpoints[stage_id] + in_q = ZmqQueue(self._zmq_ctx, zmq.PUSH, bind=in_endpoint) + out_q = 
ZmqQueue(self._zmq_ctx, zmq.PULL, bind=out_endpoint) + self._stage_in_queues.append(in_q) self._stage_out_queues.append(out_q) stage.attach_queues(in_q, out_q) @@ -332,6 +411,12 @@ def _start_stages(self, model: str) -> None: except Exception as e: logger.debug("[Omni] Failed to inject omni connector config into stage-%s: %s", stage_id, e) + if self._single_stage_id is not None and stage_id != int(self._single_stage_id): + logger.info( + f"[{self._name}] Skipping initialization of stage-{stage_id} worker due to single_stage_id setting" + ) + continue + stage.init_stage_worker( model, is_async=self.is_async, @@ -341,6 +426,7 @@ def _start_stages(self, model: str) -> None: connectors_config=stage_connectors_config, worker_backend=self.worker_backend, ray_placement_group=self._ray_pg, + ignore_runtime_config=True if self._single_stage_id is not None else False, ) logger.debug(f"[{self._name}] Stage-{stage_id} process started") @@ -351,6 +437,9 @@ def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str def _wait_for_stages_ready(self, timeout: int = 120) -> None: """Wait for all stages to report readiness with optimized polling.""" + if self._single_stage_id is not None and self.worker_backend != "ray": + timeout = self._wait_for_handshakes(timeout) + num_stages = len(self.stage_list) deadline = time.time() + max(0, int(timeout)) @@ -501,6 +590,129 @@ def close(self) -> None: if hasattr(self, "_weak_finalizer"): self._weak_finalizer() + def _process_handshake_message(self, msg: Any) -> dict[str, Any]: + """Process incoming handshake message and generate response. 
+ + Args: + msg: Decoded message from client + + Returns: + Response dictionary with ok status and either endpoints or error + """ + if not isinstance(msg, dict) or msg.get("type") != "handshake": + return {"ok": False, "error": "invalid handshake payload"} + + try: + stage_id = int(msg.get("stage_id")) + except (TypeError, ValueError) as e: + return {"ok": False, "error": f"invalid stage_id: {e}"} + + endpoints = self._handshake_endpoints.get(stage_id) + if endpoints is None: + return {"ok": False, "error": f"unknown stage_id: {stage_id}"} + + # Mark stage as seen and prepare success response + self._handshake_seen.add(stage_id) + in_endpoint, out_endpoint = endpoints + + logger.info( + "[%s] Handshake received from stage-%s", + self._name, + stage_id, + ) + + return { + "ok": True, + "in_endpoint": in_endpoint, + "out_endpoint": out_endpoint, + } + + def _run_handshake_server_loop(self) -> None: + """Main loop for handshake server - polls for messages and responds.""" + poller = zmq.Poller() + poller.register(self._zmq_handshake_socket, zmq.POLLIN) + + try: + while not self._handshake_stop.is_set(): + events = poller.poll(1000) + has_message = any(sock == self._zmq_handshake_socket and event == zmq.POLLIN for sock, event in events) + if not has_message: + continue + + msg = msgspec.msgpack.decode(self._zmq_handshake_socket.recv()) + response = msgspec.msgpack.encode(self._process_handshake_message(msg)) + self._zmq_handshake_socket.send(response) + finally: + poller.unregister(self._zmq_handshake_socket) + + def start_handshake_server(self) -> None: + """Start the ZMQ handshake server. + + The handshake server allows distributed stages to discover their + queue endpoints by querying the orchestrator with their stage_id. + Skips starting if the server is already running or ZMQ is not initialized. 
+ """ + # Skip if already running or ZMQ not initialized + if self._handshake_thread is not None or self._zmq_ctx is None: + return + + # Skip if master address/port not configured + if not self._zmq_master_address or self._zmq_master_port is None: + return + + # Create server endpoint and socket + endpoint = get_engine_client_zmq_addr( + local_only=False, host=self._zmq_master_address, port=int(self._zmq_master_port) + ) + + self._handshake_stop = threading.Event() + self._zmq_handshake_socket = make_zmq_socket(self._zmq_ctx, endpoint, zmq.REP, bind=True, linger=5000) + + # Start server thread + self._handshake_thread = threading.Thread( + target=self._run_handshake_server_loop, daemon=True, name="zmq-handshake-server" + ) + self._handshake_thread.start() + + def _wait_for_handshakes(self, timeout: int = 120) -> int: + """Wait for handshakes from all expected stages. + + Args: + timeout: Timeout in seconds for waiting for handshakes. Default is 120s. + + Returns: + Remaining timeout in seconds after waiting for handshakes. + """ + total_stages = len(self.stage_configs) + expected = set(range(total_stages)) - {int(self._single_stage_id)} + if not expected: + return timeout + + deadline = time.time() + max(0, int(timeout)) + logger.info(f"[{self._name}] Waiting for handshakes from stages: {expected} (timeout: {timeout}s)") + + # NOTE: _handshake_seen may be updated from the handshake server thread. + # It is intentionally used here without additional locking because: + # - _handshake_seen only ever grows (stages are added but never removed), and + # - we only check membership and set inclusion relative to `expected`. + # Under these monotonic semantics and the CPython GIL, concurrent reads/writes + # are safe for this usage and cannot violate correctness: we may observe a + # slightly stale view, but the loop condition remains valid and eventually + # becomes true once all expected stages have handshaked or the timeout elapses. 
+ while not expected.issubset(self._handshake_seen) and time.time() < deadline: + time.sleep(1.0) + + remaining_timeout = max(0, int(deadline - time.time())) + + if not expected.issubset(self._handshake_seen): + missing = sorted(expected - self._handshake_seen) + logger.warning( + f"[{self._name}] Handshake timeout: {len(self._handshake_seen)}/{len(expected)} " + f"stages completed handshake. Missing stages: {missing}" + ) + + return remaining_timeout + @property def _name(self) -> str: return "OmniBase" @@ -546,7 +758,12 @@ def __init__(self, model: str, **kwargs: Any) -> None: _weak_close_cleanup, self.stage_list, self._stage_in_queues, + self._stage_out_queues, self._ray_pg, + self._zmq_ctx, + self._handshake_stop, + self._zmq_handshake_socket, + self._handshake_thread, ) @overload diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 6c9723b6b5b..d90a805bd19 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -50,6 +50,10 @@ set_stage_devices, ) from vllm_omni.entrypoints.utils import detect_pid_host +from vllm_omni.entrypoints.zmq_utils import ( + ZmqQueue, + create_zmq_queue, +) from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt from vllm_omni.metrics import count_tokens_from_outputs from vllm_omni.outputs import OmniRequestOutput @@ -286,8 +290,8 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): except TypeError as error: raise TypeError(f"Invalid default_sampling_params for stage {self.stage_id}: {error}") from error # Runtime orchestration state (added) - self._in_q: mp.Queue | None = None - self._out_q: mp.Queue | None = None + self._in_q: mp.queues.Queue | ZmqQueue | str | None = None + self._out_q: mp.queues.Queue | ZmqQueue | str | None = None self._proc: mp.Process | None = None self._shm_threshold_bytes: int = 65536 self._stage_init_timeout: int = stage_init_timeout @@ -349,12 +353,16 
@@ def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: self.engine_outputs = engine_outputs # ----------------- New Orchestration APIs ----------------- - def attach_queues(self, in_q: mp.Queue, out_q: mp.Queue) -> None: + def attach_queues( + self, + in_q: mp.queues.Queue | ZmqQueue | str | None, + out_q: mp.queues.Queue | ZmqQueue | str | None, + ) -> None: """Attach input and output queues for IPC communication. Args: - in_q: Input queue for receiving tasks from orchestrator - out_q: Output queue for sending results to orchestrator + in_q: Input queue for receiving tasks from orchestrator (queue object or endpoint string) + out_q: Output queue for sending results to orchestrator (queue object or endpoint string) """ self._in_q = in_q self._out_q = out_q @@ -401,6 +409,7 @@ def init_stage_worker( batch_timeout: int = 10, connectors_config: dict | None = None, worker_backend: str = "multi_process", + ignore_runtime_config: bool = False, **kwargs: Any, ) -> None: """Initialize and start the stage worker process. @@ -416,6 +425,7 @@ def init_stage_worker( batch_timeout: Timeout in seconds for batching requests connectors_config: Configuration for stage connectors worker_backend: Backend type ("multi_process" or "ray") + ignore_runtime_config: Whether to ignore runtime configuration (default: False) **kwargs: Additional arguments (e.g. 
ray_placement_group) Raises: @@ -433,7 +443,10 @@ def init_stage_worker( ctx = ctx or mp.get_context("spawn") # Prepare lightweight dict config for worker engine_args = _to_dict(self.engine_args) - runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) + if ignore_runtime_config: + runtime_cfg = {} + else: + runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) stage_payload: dict[str, Any] = { "stage_id": self.stage_id, "engine_args": engine_args, @@ -442,6 +455,8 @@ def init_stage_worker( "connectors_config": connectors_config or {}, "stage_type": self.stage_type, "engine_input_source": self.engine_input_source, + "final_output": self.final_output, + "final_output_type": self.final_output_type, } try: old_env = os.environ.get("VLLM_LOGGING_PREFIX") @@ -453,9 +468,10 @@ def init_stage_worker( _stage_worker_async_entry, ray_placement_group, self.stage_id, - self, model=model, stage_payload=stage_payload, + in_q=self._in_q, + out_q=self._out_q, batch_timeout=batch_timeout, stage_init_timeout=self._stage_init_timeout, ) @@ -476,9 +492,10 @@ def init_stage_worker( self._proc = ctx.Process( target=_stage_worker_async_entry, args=( - self, model, stage_payload, + self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, + self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, batch_timeout, self._stage_init_timeout, ), @@ -489,8 +506,8 @@ def init_stage_worker( args=( model, stage_payload, - self._in_q, - self._out_q, + self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, + self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, batch_timeout, self._stage_init_timeout, ), @@ -514,6 +531,13 @@ def stop_stage_worker(self) -> None: self._in_q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning("Failed to send shutdown to in_q: %s", e) + close_fn = getattr(self._in_q, "close", None) + if callable(close_fn): + close_fn() + if self._out_q is not None: + close_fn = 
getattr(self._out_q, "close", None) + if callable(close_fn): + close_fn() if hasattr(self, "_ray_actor") and self._ray_actor: kill_ray_actor(self._ray_actor) @@ -636,8 +660,8 @@ def process_engine_inputs( def _stage_worker( model: str, stage_payload: dict[str, Any], - in_q: mp.Queue, - out_q: mp.Queue, + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: @@ -648,6 +672,8 @@ def _stage_worker( import os as _os import time as _time + import zmq + from vllm_omni.plugins import load_omni_general_plugins load_omni_general_plugins() @@ -674,6 +700,21 @@ def _stage_worker( if stage_type != "diffusion": _resolve_worker_cls(engine_args) + # Resolve ZMQ queue endpoints if needed + zmq_ctx = None + if isinstance(in_q, str) or isinstance(out_q, str): + zmq_ctx = zmq.Context() + if isinstance(in_q, str): + in_q = create_zmq_queue(zmq_ctx, in_q, zmq.PULL) + if isinstance(out_q, str): + out_q = create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) + # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. 
+ shm_threshold_bytes = sys.maxsize + logger.info( + "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", + stage_id, + ) + # Aggregates for running average _agg_total_tokens = 0 _agg_total_gen_time_ms = 0.0 @@ -1008,19 +1049,21 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: def _stage_worker_async_entry( - omni_stage: OmniStage, model: str, stage_payload: dict[str, Any], + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: - asyncio.run(_stage_worker_async(omni_stage, model, stage_payload, batch_timeout, stage_init_timeout)) + asyncio.run(_stage_worker_async(model, stage_payload, in_q, out_q, batch_timeout, stage_init_timeout)) async def _stage_worker_async( - omni_stage: OmniStage, model: str, stage_payload: dict[str, Any], + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: @@ -1030,6 +1073,8 @@ async def _stage_worker_async( import os as _os import time as _time + import zmq + from vllm_omni.plugins import load_omni_general_plugins load_omni_general_plugins() @@ -1049,12 +1094,26 @@ async def _stage_worker_async( shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) stage_type = stage_payload.get("stage_type", "llm") + final_output = stage_payload.get("final_output", False) + final_output_type = stage_payload.get("final_output_type", None) if stage_type != "diffusion": _resolve_worker_cls(engine_args) - in_q = omni_stage._in_q - out_q = omni_stage._out_q + # Resolve ZMQ queue endpoints if needed + zmq_ctx = None + if isinstance(in_q, str) or isinstance(out_q, str): + zmq_ctx = zmq.Context() + if isinstance(in_q, str): + in_q = create_zmq_queue(zmq_ctx, in_q, zmq.PULL) + if isinstance(out_q, str): + out_q = 
create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) + # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. + shm_threshold_bytes = sys.maxsize + logger.info( + "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", + stage_id, + ) # Aggregates for running average _agg_total_tokens = 0 @@ -1130,14 +1189,13 @@ async def _stage_worker_async( engine_args.get("disable_log_stats", True) or getattr(omni_engine_args, "disable_log_stats", True) ), ) - omni_stage.set_async_engine(stage_engine) - if hasattr(omni_stage.async_engine, "log_stats") and omni_stage.async_engine.log_stats: + if hasattr(stage_engine, "log_stats") and stage_engine.log_stats: async def _force_log(): try: while True: await asyncio.sleep(10.0) - await omni_stage.async_engine.do_log_stats() + await stage_engine.do_log_stats() except asyncio.CancelledError: pass @@ -1334,7 +1392,7 @@ async def generation_single_request(task: dict[str, Any]): batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics ): try: - r_outputs = [output_strip(output, omni_stage)] + r_outputs = [output_strip(output, final_output, final_output_type)] use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) if use_shm: out_q.put( @@ -1420,7 +1478,7 @@ def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): return StageStats(total_token=_agg_total_tokens, total_gen_time_ms=_agg_total_gen_time_ms) -def output_strip(r_output: RequestOutput | OmniRequestOutput, omni_stage: OmniStage): +def output_strip(r_output: RequestOutput | OmniRequestOutput, final_output: bool, final_output_type: str | None): """ Strip unnecessary multimodal outputs from stages results, in order to: @@ -1429,7 +1487,7 @@ def output_strip(r_output: RequestOutput | OmniRequestOutput, omni_stage: OmniSt """ # check multimodal data is required by stage output config. 
- if omni_stage.final_output and omni_stage.final_output_type != "text": + if final_output and final_output_type != "text": return r_output # If the request has already finished, should not be altered. diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 0abe029d68d..6ee42a23145 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -236,6 +236,40 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None return stage_args +def load_and_resolve_stage_configs( + model: str, + stage_configs_path: str | None, + base_engine_args: dict | None, + default_stage_cfg_factory: Any = None, +) -> tuple[str, list]: + """Load stage configurations from model or YAML file with fallback to defaults. + + Args: + model: Model name or path + stage_configs_path: Optional path to YAML file containing stage configurations + base_engine_args: Base engine arguments to merge with stage configs + default_stage_cfg_factory: Optional callable that takes no args and returns + default stage config list when no configs are found + + Returns: + Tuple of (config_path, stage_configs) + """ + if stage_configs_path is None: + config_path = resolve_model_config_path(model) + stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args) + if not stage_configs: + if default_stage_cfg_factory is not None: + default_stage_cfg = default_stage_cfg_factory() + stage_configs = OmegaConf.create(default_stage_cfg) + else: + stage_configs = [] + else: + config_path = stage_configs_path + stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args) + + return config_path, stage_configs + + def get_final_stage_id_for_e2e( output_modalities: list[str] | None, default_modalities: list[str], stage_list: list ) -> int: @@ -282,6 +316,54 @@ def get_final_stage_id_for_e2e( return final_stage_id_for_e2e +# TODO(wuhang): Remove after PR #1115. 
+def build_base_engine_args(source: Any) -> dict[str, Any] | None: + """Build base engine args with tokenizer and parallel configuration. + + Automatically detects whether source is a dict-like object or namespace object. + + Args: + source: Source object (args namespace or kwargs dict) containing configuration. + + Returns: + Dictionary containing tokenizer and parallel configuration overrides, + or None if no configuration is present. + """ + # Auto-detect source type: dict-like objects have 'get' method + is_dict_like = hasattr(source, "get") and callable(getattr(source, "get")) + + # Extract tokenizer + if is_dict_like: + tokenizer = source.get("tokenizer", None) + else: + tokenizer = getattr(source, "tokenizer", None) + + base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None + + # Extract parallel configuration + parallel_keys = [ + "tensor_parallel_size", + "pipeline_parallel_size", + "data_parallel_size", + "data_parallel_size_local", + "data_parallel_backend", + "distributed_executor_backend", + ] + + if is_dict_like: + parallel_overrides = {k: source[k] for k in parallel_keys if k in source and source[k] is not None} + else: + parallel_overrides = { + k: getattr(source, k) for k in parallel_keys if hasattr(source, k) and getattr(source, k) is not None + } + + if parallel_overrides: + base_engine_args = base_engine_args or {} + base_engine_args.update(parallel_overrides) + + return base_engine_args + + # The following code detects if the process is running in a container and if # PID host is available. If so, we can use process-scoped memory tracking; # otherwise we need sequential init locks. 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""ZMQ-based queue utilities for Omni IPC."""

from __future__ import annotations

import queue
from typing import Any

import zmq
from vllm.utils.network_utils import make_zmq_socket


class ZmqQueue:
    """Queue-like wrapper on a ZMQ socket.

    Exposes ``put``/``get`` style methods on top of a single ZMQ socket.
    Timeouts reported by ZMQ (``zmq.Again``) are re-raised as ``queue.Full``
    on the send side and ``queue.Empty`` on the receive side, so callers can
    handle this object with the same exception types as ``queue.Queue``.

    NOTE(review): objects travel via ``send_pyobj``/``recv_pyobj``, which
    pyzmq documents as pickle-based; the receiving side therefore unpickles
    whatever arrives on the socket. Only connect to trusted endpoints.
    """

    def __init__(
        self,
        ctx: zmq.Context,
        socket_type: int,
        *,
        bind: str | None = None,
        connect: str | None = None,
        recv_timeout_ms: int | None = None,
        send_timeout_ms: int | None = None,
    ) -> None:
        """Create the underlying socket and bind or connect it.

        Args:
            ctx: ZMQ context used to create the socket.
            socket_type: ZMQ socket type constant (e.g. ``zmq.PUSH``).
            bind: Endpoint to bind to. Takes precedence over ``connect``.
            connect: Endpoint to connect to; used only when ``bind`` is None.
            recv_timeout_ms: Optional receive timeout applied to the socket.
            send_timeout_ms: Optional send timeout applied to the socket.

        Raises:
            ValueError: If neither ``bind`` nor ``connect`` is given.
        """
        # Determine path and bind mode
        path = bind if bind is not None else connect
        if path is None:
            raise ValueError("Either bind or connect must be specified")
        bind_mode = bind is not None

        self._socket = make_zmq_socket(ctx, path, socket_type, bind=bind_mode, linger=5000)

        # Reusable poller for efficient polling operations
        self._poller = zmq.Poller()
        self._poller.register(self._socket, zmq.POLLIN)

        # Store default timeout settings
        self._default_recv_timeout = recv_timeout_ms
        self._default_send_timeout = send_timeout_ms

        # Apply timeout settings if specified
        if recv_timeout_ms is not None:
            self._socket.rcvtimeo = recv_timeout_ms
        if send_timeout_ms is not None:
            self._socket.sndtimeo = send_timeout_ms

        self.endpoint = path

    def put(self, obj: Any) -> None:
        """Send an object to the queue. Blocks until sent or timeout.

        Raises:
            queue.Full: If the socket reports ``zmq.Again`` (send timed out).
        """
        try:
            self._socket.send_pyobj(obj)
        except zmq.Again as e:
            raise queue.Full() from e

    def put_nowait(self, obj: Any) -> None:
        """Send an object to the queue without blocking.

        Raises:
            queue.Full: If the object cannot be sent immediately.
        """
        try:
            self._socket.send_pyobj(obj, flags=zmq.NOBLOCK)
        except zmq.Again as e:
            raise queue.Full() from e

    def get(self, timeout: float | None = None) -> Any:
        """Receive an object from the queue with optional timeout in seconds.

        Args:
            timeout: Maximum time to wait in seconds. ``None`` performs a plain
                blocking receive (still bounded by ``recv_timeout_ms`` if one
                was configured). Negative values are treated as zero.

        Raises:
            queue.Empty: If nothing arrived within the timeout.
        """
        try:
            if timeout is not None:
                # Clamp at zero: a negative poll timeout would be interpreted
                # as "wait forever" rather than "do not wait".
                wait_ms = max(0, int(timeout * 1000))
                events = dict(self._poller.poll(wait_ms))
                # Bitmask test so readiness is detected even if other event
                # flags are reported alongside POLLIN.
                if not (events.get(self._socket, 0) & zmq.POLLIN):
                    raise queue.Empty()
            return self._socket.recv_pyobj()
        except zmq.Again as e:
            # A configured socket receive timeout surfaces as zmq.Again; report
            # it as queue.Empty, consistent with get_nowait().
            raise queue.Empty() from e

    def get_nowait(self) -> Any:
        """Receive an object from the queue without blocking.

        Raises:
            queue.Empty: If no message is immediately available.
        """
        try:
            return self._socket.recv_pyobj(flags=zmq.NOBLOCK)
        except zmq.Again as e:
            raise queue.Empty() from e

    def empty(self) -> bool:
        """Check if the queue is empty without blocking."""
        events = dict(self._poller.poll(0))
        return not (events.get(self._socket, 0) & zmq.POLLIN)

    def close(self) -> None:
        """Close the underlying socket.

        NOTE(review): the socket is closed with a linger argument of 0 although
        it was created with ``linger=5000``; presumably pending outbound
        messages are dropped on close - confirm this is intended.
        """
        self._socket.close(0)


def create_zmq_queue(ctx: zmq.Context, endpoint: str, socket_type: int) -> ZmqQueue:
    """Create a ZmqQueue from an endpoint string and socket type.

    The returned queue connects to ``endpoint`` (it never binds).
    """
    return ZmqQueue(ctx, socket_type, connect=endpoint)