Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))

def run_test_and_cleanup():
llm = LLM(**llm_kwargs)
try:
_run_abort_timeout_test(llm, timeout)
finally:
llm.llm_engine.engine_core.shutdown()

# Build runtime_env only if we're using Ray
if distributed_executor_backend == "ray":
with _make_fake_nixl_pkg() as working_dir:
Expand All @@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
},
}
ray.init(runtime_env=runtime_env)

_run_abort_timeout_test(llm_kwargs, timeout)
try:
run_test_and_cleanup()
finally:
ray.shutdown()
else:
_run_abort_timeout_test(llm_kwargs, timeout)
run_test_and_cleanup()


def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic."""
llm = LLM(**llm_kwargs)
remote_prefill_opts = {
"do_remote_decode": True,
"do_remote_prefill": False,
Expand Down Expand Up @@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init):
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
),
) as mock_thread,
): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
Expand All @@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init):
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance

# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False

# Execute register_kv_caches
connector.register_kv_caches(kv_caches)

Expand Down Expand Up @@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init):
with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
Expand All @@ -1182,14 +1194,17 @@ def test_shutdown_cleans_up_resources(dist_init):
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]

mock_listener.is_alive.return_value = False

worker.shutdown()

# Test idempotency
worker.shutdown()
worker.shutdown()

mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
Expand Down
34 changes: 30 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ def request_finished(
class NixlConnectorWorker:
"""Implementation of Worker side methods"""

_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms

@dataclass
class TpKVTopology:
"""
Expand Down Expand Up @@ -719,6 +721,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._nixl_handshake_listener_stop_event: threading.Event | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
Expand Down Expand Up @@ -773,6 +776,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
stop_event: threading.Event,
base_port: int,
tp_rank: int,
):
Expand All @@ -791,7 +795,14 @@ def _nixl_handshake_listener(
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
while True:
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
while not stop_event.is_set():
events = dict(
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
)
if sock not in events:
continue
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
Expand Down Expand Up @@ -1101,14 +1112,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout,
)
ready_event = threading.Event()
ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
args=(
metadata,
ready_event,
stop_event,
self.side_channel_port,
self.tp_rank,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
self._nixl_handshake_listener_stop_event = stop_event
ready_event.wait() # Wait for listener ZMQ socket to be ready.

def add_remote_agent(
Expand Down Expand Up @@ -1782,11 +1800,19 @@ def get_block_ids_with_load_errors(self) -> set[int]:
self._invalid_block_ids = set()
return result

def __del__(self):
self.shutdown()
Comment on lines +1803 to +1804
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using __del__ for resource cleanup is an anti-pattern in Python because its execution is not guaranteed. It may not be called if the object is part of a reference cycle, or during interpreter shutdown, which can lead to resource leaks that are hard to debug. The explicit shutdown() calls, like the one you've added in the tests, are the correct and reliable way to handle cleanup.

Furthermore, __del__ can be invoked by the garbage collector at any time and from any thread. This makes the thread-safety of shutdown() critical, but the current implementation is not thread-safe. For example:

  • One thread could be iterating over self._recving_transfers while another clears it, causing a RuntimeError.
  • self._nixl_handshake_listener_t could be set to None by one thread after another has checked it for None but before calling .join() on it, resulting in an AttributeError.

I strongly recommend removing the __del__ method and relying solely on explicit shutdown() calls. If shutdown() might be called concurrently from other paths, it should be protected with a threading.Lock.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling shutdown() from here is convenient for tests - it means resources are cleaned up even without explicit shutdown calls. Outside of tests, we should be calling shutdown() explicitly

Another example of us doing this:

vllm/v1/engine/async_llm.py:235:    def __del__(self):
vllm/v1/engine/async_llm.py-236-        self.shutdown()


def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_stop_event is not None:
self._nixl_handshake_listener_stop_event.set()
self._nixl_handshake_listener_stop_event = None
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
# Generous timeout to allow the thread to exit
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
assert not self._nixl_handshake_listener_t.is_alive()
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
Expand Down