Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ like so:
ray.nodes()
# Returns current information about the nodes in the cluster, such as:
# [{'ClientID': '2a9d2b34ad24a37ed54e4fcd32bf19f915742f5b',
# 'EntryType': 0,
# 'IsInsertion': True,
# 'NodeManagerAddress': '1.2.3.4',
# 'NodeManagerPort': 43280,
# 'ObjectManagerPort': 38062,
Expand Down
54 changes: 30 additions & 24 deletions java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.ray.runtime.gcs;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -13,9 +14,9 @@
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.runtime.generated.Gcs;
import org.ray.runtime.generated.Gcs.ActorCheckpointIdData;
import org.ray.runtime.generated.Gcs.ClientTableData;
import org.ray.runtime.generated.Gcs.ClientTableData.EntryType;
import org.ray.runtime.generated.Gcs.TablePrefix;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
Expand Down Expand Up @@ -72,42 +73,47 @@ public List<NodeInfo> getAllNodeInfo() {
final UniqueId clientId = UniqueId
.fromByteBuffer(data.getClientId().asReadOnlyByteBuffer());

if (data.getEntryType() == EntryType.INSERTION) {
if (data.getIsInsertion()) {
//Code path of node insertion.
Map<String, Double> resources = new HashMap<>();
// Compute resources.
Preconditions.checkState(
data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount());
for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i));
}
NodeInfo nodeInfo = new NodeInfo(
clientId, data.getNodeManagerAddress(), true, resources);
clientId, data.getNodeManagerAddress(), true, ImmutableMap.of());
clients.put(clientId, nodeInfo);
} else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) {
Preconditions.checkState(clients.containsKey(clientId));
NodeInfo nodeInfo = clients.get(clientId);
for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i));
}
} else if (data.getEntryType() == EntryType.RES_DELETE) {
Preconditions.checkState(clients.containsKey(clientId));
NodeInfo nodeInfo = clients.get(clientId);
for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
nodeInfo.resources.remove(data.getResourcesTotalLabel(i));
}
} else {
// Code path of node deletion.
Preconditions.checkState(data.getEntryType() == EntryType.DELETION);
NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress,
false, clients.get(clientId).resources);
false, ImmutableMap.of());
clients.put(clientId, nodeInfo);
}
}

// Fill resources.
for (Map.Entry<UniqueId, NodeInfo> client : clients.entrySet()) {
if (client.getValue().isAlive) {
client.getValue().resources.putAll(getResourcesForClient(client.getKey()));
}
}

return new ArrayList<>(clients.values());
}

/**
 * Looks up the resources of a node in the node resource table.
 *
 * <p>The hash stored under the key {@code TablePrefix.NODE_RESOURCE + clientId} is read through
 * the {@code primary} client. Each field name is used as a resource name and each field value
 * is parsed as a serialized {@code Gcs.ResourceTableData}.
 *
 * @param clientId ID of the node whose resources are read.
 * @return A new mutable map from resource name to capacity; empty when no fields are returned.
 * @throws RuntimeException If a field value is not valid {@code ResourceTableData} data. The
 *     original parse exception is attached as the cause.
 */
private Map<String, Double> getResourcesForClient(UniqueId clientId) {
  final String prefix = TablePrefix.NODE_RESOURCE.toString();
  // NOTE(review): getBytes() here and new String(byte[]) below use the platform default
  // charset -- confirm this matches the encoding used by the writer of this table.
  final byte[] key = ArrayUtils.addAll(prefix.getBytes(), clientId.getBytes());
  Map<byte[], byte[]> results = primary.hgetAll(key);
  Map<String, Double> resources = new HashMap<>();
  for (Map.Entry<byte[], byte[]> entry : results.entrySet()) {
    String resourceName = new String(entry.getKey());
    Gcs.ResourceTableData resourceTableData;
    try {
      resourceTableData = Gcs.ResourceTableData.parseFrom(entry.getValue());
    } catch (InvalidProtocolBufferException e) {
      // Chain the parse failure so its message and stack trace are not lost.
      throw new RuntimeException("Received invalid protobuf data from GCS.", e);
    }
    resources.put(resourceName, resourceTableData.getResourceCapacity());
  }
  return resources;
}

/**
* If the actor exists in GCS.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ public Long set(final String key, final String value, final String field) {
return jedis.hset(key, field, value);
}
}

}

/**
 * Writes the given field/value pairs into the hash stored at {@code key}.
 *
 * <p>A connection is borrowed from {@code jedisPool} for the duration of the call and is
 * closed again by the try-with-resources statement.
 *
 * @param key The key of the hash.
 * @param hash The field/value pairs to write.
 * @return The value returned by {@code Jedis.hmset}.
 */
