diff --git a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py index b74a48f49c..0ba19c7fff 100644 --- a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py +++ b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py @@ -2,13 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import threading +import pytest import zmq from vllm_omni.distributed.omni_coordinator import ( OmniCoordClientForStage, StageStatus, ) +from vllm_omni.distributed.omni_coordinator import ( + omni_coord_client_for_stage as stage_client_module, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]: @@ -19,7 +26,8 @@ def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]: return ctx, router, endpoint -def _recv_event(router: zmq.Socket) -> dict: +def _recv_event(router: zmq.Socket, timeout_ms: int = 2000) -> dict: + assert router.poll(timeout=timeout_ms) != 0, "Timed out waiting for coordinator event" frames = router.recv_multipart() # ROUTER adds identity frame; the last frame is the payload. payload = frames[-1] @@ -108,3 +116,197 @@ def test_stage_client_close_sends_down_status(): router.close(0) ctx.term() + + +def test_stage_client_reconnects_after_send_failure(mocker): + """Verify send failure path invokes reconnect before retrying send.""" + ctx, router, endpoint = _bind_router() + + client = OmniCoordClientForStage( + endpoint, + "tcp://stage:reconnect-in", + "tcp://stage:reconnect-out", + 0, + ) + + # Discard initial registration event from the real socket. 
+ _recv_event(router) + + class _FlakySocket: + def __init__(self): + self.send_calls = 0 + self.closed = False + + def send(self, *_args, **_kwargs): + self.send_calls += 1 + if self.send_calls == 1: + raise RuntimeError("simulated send failure") + + def close(self, *_args, **_kwargs): + self.closed = True + + flaky_socket = _FlakySocket() + client._socket = flaky_socket + client._reconnect = mocker.Mock(return_value=True) + + client.update_info(queue_length=1) + + client._reconnect.assert_called_once_with(max_retries=3) + assert flaky_socket.send_calls == 2 + + client.close() + router.close(0) + ctx.term() + + +def test_stage_client_raises_when_reconnect_fails(mocker): + """Verify send failure is propagated when reconnect cannot recover.""" + ctx, router, endpoint = _bind_router() + + client = OmniCoordClientForStage( + endpoint, + "tcp://stage:reconnect-fail-in", + "tcp://stage:reconnect-fail-out", + 0, + ) + + # Discard initial registration event from the real socket. + _recv_event(router) + + class _AlwaysFailSocket: + def send(self, *_args, **_kwargs): + raise RuntimeError("simulated send failure") + + def close(self, *_args, **_kwargs): + pass + + client._socket = _AlwaysFailSocket() + client._reconnect = mocker.Mock(return_value=False) + + with pytest.raises(RuntimeError, match="simulated send failure"): + client.update_info(queue_length=2) + + client._reconnect.assert_called_once_with(max_retries=3) + client.close() + router.close(0) + ctx.term() + + +def test_stage_client_close_handles_runtime_error_in_final_update(mocker): + """Verify close() still releases resources when final update raises RuntimeError.""" + ctx, router, endpoint = _bind_router() + + client = OmniCoordClientForStage( + endpoint, + "tcp://stage:close-runtime-in", + "tcp://stage:close-runtime-out", + 0, + ) + + # Discard initial registration event from the real socket. 
+ _recv_event(router) + + client._send_event = mocker.Mock(side_effect=RuntimeError("simulated close-time failure")) + client.close() + + assert client._closed + assert client._socket.closed + + router.close(0) + ctx.term() + + +def test_reconnect_respects_retry_limit(monkeypatch): + """Verify _reconnect stops after max_retries on repeated failures.""" + attempts = {"connect": 0} + + class _FailSocket: + def close(self, *_args, **_kwargs): + pass + + def connect(self, *_args, **_kwargs): + attempts["connect"] += 1 + raise zmq.ZMQError("simulated reconnect failure") + + class _FailContext: + def socket(self, *_args, **_kwargs): + return _FailSocket() + + def term(self): + pass + + client = OmniCoordClientForStage.__new__(OmniCoordClientForStage) + client._closed = False + client._coord_zmq_addr = "tcp://127.0.0.1:9999" + client._stop_event = threading.Event() + client._send_lock = threading.RLock() + client._socket = _FailSocket() + client._ctx = _FailContext() + + monkeypatch.setattr(stage_client_module.zmq, "Context", lambda: _FailContext()) + monkeypatch.setattr(stage_client_module.time, "sleep", lambda *_args, **_kwargs: None) + + assert client._reconnect(max_retries=3, retry_interval=5.0) is False + assert attempts["connect"] == 3 + + +def test_heartbeat_loop_retries_after_transient_send_failure(): + """Verify heartbeat loop continues after one transient send failure.""" + + class _FakeStopEvent: + def __init__(self): + self.wait_calls = 0 + self._set = False + + def wait(self, timeout=None): + _ = timeout + self.wait_calls += 1 + # Run two loop iterations, then stop. 
+ return self._set or self.wait_calls >= 3 + + def is_set(self): + return self._set + + def set(self): + self._set = True + + client = OmniCoordClientForStage.__new__(OmniCoordClientForStage) + client._closed = False + client._heartbeat_interval = 0.0 + client._stop_event = _FakeStopEvent() + + calls = {"count": 0} + + def _fake_send(event_type): + assert event_type == "heartbeat" + calls["count"] += 1 + if calls["count"] == 1: + raise RuntimeError("transient heartbeat failure") + + client._send_event = _fake_send + + client._heartbeat_loop() + + assert calls["count"] == 2 + + +def test_update_info_rejected_while_closing(): + """Verify update_info is rejected once client enters closing state.""" + ctx, router, endpoint = _bind_router() + + client = OmniCoordClientForStage( + endpoint, + "tcp://stage:closing-in", + "tcp://stage:closing-out", + 0, + ) + _recv_event(router) + + client._closing = True + with pytest.raises(RuntimeError, match="closing"): + client.update_info(queue_length=3) + + client._closing = False + client.close() + router.close(0) + ctx.term() diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py index cd5c357bb4..cd3c99ab81 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py @@ -45,9 +45,10 @@ def __init__( self._status = StageStatus.UP self._queue_length = 0 self._closed = False + self._closing = False self._heartbeat_interval = 5.0 self._stop_event = threading.Event() - self._send_lock = threading.Lock() + self._send_lock = threading.RLock() self._send_event("update") @@ -57,38 +58,45 @@ def __init__( ) self._heartbeat_thread.start() - def _reconnect(self) -> bool: + def _reconnect(self, max_retries: int = 3, retry_interval: float = 5.0) -> bool: """Best-effort reconnect with up to ``max_retries`` attempts. 
- Each attempt closes the current socket/context, sleeps 5 seconds, - then creates a new DEALER socket and reconnects to the coordinator. - Caller must hold ``_send_lock``. + Each attempt closes the current socket/context, creates a new DEALER socket, + and reconnects to the coordinator; a failed attempt sleeps ``retry_interval`` seconds. Returns True on success, False if all attempts fail. """ - while not self._stop_event.is_set() and not self._closed: - try: - self._socket.close(0) - except zmq.ZMQError: - pass - try: - self._ctx.term() - except zmq.ZMQError: - pass + if max_retries <= 0: + return False - time.sleep(5.0) + for attempt in range(1, max_retries + 1): + with self._send_lock: + if self._stop_event.is_set() or self._closed: + return False + try: + self._socket.close(0) + except zmq.ZMQError: + pass + try: + self._ctx.term() + except zmq.ZMQError: + pass - try: - self._ctx = zmq.Context() - self._socket = self._ctx.socket(zmq.DEALER) - self._socket.connect(self._coord_zmq_addr) - return True - except zmq.ZMQError as e: - logger.error( - "Stage client reconnect failed, will retry in 5s (coord=%s)", - self._coord_zmq_addr, - exc_info=e, - ) - continue + try: + self._ctx = zmq.Context() + self._socket = self._ctx.socket(zmq.DEALER) + self._socket.connect(self._coord_zmq_addr) + return True + except zmq.ZMQError as e: + logger.error( + "Stage client reconnect failed (attempt=%d/%d, coord=%s)", + attempt, + max_retries, + self._coord_zmq_addr, + exc_info=e, + ) + + if retry_interval > 0: + time.sleep(retry_interval) return False def _send_event(self, event_type: str) -> None: @@ -102,20 +110,20 @@ def _send_event(self, event_type: str) -> None: to 3 times (5s sleep each) and retries the send once after a successful reconnect. Raises if reconnect or the retry send fails. 
""" - if self._closed: - raise RuntimeError("Client already closed") - - event = InstanceEvent( - input_addr=self._input_addr, - output_addr=self._output_addr, - stage_id=self._stage_id, - event_type=event_type, - status=self._status, - queue_length=self._queue_length, - ) - data = json.dumps(asdict(event)).encode("utf-8") - with self._send_lock: + if self._closed: + raise RuntimeError("Client already closed") + + event = InstanceEvent( + input_addr=self._input_addr, + output_addr=self._output_addr, + stage_id=self._stage_id, + event_type=event_type, + status=self._status, + queue_length=self._queue_length, + ) + data = json.dumps(asdict(event)).encode("utf-8") + try: self._socket.send(data, flags=zmq.NOBLOCK) return @@ -124,7 +132,7 @@ def _send_event(self, event_type: str) -> None: return except (RuntimeError, zmq.ZMQError) as e: # First send failed; try reconnecting a few times. - if not self._reconnect: + if not self._reconnect(max_retries=3): logger.error("Failed to send event and reconnect to coordinator", exc_info=e) raise @@ -149,12 +157,16 @@ def update_info( if status is None and queue_length is None: raise ValueError("At least one of status or queue_length must be provided") - if status is not None: - self._status = status - if queue_length is not None: - self._queue_length = queue_length + with self._send_lock: + if self._closed or self._closing: + raise RuntimeError("Client is closing or already closed") + + if status is not None: + self._status = status + if queue_length is not None: + self._queue_length = queue_length - self._send_event("update") + self._send_event("update") def _heartbeat_loop(self) -> None: """Periodically send heartbeat events while the client is alive.""" @@ -164,8 +176,11 @@ def _heartbeat_loop(self) -> None: try: self._send_event("heartbeat") - except (RuntimeError, zmq.ZMQError): - break + except (RuntimeError, zmq.ZMQError) as e: + if self._closed or self._stop_event.is_set(): + break + logger.warning("Heartbeat send failed; 
will retry on next interval", exc_info=e) + continue def close(self) -> None: """Send a final down event and close the underlying socket.""" @@ -177,17 +192,23 @@ def close(self) -> None: if hasattr(self, "_heartbeat_thread"): self._heartbeat_thread.join(timeout=1.0) - # Mark status as DOWN and send one last update. - self._status = StageStatus.DOWN - try: - self._send_event("update") - except zmq.ZMQError: - pass # Socket may already be broken, proceed with close + with self._send_lock: + if self._closed: + raise RuntimeError("Client already closed") - # Close DEALER socket and terminate this client's context. - self._socket.close(0) - try: - self._ctx.term() - except zmq.ZMQError: - pass - self._closed = True + self._closing = True + + # Mark status as DOWN and send one last update. + self._status = StageStatus.DOWN + try: + self._send_event("update") + except (RuntimeError, zmq.ZMQError): + pass # Socket may already be broken, proceed with close + + # Close DEALER socket and terminate this client's context. + self._socket.close(0) + try: + self._ctx.term() + except zmq.ZMQError: + pass + self._closed = True