diff --git a/BUILD.bazel b/BUILD.bazel
index da36eec0cf57..bc9e6bcd8006 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1,22 +1,55 @@
# Bazel build
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
-load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
+load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("@//bazel:ray.bzl", "flatbuffer_py_library")
load("@//bazel:cython_library.bzl", "pyx_library")
COPTS = ["-DRAY_USE_GLOG"]
-# Node manager gRPC lib.
-grpc_proto_library(
- name = "node_manager_grpc_lib",
+# === Begin of protobuf definitions ===
+
+proto_library(
+ name = "gcs_proto",
+ srcs = ["src/ray/protobuf/gcs.proto"],
+ visibility = ["//java:__subpackages__"],
+)
+
+cc_proto_library(
+ name = "gcs_cc_proto",
+ deps = [":gcs_proto"],
+)
+
+python_proto_compile(
+ name = "gcs_py_proto",
+ deps = [":gcs_proto"],
+)
+
+proto_library(
+ name = "node_manager_proto",
srcs = ["src/ray/protobuf/node_manager.proto"],
)
+cc_proto_library(
+ name = "node_manager_cc_proto",
+ deps = ["node_manager_proto"],
+)
+
+# === End of protobuf definitions ===
+
+# Node manager gRPC lib.
+cc_grpc_library(
+ name = "node_manager_cc_grpc",
+ srcs = [":node_manager_proto"],
+ grpc_only = True,
+ deps = [":node_manager_cc_proto"],
+)
+
# Node manager server and client.
cc_library(
- name = "node_manager_rpc_lib",
+ name = "node_manager_rpc",
srcs = glob([
"src/ray/rpc/*.cc",
]),
@@ -25,7 +58,7 @@ cc_library(
]),
copts = COPTS,
deps = [
- ":node_manager_grpc_lib",
+ ":node_manager_cc_grpc",
":ray_common",
"@boost//:asio",
"@com_github_grpc_grpc//:grpc++",
@@ -114,7 +147,7 @@ cc_library(
":gcs",
":gcs_fbs",
":node_manager_fbs",
- ":node_manager_rpc_lib",
+ ":node_manager_rpc",
":object_manager",
":ray_common",
":ray_util",
@@ -422,9 +455,11 @@ cc_library(
"src/ray/gcs/format",
],
deps = [
+ ":gcs_cc_proto",
":gcs_fbs",
":hiredis",
":node_manager_fbs",
+ ":node_manager_rpc",
":ray_common",
":ray_util",
":stats_lib",
@@ -555,46 +590,6 @@ filegroup(
visibility = ["//java:__subpackages__"],
)
-flatbuffer_py_library(
- name = "python_gcs_fbs",
- srcs = [
- ":gcs_fbs_file",
- ],
- outs = [
- "ActorCheckpointIdData.py",
- "ActorState.py",
- "ActorTableData.py",
- "Arg.py",
- "ClassTableData.py",
- "ClientTableData.py",
- "ConfigTableData.py",
- "CustomSerializerData.py",
- "DriverTableData.py",
- "EntryType.py",
- "ErrorTableData.py",
- "ErrorType.py",
- "FunctionTableData.py",
- "GcsEntry.py",
- "HeartbeatBatchTableData.py",
- "HeartbeatTableData.py",
- "Language.py",
- "ObjectTableData.py",
- "ProfileEvent.py",
- "ProfileTableData.py",
- "RayResource.py",
- "ResourcePair.py",
- "SchedulingState.py",
- "TablePrefix.py",
- "TablePubsub.py",
- "TaskInfo.py",
- "TaskLeaseData.py",
- "TaskReconstructionData.py",
- "TaskTableData.py",
- "TaskTableTestAndUpdate.py",
- ],
- out_prefix = "python/ray/core/generated/",
-)
-
flatbuffer_py_library(
name = "python_node_manager_fbs",
srcs = [
@@ -679,6 +674,7 @@ cc_binary(
linkstatic = 1,
visibility = ["//java:__subpackages__"],
deps = [
+ ":gcs_cc_proto",
":ray_common",
],
)
@@ -688,7 +684,7 @@ genrule(
srcs = [
"python/ray/_raylet.so",
"//:python_sources",
- "//:python_gcs_fbs",
+ "//:gcs_py_proto",
"//:python_node_manager_fbs",
"//:redis-server",
"//:redis-cli",
@@ -710,11 +706,13 @@ genrule(
cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ &&
cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
- for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done &&
mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ &&
for f in $(locations //:python_node_manager_fbs); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
done &&
+ for f in $(locations //:gcs_py_proto); do
+ cp -f $$f $$WORK_DIR/python/ray/core/generated/;
+ done &&
echo $$WORK_DIR > $@
""",
local = 1,
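
Note: the proto_library / cc_proto_library / python_proto_compile chain above
replaces the old grpc_proto_library macro: gcs.proto is compiled once, and each
language consumes it through its own rule, with gcs_py_proto emitting gcs_pb2.py
into python/ray/core/generated/. A minimal sketch of round-tripping one of the
generated messages (field names taken from the ErrorTableData usage later in
this patch):

    from ray.core.generated.gcs_pb2 import ErrorTableData

    data = ErrorTableData()
    data.error_message = "oops"           # scalar fields are assigned directly
    blob = data.SerializeToString()       # bytes, as stored in Redis by the GCS
    parsed = ErrorTableData.FromString(blob)
    assert parsed.error_message == "oops"
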
diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl
index 3e1e1838a59a..eda88bece7d2 100644
--- a/bazel/ray_deps_build_all.bzl
+++ b/bazel/ray_deps_build_all.bzl
@@ -4,6 +4,8 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep
load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure")
load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps")
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile")
+load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile")
def ray_deps_build_all():
@@ -13,4 +15,6 @@ def ray_deps_build_all():
prometheus_cpp_repositories()
python_configure(name = "local_config_python")
grpc_deps()
+ java_proto_compile()
+ python_proto_compile()
diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl
index e6dc21585699..aa322654cf9f 100644
--- a/bazel/ray_deps_setup.bzl
+++ b/bazel/ray_deps_setup.bzl
@@ -105,7 +105,14 @@ def ray_deps_setup():
http_archive(
name = "com_github_grpc_grpc",
urls = [
- "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz",
+ "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz",
],
- strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49",
+ strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472",
+ )
+
+ http_archive(
+ name = "build_stack_rules_proto",
+ urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"],
+ sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61",
+ strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841",
)
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 98fb3e0d02dd..5cf6b01217f9 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -23,20 +23,7 @@
"gym.spaces",
"ray._raylet",
"ray.core.generated",
- "ray.core.generated.ActorCheckpointIdData",
- "ray.core.generated.ClientTableData",
- "ray.core.generated.DriverTableData",
- "ray.core.generated.EntryType",
- "ray.core.generated.ErrorTableData",
- "ray.core.generated.ErrorType",
- "ray.core.generated.GcsEntry",
- "ray.core.generated.HeartbeatBatchTableData",
- "ray.core.generated.HeartbeatTableData",
- "ray.core.generated.Language",
- "ray.core.generated.ObjectTableData",
- "ray.core.generated.ProfileTableData",
- "ray.core.generated.TablePrefix",
- "ray.core.generated.TablePubsub",
+ "ray.core.generated.gcs_pb2",
"ray.core.generated.ray.protocol.Task",
"scipy",
"scipy.signal",
diff --git a/java/BUILD.bazel b/java/BUILD.bazel
index 80ccabccfc12..4960434af180 100644
--- a/java/BUILD.bazel
+++ b/java/BUILD.bazel
@@ -1,4 +1,5 @@
load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module")
+load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile")
exports_files([
"testng.xml",
@@ -50,6 +51,7 @@ define_java_module(
name = "runtime",
additional_srcs = [
":generate_java_gcs_fbs",
+ ":gcs_java_proto",
],
additional_resources = [
":java_native_deps",
@@ -68,6 +70,7 @@ define_java_module(
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_github_davidmoten_flatbuffers_java",
"@maven//:com_google_guava_guava",
+ "@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
"@maven//:commons_io_commons_io",
"@maven//:de_ruedigermoeller_fst",
@@ -148,38 +151,16 @@ java_binary(
],
)
+java_proto_compile(
+ name = "gcs_java_proto",
+ deps = ["@//:gcs_proto"],
+)
+
flatbuffers_generated_files = [
- "ActorCheckpointData.java",
- "ActorCheckpointIdData.java",
- "ActorState.java",
- "ActorTableData.java",
"Arg.java",
- "ClassTableData.java",
- "ClientTableData.java",
- "ConfigTableData.java",
- "CustomSerializerData.java",
- "DriverTableData.java",
- "EntryType.java",
- "ErrorTableData.java",
- "ErrorType.java",
- "FunctionTableData.java",
- "GcsEntry.java",
- "HeartbeatBatchTableData.java",
- "HeartbeatTableData.java",
"Language.java",
- "ObjectTableData.java",
- "ProfileEvent.java",
- "ProfileTableData.java",
- "RayResource.java",
- "ResourcePair.java",
- "SchedulingState.java",
- "TablePrefix.java",
- "TablePubsub.java",
"TaskInfo.java",
- "TaskLeaseData.java",
- "TaskReconstructionData.java",
- "TaskTableData.java",
- "TaskTableTestAndUpdate.java",
+ "ResourcePair.java",
]
flatbuffer_java_library(
@@ -198,7 +179,7 @@ genrule(
cmd = """
for f in $(locations //java:java_gcs_fbs); do
chmod +w $$f
- cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
+ mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
done
python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
""",
@@ -221,8 +202,10 @@ filegroup(
genrule(
name = "gen_maven_deps",
srcs = [
- ":java_native_deps",
+ ":gcs_java_proto",
":generate_java_gcs_fbs",
+ ":java_native_deps",
+ ":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
],
outs = ["gen_maven_deps.out"],
@@ -237,10 +220,15 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
- # Copy flatbuffers-generated files
+ # Copy protobuf-generated files.
GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
rm -rf $$GENERATED_DIR
mkdir -p $$GENERATED_DIR
+ for f in $(locations //java:gcs_java_proto); do
+        unzip -o $$f
+ mv org/ray/runtime/generated/* $$GENERATED_DIR
+ done
+ # Copy flatbuffers-generated files
for f in $(locations //java:generate_java_gcs_fbs); do
cp $$f $$GENERATED_DIR
done
@@ -250,6 +238,7 @@ genrule(
echo $$(date) > $@
""",
local = 1,
+ tags = ["no-cache"],
)
genrule(
diff --git a/java/dependencies.bzl b/java/dependencies.bzl
index 7c716166d399..ef667137562b 100644
--- a/java/dependencies.bzl
+++ b/java/dependencies.bzl
@@ -6,6 +6,7 @@ def gen_java_deps():
"com.beust:jcommander:1.72",
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.guava:guava:27.0.1-jre",
+ "com.google.protobuf:protobuf-java:3.8.0",
"com.puppycrawl.tools:checkstyle:8.15",
"com.sun.xml.bind:jaxb-core:2.3.0",
"com.sun.xml.bind:jaxb-impl:2.3.0",
diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py
index c1b723f25f8d..5bf62e56d7e4 100644
--- a/java/modify_generated_java_flatbuffers_files.py
+++ b/java/modify_generated_java_flatbuffers_files.py
@@ -4,7 +4,6 @@
import os
import sys
-
"""
This script is used for modifying the generated java flatbuffer
files for the reason: The package declaration in Java is different
@@ -21,19 +20,18 @@
PACKAGE_DECLARATION = "package org.ray.runtime.generated;"
-def add_new_line(file, line_num, text):
+def add_package(file):
with open(file, "r") as file_handler:
lines = file_handler.readlines()
- if (line_num <= 0) or (line_num > len(lines) + 1):
- return False
- lines.insert(line_num - 1, text + os.linesep)
+ if "FlatBuffers" not in lines[0]:
+ return
+
+ lines.insert(1, PACKAGE_DECLARATION + os.linesep)
with open(file, "w") as file_handler:
for line in lines:
file_handler.write(line)
- return True
-
def add_package_declarations(generated_root_path):
file_names = os.listdir(generated_root_path)
@@ -41,15 +39,11 @@ def add_package_declarations(generated_root_path):
if not file_name.endswith(".java"):
continue
full_name = os.path.join(generated_root_path, file_name)
- success = add_new_line(full_name, 2, PACKAGE_DECLARATION)
- if not success:
- raise RuntimeError("Failed to add package declarations, "
- "file name is %s" % full_name)
+ add_package(full_name)
if __name__ == "__main__":
ray_home = sys.argv[1]
root_path = os.path.join(
- ray_home,
- "java/runtime/src/main/java/org/ray/runtime/generated")
+ ray_home, "java/runtime/src/main/java/org/ray/runtime/generated")
add_package_declarations(root_path)
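
Note: add_package() now keys off flatc's marker comment instead of inserting
blindly at line 2, so a file whose first line does not mention FlatBuffers is
left untouched. A small illustration (the .java content is a made-up stand-in
for flatc output, and assumes add_package is in scope):

    import tempfile

    src = ("// automatically generated by the FlatBuffers compiler"
           ", do not modify\n\npublic final class Arg {}\n")
    with tempfile.NamedTemporaryFile("w", suffix=".java", delete=False) as f:
        f.write(src)
    add_package(f.name)  # line 2 becomes "package org.ray.runtime.generated;"
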
diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml
index c75e2eeef13f..e13dd95f927f 100644
--- a/java/runtime/pom.xml
+++ b/java/runtime/pom.xml
@@ -41,6 +41,11 @@
      <artifactId>guava</artifactId>
      <version>27.0.1-jre</version>
    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>3.8.0</version>
+    </dependency>
    <dependency>
      <groupId>com.typesafe</groupId>
      <artifactId>config</artifactId>
index 431b48ded58c..17c248ed0a57 100644
--- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java
+++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java
@@ -1,7 +1,7 @@
package org.ray.runtime.gcs;
import com.google.common.base.Preconditions;
-import java.nio.ByteBuffer;
+import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -13,10 +13,10 @@
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtimecontext.NodeInfo;
-import org.ray.runtime.generated.ActorCheckpointIdData;
-import org.ray.runtime.generated.ClientTableData;
-import org.ray.runtime.generated.EntryType;
-import org.ray.runtime.generated.TablePrefix;
+import org.ray.runtime.generated.Gcs.ActorCheckpointIdData;
+import org.ray.runtime.generated.Gcs.ClientTableData;
+import org.ray.runtime.generated.Gcs.ClientTableData.EntryType;
+import org.ray.runtime.generated.Gcs.TablePrefix;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) {
}
  public List<NodeInfo> getAllNodeInfo() {
- final String prefix = TablePrefix.name(TablePrefix.CLIENT);
+ final String prefix = TablePrefix.CLIENT.toString();
final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes());
    List<byte[]> results = primary.lrange(key, 0, -1);
@@ -63,36 +63,42 @@ public List getAllNodeInfo() {
    Map<UniqueId, NodeInfo> clients = new HashMap<>();
for (byte[] result : results) {
Preconditions.checkNotNull(result);
- ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result));
- final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer());
+ ClientTableData data = null;
+ try {
+ data = ClientTableData.parseFrom(result);
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException("Received invalid protobuf data from GCS.");
+ }
+ final UniqueId clientId = UniqueId
+ .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer());
- if (data.entryType() == EntryType.INSERTION) {
+ if (data.getEntryType() == EntryType.INSERTION) {
//Code path of node insertion.
        Map<String, Double> resources = new HashMap<>();
// Compute resources.
Preconditions.checkState(
- data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength());
- for (int i = 0; i < data.resourcesTotalLabelLength(); i++) {
- resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i));
+ data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount());
+ for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
+ resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i));
}
NodeInfo nodeInfo = new NodeInfo(
- clientId, data.nodeManagerAddress(), true, resources);
+ clientId, data.getNodeManagerAddress(), true, resources);
clients.put(clientId, nodeInfo);
- } else if (data.entryType() == EntryType.RES_CREATEUPDATE) {
+ } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) {
Preconditions.checkState(clients.containsKey(clientId));
NodeInfo nodeInfo = clients.get(clientId);
- for (int i = 0; i < data.resourcesTotalLabelLength(); i++) {
- nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i));
+ for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
+ nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i));
}
- } else if (data.entryType() == EntryType.RES_DELETE) {
+ } else if (data.getEntryType() == EntryType.RES_DELETE) {
Preconditions.checkState(clients.containsKey(clientId));
NodeInfo nodeInfo = clients.get(clientId);
- for (int i = 0; i < data.resourcesTotalLabelLength(); i++) {
- nodeInfo.resources.remove(data.resourcesTotalLabel(i));
+ for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) {
+ nodeInfo.resources.remove(data.getResourcesTotalLabel(i));
}
} else {
// Code path of node deletion.
- Preconditions.checkState(data.entryType() == EntryType.DELETION);
+ Preconditions.checkState(data.getEntryType() == EntryType.DELETION);
NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress,
false, clients.get(clientId).resources);
clients.put(clientId, nodeInfo);
@@ -107,7 +113,7 @@ public List getAllNodeInfo() {
*/
public boolean actorExists(UniqueId actorId) {
byte[] key = ArrayUtils.addAll(
- TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes());
+ TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes());
return primary.exists(key);
}
@@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) {
* Query whether the raylet task exists in Gcs.
*/
public boolean rayletTaskExistsInGcs(TaskId taskId) {
- byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(),
+ byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(),
taskId.getBytes());
RedisClient client = getShardClient(taskId);
return client.exists(key);
@@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) {
*/
  public List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
    List<Checkpoint> checkpoints = new ArrayList<>();
- final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
+ final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString();
final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
RedisClient client = getShardClient(actorId);
byte[] result = client.get(key);
if (result != null) {
- ActorCheckpointIdData data =
- ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
- UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer(
- data.checkpointIdsAsByteBuffer());
+ ActorCheckpointIdData data = null;
+ try {
+ data = ActorCheckpointIdData.parseFrom(result);
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException("Received invalid protobuf data from GCS.");
+ }
+ UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()];
+ for (int i = 0; i < checkpointIds.length; i++) {
+ checkpointIds[i] = UniqueId
+ .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer());
+ }
for (int i = 0; i < checkpointIds.length; i++) {
- checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
+ checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i)));
}
}
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
index f9e310249a35..1a7e4701c22b 100644
--- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
@@ -16,7 +16,7 @@
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
-import org.ray.runtime.generated.ErrorType;
+import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.IdUtil;
import org.ray.runtime.util.Serializer;
import org.slf4j.Logger;
@@ -29,12 +29,12 @@ public class ObjectStoreProxy {
private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class);
- private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED)
- .getBytes();
- private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED)
- .getBytes();
+ private static final byte[] WORKER_EXCEPTION_META = String
+ .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
+ private static final byte[] ACTOR_EXCEPTION_META = String
+ .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
- .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
+ .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py
index cadd197ec73f..ba72e96f41db 100644
--- a/python/ray/gcs_utils.py
+++ b/python/ray/gcs_utils.py
@@ -2,38 +2,39 @@
from __future__ import division
from __future__ import print_function
-import flatbuffers
-import ray.core.generated.ErrorTableData
-
-from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
-from ray.core.generated.ClientTableData import ClientTableData
-from ray.core.generated.DriverTableData import DriverTableData
-from ray.core.generated.ErrorTableData import ErrorTableData
-from ray.core.generated.GcsEntry import GcsEntry
-from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
-from ray.core.generated.HeartbeatTableData import HeartbeatTableData
-from ray.core.generated.Language import Language
-from ray.core.generated.ObjectTableData import ObjectTableData
-from ray.core.generated.ProfileTableData import ProfileTableData
-from ray.core.generated.TablePrefix import TablePrefix
-from ray.core.generated.TablePubsub import TablePubsub
-
from ray.core.generated.ray.protocol.Task import Task
+from ray.core.generated.gcs_pb2 import (
+ ActorCheckpointIdData,
+ ClientTableData,
+ DriverTableData,
+ ErrorTableData,
+ ErrorType,
+ GcsEntry,
+ HeartbeatBatchTableData,
+ HeartbeatTableData,
+ ObjectTableData,
+ ProfileTableData,
+ TablePrefix,
+ TablePubsub,
+ TaskTableData,
+)
+
__all__ = [
"ActorCheckpointIdData",
"ClientTableData",
"DriverTableData",
"ErrorTableData",
+ "ErrorType",
"GcsEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
- "Language",
"ObjectTableData",
"ProfileTableData",
"TablePrefix",
"TablePubsub",
"Task",
+ "TaskTableData",
"construct_error_message",
]
@@ -42,13 +43,16 @@
REPORTER_CHANNEL = "RAY_REPORTER"
# xray heartbeats
-XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
-XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii")
+XRAY_HEARTBEAT_CHANNEL = str(
+ TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii")
+XRAY_HEARTBEAT_BATCH_CHANNEL = str(
+ TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
# xray driver updates
-XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
+XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
-# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
-# TODO(rkn): We should use scoped enums, in which case we should be able to
-# just access the flatbuffer generated values.
+# These prefixes must be kept up-to-date with the TablePrefix enum in
+# gcs.proto.
+# TODO(rkn): We should use scoped enums, in which case we should be able to
+# just access the protobuf-generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
@@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp):
Returns:
The serialized object.
"""
- builder = flatbuffers.Builder(0)
- driver_offset = builder.CreateString(driver_id.binary())
- error_type_offset = builder.CreateString(error_type)
- message_offset = builder.CreateString(message)
-
- ray.core.generated.ErrorTableData.ErrorTableDataStart(builder)
- ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId(
- builder, driver_offset)
- ray.core.generated.ErrorTableData.ErrorTableDataAddType(
- builder, error_type_offset)
- ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage(
- builder, message_offset)
- ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp(
- builder, timestamp)
- error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd(
- builder)
- builder.Finish(error_data_offset)
-
- return bytes(builder.Output())
+ data = ErrorTableData()
+ data.driver_id = driver_id.binary()
+ data.type = error_type
+ data.error_message = message
+ data.timestamp = timestamp
+ return data.SerializeToString()
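
Note: the rewritten construct_error_message() above shows the shape of the
whole migration: flatbuffers' builder/offset choreography collapses into direct
field assignment plus SerializeToString(). A minimal usage sketch (the nil
driver ID is chosen here purely for illustration):

    import time
    import ray

    blob = ray.gcs_utils.construct_error_message(
        ray.DriverID.nil(), "custom_error", "something went wrong", time.time())
    data = ray.gcs_utils.ErrorTableData.FromString(blob)
    assert data.error_message == "something went wrong"
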
diff --git a/python/ray/monitor.py b/python/ray/monitor.py
index c9e0424b3eb8..35597ef231e3 100644
--- a/python/ray/monitor.py
+++ b/python/ray/monitor.py
@@ -101,28 +101,26 @@ def subscribe(self, channel):
def xray_heartbeat_batch_handler(self, unused_channel, data):
"""Handle an xray heartbeat batch message from Redis."""
- gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
- heartbeat_data = gcs_entries.Entries(0)
+ gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
+ heartbeat_data = gcs_entries.entries[0]
- message = (ray.gcs_utils.HeartbeatBatchTableData.
- GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))
+ message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
+ heartbeat_data)
- for j in range(message.BatchLength()):
- heartbeat_message = message.Batch(j)
-
- num_resources = heartbeat_message.ResourcesTotalLabelLength()
+ for heartbeat_message in message.batch:
+ num_resources = len(heartbeat_message.resources_available_label)
static_resources = {}
dynamic_resources = {}
for i in range(num_resources):
- dyn = heartbeat_message.ResourcesAvailableLabel(i)
- static = heartbeat_message.ResourcesTotalLabel(i)
+ dyn = heartbeat_message.resources_available_label[i]
+ static = heartbeat_message.resources_total_label[i]
dynamic_resources[dyn] = (
- heartbeat_message.ResourcesAvailableCapacity(i))
+ heartbeat_message.resources_available_capacity[i])
static_resources[static] = (
- heartbeat_message.ResourcesTotalCapacity(i))
+ heartbeat_message.resources_total_capacity[i])
# Update the load metrics for this raylet.
- client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
+ client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
ip = self.raylet_id_to_ip_map.get(client_id)
if ip:
self.load_metrics.update(ip, static_resources,
@@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data):
unused_channel: The message channel.
data: The message data.
"""
- gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
- driver_data = gcs_entries.Entries(0)
- message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
- driver_data, 0)
- driver_id = message.DriverId()
+ gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
+ driver_data = gcs_entries.entries[0]
+ message = ray.gcs_utils.DriverTableData.FromString(driver_data)
+ driver_id = message.driver_id
logger.info("Monitor: "
"XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
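
Note: both handlers above follow the same two-level decode that recurs
throughout this patch: GcsEntry.FromString() over the raw Redis payload, then a
table-specific FromString() per entry. The pattern in isolation (a sketch;
`data` stands for the raw bytes delivered on the heartbeat-batch channel):

    import ray
    import ray.gcs_utils as gcs_utils
    import ray.utils

    def resources_by_client(data):
        gcs_entry = gcs_utils.GcsEntry.FromString(data)
        batch = gcs_utils.HeartbeatBatchTableData.FromString(
            gcs_entry.entries[0])
        return {
            ray.utils.binary_to_hex(hb.client_id): dict(
                zip(hb.resources_available_label,
                    hb.resources_available_capacity))
            for hb in batch.batch
        }
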
diff --git a/python/ray/state.py b/python/ray/state.py
index 14ba49987ec4..35f97cd65f5e 100644
--- a/python/ray/state.py
+++ b/python/ray/state.py
@@ -10,11 +10,11 @@
import ray
from ray.function_manager import FunctionDescriptor
-import ray.gcs_utils
-from ray.ray_constants import ID_SIZE
-from ray import services
-from ray.core.generated.EntryType import EntryType
+from ray import (
+ gcs_utils,
+ services,
+)
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -31,9 +31,9 @@ def _parse_client_table(redis_client):
A list of information about the nodes in the cluster.
"""
NIL_CLIENT_ID = ray.ObjectID.nil().binary()
- message = redis_client.execute_command("RAY.TABLE_LOOKUP",
- ray.gcs_utils.TablePrefix.CLIENT,
- "", NIL_CLIENT_ID)
+ message = redis_client.execute_command(
+ "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "",
+ NIL_CLIENT_ID)
# Handle the case where no clients are returned. This should only
# occur potentially immediately after the cluster is started.
@@ -41,36 +41,31 @@ def _parse_client_table(redis_client):
return []
node_info = {}
- gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
+ gcs_entry = gcs_utils.GcsEntry.FromString(message)
ordered_client_ids = []
# Since GCS entries are append-only, we override so that
# only the latest entries are kept.
- for i in range(gcs_entry.EntriesLength()):
- client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
- gcs_entry.Entries(i), 0))
+ for entry in gcs_entry.entries:
+ client = gcs_utils.ClientTableData.FromString(entry)
resources = {
- decode(client.ResourcesTotalLabel(i)):
- client.ResourcesTotalCapacity(i)
- for i in range(client.ResourcesTotalLabelLength())
+ client.resources_total_label[i]: client.resources_total_capacity[i]
+ for i in range(len(client.resources_total_label))
}
- client_id = ray.utils.binary_to_hex(client.ClientId())
+ client_id = ray.utils.binary_to_hex(client.client_id)
- if client.EntryType() == EntryType.INSERTION:
+ if client.entry_type == gcs_utils.ClientTableData.INSERTION:
ordered_client_ids.append(client_id)
node_info[client_id] = {
"ClientID": client_id,
- "EntryType": client.EntryType(),
- "NodeManagerAddress": decode(
- client.NodeManagerAddress(), allow_none=True),
- "NodeManagerPort": client.NodeManagerPort(),
- "ObjectManagerPort": client.ObjectManagerPort(),
- "ObjectStoreSocketName": decode(
- client.ObjectStoreSocketName(), allow_none=True),
- "RayletSocketName": decode(
- client.RayletSocketName(), allow_none=True),
+ "EntryType": client.entry_type,
+ "NodeManagerAddress": client.node_manager_address,
+ "NodeManagerPort": client.node_manager_port,
+ "ObjectManagerPort": client.object_manager_port,
+ "ObjectStoreSocketName": client.object_store_socket_name,
+ "RayletSocketName": client.raylet_socket_name,
"Resources": resources
}
@@ -79,22 +74,23 @@ def _parse_client_table(redis_client):
# it cannot have previously been removed.
else:
assert client_id in node_info, "Client not found!"
- assert node_info[client_id]["EntryType"] != EntryType.DELETION, (
- "Unexpected updation of deleted client.")
+            not_deleted = (node_info[client_id]["EntryType"] !=
+                           gcs_utils.ClientTableData.DELETION)
+            assert not_deleted, "Unexpected update of deleted client."
res_map = node_info[client_id]["Resources"]
- if client.EntryType() == EntryType.RES_CREATEUPDATE:
+ if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE:
for res in resources:
res_map[res] = resources[res]
- elif client.EntryType() == EntryType.RES_DELETE:
+ elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE:
for res in resources:
res_map.pop(res, None)
- elif client.EntryType() == EntryType.DELETION:
+ elif client.entry_type == gcs_utils.ClientTableData.DELETION:
pass # Do nothing with the resmap if client deletion
else:
raise RuntimeError("Unexpected EntryType {}".format(
- client.EntryType()))
+ client.entry_type))
node_info[client_id]["Resources"] = res_map
- node_info[client_id]["EntryType"] = client.EntryType()
+ node_info[client_id]["EntryType"] = client.entry_type
# NOTE: We return the list comprehension below instead of simply doing
# 'list(node_info.values())' in order to have the nodes appear in the order
# that they joined the cluster. Python dictionaries do not preserve
@@ -244,20 +240,19 @@ def _object_table(self, object_id):
# Return information about a single object ID.
message = self._execute_command(object_id, "RAY.TABLE_LOOKUP",
- ray.gcs_utils.TablePrefix.OBJECT, "",
- object_id.binary())
+ gcs_utils.TablePrefix.Value("OBJECT"),
+ "", object_id.binary())
if message is None:
return {}
- gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
+ gcs_entry = gcs_utils.GcsEntry.FromString(message)
- assert gcs_entry.EntriesLength() > 0
+ assert len(gcs_entry.entries) > 0
- entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
- gcs_entry.Entries(0), 0)
+ entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0])
object_info = {
- "DataSize": entry.ObjectSize(),
- "Manager": entry.Manager(),
+ "DataSize": entry.object_size,
+ "Manager": entry.manager,
}
return object_info
@@ -278,10 +273,9 @@ def object_table(self, object_id=None):
return self._object_table(object_id)
else:
# Return the entire object table.
- object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string +
- "*")
+ object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*")
object_ids_binary = {
- key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
+ key[len(gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
@@ -301,17 +295,18 @@ def _task_table(self, task_id):
A dictionary with information about the task ID in question.
"""
assert isinstance(task_id, ray.TaskID)
- message = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
- ray.gcs_utils.TablePrefix.RAYLET_TASK,
- "", task_id.binary())
+ message = self._execute_command(
+ task_id, "RAY.TABLE_LOOKUP",
+ gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary())
if message is None:
return {}
- gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
-
- assert gcs_entries.EntriesLength() == 1
+ gcs_entries = gcs_utils.GcsEntry.FromString(message)
- task_table_message = ray.gcs_utils.Task.GetRootAsTask(
- gcs_entries.Entries(0), 0)
+ assert len(gcs_entries.entries) == 1
+ task_table_data = gcs_utils.TaskTableData.FromString(
+ gcs_entries.entries[0])
+ task_table_message = gcs_utils.Task.GetRootAsTask(
+ task_table_data.task, 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
@@ -368,9 +363,9 @@ def task_table(self, task_id=None):
return self._task_table(task_id)
else:
task_table_keys = self._keys(
- ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
+ gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
- key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
+ key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
@@ -380,27 +375,6 @@ def task_table(self, task_id=None):
ray.TaskID(task_id_binary))
return results
- def function_table(self, function_id=None):
- """Fetch and parse the function table.
-
- Returns:
- A dictionary that maps function IDs to information about the
- function.
- """
- self._check_connected()
- function_table_keys = self.redis_client.keys(
- ray.gcs_utils.FUNCTION_PREFIX + "*")
- results = {}
- for key in function_table_keys:
- info = self.redis_client.hgetall(key)
- function_info_parsed = {
- "DriverID": binary_to_hex(info[b"driver_id"]),
- "Module": decode(info[b"module"]),
- "Name": decode(info[b"name"])
- }
- results[binary_to_hex(info[b"function_id"])] = function_info_parsed
- return results
-
def client_table(self):
"""Fetch and parse the Redis DB client table.
@@ -423,37 +397,32 @@ def _profile_table(self, batch_id):
# TODO(rkn): This method should support limiting the number of log
# events and should also support returning a window of events.
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP",
- ray.gcs_utils.TablePrefix.PROFILE, "",
- batch_id.binary())
+ gcs_utils.TablePrefix.Value("PROFILE"),
+ "", batch_id.binary())
if message is None:
return []
- gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
+ gcs_entries = gcs_utils.GcsEntry.FromString(message)
profile_events = []
- for i in range(gcs_entries.EntriesLength()):
- profile_table_message = (
- ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
- gcs_entries.Entries(i), 0))
-
- component_type = decode(profile_table_message.ComponentType())
- component_id = binary_to_hex(profile_table_message.ComponentId())
- node_ip_address = decode(
- profile_table_message.NodeIpAddress(), allow_none=True)
+ for entry in gcs_entries.entries:
+ profile_table_message = gcs_utils.ProfileTableData.FromString(
+ entry)
- for j in range(profile_table_message.ProfileEventsLength()):
- profile_event_message = profile_table_message.ProfileEvents(j)
+ component_type = profile_table_message.component_type
+ component_id = binary_to_hex(profile_table_message.component_id)
+ node_ip_address = profile_table_message.node_ip_address
+ for profile_event_message in profile_table_message.profile_events:
profile_event = {
- "event_type": decode(profile_event_message.EventType()),
+ "event_type": profile_event_message.event_type,
"component_id": component_id,
"node_ip_address": node_ip_address,
"component_type": component_type,
- "start_time": profile_event_message.StartTime(),
- "end_time": profile_event_message.EndTime(),
- "extra_data": json.loads(
- decode(profile_event_message.ExtraData())),
+ "start_time": profile_event_message.start_time,
+ "end_time": profile_event_message.end_time,
+ "extra_data": json.loads(profile_event_message.extra_data),
}
profile_events.append(profile_event)
@@ -462,10 +431,10 @@ def _profile_table(self, batch_id):
def profile_table(self):
self._check_connected()
- profile_table_keys = self._keys(
- ray.gcs_utils.TablePrefix_PROFILE_string + "*")
+ profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string +
+ "*")
batch_identifiers_binary = [
- key[len(ray.gcs_utils.TablePrefix_PROFILE_string):]
+ key[len(gcs_utils.TablePrefix_PROFILE_string):]
for key in profile_table_keys
]
@@ -766,7 +735,7 @@ def cluster_resources(self):
clients = self.client_table()
for client in clients:
# Only count resources from latest entries of live clients.
- if client["EntryType"] != EntryType.DELETION:
+ if client["EntryType"] != gcs_utils.ClientTableData.DELETION:
for key, value in client["Resources"].items():
resources[key] += value
return dict(resources)
@@ -776,7 +745,7 @@ def _live_client_ids(self):
return {
client["ClientID"]
for client in self.client_table()
- if (client["EntryType"] != EntryType.DELETION)
+ if (client["EntryType"] != gcs_utils.ClientTableData.DELETION)
}
def available_resources(self):
@@ -800,7 +769,7 @@ def available_resources(self):
for redis_client in self.redis_clients
]
for subscribe_client in subscribe_clients:
- subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
+ subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL)
client_ids = self._live_client_ids()
@@ -809,24 +778,23 @@ def available_resources(self):
# Parse client message
raw_message = subscribe_client.get_message()
if (raw_message is None or raw_message["channel"] !=
- ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
+ gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
- gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
- data, 0))
- heartbeat_data = gcs_entries.Entries(0)
- message = (ray.gcs_utils.HeartbeatTableData.
- GetRootAsHeartbeatTableData(heartbeat_data, 0))
+ gcs_entries = gcs_utils.GcsEntry.FromString(data)
+ heartbeat_data = gcs_entries.entries[0]
+ message = gcs_utils.HeartbeatTableData.FromString(
+ heartbeat_data)
# Calculate available resources for this client
- num_resources = message.ResourcesAvailableLabelLength()
+ num_resources = len(message.resources_available_label)
dynamic_resources = {}
for i in range(num_resources):
- resource_id = decode(message.ResourcesAvailableLabel(i))
+ resource_id = message.resources_available_label[i]
dynamic_resources[resource_id] = (
- message.ResourcesAvailableCapacity(i))
+ message.resources_available_capacity[i])
# Update available resources for this client
- client_id = ray.utils.binary_to_hex(message.ClientId())
+ client_id = ray.utils.binary_to_hex(message.client_id)
available_resources_by_id[client_id] = dynamic_resources
# Update clients in cluster
@@ -860,23 +828,22 @@ def _error_messages(self, driver_id):
"""
assert isinstance(driver_id, ray.DriverID)
message = self.redis_client.execute_command(
- "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
+ "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
driver_id.binary())
# If there are no errors, return early.
if message is None:
return []
- gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
+ gcs_entries = gcs_utils.GcsEntry.FromString(message)
error_messages = []
- for i in range(gcs_entries.EntriesLength()):
- error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
- gcs_entries.Entries(i), 0)
- assert driver_id.binary() == error_data.DriverId()
+ for entry in gcs_entries.entries:
+ error_data = gcs_utils.ErrorTableData.FromString(entry)
+ assert driver_id.binary() == error_data.driver_id
error_message = {
- "type": decode(error_data.Type()),
- "message": decode(error_data.ErrorMessage()),
- "timestamp": error_data.Timestamp(),
+ "type": error_data.type,
+ "message": error_data.error_message,
+ "timestamp": error_data.timestamp,
}
error_messages.append(error_message)
return error_messages
@@ -899,9 +866,9 @@ def error_messages(self, driver_id=None):
return self._error_messages(driver_id)
error_table_keys = self.redis_client.keys(
- ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
+ gcs_utils.TablePrefix_ERROR_INFO_string + "*")
driver_ids = [
- key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
+ key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
for key in error_table_keys
]
@@ -923,30 +890,23 @@ def actor_checkpoint_info(self, actor_id):
message = self._execute_command(
actor_id,
"RAY.TABLE_LOOKUP",
- ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID,
+ gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"),
"",
actor_id.binary(),
)
if message is None:
return None
- gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
- entry = (
- ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
- gcs_entry.Entries(0), 0))
- checkpoint_ids_str = entry.CheckpointIds()
- num_checkpoints = len(checkpoint_ids_str) // ID_SIZE
- assert len(checkpoint_ids_str) % ID_SIZE == 0
+ gcs_entry = gcs_utils.GcsEntry.FromString(message)
+ entry = gcs_utils.ActorCheckpointIdData.FromString(
+ gcs_entry.entries[0])
checkpoint_ids = [
- ray.ActorCheckpointID(
- checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)])
- for i in range(num_checkpoints)
+ ray.ActorCheckpointID(checkpoint_id)
+ for checkpoint_id in entry.checkpoint_ids
]
return {
- "ActorID": ray.utils.binary_to_hex(entry.ActorId()),
+ "ActorID": ray.utils.binary_to_hex(entry.actor_id),
"CheckpointIds": checkpoint_ids,
- "Timestamps": [
- entry.Timestamps(i) for i in range(num_checkpoints)
- ],
+ "Timestamps": list(entry.timestamps),
}
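
Note: two enum spellings appear above. Top-level enums such as TablePrefix are
read with Value("NAME") where the Redis command wants the integer, while values
of an enum nested in a message (ClientTableData.INSERTION) are promoted onto
the message class. Both resolve to the same integers; a short sketch:

    from ray.core.generated.gcs_pb2 import ClientTableData, TablePubsub

    # Top-level enum: attribute access and Value() are interchangeable.
    assert (TablePubsub.Value("ERROR_INFO_PUBSUB") ==
            TablePubsub.ERROR_INFO_PUBSUB)
    # Nested enum: values sit on the containing message class.
    assert (ClientTableData.INSERTION ==
            ClientTableData.EntryType.Value("INSERTION"))
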
diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py
index 703c3a1420ed..76dfd3000b86 100644
--- a/python/ray/tests/cluster_utils.py
+++ b/python/ray/tests/cluster_utils.py
@@ -8,7 +8,7 @@
import redis
import ray
-from ray.core.generated.EntryType import EntryType
+from ray.gcs_utils import ClientTableData
logger = logging.getLogger(__name__)
@@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30):
clients = ray.state._parse_client_table(redis_client)
live_clients = [
client for client in clients
- if client["EntryType"] == EntryType.INSERTION
+ if client["EntryType"] == ClientTableData.INSERTION
]
expected = len(self.list_all_nodes())
diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py
index 7f1f78d1b5c4..6b4bd754cd4d 100644
--- a/python/ray/tests/test_basic.py
+++ b/python/ray/tests/test_basic.py
@@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only):
r = ray.worker.global_worker.redis_client
- r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
- ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
- error_data)
+ r.execute_command("RAY.TABLE_APPEND",
+ ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
+ ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
+ driver_id.binary(), error_data)
# Before https://github.com/ray-project/ray/pull/3316 this would
# give an error
- r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
- ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
- error_data)
+ r.execute_command("RAY.TABLE_APPEND",
+ ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
+ ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
+ driver_id.binary(), error_data)
@pytest.mark.skipif(
diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index 51b906695c2d..a560e461f7a2 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only):
malformed_message = "asdf"
redis_client = ray.worker.global_worker.redis_client
redis_client.execute_command(
- "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH,
- ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message)
+ "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"),
+ ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id,
+ malformed_message)
wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1)
diff --git a/python/ray/utils.py b/python/ray/utils.py
index 7b87486e325e..0db48e41d025 100644
--- a/python/ray/utils.py
+++ b/python/ray/utils.py
@@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client,
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
message, time.time())
- redis_client.execute_command("RAY.TABLE_APPEND",
- ray.gcs_utils.TablePrefix.ERROR_INFO,
- ray.gcs_utils.TablePubsub.ERROR_INFO,
- driver_id.binary(), error_data)
+ redis_client.execute_command(
+ "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
+ ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
+ driver_id.binary(), error_data)
def is_cython(obj):
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 7505120574a6..710f0db43c6b 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -47,7 +47,7 @@
from ray import import_thread
from ray import profiling
-from ray.core.generated.ErrorType import ErrorType
+from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayActorError,
RayError,
@@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id,
# Otherwise, return an exception object based on
# the error type.
error_type = int(metadata)
- if error_type == ErrorType.WORKER_DIED:
+ if error_type == ErrorType.Value("WORKER_DIED"):
return RayWorkerError()
- elif error_type == ErrorType.ACTOR_DIED:
+ elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
- elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE:
+ elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
return UnreconstructableError(ray.ObjectID(object_id.binary()))
else:
assert False, "Unrecognized error type " + str(error_type)
@@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
# Really we should just subscribe to the errors for this specific job.
# However, currently all errors seem to be published on the same channel.
error_pubsub_channel = str(
- ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii")
+ ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii")
worker.error_message_pubsub_client.subscribe(error_pubsub_channel)
# worker.error_message_pubsub_client.psubscribe("*")
@@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
if msg is None:
threads_stopped.wait(timeout=0.01)
continue
- gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
- msg["data"], 0)
- assert gcs_entry.EntriesLength() == 1
- error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
- gcs_entry.Entries(0), 0)
- driver_id = error_data.DriverId()
+ gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"])
+ assert len(gcs_entry.entries) == 1
+ error_data = ray.gcs_utils.ErrorTableData.FromString(
+ gcs_entry.entries[0])
+ driver_id = error_data.driver_id
if driver_id not in [
worker.task_driver_id.binary(),
DriverID.nil().binary()
]:
continue
- error_message = ray.utils.decode(error_data.ErrorMessage())
- if (ray.utils.decode(
- error_data.Type()) == ray_constants.TASK_PUSH_ERROR):
+ error_message = error_data.error_message
+        if error_data.type == ray_constants.TASK_PUSH_ERROR:
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
else:
@@ -1878,14 +1876,16 @@ def connect(node,
{}, # resource_map.
{}, # placement_resource_map.
)
+ task_table_data = ray.gcs_utils.TaskTableData()
+ task_table_data.task = driver_task._serialized_raylet_task()
# Add the driver task to the task table.
- ray.state.state._execute_command(driver_task.task_id(),
- "RAY.TABLE_ADD",
- ray.gcs_utils.TablePrefix.RAYLET_TASK,
- ray.gcs_utils.TablePubsub.RAYLET_TASK,
- driver_task.task_id().binary(),
- driver_task._serialized_raylet_task())
+ ray.state.state._execute_command(
+ driver_task.task_id(), "RAY.TABLE_ADD",
+ ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
+ ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
+ driver_task.task_id().binary(),
+ task_table_data.SerializeToString())
# Set the driver's current task ID to the task ID assigned to the
# driver task.
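
Note: the driver task is now written wrapped in a TaskTableData protobuf before
RAY.TABLE_ADD, which is exactly what _task_table() in state.py unwraps on the
read side. A sketch of the round trip (serialized_task stands in for the bytes
from driver_task._serialized_raylet_task()):

    import ray.gcs_utils as gcs_utils

    def wrap_and_unwrap(serialized_task):
        data = gcs_utils.TaskTableData()
        data.task = serialized_task        # flatbuffer bytes for the task
        blob = data.SerializeToString()    # value stored in the GCS task table
        # Read side, mirroring state.py's _task_table():
        parsed = gcs_utils.TaskTableData.FromString(blob)
        return gcs_utils.Task.GetRootAsTask(parsed.task, 0)
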
diff --git a/python/setup.py b/python/setup.py
index db8676042de9..e7cf14737ee2 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -150,6 +150,7 @@ def find_version(*filepath):
"six >= 1.0.0",
"flatbuffers",
"faulthandler;python_version<'3.3'",
+ "protobuf",
]
setup(
diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc
index c9b1e138575d..6de29bb52764 100644
--- a/src/ray/gcs/client.cc
+++ b/src/ray/gcs/client.cc
@@ -206,10 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_;
ClientTable &AsyncGcsClient::client_table() { return *client_table_; }
-FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
-
-ClassTable &AsyncGcsClient::class_table() { return *class_table_; }
-
HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() {
diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h
index c9f5b4bca624..5e70025b39a0 100644
--- a/src/ray/gcs/client.h
+++ b/src/ray/gcs/client.h
@@ -44,11 +44,7 @@ class RAY_EXPORT AsyncGcsClient {
/// one event loop should be attached at a time.
Status Attach(boost::asio::io_service &io_service);
- inline FunctionTable &function_table();
// TODO: Some API for getting the error on the driver
- inline ClassTable &class_table();
- inline CustomSerializerTable &custom_serializer_table();
- inline ConfigTable &config_table();
ObjectTable &object_table();
raylet::TaskTable &raylet_task_table();
ActorTable &actor_table();
@@ -81,8 +77,6 @@ class RAY_EXPORT AsyncGcsClient {
std::string DebugString() const;
private:
-  std::unique_ptr<FunctionTable> function_table_;
-  std::unique_ptr<ClassTable> class_table_;
  std::unique_ptr<ObjectTable> object_table_;
  std::unique_ptr<raylet::TaskTable> raylet_task_table_;
  std::unique_ptr<ActorTable> actor_table_;
diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc
index c7dc02e50651..55115b1e2067 100644
--- a/src/ray/gcs/client_test.cc
+++ b/src/ray/gcs/client_test.cc
@@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio {
void TestTableLookup(const DriverID &driver_id,
                     std::shared_ptr<gcs::AsyncGcsClient> client) {
TaskID task_id = TaskID::FromRandom();
-  auto data = std::make_shared<protocol::TaskT>();
-  data->task_specification = "123";
+  auto data = std::make_shared<TaskTableData>();
+  data->set_task("123");
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &d) {
+ const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task_specification, d.task_specification);
+ ASSERT_EQ(data->task(), d.task());
};
// Check that the lookup returns the added task.
auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &d) {
+ const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task_specification, d.task_specification);
+ ASSERT_EQ(data->task(), d.task());
test->Stop();
};
@@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id,
TaskID task_id = TaskID::FromRandom();
std::vector node_manager_ids = {"abc", "def", "ghi"};
for (auto &node_manager_id : node_manager_ids) {
- auto data = std::make_shared();
- data->node_manager_id = node_manager_id;
+ auto data = std::make_shared();
+ data->set_node_manager_id(node_manager_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskReconstructionDataT &d) {
+ const TaskReconstructionData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->node_manager_id, d.node_manager_id);
+ ASSERT_EQ(data->node_manager_id(), d.node_manager_id());
};
RAY_CHECK_OK(
client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback));
@@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id,
// Check that lookup returns the added object entries.
  auto lookup_callback = [task_id, node_manager_ids](
                             gcs::AsyncGcsClient *client, const TaskID &id,
-                             const std::vector<TaskReconstructionDataT> &data) {
+                             const std::vector<TaskReconstructionData> &data) {
ASSERT_EQ(id, task_id);
for (const auto &entry : data) {
- ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]);
+ ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == node_manager_ids.size()) {
@@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id,
// Check that the lookup does not return data.
auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &d) { RAY_CHECK(false); };
+ const TaskTableData &d) { RAY_CHECK(false); };
// Check that the lookup returns an empty entry.
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) {
@@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id,
                     std::shared_ptr<gcs::AsyncGcsClient> client) {
  TaskID task_id = TaskID::FromRandom();
  std::vector<std::string> node_manager_ids = {"A", "B"};
-  std::vector<std::shared_ptr<TaskReconstructionDataT>> data_log;
+  std::vector<std::shared_ptr<TaskReconstructionData>> data_log;
  for (const auto &node_manager_id : node_manager_ids) {
-    auto data = std::make_shared<TaskReconstructionDataT>();
-    data->node_manager_id = node_manager_id;
+    auto data = std::make_shared<TaskReconstructionData>();
+ data->set_node_manager_id(node_manager_id);
data_log.push_back(data);
}
// Check that we added the correct task.
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskReconstructionDataT &d) {
+ const TaskReconstructionData &d) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
@@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id,
  auto lookup_callback = [node_manager_ids](
      gcs::AsyncGcsClient *client, const TaskID &id,
-      const std::vector<TaskReconstructionDataT> &data) {
+      const std::vector<TaskReconstructionData> &data) {
    std::vector<std::string> appended_managers;
for (const auto &entry : data) {
- appended_managers.push_back(entry.node_manager_id);
+ appended_managers.push_back(entry.node_manager_id());
}
ASSERT_EQ(appended_managers, node_manager_ids);
test->Stop();
@@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
ObjectID object_id = ObjectID::FromRandom();
std::vector managers = {"abc", "def", "ghi"};
for (auto &manager : managers) {
- auto data = std::make_shared();
- data->manager = manager;
+ auto data = std::make_shared();
+ data->set_manager(manager);
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id,
- const ObjectTableDataT &d) {
+ const ObjectTableData &d) {
ASSERT_EQ(id, object_id);
- ASSERT_EQ(data->manager, d.manager);
+ ASSERT_EQ(data->manager(), d.manager());
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback));
}
// Check that lookup returns the added object entries.
-  auto lookup_callback = [object_id, managers](
-      gcs::AsyncGcsClient *client, const ObjectID &id,
-      const std::vector<ObjectTableDataT> &data) {
+  auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client,
+                                               const ObjectID &id,
+                                               const std::vector<ObjectTableData> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), managers.size());
test->IncrementNumCallbacks();
@@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback));
for (auto &manager : managers) {
-    auto data = std::make_shared<ObjectTableDataT>();
-    data->manager = manager;
+    auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(manager);
// Check that we added the correct object entries.
auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client,
const ObjectID &id,
- const ObjectTableDataT &d) {
+ const ObjectTableData &d) {
ASSERT_EQ(id, object_id);
- ASSERT_EQ(data->manager, d.manager);
+ ASSERT_EQ(data->manager(), d.manager());
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(
@@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> cli
// Check that the entries are removed.
  auto lookup_callback2 = [object_id, managers](
                              gcs::AsyncGcsClient *client, const ObjectID &id,
-                              const std::vector<ObjectTableDataT> &data) {
+                              const std::vector<ObjectTableData> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), 0);
test->IncrementNumCallbacks();
@@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) {
void TestDeleteKeysFromLog(
    const DriverID &driver_id, std::shared_ptr<gcs::AsyncGcsClient> client,
-    std::vector<std::shared_ptr<TaskReconstructionDataT>> &data_vector) {
+    std::vector<std::shared_ptr<TaskReconstructionData>> &data_vector) {
  std::vector<TaskID> ids;
TaskID task_id;
for (auto &data : data_vector) {
@@ -340,9 +340,9 @@ void TestDeleteKeysFromLog(
ids.push_back(task_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskReconstructionDataT &d) {
+ const TaskReconstructionData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->node_manager_id, d.node_manager_id);
+ ASSERT_EQ(data->node_manager_id(), d.node_manager_id());
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(
@@ -352,7 +352,7 @@ void TestDeleteKeysFromLog(
// Check that lookup returns the added object entries.
auto lookup_callback = [task_id, data_vector](
gcs::AsyncGcsClient *client, const TaskID &id,
- const std::vector<TaskReconstructionDataT> &data) {
+ const std::vector<TaskReconstructionData> &data) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data.size(), 1);
test->IncrementNumCallbacks();
@@ -367,7 +367,7 @@ void TestDeleteKeysFromLog(
}
for (const auto &task_id : ids) {
auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
- const std::vector<TaskReconstructionDataT> &data) {
+ const std::vector<TaskReconstructionData> &data) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(data.size() == 0);
test->IncrementNumCallbacks();
@@ -379,7 +379,7 @@ void TestDeleteKeysFromLog(
void TestDeleteKeysFromTable(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client,
- std::vector<std::shared_ptr<protocol::TaskT>> &data_vector,
+ std::vector<std::shared_ptr<TaskTableData>> &data_vector,
bool stop_at_end) {
std::vector<TaskID> ids;
TaskID task_id;
@@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id,
ids.push_back(task_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &d) {
+ const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task_specification, d.task_specification);
+ ASSERT_EQ(data->task(), d.task());
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback));
}
for (const auto &task_id : ids) {
auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &data) {
+ const TaskTableData &data) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
@@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id,
test->IncrementNumCallbacks();
};
auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id,
- const protocol::TaskT &data) { ASSERT_TRUE(false); };
+ const TaskTableData &data) { ASSERT_TRUE(false); };
for (size_t i = 0; i < ids.size(); ++i) {
RAY_CHECK_OK(client->raylet_task_table().Lookup(
driver_id, task_id, undesired_callback, expected_failure_callback));
@@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id,
void TestDeleteKeysFromSet(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client,
- std::vector<std::shared_ptr<ObjectTableDataT>> &data_vector) {
+ std::vector<std::shared_ptr<ObjectTableData>> &data_vector) {
std::vector<ObjectID> ids;
ObjectID object_id;
for (auto &data : data_vector) {
@@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id,
ids.push_back(object_id);
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id,
- const ObjectTableDataT &d) {
+ const ObjectTableData &d) {
ASSERT_EQ(id, object_id);
- ASSERT_EQ(data->manager, d.manager);
+ ASSERT_EQ(data->manager(), d.manager());
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback));
@@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id,
// Check that lookup returns the added object entries.
auto lookup_callback = [object_id, data_vector](
gcs::AsyncGcsClient *client, const ObjectID &id,
- const std::vector<ObjectTableDataT> &data) {
+ const std::vector<ObjectTableData> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), 1);
test->IncrementNumCallbacks();
@@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id,
}
for (const auto &object_id : ids) {
auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id,
- const std::vector<ObjectTableDataT> &data) {
+ const std::vector<ObjectTableData> &data) {
ASSERT_EQ(id, object_id);
ASSERT_TRUE(data.size() == 0);
test->IncrementNumCallbacks();
@@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id,
void TestDeleteKeys(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Test delete function for keys of Log.
- std::vector<std::shared_ptr<TaskReconstructionDataT>> task_reconstruction_vector;
+ std::vector<std::shared_ptr<TaskReconstructionData>> task_reconstruction_vector;
auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
- auto data = std::make_shared<TaskReconstructionDataT>();
- data->node_manager_id = ObjectID::FromRandom().Hex();
+ auto data = std::make_shared<TaskReconstructionData>();
+ data->set_node_manager_id(ObjectID::FromRandom().Hex());
task_reconstruction_vector.push_back(data);
}
};
@@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id,
TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector);
// Test delete function for keys of Table.
- std::vector<std::shared_ptr<protocol::TaskT>> task_vector;
+ std::vector<std::shared_ptr<TaskTableData>> task_vector;
auto AppendTaskData = [&task_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
- auto task_data = std::make_shared<protocol::TaskT>();
- task_data->task_specification = ObjectID::FromRandom().Hex();
+ auto task_data = std::make_shared<TaskTableData>();
+ task_data->set_task(ObjectID::FromRandom().Hex());
task_vector.push_back(task_data);
}
};
@@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id,
9 * RayConfig::instance().maximum_gcs_deletion_batch_size());
// Test delete function for keys of Set.
- std::vector<std::shared_ptr<ObjectTableDataT>> object_vector;
+ std::vector<std::shared_ptr<ObjectTableData>> object_vector;
auto AppendObjectData = [&object_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = ObjectID::FromRandom().Hex();
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(ObjectID::FromRandom().Hex());
object_vector.push_back(data);
}
};
@@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) {
TestDeleteKeys(driver_id_, client_);
}
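// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Every test change above follows one
// mechanical pattern: the flatbuffers object API ("FooT" types) exposes plain
// mutable fields, while the protobuf messages generated from
// src/ray/protobuf/gcs.proto expose set_*/getter accessors. A minimal hedged
// illustration using the ObjectTableData message touched by this PR:

#include <memory>
#include <string>
#include "ray/protobuf/gcs.pb.h"

void AccessorMigrationSketch() {
  // Before (flatbuffers): data->manager = "abc"; std::string m = data->manager;
  auto data = std::make_shared<ray::rpc::ObjectTableData>();
  data->set_manager("abc");                // generated setter
  const std::string &m = data->manager();  // generated getter
  (void)m;
}
// ---------------------------------------------------------------------------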
-// Task table callbacks.
-void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskTableDataT &data) {
- ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED);
- ASSERT_EQ(data.raylet_id, kRandomId);
-}
-
-void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskTableDataT &data, bool do_stop) {
- ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED);
- ASSERT_EQ(data.raylet_id, kRandomId);
- if (do_stop) {
- test->Stop();
- }
-}
-void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskTableDataT &data) {
- TaskLookupHelper(client, id, data, /*do_stop=*/false);
-}
-void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskTableDataT &data) {
- TaskLookupHelper(client, id, data, /*do_stop=*/true);
-}
-
-void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
- RAY_CHECK(false);
-}
-
-void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id,
- const TaskTableDataT &data) {
- ASSERT_EQ(data.scheduling_state, SchedulingState::LOST);
- test->Stop();
-}
-
-void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
- RAY_CHECK(false);
- test->Stop();
-}
-
void TestLogSubscribeAll(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
std::vector<DriverID> driver_ids;
@@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id,
// Callback for a notification.
auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client,
const DriverID &id,
- const std::vector<DriverTableDataT> data) {
+ const std::vector<DriverTableData> data) {
ASSERT_EQ(id, driver_ids[test->NumCallbacks()]);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
- ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary());
+ ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary());
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids.size()) {
@@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id,
auto notification_callback = [object_ids, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsChangeMode change_mode,
- const std::vector<ObjectTableDataT> data) {
+ const std::vector<ObjectTableData> data) {
if (test->NumCallbacks() < 3 * 3) {
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
} else {
@@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id,
ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
- ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]);
+ ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == object_ids.size() * 3 * 2) {
@@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id,
// We have subscribed. Do the writes to the table.
for (size_t i = 0; i < object_ids.size(); i++) {
for (size_t j = 0; j < managers.size(); j++) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = managers[j];
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(managers[j]);
for (int k = 0; k < 3; k++) {
// Add the same entry several times.
// Expect no notification if the entry already exists.
@@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id,
}
for (size_t i = 0; i < object_ids.size(); i++) {
for (size_t j = 0; j < managers.size(); j++) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = managers[j];
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(managers[j]);
for (int k = 0; k < 3; k++) {
// Remove the same entry several times.
// Expect no notification if the entry doesn't exist.
@@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id,
// received for keys that we requested notifications for.
auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client,
const TaskID &id,
- const protocol::TaskT &data) {
+ const TaskTableData &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, task_id2);
// Check that we get notifications in the same order as the writes.
- ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]);
+ ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]);
test->IncrementNumCallbacks();
if (test->NumCallbacks() == task_specs2.size()) {
test->Stop();
@@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id,
// Write both keys. We should only receive notifications for the key that
// we requested them for.
for (const auto &task_spec : task_specs1) {
- auto data = std::make_shared<protocol::TaskT>();
- data->task_specification = task_spec;
+ auto data = std::make_shared<TaskTableData>();
+ data->set_task(task_spec);
RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr));
}
for (const auto &task_spec : task_specs2) {
- auto data = std::make_shared<protocol::TaskT>();
- data->task_specification = task_spec;
+ auto data = std::make_shared<TaskTableData>();
+ data->set_task(task_spec);
RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr));
}
};
@@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id,
// Add a log entry.
DriverID driver_id1 = DriverID::FromRandom();
std::vector driver_ids1 = {"abc", "def", "ghi"};
- auto data1 = std::make_shared();
- data1->driver_id = driver_ids1[0];
+ auto data1 = std::make_shared();
+ data1->set_driver_id(driver_ids1[0]);
RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr));
// Add a log entry at a second key.
DriverID driver_id2 = DriverID::FromRandom();
std::vector driver_ids2 = {"jkl", "mno", "pqr"};
- auto data2 = std::make_shared();
- data2->driver_id = driver_ids2[0];
+ auto data2 = std::make_shared();
+ data2->set_driver_id(driver_ids2[0]);
RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr));
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [driver_id2, driver_ids2](
gcs::AsyncGcsClient *client, const UniqueID &id,
- const std::vector<DriverTableDataT> &data) {
+ const std::vector<DriverTableData> &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, driver_id2);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
- ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]);
+ ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids2.size()) {
@@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id,
// we requested them for.
auto remaining = std::vector<std::string>(++driver_ids1.begin(), driver_ids1.end());
for (const auto &driver_id_it : remaining) {
- auto data = std::make_shared<DriverTableDataT>();
- data->driver_id = driver_id_it;
+ auto data = std::make_shared<DriverTableData>();
+ data->set_driver_id(driver_id_it);
RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr));
}
remaining = std::vector<std::string>(++driver_ids2.begin(), driver_ids2.end());
for (const auto &driver_id_it : remaining) {
- auto data = std::make_shared<DriverTableDataT>();
- data->driver_id = driver_id_it;
+ auto data = std::make_shared<DriverTableData>();
+ data->set_driver_id(driver_id_it);
RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr));
}
};
@@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id,
// Add a set entry.
ObjectID object_id1 = ObjectID::FromRandom();
std::vector managers1 = {"abc", "def", "ghi"};
- auto data1 = std::make_shared();
- data1->manager = managers1[0];
+ auto data1 = std::make_shared();
+ data1->set_manager(managers1[0]);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr));
// Add a set entry at a second key.
ObjectID object_id2 = ObjectID::FromRandom();
std::vector managers2 = {"jkl", "mno", "pqr"};
- auto data2 = std::make_shared();
- data2->manager = managers2[0];
+ auto data2 = std::make_shared();
+ data2->set_manager(managers2[0]);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr));
// The callback for a notification from the table. This should only be
@@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id,
auto notification_callback = [object_id2, managers2](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsChangeMode change_mode,
- const std::vector<ObjectTableDataT> &data) {
+ const std::vector<ObjectTableData> &data) {
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, object_id2);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
- ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]);
+ ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == managers2.size()) {
@@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id,
// we requested them for.
auto remaining = std::vector<std::string>(++managers1.begin(), managers1.end());
for (const auto &manager : remaining) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = manager;
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(manager);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr));
}
remaining = std::vector<std::string>(++managers2.begin(), managers2.end());
for (const auto &manager : remaining) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = manager;
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(manager);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr));
}
};
@@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id,
// Add a table entry.
TaskID task_id = TaskID::FromRandom();
std::vector task_specs = {"jkl", "mno", "pqr"};
- auto data = std::make_shared();
- data->task_specification = task_specs[0];
+ auto data = std::make_shared();
+ data->set_task(task_specs[0]);
RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr));
// The failure callback should not be called since all keys are non-empty
@@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id,
// received for keys that we requested notifications for.
auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client,
const TaskID &id,
- const protocol::TaskT &data) {
+ const TaskTableData &data) {
ASSERT_EQ(id, task_id);
// Check that we only get notifications for the first and last writes,
// since notifications are canceled in between.
if (test->NumCallbacks() == 0) {
- ASSERT_EQ(data.task_specification, task_specs.front());
+ ASSERT_EQ(data.task(), task_specs.front());
} else {
- ASSERT_EQ(data.task_specification, task_specs.back());
+ ASSERT_EQ(data.task(), task_specs.back());
}
test->IncrementNumCallbacks();
if (test->NumCallbacks() == 2) {
@@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id,
// a notification for these writes.
auto remaining = std::vector<std::string>(++task_specs.begin(), task_specs.end());
for (const auto &task_spec : remaining) {
- auto data = std::make_shared<protocol::TaskT>();
- data->task_specification = task_spec;
+ auto data = std::make_shared<TaskTableData>();
+ data->set_task(task_spec);
RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
@@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id,
// Add a log entry.
DriverID random_driver_id = DriverID::FromRandom();
std::vector driver_ids = {"jkl", "mno", "pqr"};
- auto data = std::make_shared();
- data->driver_id = driver_ids[0];
+ auto data = std::make_shared();
+ data->set_driver_id(driver_ids[0]);
RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr));
// The callback for a notification from the object table. This should only be
// received for the object that we requested notifications for.
auto notification_callback = [random_driver_id, driver_ids](
gcs::AsyncGcsClient *client, const UniqueID &id,
- const std::vector<DriverTableDataT> &data) {
+ const std::vector<DriverTableData> &data) {
ASSERT_EQ(id, random_driver_id);
// Check that we get a duplicate notification for the first write. We get a
// duplicate notification because the log is append-only and notifications
@@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id,
auto driver_ids_copy = driver_ids;
driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front());
for (const auto &entry : data) {
- ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]);
+ ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids_copy.size()) {
@@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id,
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++driver_ids.begin(), driver_ids.end());
for (const auto &remaining_driver_id : remaining) {
- auto data = std::make_shared<DriverTableDataT>();
- data->driver_id = remaining_driver_id;
+ auto data = std::make_shared<DriverTableData>();
+ data->set_driver_id(remaining_driver_id);
RAY_CHECK_OK(
client->driver_table().Append(driver_id, random_driver_id, data, nullptr));
}
@@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
// Add a set entry.
ObjectID object_id = ObjectID::FromRandom();
std::vector managers = {"jkl", "mno", "pqr"};
- auto data = std::make_shared();
- data->manager = managers[0];
+ auto data = std::make_shared();
+ data->set_manager(managers[0]);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr));
// The callback for a notification from the object table. This should only be
@@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
auto notification_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsChangeMode change_mode,
- const std::vector<ObjectTableDataT> &data) {
+ const std::vector<ObjectTableData> &data) {
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
ASSERT_EQ(id, object_id);
// Check that we get a duplicate notification for the first write. We get a
@@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
// are canceled after the first write, then requested again.
if (data.size() == 1) {
// first notification
- ASSERT_EQ(data[0].manager, managers[0]);
+ ASSERT_EQ(data[0].manager(), managers[0]);
test->IncrementNumCallbacks();
} else {
// second notification
@@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
std::unordered_set<std::string> managers_set(managers.begin(), managers.end());
std::unordered_set<std::string> data_managers_set;
for (const auto &entry : data) {
- data_managers_set.insert(entry.manager);
+ data_managers_set.insert(entry.manager());
test->IncrementNumCallbacks();
}
ASSERT_EQ(managers_set, data_managers_set);
@@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++managers.begin(), managers.end());
for (const auto &manager : remaining) {
- auto data = std::make_shared<ObjectTableDataT>();
- data->manager = manager;
+ auto data = std::make_shared<ObjectTableData>();
+ data->set_manager(manager);
RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
@@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) {
}
void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id,
- const ClientTableDataT &data, bool is_insertion) {
+ const ClientTableData &data, bool is_insertion) {
ClientID added_id = client->client_table().GetLocalClientId();
ASSERT_EQ(client_id, added_id);
- ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id);
- ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id);
- ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion);
+ ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id);
+ ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id);
+ ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion);
- ClientTableDataT cached_client;
+ ClientTableData cached_client;
client->client_table().GetClient(added_id, cached_client);
- ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id);
- ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion);
+ ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id);
+ ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion);
}
void TestClientTableConnect(const DriverID &driver_id,
@@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
- [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
+ [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) {
ClientTableNotification(client, id, data, true);
test->Stop();
});
// Connect and disconnect to client table. We should receive notifications
// for the addition and removal of our own entry.
- ClientTableDataT local_client_info = client->client_table().GetLocalClient();
- local_client_info.node_manager_address = "127.0.0.1";
- local_client_info.node_manager_port = 0;
- local_client_info.object_manager_port = 0;
+ ClientTableData local_client_info = client->client_table().GetLocalClient();
+ local_client_info.set_node_manager_address("127.0.0.1");
+ local_client_info.set_node_manager_port(0);
+ local_client_info.set_object_manager_port(0);
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
test->Start();
}
@@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
- [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
+ [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) {
ClientTableNotification(client, id, data, /*is_insertion=*/true);
// Disconnect from the client table. We should receive a notification
// for the removal of our own entry.
RAY_CHECK_OK(client->client_table().Disconnect());
});
client->client_table().RegisterClientRemovedCallback(
- [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
+ [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) {
ClientTableNotification(client, id, data, /*is_insertion=*/false);
test->Stop();
});
// Connect to the client table. We should receive notification for the
// addition of our own entry.
- ClientTableDataT local_client_info = client->client_table().GetLocalClient();
- local_client_info.node_manager_address = "127.0.0.1";
- local_client_info.node_manager_port = 0;
- local_client_info.object_manager_port = 0;
+ ClientTableData local_client_info = client->client_table().GetLocalClient();
+ local_client_info.set_node_manager_address("127.0.0.1");
+ local_client_info.set_node_manager_port(0);
+ local_client_info.set_object_manager_port(0);
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
test->Start();
}
@@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
- [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
+ [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) {
ClientTableNotification(client, id, data, true);
});
client->client_table().RegisterClientRemovedCallback(
- [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
+ [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) {
ClientTableNotification(client, id, data, false);
test->Stop();
});
// Connect to then immediately disconnect from the client table. We should
// receive notifications for the addition and removal of our own entry.
- ClientTableDataT local_client_info = client->client_table().GetLocalClient();
- local_client_info.node_manager_address = "127.0.0.1";
- local_client_info.node_manager_port = 0;
- local_client_info.object_manager_port = 0;
+ ClientTableData local_client_info = client->client_table().GetLocalClient();
+ local_client_info.set_node_manager_address("127.0.0.1");
+ local_client_info.set_node_manager_port(0);
+ local_client_info.set_object_manager_port(0);
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
RAY_CHECK_OK(client->client_table().Disconnect());
test->Start();
@@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) {
void TestClientTableMarkDisconnected(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
- ClientTableDataT local_client_info = client->client_table().GetLocalClient();
- local_client_info.node_manager_address = "127.0.0.1";
- local_client_info.node_manager_port = 0;
- local_client_info.object_manager_port = 0;
+ ClientTableData local_client_info = client->client_table().GetLocalClient();
+ local_client_info.set_node_manager_address("127.0.0.1");
+ local_client_info.set_node_manager_port(0);
+ local_client_info.set_object_manager_port(0);
// Connect to the client table to start receiving notifications.
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
// Mark a different client as dead.
@@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id,
// marked as dead.
client->client_table().RegisterClientRemovedCallback(
[dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id,
- const ClientTableDataT &data) {
- ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id);
+ const ClientTableData &data) {
+ ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id);
test->Stop();
});
test->Start();
@@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id,
const int expected_count = 14;
ClientID client_id = ClientID::FromRandom();
// Prepare the first resource map: data_map1.
- auto cpu_data = std::make_shared<RayResourceT>();
- cpu_data->resource_name = "CPU";
- cpu_data->resource_capacity = 100;
- auto gpu_data = std::make_shared<RayResourceT>();
- gpu_data->resource_name = "GPU";
- gpu_data->resource_capacity = 2;
+ auto cpu_data = std::make_shared<RayResource>();
+ cpu_data->set_resource_name("CPU");
+ cpu_data->set_resource_capacity(100);
+ auto gpu_data = std::make_shared<RayResource>();
+ gpu_data->set_resource_name("GPU");
+ gpu_data->set_resource_capacity(2);
DynamicResourceTable::DataMap data_map1;
data_map1.emplace("CPU", cpu_data);
data_map1.emplace("GPU", gpu_data);
// Prepare the second resource map: data_map2 which decreases CPU,
// increases GPU and add a new CUSTOM compared to data_map1.
- auto data_cpu = std::make_shared<RayResourceT>();
- data_cpu->resource_name = "CPU";
- data_cpu->resource_capacity = 50;
- auto data_gpu = std::make_shared<RayResourceT>();
- data_gpu->resource_name = "GPU";
- data_gpu->resource_capacity = 10;
- auto data_custom = std::make_shared<RayResourceT>();
- data_custom->resource_name = "CUSTOM";
- data_custom->resource_capacity = 2;
+ auto data_cpu = std::make_shared<RayResource>();
+ data_cpu->set_resource_name("CPU");
+ data_cpu->set_resource_capacity(50);
+ auto data_gpu = std::make_shared<RayResource>();
+ data_gpu->set_resource_name("GPU");
+ data_gpu->set_resource_capacity(10);
+ auto data_custom = std::make_shared<RayResource>();
+ data_custom->set_resource_name("CUSTOM");
+ data_custom->set_resource_capacity(2);
+ data_custom->set_resource_name("CUSTOM");
+ data_custom->set_resource_capacity(2);
DynamicResourceTable::DataMap data_map2;
data_map2.emplace("CPU", data_cpu);
data_map2.emplace("GPU", data_gpu);
data_map2.emplace("CUSTOM", data_custom);
- data_map2["CPU"]->resource_capacity = 50;
+ data_map2["CPU"]->set_resource_capacity(50);
// This is a common comparison function for the test.
auto compare_test = [](const DynamicResourceTable::DataMap &data1,
const DynamicResourceTable::DataMap &data2) {
@@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id,
for (const auto &data : data1) {
auto iter = data2.find(data.first);
ASSERT_TRUE(iter != data2.end());
- ASSERT_EQ(iter->second->resource_name, data.second->resource_name);
- ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity);
+ ASSERT_EQ(iter->second->resource_name(), data.second->resource_name());
+ ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity());
}
};
auto subscribe_callback = [](AsyncGcsClient *client) {
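// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] TestHashTable above fills
// DynamicResourceTable::DataMap with shared_ptr-wrapped protobuf messages.
// The actual DataMap typedef lives in tables.h; assuming it is a string-keyed
// map (as the emplace calls suggest), the construction boils down to:

#include <memory>
#include <string>
#include <unordered_map>
#include "ray/protobuf/gcs.pb.h"

// Hypothetical alias standing in for DynamicResourceTable::DataMap.
using ResourceMap =
    std::unordered_map<std::string, std::shared_ptr<ray::rpc::RayResource>>;

ResourceMap MakeResourceMapSketch() {
  ResourceMap map;
  auto cpu = std::make_shared<ray::rpc::RayResource>();
  cpu->set_resource_name("CPU");
  cpu->set_resource_capacity(100);
  map.emplace("CPU", cpu);  // keyed by resource name, as in the test
  return map;
}
// ---------------------------------------------------------------------------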
diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs
index 90476da73425..c06c79a02928 100644
--- a/src/ray/gcs/format/gcs.fbs
+++ b/src/ray/gcs/format/gcs.fbs
@@ -1,52 +1,9 @@
-enum Language:int {
- PYTHON = 0,
- CPP = 1,
- JAVA = 2
-}
-
-// These indexes are mapped to strings in ray_redis_module.cc.
-enum TablePrefix:int {
- UNUSED = 0,
- TASK,
- RAYLET_TASK,
- CLIENT,
- OBJECT,
- ACTOR,
- FUNCTION,
- TASK_RECONSTRUCTION,
- HEARTBEAT,
- HEARTBEAT_BATCH,
- ERROR_INFO,
- DRIVER,
- PROFILE,
- TASK_LEASE,
- ACTOR_CHECKPOINT,
- ACTOR_CHECKPOINT_ID,
- NODE_RESOURCE,
-}
+// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`).
-// The channel that Add operations to the Table should be published on, if any.
-enum TablePubsub:int {
- NO_PUBLISH = 0,
- TASK,
- RAYLET_TASK,
- CLIENT,
- OBJECT,
- ACTOR,
- HEARTBEAT,
- HEARTBEAT_BATCH,
- ERROR_INFO,
- TASK_LEASE,
- DRIVER,
- NODE_RESOURCE,
-}
-
-// Enum for the entry type in the ClientTable
-enum EntryType:int {
- INSERTION = 0,
- DELETION,
- RES_CREATEUPDATE,
- RES_DELETE,
+enum Language:int {
+ PYTHON=0,
+ JAVA=1,
+ CPP=2,
}
table Arg {
@@ -120,118 +77,6 @@ table ResourcePair {
value: double;
}
-enum GcsChangeMode:int {
- APPEND_OR_ADD = 0,
- REMOVE,
-}
-
-table GcsEntry {
- change_mode: GcsChangeMode;
- id: string;
- entries: [string];
-}
-
-table FunctionTableData {
- language: Language;
- name: string;
- data: string;
-}
-
-table ObjectTableData {
- // The size of the object.
- object_size: long;
- // The node manager ID that this object appeared on or was evicted by.
- manager: string;
-}
-
-table TaskReconstructionData {
- // The number of times this task has been reconstructed so far.
- num_reconstructions: int;
- // The node manager that is trying to reconstruct the task.
- node_manager_id: string;
-}
-
-enum SchedulingState:int {
- NONE = 0,
- WAITING = 1,
- SCHEDULED = 2,
- QUEUED = 4,
- RUNNING = 8,
- DONE = 16,
- LOST = 32,
- RECONSTRUCTING = 64
-}
-
-table TaskTableData {
- // The state of the task.
- scheduling_state: SchedulingState;
- // A raylet ID.
- raylet_id: string;
- // A string of bytes representing the task's TaskExecutionDependencies.
- execution_dependencies: string;
- // The number of times the task was spilled back by raylets.
- spillback_count: long;
- // A string of bytes representing the task specification.
- task_info: string;
- // TODO(pcm): This is at the moment duplicated in task_info, remove that one
- updated: bool;
-}
-
-table TaskTableTestAndUpdate {
- test_raylet_id: string;
- test_state_bitmask: SchedulingState;
- update_state: SchedulingState;
-}
-
-table ClassTableData {
-}
-
-enum ActorState:int {
- // Actor is alive.
- ALIVE = 0,
- // Actor is dead, now being reconstructed.
- // After reconstruction finishes, the state will become alive again.
- RECONSTRUCTING = 1,
- // Actor is already dead and won't be reconstructed.
- DEAD = 2
-}
-
-table ActorTableData {
- // The ID of the actor that was created.
- actor_id: string;
- // The dummy object ID returned by the actor creation task. If the actor
- // dies, then this is the object that should be reconstructed for the actor
- // to be recreated.
- actor_creation_dummy_object_id: string;
- // The ID of the driver that created the actor.
- driver_id: string;
- // The ID of the node manager that created the actor.
- node_manager_id: string;
- // Current state of this actor.
- state: ActorState;
- // Max number of times this actor should be reconstructed.
- max_reconstructions: int;
- // Remaining number of reconstructions.
- remaining_reconstructions: int;
-}
-
-table ErrorTableData {
- // The ID of the driver that the error is for.
- driver_id: string;
- // The type of the error.
- type: string;
- // The error message.
- error_message: string;
- // The timestamp of the error message.
- timestamp: double;
-}
-
-table CustomSerializerData {
-}
-
-table ConfigTableData {
-}
-
table ProfileEvent {
// The type of the event.
event_type: string;
@@ -258,119 +103,3 @@ table ProfileTableData {
// we don't want each event to require a GCS command.
profile_events: [ProfileEvent];
}
-
-table RayResource {
- // The type of the resource.
- resource_name: string;
- // The total capacity of this resource type.
- resource_capacity: double;
-}
-
-table ClientTableData {
- // The client ID of the client that the message is about.
- client_id: string;
- // The IP address of the client's node manager.
- node_manager_address: string;
- // The IPC socket name of the client's raylet.
- raylet_socket_name: string;
- // The IPC socket name of the client's plasma store.
- object_store_socket_name: string;
- // The port at which the client's node manager is listening for TCP
- // connections from other node managers.
- node_manager_port: int;
- // The port at which the client's object manager is listening for TCP
- // connections from other object managers.
- object_manager_port: int;
- // Enum to store the entry type in the log
- entry_type: EntryType = INSERTION;
- resources_total_label: [string];
- resources_total_capacity: [double];
-}
-
-table HeartbeatTableData {
- // Node manager client id
- client_id: string;
- // Resource capacity currently available on this node manager.
- resources_available_label: [string];
- resources_available_capacity: [double];
- // Total resource capacity configured for this node manager.
- resources_total_label: [string];
- resources_total_capacity: [double];
- // Aggregate outstanding resource load on this node manager.
- resource_load_label: [string];
- resource_load_capacity: [double];
-}
-
-table HeartbeatBatchTableData {
- batch: [HeartbeatTableData];
-}
-
-// Data for a lease on task execution.
-table TaskLeaseData {
- // Node manager client ID.
- node_manager_id: string;
- // The time that the lease was last acquired at. NOTE(swang): This is the
- // system clock time according to the node that added the entry and is not
- // synchronized with other nodes.
- acquired_at: long;
- // The period that the lease is active for.
- timeout: long;
-}
-
-table DriverTableData {
- // The driver ID.
- driver_id: string;
- // Whether it's dead.
- is_dead: bool;
-}
-
-// This table stores the actor checkpoint data. An actor checkpoint
-// is the snapshot of an actor's state in the actor registration.
-// See `actor_registration.h` for more detailed explanation of these fields.
-table ActorCheckpointData {
- // ID of this actor.
- actor_id: string;
- // The dummy object ID of actor's most recently executed task.
- execution_dependency: string;
- // A list of IDs of this actor's handles.
- handle_ids: [string];
- // The task counters of the above handles.
- task_counters: [long];
- // The frontier dependencies of the above handles.
- frontier_dependencies: [string];
- // A list of unreleased dummy objects from this actor.
- unreleased_dummy_objects: [string];
- // The numbers of dependencies for the above unreleased dummy objects.
- num_dummy_object_dependencies: [int];
-}
-
-// This table stores the actor-to-available-checkpoint-ids mapping.
-table ActorCheckpointIdData {
- // ID of this actor.
- actor_id: string;
- // IDs of this actor's available checkpoints.
- // Note, this is a long string that concatenates all the IDs.
- checkpoint_ids: string;
- // A list of the timestamps for each of the above `checkpoint_ids`.
- timestamps: [long];
-}
-
-// This enum type is used as object's metadata to indicate the object's creating
-// task has failed because of a certain error.
-// TODO(hchen): We may want to make these errors more specific. E.g., we may want
-// to distinguish between intentional and expected actor failures, and between
-// worker process failure and node failure.
-enum ErrorType:int {
- // Indicates that a task failed because the worker died unexpectedly while executing it.
- WORKER_DIED = 1,
- // Indicates that a task failed because the actor died unexpectedly before finishing it.
- ACTOR_DIED = 2,
- // Indicates that an object is lost and cannot be reconstructed.
- // Note, this currently only happens to actor objects. When the actor's state is already
- // after the object's creating task, the actor cannot re-run the task.
- // TODO(hchen): we may want to reuse this error type for more cases. E.g.,
- // 1) A object that was put by the driver.
- // 2) The object's creating task is already cleaned up from GCS (this currently
- // crashes raylet).
- OBJECT_UNRECONSTRUCTABLE = 3,
-}
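// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Two consequences of moving these
// enums to protobuf show up in the C++ changes: enum values nest under their
// message (EntryType::INSERTION becomes ClientTableData::INSERTION), and
// protoc emits a <Enum>_Name() helper replacing flatbuffers' EnumName<Enum>()
// free functions. Hedged example; it assumes the proto TablePrefix enum keeps
// the RAYLET_TASK value from the fbs enum removed above:

#include <string>
#include "ray/protobuf/gcs.pb.h"

void EnumMigrationSketch() {
  auto state = ray::rpc::ClientTableData::INSERTION;  // nested enum value
  (void)state;
  // Generated name helper, as used by PrefixedKeyString() in the redis module.
  std::string name =
      ray::rpc::TablePrefix_Name(ray::rpc::TablePrefix::RAYLET_TASK);
  (void)name;
}
// ---------------------------------------------------------------------------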
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index fc42e5cd98c2..093aab2455d9 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -9,7 +9,7 @@
#include "ray/common/status.h"
#include "ray/util/logging.h"
-#include "ray/gcs/format/gcs_generated.h"
+#include "ray/protobuf/gcs.pb.h"
extern "C" {
#include "ray/thirdparty/hiredis/adapters/ae.h"
@@ -25,6 +25,9 @@ namespace ray {
namespace gcs {
+using rpc::TablePrefix;
+using rpc::TablePubsub;
+
/// A simple reply wrapper for redis reply.
class CallbackReply {
public:
@@ -126,8 +129,8 @@ class RedisContext {
/// -1 for unused. If set, then data must be provided.
/// \return Status.
template <typename ID>
- Status RunAsync(const std::string &command, const ID &id, const uint8_t *data,
- int64_t length, const TablePrefix prefix,
+ Status RunAsync(const std::string &command, const ID &id, const void *data,
+ size_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, RedisCallback redisCallback,
int log_length = -1);
@@ -157,9 +160,9 @@ class RedisContext {
};
template <typename ID>
-Status RedisContext::RunAsync(const std::string &command, const ID &id,
- const uint8_t *data, int64_t length,
- const TablePrefix prefix, const TablePubsub pubsub_channel,
+Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data,
+ size_t length, const TablePrefix prefix,
+ const TablePubsub pubsub_channel,
RedisCallback redisCallback, int log_length) {
int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false);
if (length > 0) {
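// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Widening RunAsync() from
// (const uint8_t *, int64_t) to (const void *, size_t) lets callers hand over
// protobuf wire bytes without casts: SerializeAsString() yields a std::string
// whose data()/length() plug straight into the new signature, which is how
// tables.cc calls it below. Hedged sketch:

#include <string>
#include "ray/protobuf/gcs.pb.h"

std::string SerializeForRunAsyncSketch(const ray::rpc::TaskTableData &data) {
  // The returned buffer is passed to RunAsync as (str.data(), str.length()).
  return data.SerializeAsString();
}
// ---------------------------------------------------------------------------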
diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc
index e291b7ffdb32..c3a82c320d06 100644
--- a/src/ray/gcs/redis_module/ray_redis_module.cc
+++ b/src/ray/gcs/redis_module/ray_redis_module.cc
@@ -5,11 +5,16 @@
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/gcs/format/gcs_generated.h"
+#include "ray/protobuf/gcs.pb.h"
#include "ray/util/logging.h"
#include "redis_string.h"
#include "redismodule.h"
using ray::Status;
+using ray::rpc::GcsChangeMode;
+using ray::rpc::GcsEntry;
+using ray::rpc::TablePrefix;
+using ray::rpc::TablePubsub;
#if RAY_USE_NEW_GCS
// Under this flag, ray-project/credis will be loaded. Specifically, via
@@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe
REDISMODULE_OK) {
return Status::RedisError("Pubsub channel must be a valid integer.");
}
- if (pubsub_channel_long > static_cast<long long>(TablePubsub::MAX) ||
- pubsub_channel_long < static_cast<long long>(TablePubsub::MIN)) {
+ if (pubsub_channel_long >= static_cast<long long>(TablePubsub::TABLE_PUBSUB_MAX) ||
+ pubsub_channel_long <= static_cast<long long>(TablePubsub::TABLE_PUBSUB_MIN)) {
return Status::RedisError("Pubsub channel must be in the TablePubsub range.");
} else {
*out = static_cast<TablePubsub>(pubsub_channel_long);
@@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx,
const RedisModuleString *id) {
// Format the pubsub channel enum to a string. TablePubsub_MAX should be more
// than enough digits, but add 1 just in case for the null terminator.
- char pubsub_channel[static_cast<int>(TablePubsub::MAX) + 1];
+ char pubsub_channel[static_cast<int>(TablePubsub::TABLE_PUBSUB_MAX) + 1];
TablePubsub table_pubsub;
RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str));
sprintf(pubsub_channel, "%d", static_cast(table_pubsub));
@@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix *
REDISMODULE_OK) {
return Status::RedisError("Prefix must be a valid TablePrefix integer");
}
- if (table_prefix_long > static_cast<long long>(TablePrefix::MAX) ||
- table_prefix_long < static_cast<long long>(TablePrefix::MIN)) {
+ if (table_prefix_long >= static_cast<long long>(TablePrefix::TABLE_PREFIX_MAX) ||
+ table_prefix_long <= static_cast<long long>(TablePrefix::TABLE_PREFIX_MIN)) {
return Status::RedisError("Prefix must be in the TablePrefix range");
} else {
*out = static_cast<TablePrefix>(table_prefix_long);
@@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre
if (!ParseTablePrefix(prefix_enum, &prefix).ok()) {
return nullptr;
}
- return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname);
+ return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname);
}
// TODO(swang): This helper function should be deprecated by the version below,
@@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx,
int mode, RedisModuleString **mutated_key_str) {
TablePrefix prefix;
RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix));
- *out =
- OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str);
+ *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode,
+ mutated_key_str);
return Status::OK();
}
@@ -165,18 +170,24 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
return Status::OK();
}
-/// This is a helper method to convert a redis module string to a flatbuffer
-/// string.
+/// A helper function that creates `GcsEntry` protobuf object.
///
-/// \param fbb The flatbuffer builder.
-/// \param redis_string The redis string.
-/// \return The flatbuffer string.
-flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
- flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) {
- size_t redis_string_size;
- const char *redis_string_str =
- RedisModule_StringPtrLen(redis_string, &redis_string_size);
- return fbb.CreateString(redis_string_str, redis_string_size);
+/// \param[in] id Id of the entry.
+/// \param[in] change_mode Change mode of the entry.
+/// \param[in] entries Vector of entries.
+/// \param[out] result The created `GcsEntry` object.
+inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode,
+ const std::vector<RedisModuleString *> &entries,
+ GcsEntry *result) {
+ const char *data;
+ size_t size;
+ data = RedisModule_StringPtrLen(id, &size);
+ result->set_id(data, size);
+ result->set_change_mode(change_mode);
+ for (const auto &entry : entries) {
+ data = RedisModule_StringPtrLen(entry, &size);
+ result->add_entries(data, size);
+ }
}
/// Helper method to publish formatted data to target channel.
@@ -234,13 +245,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
RedisModuleString *id, GcsChangeMode change_mode,
RedisModuleString *data) {
// Serialize the notification to send.
- flatbuffers::FlatBufferBuilder fbb;
- auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
- auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id),
- fbb.CreateVector(&data_flatbuf, 1));
- fbb.Finish(message);
- auto data_buffer = RedisModule_CreateString(
- ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize());
+ GcsEntry gcs_entry;
+ CreateGcsEntry(id, change_mode, {data}, &gcs_entry);
+ std::string str = gcs_entry.SerializeAsString();
+ auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size());
return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer);
}
@@ -570,19 +578,20 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
size_t update_data_len = 0;
const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len);
- auto data_vec = flatbuffers::GetRoot<GcsEntry>(update_data_buf);
- *change_mode = data_vec->change_mode();
+ GcsEntry gcs_entry;
+ gcs_entry.ParseFromArray(update_data_buf, update_data_len);
+ *change_mode = gcs_entry.change_mode();
+
if (*change_mode == GcsChangeMode::APPEND_OR_ADD) {
// This code path means they are updating command.
- size_t total_size = data_vec->entries()->size();
+ size_t total_size = gcs_entry.entries_size();
REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector.");
for (int i = 0; i < total_size; i += 2) {
// Reconstruct a key-value pair from a flattened list.
RedisModuleString *entry_key = RedisModule_CreateString(
- ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
- RedisModuleString *entry_value =
- RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(),
- data_vec->entries()->Get(i + 1)->size());
+ ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size());
+ RedisModuleString *entry_value = RedisModule_CreateString(
+ ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size());
// Returning 0 if key exists(still updated), 1 if the key is created.
RAY_IGNORE_EXPR(
RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL));
@@ -590,27 +599,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
*changed_data = update_data;
} else {
// This code path means the command wants to remove the entries.
- size_t total_size = data_vec->entries()->size();
- flatbuffers::FlatBufferBuilder fbb;
- std::vector<flatbuffers::Offset<flatbuffers::String>> data;
+ GcsEntry updated;
+ updated.set_id(gcs_entry.id());
+ updated.set_change_mode(gcs_entry.change_mode());
+
+ size_t total_size = gcs_entry.entries_size();
for (int i = 0; i < total_size; i++) {
RedisModuleString *entry_key = RedisModule_CreateString(
- ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
+ ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size());
int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key,
REDISMODULE_HASH_DELETE, NULL);
if (deleted_num != 0) {
// The corresponding key is removed.
- data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(),
- data_vec->entries()->Get(i)->size()));
+ updated.add_entries(gcs_entry.entries(i));
}
}
- auto message =
- CreateGcsEntry(fbb, data_vec->change_mode(),
- fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()),
- fbb.CreateVector(data));
- fbb.Finish(message);
- *changed_data = RedisModule_CreateString(
- ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize());
+
+ // Serialize updated data.
+ std::string str = updated.SerializeAsString();
+ *changed_data = RedisModule_CreateString(ctx, str.data(), str.size());
auto size = RedisModule_ValueLength(key);
if (size == 0) {
REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK,
@@ -631,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
/// key should be published to. When publishing to a specific client, the
/// channel name should be :.
/// \param id The ID of the key to remove from.
-/// \param data The GcsEntry flatbugger data used to update this hash table.
+/// \param data The GcsEntry protobuf data used to update this hash table.
/// 1). For deletion, this is a list of keys.
/// 2). For updating, this is a list of pairs with each key followed by the value.
/// \return OK if the remove succeeds, or an error message string if the remove
@@ -648,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a
return Hash_DoPublish(ctx, new_argv.data());
}
-/// A helper function to create and finish a GcsEntry, based on the
+/// A helper function to create a GcsEntry protobuf, based on the
/// current value or values at the given key.
///
/// \param ctx The Redis module context.
@@ -658,21 +665,18 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a
/// \param prefix_str The string prefix associated with the open Redis key.
/// When parsed, this is expected to be a TablePrefix.
/// \param entry_id The UniqueID associated with the open Redis key.
-/// \param fbb A flatbuffer builder used to build the GcsEntry.
-Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
- RedisModuleString *prefix_str, RedisModuleString *entry_id,
- flatbuffers::FlatBufferBuilder &fbb) {
+/// \param[out] gcs_entry The created GcsEntry.
+Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
+ RedisModuleString *prefix_str, RedisModuleString *entry_id,
+ GcsEntry *gcs_entry) {
auto key_type = RedisModule_KeyType(table_key);
switch (key_type) {
case REDISMODULE_KEYTYPE_STRING: {
- // Build the flatbuffer from the string data.
+ // Build the GcsEntry from the string data.
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry);
size_t data_len = 0;
char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
- auto data = fbb.CreateString(data_buf, data_len);
- auto message =
- CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
- RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1));
- fbb.Finish(message);
+ gcs_entry->add_entries(data_buf, data_len);
} break;
case REDISMODULE_KEYTYPE_LIST:
case REDISMODULE_KEYTYPE_HASH:
@@ -696,27 +700,20 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str);
break;
}
- // Build the flatbuffer from the set of log entries.
+ // Build the GcsEntry from the set of log entries.
if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) {
return Status::RedisError("Empty list/set/hash or wrong type");
}
- std::vector<flatbuffers::Offset<flatbuffers::String>> data;
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry);
for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) {
RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i);
size_t len;
const char *element_str = RedisModule_CallReplyStringPtr(element, &len);
- data.push_back(fbb.CreateString(element_str, len));
+ gcs_entry->add_entries(element_str, len);
}
- auto message =
- CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
- RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
- fbb.Finish(message);
} break;
case REDISMODULE_KEYTYPE_EMPTY: {
- auto message = CreateGcsEntry(
- fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
- fbb.CreateVector(std::vector<flatbuffers::Offset<flatbuffers::String>>()));
- fbb.Finish(message);
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry);
} break;
default:
return Status::RedisError("Invalid Redis type during lookup.");
@@ -752,11 +749,12 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
if (table_key == nullptr) {
RedisModule_ReplyWithNull(ctx);
} else {
- // Serialize the data to a flatbuffer to return to the client.
- flatbuffers::FlatBufferBuilder fbb;
- REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb));
- RedisModule_ReplyWithStringBuffer(
- ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize());
+ // Serialize the data to a GcsEntry to return to the client.
+ GcsEntry gcs_entry;
+ REPLY_AND_RETURN_IF_NOT_OK(
+ TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry));
+ std::string str = gcs_entry.SerializeAsString();
+ RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size());
}
return REDISMODULE_OK;
}
@@ -870,10 +868,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin
// Publish the current value at the key to the client that is requesting
// notifications. An empty notification will be published if the key is
// empty.
- flatbuffers::FlatBufferBuilder fbb;
- REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb));
- RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
- reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize());
+ GcsEntry gcs_entry;
+ REPLY_AND_RETURN_IF_NOT_OK(
+ TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry));
+ std::string str = gcs_entry.SerializeAsString();
+ RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size());
return RedisModule_ReplyWithNull(ctx);
}
@@ -940,53 +939,6 @@ Status IsNil(bool *out, const std::string &data) {
return Status::OK();
}
-// This is a temporary redis command that will be removed once
-// the GCS uses https://github.com/pcmoritz/credis.
-// Be careful, this only supports Task Table payloads.
-int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
- int argc) {
- if (argc != 5) {
- return RedisModule_WrongArity(ctx);
- }
- RedisModuleString *prefix_str = argv[1];
- RedisModuleString *id = argv[3];
- RedisModuleString *update_data = argv[4];
-
- RedisModuleKey *key;
- REPLY_AND_RETURN_IF_NOT_OK(
- OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE));
-
- size_t value_len = 0;
- char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ);
-
- size_t update_len = 0;
- const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len);
-
- auto data =
- flatbuffers::GetMutableRoot<TaskTableData>(reinterpret_cast<void *>(value_buf));
-
- auto update = flatbuffers::GetRoot<TaskTableTestAndUpdate>(update_buf);
-
- bool do_update = static_cast<int>(data->scheduling_state()) &
- static_cast<int>(update->test_state_bitmask());
-
- bool is_nil_result;
- REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str()));
- if (!is_nil_result) {
- do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str();
- }
-
- if (do_update) {
- REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()),
- "mutate_scheduling_state failed");
- }
- REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed");
-
- int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len);
-
- return result;
-}
-
std::string DebugString() {
std::stringstream result;
result << "RedisModule:";
@@ -1016,7 +968,6 @@ AUTO_MEMORY(TableLookup_RedisCommand);
AUTO_MEMORY(TableRequestNotifications_RedisCommand);
AUTO_MEMORY(TableDelete_RedisCommand);
AUTO_MEMORY(TableCancelNotifications_RedisCommand);
-AUTO_MEMORY(TableTestAndUpdate_RedisCommand);
AUTO_MEMORY(DebugString_RedisCommand);
#if RAY_USE_NEW_GCS
AUTO_MEMORY(ChainTableAdd_RedisCommand);
@@ -1082,12 +1033,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
}
- if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update",
- TableTestAndUpdate_RedisCommand, "write", 0, 0,
- 0) == REDISMODULE_ERR) {
- return REDISMODULE_ERR;
- }
-
if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand,
"readonly", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc
index 33f1615580a6..b7c19ebfd595 100644
--- a/src/ray/gcs/tables.cc
+++ b/src/ray/gcs/tables.cc
@@ -3,6 +3,7 @@
#include "ray/common/common_protocol.h"
#include "ray/common/ray_config.h"
#include "ray/gcs/client.h"
+#include "ray/rpc/util.h"
#include "ray/util/util.h"
namespace {
@@ -39,48 +40,44 @@ namespace gcs {
template <typename ID, typename Data>
Status Log<ID, Data>::Append(const DriverID &driver_id, const ID &id,
- std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
+ std::shared_ptr<Data> &data, const WriteCallback &done) {
num_appends_++;
- auto callback = [this, id, dataT, done](const CallbackReply &reply) {
+ auto callback = [this, id, data, done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
// Failed to append the entry.
RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:"
<< status.ToString();
if (done != nullptr) {
- (done)(client_, id, *dataT);
+ (done)(client_, id, *data);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- fbb.ForceDefaults(true);
- fbb.Finish(Data::Pack(fbb, dataT.get()));
- return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id,
- fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
- pubsub_channel_, std::move(callback));
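+ // With protobuf, the message is shipped as its serialized string instead of a flatbuffer.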
+ std::string str = data->SerializeAsString();
+ return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(),
+ str.length(), prefix_, pubsub_channel_,
+ std::move(callback));
}
template <typename ID, typename Data>
Status Log<ID, Data>::AppendAt(const DriverID &driver_id, const ID &id,
- std::shared_ptr<DataT> &dataT, const WriteCallback &done,
+ std::shared_ptr<Data> &data, const WriteCallback &done,
const WriteCallback &failure, int log_length) {
num_appends_++;
- auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) {
+ auto callback = [this, id, data, done, failure](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
if (status.ok()) {
if (done != nullptr) {
- (done)(client_, id, *dataT);
+ (done)(client_, id, *data);
}
} else {
if (failure != nullptr) {
- (failure)(client_, id, *dataT);
+ (failure)(client_, id, *data);
}
}
};
- flatbuffers::FlatBufferBuilder fbb;
- fbb.ForceDefaults(true);
- fbb.Finish(Data::Pack(fbb, dataT.get()));
- return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id,
- fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
- pubsub_channel_, std::move(callback), log_length);
+ std::string str = data->SerializeAsString();
+ return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(),
+ str.length(), prefix_, pubsub_channel_,
+ std::move(callback), log_length);
}
template <typename ID, typename Data>
@@ -89,16 +86,15 @@ Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
num_lookups_++;
auto callback = [this, id, lookup](const CallbackReply &reply) {
if (lookup != nullptr) {
- std::vector<DataT> results;
+ std::vector<Data> results;
if (!reply.IsNil()) {
- const auto data = reply.ReadAsString();
- auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
- RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
- for (size_t i = 0; i < root->entries()->size(); i++) {
- DataT result;
- auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
- data_root->UnPackTo(&result);
- results.emplace_back(std::move(result));
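+ // The reply is a serialized GcsEntry; each entry is itself a serialized Data message.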
+ GcsEntry gcs_entry;
+ gcs_entry.ParseFromString(reply.ReadAsString());
+ RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id);
+ for (size_t i = 0; i < gcs_entry.entries_size(); i++) {
+ Data data;
+ data.ParseFromString(gcs_entry.entries(i));
+ results.emplace_back(std::move(data));
}
}
lookup(client_, id, results);
@@ -115,7 +111,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
const SubscriptionCallback &done) {
auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id,
const GcsChangeMode change_mode,
- const std::vector<DataT> &data) {
+ const std::vector<Data> &data) {
RAY_CHECK(change_mode != GcsChangeMode::REMOVE);
subscribe(client, id, data);
};
@@ -141,19 +137,16 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
// Data is provided. This is the callback for a message.
if (subscribe != nullptr) {
// Parse the notification.
- auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
- ID id;
- if (root->id()->size() > 0) {
- id = from_flatbuf<ID>(*root->id());
- }
- std::vector<DataT> results;
- for (size_t i = 0; i < root->entries()->size(); i++) {
- DataT result;
- auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
- data_root->UnPackTo(&result);
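+ // Notifications now arrive as a serialized GcsEntry; recover the ID and each Data entry.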
+ GcsEntry gcs_entry;
+ gcs_entry.ParseFromString(data);
+ ID id = ID::FromBinary(gcs_entry.id());
+ std::vector<Data> results;
+ for (size_t i = 0; i < gcs_entry.entries_size(); i++) {
+ Data result;
+ result.ParseFromString(gcs_entry.entries(i));
results.emplace_back(std::move(result));
}
- subscribe(client_, id, root->change_mode(), results);
+ subscribe(client_, id, gcs_entry.change_mode(), results);
}
}
};
@@ -234,19 +227,17 @@ std::string Log<ID, Data>::DebugString() const {
template <typename ID, typename Data>
Status Table<ID, Data>::Add(const DriverID &driver_id, const ID &id,
- std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
+ std::shared_ptr<Data> &data, const WriteCallback &done) {
num_adds_++;
- auto callback = [this, id, dataT, done](const CallbackReply &reply) {
+ auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
- (done)(client_, id, *dataT);
+ (done)(client_, id, *data);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- fbb.ForceDefaults(true);
- fbb.Finish(Data::Pack(fbb, dataT.get()));
- return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id,
- fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
- pubsub_channel_, std::move(callback));
+ std::string str = data->SerializeAsString();
+ return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(),
+ str.length(), prefix_, pubsub_channel_,
+ std::move(callback));
}
template <typename ID, typename Data>
@@ -255,7 +246,7 @@ Status Table<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
num_lookups_++;
return Log<ID, Data>::Lookup(driver_id, id,
[lookup, failure](AsyncGcsClient *client, const ID &id,
- const std::vector<DataT> &data) {
+ const std::vector<Data> &data) {
if (data.empty()) {
if (failure != nullptr) {
(failure)(client, id);
@@ -277,7 +268,7 @@ Status Table<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &cli
return Log<ID, Data>::Subscribe(
driver_id, client_id,
[subscribe, failure](AsyncGcsClient *client, const ID &id,
- const std::vector<DataT> &data) {
+ const std::vector<Data> &data) {
RAY_CHECK(data.empty() || data.size() == 1);
if (data.size() == 1) {
subscribe(client, id, data[0]);
@@ -299,36 +290,30 @@ std::string Table<ID, Data>::DebugString() const {
template <typename ID, typename Data>
Status Set<ID, Data>::Add(const DriverID &driver_id, const ID &id,
- std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
+ std::shared_ptr<Data> &data, const WriteCallback &done) {
num_adds_++;
- auto callback = [this, id, dataT, done](const CallbackReply &reply) {
+ auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
- (done)(client_, id, *dataT);
+ (done)(client_, id, *data);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- fbb.ForceDefaults(true);
- fbb.Finish(Data::Pack(fbb, dataT.get()));
- return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(),
- fbb.GetSize(), prefix_, pubsub_channel_,
- std::move(callback));
+ std::string str = data->SerializeAsString();
+ return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(),
+ prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
Status Set<ID, Data>::Remove(const DriverID &driver_id, const ID &id,
- std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
+ std::shared_ptr<Data> &data, const WriteCallback &done) {
num_removes_++;
- auto callback = [this, id, dataT, done](const CallbackReply &reply) {
+ auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
- (done)(client_, id, *dataT);
+ (done)(client_, id, *data);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- fbb.ForceDefaults(true);
- fbb.Finish(Data::Pack(fbb, dataT.get()));
- return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(),
- fbb.GetSize(), prefix_, pubsub_channel_,
- std::move(callback));
+ std::string str = data->SerializeAsString();
+ return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(),
+ prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -348,26 +333,16 @@ Status Hash<ID, Data>::Update(const DriverID &driver_id, const ID &id,
(done)(client_, id, data_map);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
- data_vec.reserve(data_map.size() * 2);
- for (auto const &pair : data_map) {
- // Add the key.
- data_vec.push_back(fbb.CreateString(pair.first));
- flatbuffers::FlatBufferBuilder fbb_data;
- fbb_data.ForceDefaults(true);
- fbb_data.Finish(Data::Pack(fbb_data, pair.second.get()));
- std::string data(reinterpret_cast<char *>(fbb_data.GetBufferPointer()),
- fbb_data.GetSize());
- // Add the value.
- data_vec.push_back(fbb.CreateString(data));
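+ // Flatten the map into a GcsEntry whose entries alternate between a key and its serialized value.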
+ GcsEntry gcs_entry;
+ gcs_entry.set_id(id.Binary());
+ gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD);
+ for (const auto &pair : data_map) {
+ gcs_entry.add_entries(pair.first);
+ gcs_entry.add_entries(pair.second->SerializeAsString());
}
-
- fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
- fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec)));
- return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
- fbb.GetSize(), prefix_, pubsub_channel_,
- std::move(callback));
+ std::string str = gcs_entry.SerializeAsString();
+ return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(),
+ prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -380,19 +355,15 @@ Status Hash<ID, Data>::RemoveEntries(const DriverID &driver_id, const ID &id,
(remove_callback)(client_, id, keys);
}
};
- flatbuffers::FlatBufferBuilder fbb;
- std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
- data_vec.reserve(keys.size());
- // Add the keys.
- for (auto const &key : keys) {
- data_vec.push_back(fbb.CreateString(key));
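+ // For removals, the GcsEntry carries only the keys to delete.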
+ GcsEntry gcs_entry;
+ gcs_entry.set_id(id.Binary());
+ gcs_entry.set_change_mode(GcsChangeMode::REMOVE);
+ for (const auto &key : keys) {
+ gcs_entry.add_entries(key);
}
-
- fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()),
- fbb.CreateVector(data_vec)));
- return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
- fbb.GetSize(), prefix_, pubsub_channel_,
- std::move(callback));
+ std::string str = gcs_entry.SerializeAsString();
+ return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(),
+ prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -412,17 +383,15 @@ Status Hash<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
DataMap results;
if (!reply.IsNil()) {
const auto data = reply.ReadAsString();
- auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
- RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
- RAY_CHECK(root->entries()->size() % 2 == 0);
- for (size_t i = 0; i < root->entries()->size(); i += 2) {
- std::string key(root->entries()->Get(i)->data(),
- root->entries()->Get(i)->size());
- auto result = std::make_shared<DataT>();
- auto data_root =
- flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
- data_root->UnPackTo(result.get());
- results.emplace(key, std::move(result));
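+ // Entries alternate between a key and a serialized value; rebuild the map from consecutive pairs.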
+ GcsEntry gcs_entry;
+ gcs_entry.ParseFromString(reply.ReadAsString());
+ RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id);
+ RAY_CHECK(gcs_entry.entries_size() % 2 == 0);
+ for (int i = 0; i < gcs_entry.entries_size(); i += 2) {
+ const auto &key = gcs_entry.entries(i);
+ const auto value = std::make_shared<Data>();
+ value->ParseFromString(gcs_entry.entries(i + 1));
+ results.emplace(key, std::move(value));
}
}
lookup(client_, id, results);
@@ -451,31 +420,24 @@ Status Hash<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clie
// Data is provided. This is the callback for a message.
if (subscribe != nullptr) {
// Parse the notification.
- auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
+ GcsEntry gcs_entry;
+ gcs_entry.ParseFromString(data);
+ ID id = ID::FromBinary(gcs_entry.id());
DataMap data_map;
- ID id;
- if (root->id()->size() > 0) {
- id = from_flatbuf<ID>(*root->id());
- }
- if (root->change_mode() == GcsChangeMode::REMOVE) {
- for (size_t i = 0; i < root->entries()->size(); i++) {
- std::string key(root->entries()->Get(i)->data(),
- root->entries()->Get(i)->size());
- data_map.emplace(key, std::shared_ptr<DataT>());
+ if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) {
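+ // A REMOVE notification carries only keys; map each to a null pointer to mark deletion.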
+ for (const auto &key : gcs_entry.entries()) {
+ data_map.emplace(key, std::shared_ptr<Data>());
}
} else {
- RAY_CHECK(root->entries()->size() % 2 == 0);
- for (size_t i = 0; i < root->entries()->size(); i += 2) {
- std::string key(root->entries()->Get(i)->data(),
- root->entries()->Get(i)->size());
- auto result = std::make_shared<DataT>();
- auto data_root =
- flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
- data_root->UnPackTo(result.get());
- data_map.emplace(key, std::move(result));
+ RAY_CHECK(gcs_entry.entries_size() % 2 == 0);
+ for (int i = 0; i < gcs_entry.entries_size(); i += 2) {
+ const auto &key = gcs_entry.entries(i);
+ const auto value = std::make_shared<Data>();
+ value->ParseFromString(gcs_entry.entries(i + 1));
+ data_map.emplace(key, std::move(value));
}
}
- subscribe(client_, id, root->change_mode(), data_map);
+ subscribe(client_, id, gcs_entry.change_mode(), data_map);
}
}
};
@@ -490,11 +452,11 @@ Status Hash