diff --git a/python/ray/tests/test_core_worker_fault_tolerance.py b/python/ray/tests/test_core_worker_fault_tolerance.py index d1fa4c3d463a..6ab8cf9ba5de 100644 --- a/python/ray/tests/test_core_worker_fault_tolerance.py +++ b/python/ray/tests/test_core_worker_fault_tolerance.py @@ -4,8 +4,9 @@ import pytest import ray -from ray._common.test_utils import wait_for_condition +from ray._common.test_utils import SignalActor, wait_for_condition from ray.core.generated import common_pb2, gcs_pb2 +from ray.exceptions import GetTimeoutError, TaskCancelledError from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -164,5 +165,47 @@ def verify_actor_ref_deleted(): wait_for_condition(verify_actor_ref_deleted, timeout=30) +@pytest.fixture +def inject_cancel_remote_task_rpc_failure(monkeypatch, request): + deterministic_failure = request.param + monkeypatch.setenv( + "RAY_testing_rpc_failure", + "CoreWorkerService.grpc_client.CancelRemoteTask=1:" + + ("100:0" if deterministic_failure == "request" else "0:100"), + ) + + +@pytest.mark.parametrize( + "inject_cancel_remote_task_rpc_failure", ["request", "response"], indirect=True +) +def test_cancel_remote_task_rpc_retry_and_idempotency( + inject_cancel_remote_task_rpc_failure, ray_start_cluster +): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + cluster.add_node(num_cpus=1, resources={"worker1": 1}) + cluster.add_node(num_cpus=1, resources={"worker2": 1}) + signaler = SignalActor.remote() + + @ray.remote(num_cpus=1, resources={"worker1": 1}) + def wait_for(y): + return ray.get(y[0]) + + @ray.remote(num_cpus=1, resources={"worker2": 1}) + def remote_wait(sg): + return [wait_for.remote([sg[0]])] + + sig = signaler.wait.remote() + + outer = remote_wait.remote([sig]) + inner = ray.get(outer)[0] + with pytest.raises(GetTimeoutError): + ray.get(inner, timeout=1) + ray.cancel(inner) + with pytest.raises(TaskCancelledError): + ray.get(inner, timeout=10) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/src/mock/ray/core_worker/core_worker.h b/src/mock/ray/core_worker/core_worker.h index 403d97de0db0..563d7f3d3f6c 100644 --- a/src/mock/ray/core_worker/core_worker.h +++ b/src/mock/ray/core_worker/core_worker.h @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker { rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, - HandleRemoteCancelTask, - (rpc::RemoteCancelTaskRequest request, - rpc::RemoteCancelTaskReply *reply, + HandleCancelRemoteTask, + (rpc::CancelRemoteTaskRequest request, + rpc::CancelRemoteTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h index 6fd174d254d9..cd293cebbd93 100644 --- a/src/mock/ray/rpc/worker/core_worker_client.h +++ b/src/mock/ray/rpc/worker/core_worker_client.h @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface { const ClientCallback &callback), (override)); MOCK_METHOD(void, - RemoteCancelTask, - (const RemoteCancelTaskRequest &request, - const ClientCallback &callback), + CancelRemoteTask, + (CancelRemoteTaskRequest && request, + const ClientCallback &callback), (override)); MOCK_METHOD(void, GetCoreWorkerStats, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 814d42bde6d9..1003ed5232cc 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -3871,8 +3871,8 @@ void CoreWorker::ProcessSubscribeForRefRemoved( reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address); } -void CoreWorker::HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request, - rpc::RemoteCancelTaskReply *reply, +void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, + rpc::CancelRemoteTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill(), diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 883208d63c08..84c194dea454 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1209,8 +1209,8 @@ class CoreWorker { rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. - void HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request, - rpc::RemoteCancelTaskReply *reply, + void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, + rpc::CancelRemoteTaskReply *reply, rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. diff --git a/src/ray/core_worker/core_worker_rpc_proxy.h b/src/ray/core_worker/core_worker_rpc_proxy.h index 8105232e67a9..48865d81b7b9 100644 --- a/src/ray/core_worker/core_worker_rpc_proxy.h +++ b/src/ray/core_worker/core_worker_rpc_proxy.h @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler { RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns) RAY_CORE_WORKER_RPC_PROXY(KillActor) RAY_CORE_WORKER_RPC_PROXY(CancelTask) - RAY_CORE_WORKER_RPC_PROXY(RemoteCancelTask) + RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask) RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader) RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats) RAY_CORE_WORKER_RPC_PROXY(LocalGC) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index a3a550dbaa55..e5540aa502df 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -69,7 +69,7 @@ void CoreWorkerGrpcService::InitServerCallFactories( RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, - RemoteCancelTask, + CancelRemoteTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index e70b9c8ee475..4559a45447c1 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -93,8 +93,8 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler { CancelTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleRemoteCancelTask(RemoteCancelTaskRequest request, - RemoteCancelTaskReply *reply, + virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request, + CancelRemoteTaskReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleRegisterMutableObjectReader( diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index bbd436547611..49191d9b8077 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -777,11 +777,17 @@ void NormalTaskSubmitter::CancelRemoteTask(const ObjectID &object_id, bool force_kill, bool recursive) { auto client = core_worker_client_pool_->GetOrConnect(worker_addr); - auto request = rpc::RemoteCancelTaskRequest(); + auto request = rpc::CancelRemoteTaskRequest(); request.set_force_kill(force_kill); request.set_recursive(recursive); request.set_remote_object_id(object_id.Binary()); - client->RemoteCancelTask(request, nullptr); + client->CancelRemoteTask( + std::move(request), + [](const Status &status, const rpc::CancelRemoteTaskReply &reply) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to cancel remote task: " << status.ToString(); + } + }); } bool NormalTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) { diff --git a/src/ray/core_worker_rpc_client/core_worker_client.h b/src/ray/core_worker_rpc_client/core_worker_client.h index 06f117ddecf6..b9fa6b2ea71f 100644 --- a/src/ray/core_worker_rpc_client/core_worker_client.h +++ b/src/ray/core_worker_rpc_client/core_worker_client.h @@ -84,11 +84,12 @@ class CoreWorkerClient : public std::enable_shared_from_this, /*method_timeout_ms*/ -1, override) - VOID_RPC_CLIENT_METHOD(CoreWorkerService, - RemoteCancelTask, - grpc_client_, - /*method_timeout_ms*/ -1, - override) + VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_, + CoreWorkerService, + CancelRemoteTask, + grpc_client_, + /*method_timeout_ms*/ -1, + override) VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_, CoreWorkerService, diff --git a/src/ray/core_worker_rpc_client/core_worker_client_interface.h b/src/ray/core_worker_rpc_client/core_worker_client_interface.h index e5d2a554e558..80c37f709d34 100644 --- a/src/ray/core_worker_rpc_client/core_worker_client_interface.h +++ b/src/ray/core_worker_rpc_client/core_worker_client_interface.h @@ -76,9 +76,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { virtual void CancelTask(const CancelTaskRequest &request, const ClientCallback &callback) = 0; - virtual void RemoteCancelTask( - const RemoteCancelTaskRequest &request, - const ClientCallback &callback) = 0; + virtual void CancelRemoteTask( + CancelRemoteTaskRequest &&request, + const ClientCallback &callback) = 0; virtual void RegisterMutableObjectReader( const RegisterMutableObjectReaderRequest &request, diff --git a/src/ray/core_worker_rpc_client/fake_core_worker_client.h b/src/ray/core_worker_rpc_client/fake_core_worker_client.h index 1e1a51002c60..02a380b9e397 100644 --- a/src/ray/core_worker_rpc_client/fake_core_worker_client.h +++ b/src/ray/core_worker_rpc_client/fake_core_worker_client.h @@ -80,8 +80,8 @@ class FakeCoreWorkerClient : public CoreWorkerClientInterface { void CancelTask(const CancelTaskRequest &request, const ClientCallback &callback) override {} - void RemoteCancelTask(const RemoteCancelTaskRequest &request, - const ClientCallback &callback) override {} + void CancelRemoteTask(CancelRemoteTaskRequest &&request, + const ClientCallback &callback) override {} void RegisterMutableObjectReader( const RegisterMutableObjectReaderRequest &request, diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 51e252f97354..2bec20b67ffe 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -293,7 +293,7 @@ message CancelTaskReply { bool attempt_succeeded = 2; } -message RemoteCancelTaskRequest { +message CancelRemoteTaskRequest { // Object ID of the remote task that should be killed. bytes remote_object_id = 1; // Whether to kill the worker. @@ -302,7 +302,7 @@ message RemoteCancelTaskRequest { bool recursive = 3; } -message RemoteCancelTaskReply {} +message CancelRemoteTaskReply {} message GetCoreWorkerStatsRequest { // The ID of the worker this message is intended for. @@ -534,8 +534,8 @@ service CoreWorkerService { rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply); // Request from a worker to the owner worker to issue a cancellation. - // Failure: TODO: needs failure behavior - rpc RemoteCancelTask(RemoteCancelTaskRequest) returns (RemoteCancelTaskReply); + // Failure: Retries, it's idempotent. + rpc CancelRemoteTask(CancelRemoteTaskRequest) returns (CancelRemoteTaskReply); // From raylet to get metrics from its workers. // Failure: Should not fail, always from local raylet.