39 changes: 6 additions & 33 deletions python/ray/tests/conftest.py
@@ -79,24 +79,16 @@ def ray_start_10_cpus(request):
 
 
 @contextmanager
-def _ray_start_cluster(**kwargs):
+def _ray_start_cluster(num_nodes=0, do_init=False, **kwargs):
     init_kwargs = get_default_fixture_ray_kwargs()
-    num_nodes = 0
-    do_init = False
-    # num_nodes & do_init are not arguments for ray.init, so delete them.
-    if "num_nodes" in kwargs:
-        num_nodes = kwargs["num_nodes"]
-        del kwargs["num_nodes"]
-    if "do_init" in kwargs:
-        do_init = kwargs["do_init"]
-        del kwargs["do_init"]
-    elif num_nodes > 0:
+    if num_nodes > 0:
         do_init = True
     init_kwargs.update(kwargs)
     cluster = Cluster()
     remote_nodes = []
-    for _ in range(num_nodes):
+    for i in range(num_nodes):
         remote_nodes.append(cluster.add_node(**init_kwargs))
+    # Make sure the driver is connecting to the head node.
     if do_init:
         ray.init(redis_address=cluster.redis_address)
     yield cluster
@@ -116,14 +108,14 @@ def ray_start_cluster(request):
 @pytest.fixture
 def ray_start_cluster_head(request):
     param = getattr(request, "param", {})
-    with _ray_start_cluster(do_init=True, num_nodes=1, **param) as res:
+    with _ray_start_cluster(num_nodes=1, do_init=True, **param) as res:
         yield res
 
 
 @pytest.fixture
 def ray_start_cluster_2_nodes(request):
     param = getattr(request, "param", {})
-    with _ray_start_cluster(do_init=True, num_nodes=2, **param) as res:
+    with _ray_start_cluster(num_nodes=2, do_init=True, **param) as res:
         yield res


@@ -161,22 +153,3 @@ def call_ray_start(request):
     ray.shutdown()
     # Kill the Ray cluster.
     subprocess.Popen(["ray", "stop"]).wait()
-
-
-@pytest.fixture()
-def two_node_cluster():
-    internal_config = json.dumps({
-        "initial_reconstruction_timeout_milliseconds": 200,
-        "num_heartbeats_timeout": 10,
-    })
-    cluster = ray.tests.cluster_utils.Cluster(
-        head_node_args={"_internal_config": internal_config})
-    for _ in range(2):
-        remote_node = cluster.add_node(
-            num_cpus=1, _internal_config=internal_config)
-    ray.init(redis_address=cluster.redis_address)
-    yield cluster, remote_node
-
-    # The code after the yield will run as teardown code.
-    ray.shutdown()
-    cluster.shutdown()
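
For orientation, here is a hedged sketch (a hypothetical test, not part of this PR) of how the reworked fixture is consumed: num_nodes and do_init are now real keyword arguments of _ray_start_cluster, so pytest's indirect parametrization forwards them directly instead of popping them out of **kwargs.

    import pytest

    # Hypothetical consumer: num_nodes=2 starts a head node plus one worker,
    # and do_init is then forced to True, so ray.init has already connected
    # the driver to the head node by the time the test body runs.
    @pytest.mark.parametrize(
        "ray_start_cluster", [{"num_nodes": 2, "num_cpus": 1}], indirect=True)
    def test_two_raylets(ray_start_cluster):
        cluster = ray_start_cluster
        assert len(list(cluster.worker_nodes)) == 1
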
9 changes: 4 additions & 5 deletions python/ray/tests/test_actor.py
@@ -1995,9 +1995,10 @@ def method(self):
 
 def test_custom_label_placement(ray_start_cluster):
     cluster = ray_start_cluster
-    cluster.add_node(num_cpus=2, resources={"CustomResource1": 2})
+    node = cluster.add_node(num_cpus=2, resources={"CustomResource1": 2})
     cluster.add_node(num_cpus=2, resources={"CustomResource2": 2})
     ray.init(redis_address=cluster.redis_address)
+    resource1_plasma_socket_name = node.plasma_store_socket_name
 
     @ray.remote(resources={"CustomResource1": 1})
     class ResourceActor1(object):
@@ -2009,17 +2010,15 @@ class ResourceActor2(object):
     def get_location(self):
         return ray.worker.global_worker.plasma_client.store_socket_name
 
-    local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
-
     # Create some actors.
     actors1 = [ResourceActor1.remote() for _ in range(2)]
     actors2 = [ResourceActor2.remote() for _ in range(2)]
     locations1 = ray.get([a.get_location.remote() for a in actors1])
     locations2 = ray.get([a.get_location.remote() for a in actors2])
     for location in locations1:
-        assert location == local_plasma
+        assert location == resource1_plasma_socket_name
     for location in locations2:
-        assert location != local_plasma
+        assert location != resource1_plasma_socket_name
 
 
 def test_creating_more_actors_than_resources(shutdown_only):
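
The pattern these test changes introduce, keeping the node handle returned by cluster.add_node and comparing plasma store socket names against it, avoids assuming the driver shares a plasma store with any particular raylet. A minimal hedged sketch of the idiom (the actor and helper names are hypothetical):

    import ray

    @ray.remote(resources={"CustomResource1": 1})
    class Pinned(object):
        def where(self):
            # Each raylet owns one plasma store, so the store socket name
            # identifies the node this actor was placed on.
            return ray.worker.global_worker.plasma_client.store_socket_name

    def assert_placed_on(actor, node):
        # node is the handle returned by cluster.add_node(...).
        assert ray.get(actor.where.remote()) == node.plasma_store_socket_name
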
21 changes: 9 additions & 12 deletions python/ray/tests/test_basic.py
@@ -1872,10 +1872,9 @@ def f():
 def test_zero_cpus_actor(ray_start_cluster):
     cluster = ray_start_cluster
     cluster.add_node(num_cpus=0)
-    cluster.add_node(num_cpus=2)
+    node = cluster.add_node(num_cpus=2)
     ray.init(redis_address=cluster.redis_address)
-
-    local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
+    available_plasma_socket_name = node.plasma_store_socket_name
 
     @ray.remote
     class Foo(object):
@@ -1884,7 +1883,7 @@ def method(self):
 
     # Make sure tasks and actors run on the remote raylet.
     a = Foo.remote()
-    assert ray.get(a.method.remote()) != local_plasma
+    assert ray.get(a.method.remote()) == available_plasma_socket_name
 
 
 def test_fractional_resources(shutdown_only):
@@ -2067,8 +2066,9 @@ def run_nested2():
 def test_custom_resources(ray_start_cluster):
     cluster = ray_start_cluster
     cluster.add_node(num_cpus=3, resources={"CustomResource": 0})
-    cluster.add_node(num_cpus=3, resources={"CustomResource": 1})
+    node = cluster.add_node(num_cpus=3, resources={"CustomResource": 1})
     ray.init(redis_address=cluster.redis_address)
+    available_plasma_socket_name = node.plasma_store_socket_name
 
     @ray.remote
     def f():
@@ -2088,12 +2088,10 @@ def h():
     # The f tasks should be scheduled on both raylets.
     assert len(set(ray.get([f.remote() for _ in range(50)]))) == 2
 
-    local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
-
     # The g tasks should be scheduled only on the second raylet.
     raylet_ids = set(ray.get([g.remote() for _ in range(50)]))
     assert len(raylet_ids) == 1
-    assert list(raylet_ids)[0] != local_plasma
+    assert list(raylet_ids)[0] == available_plasma_socket_name
 
     # Make sure that resource bookkeeping works when a task that uses a
     # custom resources gets blocked.
@@ -2107,12 +2105,13 @@ def test_two_custom_resources(ray_start_cluster):
         "CustomResource1": 1,
         "CustomResource2": 2
     })
-    cluster.add_node(
+    node = cluster.add_node(
         num_cpus=3, resources={
             "CustomResource1": 3,
             "CustomResource2": 4
         })
     ray.init(redis_address=cluster.redis_address)
+    available_plasma_socket_name = node.plasma_store_socket_name
 
     @ray.remote(resources={"CustomResource1": 1})
     def f():
Expand Down Expand Up @@ -2143,12 +2142,10 @@ def k():
assert len(set(ray.get([f.remote() for _ in range(50)]))) == 2
assert len(set(ray.get([g.remote() for _ in range(50)]))) == 2

local_plasma = ray.worker.global_worker.plasma_client.store_socket_name

# The h tasks should be scheduled only on the second raylet.
raylet_ids = set(ray.get([h.remote() for _ in range(50)]))
assert len(raylet_ids) == 1
assert list(raylet_ids)[0] != local_plasma
assert list(raylet_ids)[0] == available_plasma_socket_name

# Make sure that tasks with unsatisfied custom resource requirements do
# not get scheduled.
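
These assertions lean on the scheduler's custom-resource guarantee: a task that requests a custom resource can only run on a raylet whose node declared a nonzero amount of that resource at startup. A hedged sketch of the guarantee in isolation, assuming the two-node setup and the node handle from the test above:

    import ray

    @ray.remote(resources={"CustomResource": 1})
    def where():
        # Never schedulable on the node started with "CustomResource": 0;
        # only the node that declared "CustomResource": 1 can satisfy it.
        return ray.worker.global_worker.plasma_client.store_socket_name

    assert ray.get(where.remote()) == node.plasma_store_socket_name
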
3 changes: 2 additions & 1 deletion python/ray/tests/test_object_manager.py
@@ -222,11 +222,12 @@ def test_object_transfer_retry(ray_start_cluster):
     object_store_memory = 10**8
     cluster.add_node(
         object_store_memory=object_store_memory, _internal_config=config)
+    # Make sure the driver is connecting to the head node.
+    ray.init(redis_address=cluster.redis_address)
     cluster.add_node(
         num_gpus=1,
         object_store_memory=object_store_memory,
         _internal_config=config)
-    ray.init(redis_address=cluster.redis_address)
 
     @ray.remote(num_gpus=1)
     def f(size):
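
The substance of this hunk is ordering: ray.init(redis_address=...) now runs while the head node is the only node in the cluster, so the driver cannot attach to the GPU node that is added afterwards. Schematically, as a hedged sketch reusing the test's own calls:

    cluster.add_node(object_store_memory=object_store_memory)  # head node only
    ray.init(redis_address=cluster.redis_address)  # driver joins the head node
    cluster.add_node(num_gpus=1,
                     object_store_memory=object_store_memory)  # remote node added later
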
5 changes: 3 additions & 2 deletions python/ray/tests/test_signal.py
@@ -274,7 +274,7 @@ def send_signals(self, value, count):
     assert len(result_list) == count
 
 
-def test_signal_on_node_failure(two_node_cluster):
+def test_signal_on_node_failure(ray_start_cluster_2_nodes):
     """Test actor checkpointing on a remote node."""
 
     class ActorSignal(object):
@@ -285,7 +285,8 @@ def local_plasma(self):
         return ray.worker.global_worker.plasma_client.store_socket_name
 
     # Place the actor on the remote node.
-    cluster, remote_node = two_node_cluster
+    cluster = ray_start_cluster_2_nodes
+    remote_node = list(cluster.worker_nodes)[0]
     actor_cls = ray.remote(max_reconstructions=0)(ActorSignal)
     actor = actor_cls.remote()
     # Try until we put an actor on a different node.
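
The deleted two_node_cluster fixture baked an _internal_config into every test that used it; with the shared ray_start_cluster_2_nodes fixture the same knobs can be restored per test through indirect parametrization. A hypothetical sketch, with the config values copied from the old fixture:

    import json

    import pytest

    @pytest.mark.parametrize(
        "ray_start_cluster_2_nodes",
        [{
            "_internal_config": json.dumps({
                "initial_reconstruction_timeout_milliseconds": 200,
                "num_heartbeats_timeout": 10,
            })
        }],
        indirect=True)
    def test_with_old_fixture_config(ray_start_cluster_2_nodes):
        cluster = ray_start_cluster_2_nodes
        remote_node = list(cluster.worker_nodes)[0]
        assert remote_node is not None
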
76 changes: 65 additions & 11 deletions src/ray/gcs/client_test.cc
@@ -70,7 +70,7 @@ class TestGcsWithAsio : public TestGcs {
   void Start() override { io_service_.run(); }
   void Stop() override { io_service_.stop(); }
 
- private:
+ protected:
   boost::asio::io_service io_service_;
   // Give the event loop some work so that it's forced to run until Stop() is
   // called.
@@ -292,11 +292,55 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
   // Do a lookup at the object ID.
   RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback));
 
+  // Test the Replace function of set.
+  const std::string suffix = "-replaced";
+  for (auto &manager : managers) {
+    auto old_data = std::make_shared<ObjectTableDataT>();
+    auto new_data = std::make_shared<ObjectTableDataT>();
+    old_data->manager = manager;
+    new_data->manager = manager + suffix;
+    // Check that we added the correct object entries.
+    auto add_callback = [object_id, new_data, suffix](
+        gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) {
+      ASSERT_EQ(id, object_id);
+      ASSERT_EQ(new_data->manager, d.manager);
+      test->IncrementNumCallbacks();
+    };
+    auto no_trigger_failure = [](gcs::AsyncGcsClient *client, const UniqueID &id) {
+      ASSERT_TRUE(false);
+    };
+    auto trigger_failure = [](gcs::AsyncGcsClient *client, const UniqueID &id) {
+      ASSERT_TRUE(true);
+    };
+    auto no_trigger_add = [](gcs::AsyncGcsClient *client, const UniqueID &id,
+                             const ObjectTableDataT &d) { ASSERT_TRUE(false); };
+    RAY_CHECK_OK(client->object_table().Replace(driver_id, object_id, old_data, new_data,
+                                                add_callback, no_trigger_failure));
+    // The old entry should be replaced with the new entry.
+    RAY_CHECK_OK(client->object_table().Replace(driver_id, object_id, old_data, new_data,
+                                                no_trigger_add, trigger_failure));
+  }
+
+  // Check that lookup returns the added object entries.
+  auto lookup_callback2 = [object_id, managers, suffix](
+      gcs::AsyncGcsClient *client, const ObjectID &id,
+      const std::vector<ObjectTableDataT> &data) {
+    ASSERT_EQ(id, object_id);
+    ASSERT_EQ(data.size(), managers.size());
+    for (auto &d : data) {
+      ASSERT_EQ(d.manager.find(suffix), managers[0].length());
+    }
+    test->IncrementNumCallbacks();
+  };
+
+  // Do a lookup at the object ID.
+  RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback2));
+
   for (auto &manager : managers) {
     auto data = std::make_shared<ObjectTableDataT>();
-    data->manager = manager;
+    data->manager = manager + suffix;
     // Check that we added the correct object entries.
-    auto remove_entry_callback = [object_id, data](
+    auto remove_entry_callback = [object_id, data, suffix](
         gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) {
       ASSERT_EQ(id, object_id);
       ASSERT_EQ(data->manager, d.manager);
@@ -307,7 +351,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
   }
 
   // Check that the entries are removed.
-  auto lookup_callback2 = [object_id, managers](
+  auto lookup_callback3 = [object_id, managers](
       gcs::AsyncGcsClient *client, const ObjectID &id,
       const std::vector<ObjectTableDataT> &data) {
     ASSERT_EQ(id, object_id);
@@ -317,11 +361,11 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
   };
 
   // Do a lookup at the object ID.
-  RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback2));
+  RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback3));
   // Run the event loop. The loop will only stop if the Lookup callback is
   // called (or an assertion failure).
   test->Start();
-  ASSERT_EQ(test->NumCallbacks(), managers.size() * 2 + 2);
+  ASSERT_EQ(test->NumCallbacks(), managers.size() * 3 + 3);
 }
 
 TEST_F(TestGcsWithAsio, TestSet) {
@@ -1227,14 +1271,14 @@ void TestClientTableDisconnect(const DriverID &driver_id,
   // event will stop the event loop.
   client->client_table().RegisterClientAddedCallback(
       [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
-        ClientTableNotification(client, id, data, /*is_insertion=*/true);
+        ClientTableNotification(client, id, data, /*is_connected=*/true);
         // Disconnect from the client table. We should receive a notification
         // for the removal of our own entry.
         RAY_CHECK_OK(client->client_table().Disconnect());
       });
   client->client_table().RegisterClientRemovedCallback(
       [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
-        ClientTableNotification(client, id, data, /*is_insertion=*/false);
+        ClientTableNotification(client, id, data, /*is_connected=*/false);
         test->Stop();
       });
   // Connect to the client table. We should receive notification for the
@@ -1282,15 +1326,21 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) {
 }
 
 void TestClientTableMarkDisconnected(const DriverID &driver_id,
-                                     std::shared_ptr<gcs::AsyncGcsClient> client) {
+                                     std::shared_ptr<gcs::AsyncGcsClient> client,
+                                     std::shared_ptr<gcs::AsyncGcsClient> dead_client) {
+  // Since marking a non-existent node does not trigger the ClientRemovedCallback,
+  // a second client is used: it connects first and is then marked as dead.
+  ClientTableDataT dead_client_info = dead_client->client_table().GetLocalClient();
+  RAY_CHECK_OK(dead_client->client_table().Connect(dead_client_info));
+
   ClientTableDataT local_client_info = client->client_table().GetLocalClient();
   local_client_info.node_manager_address = "127.0.0.1";
   local_client_info.node_manager_port = 0;
   local_client_info.object_manager_port = 0;
   // Connect to the client table to start receiving notifications.
   RAY_CHECK_OK(client->client_table().Connect(local_client_info));
   // Mark a different client as dead.
-  ClientID dead_client_id = ClientID::from_random();
+  ClientID dead_client_id = ClientID::from_binary(dead_client_info.client_id);
   RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id));
   // Make sure we only get a notification for the removal of the client we
   // marked as dead.
@@ -1304,7 +1354,11 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id,
 
 TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) {
   test = this;
-  TestClientTableMarkDisconnected(driver_id_, client_);
+  auto dead_client =
+      std::make_shared<gcs::AsyncGcsClient>("127.0.0.1", 6379, command_type_,
+                                            /*is_test_client=*/true);
+  RAY_CHECK_OK(dead_client->Attach(io_service_));
+  TestClientTableMarkDisconnected(driver_id_, client_, dead_client);
 }
 
 #undef TEST_MACRO
7 changes: 7 additions & 0 deletions src/ray/gcs/format/gcs.fbs
@@ -281,6 +281,13 @@ table ClientTableData {
   resources_total_capacity: [double];
 }
 
+table SetReplaceEntryData {
+  // The old packed data to remove.
+  old_data: string;
+  // The new packed data to add.
+  new_data: string;
+}
+
 table HeartbeatTableData {
   // Node manager client id
   client_id: string;