Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ py_test_module_list(
"test_asyncio.py",
"test_autoscaler.py",
"test_autoscaler_yaml.py",
"test_client_init.py",
"test_client_metadata.py",
"test_client.py",
"test_client_references.py",
Expand Down
51 changes: 21 additions & 30 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,13 @@
import time
import sys
import logging
import threading

import ray.util.client.server.server as ray_client_server
from ray.util.client import RayAPIStub
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server


def test_num_clients(shutdown_only):
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
# NOTE(review): indentation is not visible in this view, so which
# statements sit inside the try/finally is inferred -- confirm against the
# real file before restructuring.
# Start a client server on localhost:50051 for the clients below.
server = ray_client_server.serve("localhost:50051")
try:
# Each connect() result is indexed by "num_clients"; the first client
# is expected to see 1 and the second to see 2.
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
# Presumably gives the server time to register the disconnects before the
# next client connects -- TODO confirm.
time.sleep(1)

# A third client connecting after the disconnects is expected to be counted
# as the only client.
api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
finally:
# Stop the server whether or not the assertions above passed.
server.stop(0)


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:
Expand Down Expand Up @@ -373,5 +344,25 @@ def test_internal_kv(ray_start_regular_shared):
assert ray._internal_kv_get("apple") == b""


def test_startup_retry(ray_start_regular_shared):
    """Check that a client keeps retrying until the server becomes available.

    First verifies that a connect limited to ``connection_retries=1`` raises
    ``ConnectionError`` while nothing is serving on the port. Then starts a
    client in a background thread, brings the server up three seconds later,
    and requires that the background connect/disconnect finished without
    raising.
    """
    from ray.util.client import ray as ray_client
    ray_client._inside_client_test = True
    server = None
    try:
        # No server is listening yet, so a single-attempt connect must fail.
        with pytest.raises(ConnectionError):
            ray_client.connect("localhost:50051", connection_retries=1)

        # Exceptions raised in a thread never reach the test thread, so
        # record them here and assert on them after the join.
        thread_errors = []

        def run_client():
            try:
                ray_client.connect("localhost:50051")
                ray_client.disconnect()
            except Exception as e:  # recorded and re-checked below
                thread_errors.append(e)

        thread = threading.Thread(target=run_client, daemon=True)
        thread.start()
        # Let the client begin connecting before the server exists.
        time.sleep(3)
        server = ray_client_server.serve("localhost:50051")
        thread.join()
        assert not thread_errors, thread_errors
    finally:
        # Always release the server and reset the module flag so a failure
        # here cannot leak state into later tests.
        if server is not None:
            server.stop(0)
        ray_client._inside_client_test = False


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
37 changes: 37 additions & 0 deletions python/ray/tests/test_client_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Client tests that run their own init (as with init_and_serve) live here"""
import time

import ray.util.client.server.server as ray_client_server

from ray.util.client import RayAPIStub


def test_num_clients():
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
# NOTE(review): indentation is not visible in this view, so which
# statements sit inside the try/finally is inferred -- confirm against the
# real file before restructuring.
# init_and_serve() returns a pair; only the first element (the server) is
# used here.
server, _ = ray_client_server.init_and_serve("localhost:50051")
try:
# Each connect() result is indexed by "num_clients"; the first client
# is expected to see 1 and the second to see 2.
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
# Presumably gives the server time to register the disconnects before the
# next client connects -- TODO confirm.
time.sleep(1)

# A third client connecting after the disconnects is expected to be counted
# as the only client.
api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
# Disconnect the last client before the server is shut down.
api3.disconnect()
finally:
# Shut down through the server module's own helper, passing the server
# returned by init_and_serve().
ray_client_server.shutdown_with_server(server)
# Presumably lets the shutdown settle before the next test starts -- TODO
# confirm (and confirm whether this sleep belongs inside the finally).
time.sleep(2)
66 changes: 55 additions & 11 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import base64
import json
import logging
import time
import uuid
from collections import defaultdict
from typing import Any
Expand Down Expand Up @@ -33,6 +34,13 @@
MAX_TIMEOUT_SEC = 30


def backoff(timeout: int) -> int:
    """Return the next retry timeout.

    The result is ``timeout`` plus five seconds, capped at
    ``MAX_TIMEOUT_SEC``.
    """
    return min(timeout + 5, MAX_TIMEOUT_SEC)


class Worker:
def __init__(self,
conn_str: str = "",
Expand All @@ -59,23 +67,59 @@ def __init__(self,
else:
self.channel = grpc.insecure_channel(conn_str)

# Retry the connection until the channel responds to something
# looking like a gRPC connection, though it may be a proxy.
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
while conn_attempts < connection_retries + 1:
ray_ready = False
while conn_attempts < max(connection_retries, 1):
conn_attempts += 1
try:
# Let gRPC wait for us to see if the channel becomes ready.
# If it throws, we couldn't connect.
grpc.channel_ready_future(self.channel).result(timeout=timeout)
break
# The HTTP2 channel is ready. Wrap the channel with the
# RayletDriverStub, allowing for unary requests.
self.server = ray_client_pb2_grpc.RayletDriverStub(
self.channel)
# Now the HTTP2 channel is ready, or proxied, but the
# servicer may not be ready. Call is_initialized() and if
# it throws, the servicer is not ready. On success, the
# `ray_ready` result is checked.
ray_ready = self.is_initialized()
if ray_ready:
# Ray is ready! Break out of the retry loop
break
# Ray is not ready yet, wait a timeout
time.sleep(timeout)
except grpc.FutureTimeoutError:
if conn_attempts >= connection_retries:
raise ConnectionError("ray client connection timeout")
logger.info(f"Couldn't connect in {timeout} seconds, retrying")
timeout = timeout + 5
if timeout > MAX_TIMEOUT_SEC:
timeout = MAX_TIMEOUT_SEC

self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)

logger.info(
f"Couldn't connect channel in {timeout} seconds, retrying")
# Note that channel_ready_future constitutes its own timeout,
# which is why we do not sleep here.
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
# UNAVAILABLE is gRPC's retryable error,
# so we do that here.
logger.info("Ray client server unavailable, "
f"retrying in {timeout}s...")
logger.debug(f"Received when checking init: {e.details()}")
# Ray is not ready yet, wait a timeout
time.sleep(timeout)
else:
# Any other gRPC error gets a reraise
raise e
# Fallthrough, backoff, and retry at the top of the loop
logger.info("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
timeout = backoff(timeout)

# If we made it through the loop without ray_ready it means we've used
# up our retries and should error back to the user.
if not ray_ready:
raise ConnectionError("ray client connection timeout")

# Initialize the streams to finish protocol negotiation.
self.data_client = DataClient(self.channel, self._client_id,
self.metadata)
self.reference_count: Dict[bytes, int] = defaultdict(int)
Expand Down