public String hmset(String key, Map<String, String> hash) {
  try (Jedis connection = jedisPool.getResource()) {
    final String reply = connection.hmset(key, hash);
    return reply;
  }
}

/**
 * Reads every field and value of the hash stored at {@code key}.
 *
 * <p>A connection is borrowed from {@code jedisPool} for the duration of the call and is
 * closed again by the try-with-resources statement.
 *
 * @param key The binary key of the hash.
 * @return The value returned by {@code Jedis.hgetAll} for this key.
 */
public Map<byte[], byte[]> hgetAll(byte[] key) {
  try (Jedis connection = jedisPool.getResource()) {
    final Map<byte[], byte[]> fields = connection.hgetAll(key);
    return fields;
  }
}

public String get(final String key, final String field) {
Expand All @@ -67,7 +71,6 @@ public String get(final String key, final String field) {
return jedis.hget(key, field);
}
}

}

public byte[] get(byte[] key) {
Expand Down
2 changes: 2 additions & 0 deletions python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TablePrefix,
TablePubsub,
TaskTableData,
ResourceTableData,
)

__all__ = [
Expand All @@ -32,6 +33,7 @@
"TablePrefix",
"TablePubsub",
"TaskTableData",
"ResourceTableData",
"construct_error_message",
]

Expand Down
77 changes: 48 additions & 29 deletions python/ray/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,47 +50,35 @@ def _parse_client_table(redis_client):
for entry in gcs_entry.entries:
client = gcs_utils.ClientTableData.FromString(entry)

resources = {
client.resources_total_label[i]: client.resources_total_capacity[i]
for i in range(len(client.resources_total_label))
}
client_id = ray.utils.binary_to_hex(client.client_id)

if client.entry_type == gcs_utils.ClientTableData.INSERTION:
if client.is_insertion:
ordered_client_ids.append(client_id)
node_info[client_id] = {
"ClientID": client_id,
"EntryType": client.entry_type,
"IsInsertion": client.is_insertion,
"NodeManagerAddress": client.node_manager_address,
"NodeManagerPort": client.node_manager_port,
"ObjectManagerPort": client.object_manager_port,
"ObjectStoreSocketName": client.object_store_socket_name,
"RayletSocketName": client.raylet_socket_name,
"Resources": resources
"RayletSocketName": client.raylet_socket_name
}

# If this client is being updated, then it must
# If this client is being removed, then it must
# have previously been inserted, and
# it cannot have previously been removed.
else:
assert client_id in node_info, "Client not found!"
is_deletion = (node_info[client_id]["EntryType"] !=
gcs_utils.ClientTableData.DELETION)
assert is_deletion, "Unexpected updation of deleted client."
res_map = node_info[client_id]["Resources"]
if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE:
for res in resources:
res_map[res] = resources[res]
elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE:
for res in resources:
res_map.pop(res, None)
elif client.entry_type == gcs_utils.ClientTableData.DELETION:
pass # Do nothing with the resmap if client deletion
else:
raise RuntimeError("Unexpected EntryType {}".format(
client.entry_type))
node_info[client_id]["Resources"] = res_map
node_info[client_id]["EntryType"] = client.entry_type
assert node_info[client_id]["IsInsertion"], (
"Unexpected duplicate removal of client.")
node_info[client_id]["IsInsertion"] = client.is_insertion
# Fill resource info.
for client_id in ordered_client_ids:
if node_info[client_id]["IsInsertion"]:
resources = _parse_resource_table(redis_client, client_id)
else:
resources = {}
node_info[client_id]["Resources"] = resources
# NOTE: We return the list comprehension below instead of simply doing
# 'list(node_info.values())' in order to have the nodes appear in the order
# that they joined the cluster. Python dictionaries do not preserve
Expand All @@ -100,6 +88,38 @@ def _parse_client_table(redis_client):
return [node_info[client_id] for client_id in ordered_client_ids]


def _parse_resource_table(redis_client, client_id):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Considering client_id is stored/used in hex at most places in state.py and elsewhere, would it be more consistent to have the client_id argument as hex here and convert it to binary inside this method, rather than invoking it as _parse_resource_table(redis_client, ray.utils.hex_to_binary(client_id))?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Sorry that I'm not familiar with state.py.

"""Read the resource table of the node with the given client ID.

Args:
redis_client: A client to the primary Redis shard.
client_id: The client ID of the node in hex. It is converted to binary
before being used as the lookup key.

Returns:
A dict mapping each resource name to its capacity for this node, or
an empty dict if the lookup returns no message.

Raises:
Exception: If the number of entries in the reply is odd.
"""
message = redis_client.execute_command(
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("NODE_RESOURCE"), "",
ray.utils.hex_to_binary(client_id))

# No reply for this key: report no resources.
if message is None:
return {}

