Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,26 @@

logger = logging.getLogger(__name__)

INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30


class Worker:
def __init__(self,
conn_str: str = "",
secure: bool = False,
metadata: List[Tuple[str, str]] = None):
metadata: List[Tuple[str, str]] = None,
connection_retries=3):
"""Initializes the worker side grpc client.

Args:
conn_str: The host:port connection string for the ray server.
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
connection_retries: Number of times to attempt to reconnect to the
ray server if it doesn't respond immediately. Setting to 0 tries
at least once. For infinite retries, catch the ConnectionError
exception.
"""
self.metadata = metadata if metadata else []
self.channel = None
Expand All @@ -49,6 +58,21 @@ def __init__(self,
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.channel = grpc.insecure_channel(conn_str)

# Wait for the gRPC channel to become ready, retrying on timeout.
# The first wait lasts INITIAL_TIMEOUT_SEC; every timed-out attempt
# lengthens the next wait by 5 seconds, capped at MAX_TIMEOUT_SEC.
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
while conn_attempts < connection_retries + 1:
    conn_attempts += 1
    try:
        grpc.channel_ready_future(self.channel).result(timeout=timeout)
    except grpc.FutureTimeoutError:
        # NOTE(review): this check fires once conn_attempts reaches
        # connection_retries, so under continuous failure the loop bound
        # of connection_retries + 1 is never reached (e.g. 1 behaves
        # like 0: a single attempt). Kept as-is to preserve the current
        # attempt count -- confirm whether `>` was intended.
        if conn_attempts >= connection_retries:
            raise ConnectionError("ray client connection timeout")
        logger.info(f"Couldn't connect in {timeout} seconds, retrying")
        timeout = min(timeout + 5, MAX_TIMEOUT_SEC)
    else:
        # The channel is ready: stop here instead of re-waiting on fresh
        # ready-futures for every remaining attempt.
        break

self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)

self.data_client = DataClient(self.channel, self._client_id,
Expand Down