diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index cd7089cb5129..989eed79607e 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -7,6 +7,7 @@
 import json
 import logging
 import os
+import queue
 import tempfile
 import threading
 import time
@@ -163,6 +164,14 @@ def __init__(
         self._req_id_lock = threading.Lock()
         self._req_id = 0
 
+        # ReleaseObject grabs a lock, so it should not be called directly from
+        # __del__ methods that may be executed at any time on the Python main thread.
+        self._release_queue = queue.SimpleQueue()
+        self._release_thread = threading.Thread(
+            target=self._release_server_worker, daemon=True
+        )
+        self._release_thread.start()
+
     def _connect_channel(self, reconnecting=False) -> None:
         """
         Attempts to connect to the server specified by conn_str. If
@@ -644,8 +653,44 @@ def call_release(self, id: bytes) -> None:
 
     def _release_server(self, id: bytes) -> None:
         if self.data_client is not None:
-            logger.debug(f"Releasing {id.hex()}")
-            self.data_client.ReleaseObject(ray_client_pb2.ReleaseRequest(ids=[id]))
+            logger.debug(f"Put {id.hex()} to release queue")
+            self._release_queue.put(id)
+
+    def _release_server_worker(self) -> None:
+        """Drain the release queue on a background thread.
+
+        Each queued object ID is released on the server via
+        ``data_client.ReleaseObject``. The loop exits when the ``None``
+        sentinel is dequeued or once ``self.closed`` becomes true; the 1s
+        ``get`` timeout bounds how long the latter takes to be noticed.
+        """
+        while not self.closed:
+            try:
+                object_id = self._release_queue.get(timeout=1)
+            except queue.Empty:
+                # Nothing to release yet; loop to re-check ``self.closed``.
+                continue
+
+            if object_id is None:  # Sentinel value for shutdown
+                logger.debug("Received sentinel, will stop release thread.")
+                break
+
+            if self.data_client is None:
+                continue
+
+            logger.debug(f"Releasing {object_id.hex()}")
+            try:
+                self.data_client.ReleaseObject(
+                    ray_client_pb2.ReleaseRequest(ids=[object_id])
+                )
+            except Exception as e:
+                # Best effort: log and keep going so one failed release
+                # does not kill the release thread.
+                logger.warning(
+                    f"Failed to release object {object_id.hex()}: {e}. "
+                    "This is expected if the connection is closed."
+                )
+        logger.debug("Release thread finished.")
 
     def call_retain(self, id: bytes) -> None:
         logger.debug(f"Retaining {id.hex()}")
@@ -653,6 +698,13 @@ def call_retain(self, id: bytes) -> None:
 
     def close(self):
         self._in_shutdown = True
+
+        self._release_queue.put(None)  # Sentinel
+        timeout = 5
+        self._release_thread.join(timeout=timeout)
+        if self._release_thread.is_alive():
+            logger.warning(f"The release thread failed to join in {timeout}s.")
+
         self.closed = True
         self.data_client.close()
         self.log_client.close()