resources = {}
gcs_entry = gcs_utils.GcsEntry.FromString(message)
entries_len = len(gcs_entry.entries)
# Entries are consumed in pairs below, so an odd count cannot be parsed.
if entries_len % 2 != 0:
raise Exception("Invalid entry size for resource lookup: " +
str(entries_len))

# Even indices hold the key (decoded as the resource name); the following
# odd index holds the serialized ResourceTableData for that key.
for i in range(0, entries_len, 2):
resource_table_data = gcs_utils.ResourceTableData.FromString(
gcs_entry.entries[i + 1])
resources[decode(
gcs_entry.entries[i])] = resource_table_data.resource_capacity
return resources


class GlobalState(object):
"""A class used to interface with the Ray control state.

Expand Down Expand Up @@ -800,7 +820,7 @@ def cluster_resources(self):
clients = self.client_table()
for client in clients:
# Only count resources from latest entries of live clients.
if client["EntryType"] != gcs_utils.ClientTableData.DELETION:
if client["IsInsertion"]:
for key, value in client["Resources"].items():
resources[key] += value
return dict(resources)
Expand All @@ -809,8 +829,7 @@ def _live_client_ids(self):
"""Returns a set of client IDs corresponding to clients still alive."""
return {
client["ClientID"]
for client in self.client_table()
if (client["EntryType"] != gcs_utils.ClientTableData.DELETION)
for client in self.client_table() if (client["IsInsertion"])
}

def available_resources(self):
Expand Down
4 changes: 1 addition & 3 deletions python/ray/tests/cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import redis

import ray
from ray.gcs_utils import ClientTableData

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,8 +175,7 @@ def wait_for_nodes(self, timeout=30):
while time.time() - start_time < timeout:
clients = ray.state._parse_client_table(redis_client)
live_clients = [
client for client in clients
if client["EntryType"] == ClientTableData.INSERTION
client for client in clients if client["IsInsertion"]
]

expected = len(self.list_all_nodes())
Expand Down
30 changes: 12 additions & 18 deletions src/ray/gcs/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,12 +1149,12 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client
ASSERT_EQ(client_id, added_id);
ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id);
ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id);
ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion);
ASSERT_EQ(data.is_insertion(), is_insertion);

ClientTableData cached_client;
client->client_table().GetClient(added_id, cached_client);
ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id);
ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion);
ASSERT_EQ(cached_client.is_insertion(), is_insertion);
}

void TestClientTableConnect(const JobID &job_id,
Expand Down Expand Up @@ -1273,29 +1273,24 @@ void TestHashTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cli
const int expected_count = 14;
ClientID client_id = ClientID::FromRandom();
// Prepare the first resource map: data_map1.
auto cpu_data = std::make_shared<RayResource>();
cpu_data->set_resource_name("CPU");
cpu_data->set_resource_capacity(100);
auto gpu_data = std::make_shared<RayResource>();
gpu_data->set_resource_name("GPU");
gpu_data->set_resource_capacity(2);
DynamicResourceTable::DataMap data_map1;
auto cpu_data = std::make_shared<ResourceTableData>();
cpu_data->set_resource_capacity(100);
data_map1.emplace("CPU", cpu_data);
auto gpu_data = std::make_shared<ResourceTableData>();
gpu_data->set_resource_capacity(2);
data_map1.emplace("GPU", gpu_data);
// Prepare the second resource map: data_map2 which decreases CPU,
// increases GPU and add a new CUSTOM compared to data_map1.
auto data_cpu = std::make_shared<RayResource>();
data_cpu->set_resource_name("CPU");
data_cpu->set_resource_capacity(50);
auto data_gpu = std::make_shared<RayResource>();
data_gpu->set_resource_name("GPU");
data_gpu->set_resource_capacity(10);
auto data_custom = std::make_shared<RayResource>();
data_custom->set_resource_name("CUSTOM");
data_custom->set_resource_capacity(2);
DynamicResourceTable::DataMap data_map2;
auto data_cpu = std::make_shared<ResourceTableData>();
data_cpu->set_resource_capacity(50);
data_map2.emplace("CPU", data_cpu);
auto data_gpu = std::make_shared<ResourceTableData>();
data_gpu->set_resource_capacity(10);
data_map2.emplace("GPU", data_gpu);
auto data_custom = std::make_shared<ResourceTableData>();
data_custom->set_resource_capacity(2);
data_map2.emplace("CUSTOM", data_custom);
data_map2["CPU"]->set_resource_capacity(50);
// This is a common comparison function for the test.
Expand All @@ -1305,7 +1300,6 @@ void TestHashTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cli
for (const auto &data : data1) {
auto iter = data2.find(data.first);
ASSERT_TRUE(iter != data2.end());
ASSERT_EQ(iter->second->resource_name(), data.second->resource_name());
ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity());
}
};
Expand Down
Loading