diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 7ef15c717309..745b78b51ea1 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -29,17 +29,26 @@
 
 logger = logging.getLogger(__name__)
 
+INITIAL_TIMEOUT_SEC = 5
+MAX_TIMEOUT_SEC = 30
+
 
 class Worker:
     def __init__(self,
                  conn_str: str = "",
                  secure: bool = False,
-                 metadata: List[Tuple[str, str]] = None):
+                 metadata: List[Tuple[str, str]] = None,
+                 connection_retries=3):
         """Initializes the worker side grpc client.
 
         Args:
+            conn_str: The host:port connection string for the ray server.
             secure: whether to use SSL secure channel or not.
             metadata: additional metadata passed in the grpc request headers.
+            connection_retries: Number of times to attempt to reconnect to the
+              ray server if it doesn't respond immediately. Setting to 0 tries
+              at least once. For infinite retries, catch the ConnectionError
+              exception.
         """
         self.metadata = metadata if metadata else []
         self.channel = None
@@ -49,6 +58,21 @@ def __init__(self,
             self.channel = grpc.secure_channel(conn_str, credentials)
         else:
             self.channel = grpc.insecure_channel(conn_str)
+
+        # One initial attempt plus up to `connection_retries` retries.
+        conn_attempts = 0
+        timeout = INITIAL_TIMEOUT_SEC
+        while True:
+            conn_attempts += 1
+            try:
+                grpc.channel_ready_future(self.channel).result(timeout=timeout)
+                break
+            except grpc.FutureTimeoutError:
+                if conn_attempts > connection_retries:
+                    raise ConnectionError("ray client connection timeout")
+                logger.info(f"Couldn't connect in {timeout} seconds, retrying")
+                timeout = min(timeout + 5, MAX_TIMEOUT_SEC)
+
         self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
         self.data_client = DataClient(self.channel, self._client_id,