Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion python/ray/tests/test_core_worker_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest

import ray
from ray._common.test_utils import wait_for_condition
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.core.generated import common_pb2, gcs_pb2
from ray.exceptions import GetTimeoutError, TaskCancelledError
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


Expand Down Expand Up @@ -164,5 +165,47 @@ def verify_actor_ref_deleted():
wait_for_condition(verify_actor_ref_deleted, timeout=30)


@pytest.fixture
def inject_cancel_remote_task_rpc_failure(monkeypatch, request):
    """Configure RPC failure injection for CancelRemoteTask via the environment.

    Reads the indirect parametrization value from ``request.param`` and sets
    the ``RAY_testing_rpc_failure`` environment variable for
    ``CoreWorkerService.grpc_client.CancelRemoteTask``. The configured value
    ends in ``1:100:0`` when the parameter is ``"request"`` and in ``1:0:100``
    for any other parameter value.

    NOTE(review): the numeric fields presumably mean
    ``max_failures:request_failure_pct:response_failure_pct`` -- confirm
    against Ray's RPC failure-injection configuration.
    """
    failure_spec = "100:0" if request.param == "request" else "0:100"
    monkeypatch.setenv(
        "RAY_testing_rpc_failure",
        "CoreWorkerService.grpc_client.CancelRemoteTask=1:" + failure_spec,
    )


@pytest.mark.parametrize(
    "inject_cancel_remote_task_rpc_failure", ["request", "response"], indirect=True
)
def test_cancel_remote_task_rpc_retry_and_idempotency(
    inject_cancel_remote_task_rpc_failure, ray_start_cluster
):
    """Cancel a nested task while CancelRemoteTask RPC failures are injected.

    Runs once per fixture parameter ("request" and "response"). The cluster
    has a zero-CPU first node plus two one-CPU nodes tagged with the custom
    resources ``worker1`` and ``worker2``. A task pinned to ``worker2``
    launches a task pinned to ``worker1`` and hands its reference back to the
    driver inside a list. The inner task blocks on a ``SignalActor.wait`` call
    that this test never signals, so it can only finish by being cancelled.

    The test passes when ``ray.get`` on the inner reference first times out
    and, after ``ray.cancel``, raises ``TaskCancelledError`` within 10 seconds.

    NOTE(review): presumably the driver is not the owner of the inner task, so
    ``ray.cancel`` is routed through the CancelRemoteTask RPC -- confirm.
    """
    cluster = ray_start_cluster
    # First node contributes no CPUs; the driver connects before the worker
    # nodes are added.
    cluster.add_node(num_cpus=0)
    ray.init(address=cluster.address)
    cluster.add_node(num_cpus=1, resources={"worker1": 1})
    cluster.add_node(num_cpus=1, resources={"worker2": 1})
    signal_actor = SignalActor.remote()

    @ray.remote(num_cpus=1, resources={"worker1": 1})
    def wait_for(wrapped_ref):
        # The reference arrives wrapped in a list and is resolved here.
        return ray.get(wrapped_ref[0])

    @ray.remote(num_cpus=1, resources={"worker2": 1})
    def remote_wait(wrapped_ref):
        # Launch the inner task from inside this task and return its
        # reference inside a list.
        return [wait_for.remote([wrapped_ref[0]])]

    # Nothing in this test sends the signal, so this wait never completes.
    pending_wait = signal_actor.wait.remote()

    outer_ref = remote_wait.remote([pending_wait])
    inner_ref = ray.get(outer_ref)[0]

    # The inner task must still be blocked before it is cancelled.
    with pytest.raises(GetTimeoutError):
        ray.get(inner_ref, timeout=1)

    ray.cancel(inner_ref)

    with pytest.raises(TaskCancelledError):
        ray.get(inner_ref, timeout=10)


if __name__ == "__main__":
    # Local import: only needed when this file is executed directly as a script.
    import sys

    # pytest.main() returns the exit code instead of exiting the process, so
    # hand it to sys.exit(); otherwise a failing run would still exit with 0.
    sys.exit(pytest.main([__file__, "-v", "-s"]))
6 changes: 3 additions & 3 deletions src/mock/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker {
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleRemoteCancelTask,
(rpc::RemoteCancelTaskRequest request,
rpc::RemoteCancelTaskReply *reply,
HandleCancelRemoteTask,
(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
Expand Down
6 changes: 3 additions & 3 deletions src/mock/ray/rpc/worker/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface {
const ClientCallback<CancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
RemoteCancelTask,
(const RemoteCancelTaskRequest &request,
const ClientCallback<RemoteCancelTaskReply> &callback),
CancelRemoteTask,
(CancelRemoteTaskRequest && request,
const ClientCallback<CancelRemoteTaskReply> &callback),
(override));
MOCK_METHOD(void,
GetCoreWorkerStats,
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3871,8 +3871,8 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address);
}

void CoreWorker::HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request,
rpc::RemoteCancelTaskReply *reply,
void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
request.force_kill(),
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1209,8 +1209,8 @@ class CoreWorker {
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request,
rpc::RemoteCancelTaskReply *reply,
void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker_rpc_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler {
RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns)
RAY_CORE_WORKER_RPC_PROXY(KillActor)
RAY_CORE_WORKER_RPC_PROXY(CancelTask)
RAY_CORE_WORKER_RPC_PROXY(RemoteCancelTask)
RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask)
RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader)
RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats)
RAY_CORE_WORKER_RPC_PROXY(LocalGC)
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/grpc_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void CoreWorkerGrpcService::InitServerCallFactories(
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
RemoteCancelTask,
CancelRemoteTask,
max_active_rpcs_per_handler_,
AuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler {
CancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleRemoteCancelTask(RemoteCancelTaskRequest request,
RemoteCancelTaskReply *reply,
virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request,
CancelRemoteTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterMutableObjectReader(
Expand Down
10 changes: 8 additions & 2 deletions src/ray/core_worker/task_submission/normal_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -777,11 +777,17 @@ void NormalTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
bool force_kill,
bool recursive) {
auto client = core_worker_client_pool_->GetOrConnect(worker_addr);
auto request = rpc::RemoteCancelTaskRequest();
auto request = rpc::CancelRemoteTaskRequest();
request.set_force_kill(force_kill);
request.set_recursive(recursive);
request.set_remote_object_id(object_id.Binary());
client->RemoteCancelTask(request, nullptr);
client->CancelRemoteTask(
std::move(request),
[](const Status &status, const rpc::CancelRemoteTaskReply &reply) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to cancel remote task: " << status.ToString();
}
});
}

bool NormalTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) {
Expand Down
11 changes: 6 additions & 5 deletions src/ray/core_worker_rpc_client/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
/*method_timeout_ms*/ -1,
override)

VOID_RPC_CLIENT_METHOD(CoreWorkerService,
RemoteCancelTask,
grpc_client_,
/*method_timeout_ms*/ -1,
override)
VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_,
CoreWorkerService,
CancelRemoteTask,
grpc_client_,
/*method_timeout_ms*/ -1,
override)

VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_,
CoreWorkerService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
virtual void CancelTask(const CancelTaskRequest &request,
const ClientCallback<CancelTaskReply> &callback) = 0;

virtual void RemoteCancelTask(
const RemoteCancelTaskRequest &request,
const ClientCallback<RemoteCancelTaskReply> &callback) = 0;
virtual void CancelRemoteTask(
CancelRemoteTaskRequest &&request,
const ClientCallback<CancelRemoteTaskReply> &callback) = 0;

virtual void RegisterMutableObjectReader(
const RegisterMutableObjectReaderRequest &request,
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker_rpc_client/fake_core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class FakeCoreWorkerClient : public CoreWorkerClientInterface {
void CancelTask(const CancelTaskRequest &request,
const ClientCallback<CancelTaskReply> &callback) override {}

void RemoteCancelTask(const RemoteCancelTaskRequest &request,
const ClientCallback<RemoteCancelTaskReply> &callback) override {}
void CancelRemoteTask(CancelRemoteTaskRequest &&request,
const ClientCallback<CancelRemoteTaskReply> &callback) override {}

void RegisterMutableObjectReader(
const RegisterMutableObjectReaderRequest &request,
Expand Down
8 changes: 4 additions & 4 deletions src/ray/protobuf/core_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ message CancelTaskReply {
bool attempt_succeeded = 2;
}

message RemoteCancelTaskRequest {
message CancelRemoteTaskRequest {
// Object ID of the remote task that should be killed.
bytes remote_object_id = 1;
// Whether to kill the worker.
Expand All @@ -302,7 +302,7 @@ message RemoteCancelTaskRequest {
bool recursive = 3;
}

message RemoteCancelTaskReply {}
message CancelRemoteTaskReply {}

message GetCoreWorkerStatsRequest {
// The ID of the worker this message is intended for.
Expand Down Expand Up @@ -534,8 +534,8 @@ service CoreWorkerService {
rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);

// Request from a worker to the owner worker to issue a cancellation.
// Failure: TODO: needs failure behavior
rpc RemoteCancelTask(RemoteCancelTaskRequest) returns (RemoteCancelTaskReply);
// Failure: Retries, it's idempotent.
rpc CancelRemoteTask(CancelRemoteTaskRequest) returns (CancelRemoteTaskReply);

// From raylet to get metrics from its workers.
// Failure: Should not fail, always from local raylet.
Expand Down