diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index 9cb8f450c1e2..9a3e58ea1de6 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -11,7 +11,11 @@ import ray from ray import serve from ray._common.test_utils import SignalActor -from ray.serve._private.test_utils import get_application_url, get_application_urls +from ray.serve._private.test_utils import ( + get_application_url, + get_application_urls, + send_signal_on_cancellation, +) from ray.serve.handle import DeploymentHandle @@ -326,12 +330,9 @@ def test_http_disconnect(serve_instance): class SimpleGenerator: def __call__(self, request: Request) -> StreamingResponse: async def wait_for_disconnect(): - try: - yield "hi" - await asyncio.sleep(100) - except asyncio.CancelledError: - print("Cancelled!") - signal_actor.send.remote() + yield "hi" + async with send_signal_on_cancellation(signal_actor): + pass return StreamingResponse(wait_for_disconnect()) diff --git a/python/ray/tests/test_core_worker_fault_tolerance.py b/python/ray/tests/test_core_worker_fault_tolerance.py index 3ff65345870d..b320777f0b14 100644 --- a/python/ray/tests/test_core_worker_fault_tolerance.py +++ b/python/ray/tests/test_core_worker_fault_tolerance.py @@ -172,7 +172,7 @@ def inject_cancel_remote_task_rpc_failure(monkeypatch, request): failure = RPC_FAILURE_MAP[deterministic_failure] monkeypatch.setenv( "RAY_testing_rpc_failure", - f"CoreWorkerService.grpc_client.CancelRemoteTask=1:{failure}", + f"CoreWorkerService.grpc_client.RequestOwnerToCancelTask=1:{failure}", ) diff --git a/python/ray/tests/test_raylet_fault_tolerance.py b/python/ray/tests/test_raylet_fault_tolerance.py index ee38d64e3af8..8c4597da8d78 100644 --- a/python/ray/tests/test_raylet_fault_tolerance.py +++ b/python/ray/tests/test_raylet_fault_tolerance.py @@ -4,12 +4,13 @@ import pytest import ray +from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import ( RPC_FAILURE_MAP, RPC_FAILURE_TYPES, - wait_for_condition, ) from ray.core.generated import autoscaler_pb2 +from ray.exceptions import GetTimeoutError, TaskCancelledError from ray.util.placement_group import placement_group, remove_placement_group from ray.util.scheduling_strategies import ( NodeAffinitySchedulingStrategy, @@ -247,5 +248,57 @@ def verify_process_killed(): wait_for_condition(verify_process_killed, timeout=30) +@pytest.fixture +def inject_cancel_local_task_rpc_failure(monkeypatch, request): + failure = RPC_FAILURE_MAP[request.param] + monkeypatch.setenv( + "RAY_testing_rpc_failure", + f"NodeManagerService.grpc_client.CancelLocalTask=1:{failure}", + ) + + +@pytest.mark.parametrize( + "inject_cancel_local_task_rpc_failure", RPC_FAILURE_TYPES, indirect=True +) +@pytest.mark.parametrize("force_kill", [True, False]) +def test_cancel_local_task_rpc_retry_and_idempotency( + inject_cancel_local_task_rpc_failure, force_kill, shutdown_only +): + """Test that CancelLocalTask RPC retries work correctly. + + Verify that the RPC is idempotent when network failures occur. + When force_kill=True, verify the worker process is actually killed using psutil. 
+ """ + ray.init(num_cpus=1) + signaler = SignalActor.remote() + + @ray.remote(num_cpus=1) + def get_pid(): + return os.getpid() + + @ray.remote(num_cpus=1) + def blocking_task(): + return ray.get(signaler.wait.remote()) + + worker_pid = ray.get(get_pid.remote()) + + blocking_ref = blocking_task.remote() + + with pytest.raises(GetTimeoutError): + ray.get(blocking_ref, timeout=1) + + ray.cancel(blocking_ref, force=force_kill) + + with pytest.raises(TaskCancelledError): + ray.get(blocking_ref, timeout=10) + + if force_kill: + + def verify_process_killed(): + return not psutil.pid_exists(worker_pid) + + wait_for_condition(verify_process_killed, timeout=30) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/mock/ray/core_worker/core_worker.h b/src/mock/ray/core_worker/core_worker.h index 563d7f3d3f6c..1ba83cb01a3d 100644 --- a/src/mock/ray/core_worker/core_worker.h +++ b/src/mock/ray/core_worker/core_worker.h @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker { rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, - HandleCancelRemoteTask, - (rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, + HandleRequestOwnerToCancelTask, + (rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index 9419043c6aeb..8c4a76734b27 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -156,6 +156,11 @@ class MockRayletClientInterface : public RayletClientInterface { (const rpc::ClientCallback &callback), (override)); MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override)); + MOCK_METHOD(void, + CancelLocalTask, + (const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback &callback), + (override)); }; } // namespace ray diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h index cd293cebbd93..e537a35f7255 100644 --- a/src/mock/ray/rpc/worker/core_worker_client.h +++ b/src/mock/ray/rpc/worker/core_worker_client.h @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface { const ClientCallback &callback), (override)); MOCK_METHOD(void, - CancelRemoteTask, - (CancelRemoteTaskRequest && request, - const ClientCallback &callback), + RequestOwnerToCancelTask, + (RequestOwnerToCancelTaskRequest && request, + const ClientCallback &callback), (override)); MOCK_METHOD(void, GetCoreWorkerStats, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 4c8484826211..b9da6e7c6144 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2460,12 +2460,13 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, } if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { - // We don't have CancelRemoteTask for actor_task_submitter_ + // We don't have RequestOwnerToCancelTask for actor_task_submitter_ // because it requires the same implementation. 
RAY_LOG(DEBUG).WithField(object_id) << "Request to cancel a task of object to an owner " << obj_addr.SerializeAsString(); - normal_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, recursive); + normal_task_submitter_->RequestOwnerToCancelTask( + object_id, obj_addr, force_kill, recursive); return Status::OK(); } @@ -3910,9 +3911,10 @@ void CoreWorker::ProcessSubscribeForRefRemoved( reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address); } -void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) { +void CoreWorker::HandleRequestOwnerToCancelTask( + rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill(), request.recursive()); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index b3c46fb51a8e..15770e3eb79f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1208,9 +1208,9 @@ class CoreWorker { rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. - void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, - rpc::SendReplyCallback send_reply_callback); + void HandleRequestOwnerToCancelTask(rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. void HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request, diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index aae14dd55795..9e12ab9c0d37 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -500,6 +500,8 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( auto actor_task_submitter = std::make_unique( *core_worker_client_pool, + *raylet_client_pool, + gcs_client, *memory_store, *task_manager, *actor_creator, @@ -538,6 +540,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( local_raylet_rpc_client, core_worker_client_pool, raylet_client_pool, + gcs_client, std::move(lease_policy), memory_store, *task_manager, @@ -554,7 +557,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( // OBJECT_STORE. 
return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_service_), + io_service_, *scheduler_placement_time_ms_histogram_); auto report_locality_data_callback = [this]( diff --git a/src/ray/core_worker/core_worker_rpc_proxy.h b/src/ray/core_worker/core_worker_rpc_proxy.h index 48865d81b7b9..808da2cde210 100644 --- a/src/ray/core_worker/core_worker_rpc_proxy.h +++ b/src/ray/core_worker/core_worker_rpc_proxy.h @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler { RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns) RAY_CORE_WORKER_RPC_PROXY(KillActor) RAY_CORE_WORKER_RPC_PROXY(CancelTask) - RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask) + RAY_CORE_WORKER_RPC_PROXY(RequestOwnerToCancelTask) RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader) RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats) RAY_CORE_WORKER_RPC_PROXY(LocalGC) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index adb5b62786d4..b0e4d9ce37d7 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -77,7 +77,7 @@ void CoreWorkerGrpcService::InitServerCallFactories( max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, - CancelRemoteTask, + RequestOwnerToCancelTask, max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index d605f5176533..b90349f6f009 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -96,9 +96,9 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler { CancelTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request, - CancelRemoteTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; + virtual void HandleRequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest request, + RequestOwnerToCancelTaskReply *reply, + SendReplyCallback send_reply_callback) = 0; virtual void HandleRegisterMutableObjectReader( RegisterMutableObjectReaderRequest request, diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index e8b9cd096993..6856592a792b 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -53,6 +53,17 @@ ray_cc_library( ], ) +ray_cc_library( + name = "task_submission_util", + hdrs = ["task_submission_util.h"], + visibility = [":__subpackages__"], + deps = [ + "//src/ray/common:asio", + "//src/ray/common:id", + "//src/ray/gcs_rpc_client:gcs_client", + ], +) + ray_cc_library( name = "actor_task_submitter", srcs = ["actor_task_submitter.cc"], @@ -66,12 +77,16 @@ ray_cc_library( ":dependency_resolver", ":out_of_order_actor_submit_queue", ":sequential_actor_submit_queue", + ":task_submission_util", "//src/ray/common:asio", "//src/ray/common:id", "//src/ray/common:protobuf_utils", "//src/ray/core_worker:actor_creator", "//src/ray/core_worker:reference_counter_interface", "//src/ray/core_worker_rpc_client:core_worker_client_pool", + "//src/ray/gcs_rpc_client:gcs_client", + "//src/ray/raylet_rpc_client:raylet_client_interface", + "//src/ray/raylet_rpc_client:raylet_client_pool", "//src/ray/rpc:rpc_callback_types", "//src/ray/util:time", "@com_google_absl//absl/base:core_headers", @@ -90,6 +105,7 @@ 
ray_cc_library(
     ],
     deps = [
         ":dependency_resolver",
+        ":task_submission_util",
         "//src/ray/common:id",
         "//src/ray/common:lease",
         "//src/ray/common:protobuf_utils",
@@ -97,7 +113,9 @@ ray_cc_library(
         "//src/ray/core_worker:memory_store",
         "//src/ray/core_worker:task_manager_interface",
         "//src/ray/core_worker_rpc_client:core_worker_client_pool",
+        "//src/ray/gcs_rpc_client:gcs_client",
         "//src/ray/raylet_rpc_client:raylet_client_interface",
+        "//src/ray/raylet_rpc_client:raylet_client_pool",
         "//src/ray/util:time",
         "@com_google_absl//absl/base:core_headers",
     ],
diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc
index e0cc2fb20d75..e5a4a46b7f53 100644
--- a/src/ray/core_worker/task_submission/actor_task_submitter.cc
+++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc
@@ -21,6 +21,7 @@
 #include

 #include "ray/common/protobuf_utils.h"
+#include "ray/core_worker/task_submission/task_submission_util.h"
 #include "ray/util/time.h"

 namespace ray {
@@ -912,17 +913,16 @@ std::string ActorTaskSubmitter::DebugString(const ActorID &actor_id) const {
   return stream.str();
 }

-void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec,
-                                         bool recursive,
-                                         int64_t milliseconds) {
+void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec, bool recursive) {
+  auto delay_ms = RayConfig::instance().cancellation_retry_ms();
   RAY_LOG(DEBUG).WithField(task_spec.TaskId())
-      << "Task cancelation will be retried in " << milliseconds << " ms";
+      << "Task cancellation will be retried in " << delay_ms << " ms";
   execute_after(
       io_service_,
       [this, task_spec = std::move(task_spec), recursive] {
         CancelTask(task_spec, recursive);
       },
-      std::chrono::milliseconds(milliseconds));
+      std::chrono::milliseconds(delay_ms));
 }

 void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) {
@@ -997,44 +997,56 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)
   // an executor tells us to stop retrying.

   // If there's no client, it means actor is not created yet.
-  // Retry in 1 second.
+  // Retry after the configured delay.
+  NodeID node_id;
+  std::string executor_worker_id;
   {
     absl::MutexLock lock(&mu_);
     RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC.";
     auto queue = client_queues_.find(actor_id);
     RAY_CHECK(queue != client_queues_.end());
     if (!queue->second.client_address_.has_value()) {
-      RetryCancelTask(task_spec, recursive, 1000);
+      RetryCancelTask(task_spec, recursive);
       return;
     }
-
-    rpc::CancelTaskRequest request;
-    request.set_intended_task_id(task_spec.TaskIdBinary());
-    request.set_force_kill(force_kill);
-    request.set_recursive(recursive);
-    request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
-    auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_);
-    client->CancelTask(request,
-                       [this, task_spec = std::move(task_spec), recursive, task_id](
-                           const Status &status, const rpc::CancelTaskReply &reply) {
-                         RAY_LOG(DEBUG).WithField(task_spec.TaskId())
-                             << "CancelTask RPC response received with status "
-                             << status.ToString();
-
-                         // Keep retrying every 2 seconds until a task is officially
-                         // finished.
-                         if (!task_manager_.GetTaskSpec(task_id)) {
-                           // Task is already finished.
-                           RAY_LOG(DEBUG).WithField(task_spec.TaskId())
-                               << "Task is finished. Stop a cancel request.";
-                           return;
-                         }
-
-                         if (!reply.attempt_succeeded()) {
-                           RetryCancelTask(task_spec, recursive, 2000);
-                         }
-                       });
+    node_id = NodeID::FromBinary(queue->second.client_address_.value().node_id());
+    executor_worker_id = queue->second.client_address_.value().worker_id();
   }
+
+  auto do_cancel_local_task =
+      [this, task_spec = std::move(task_spec), force_kill, recursive, executor_worker_id](
+          const rpc::Address &raylet_address) mutable {
+        rpc::CancelLocalTaskRequest request;
+        request.set_intended_task_id(task_spec.TaskIdBinary());
+        request.set_force_kill(force_kill);
+        request.set_recursive(recursive);
+        request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
+        request.set_executor_worker_id(executor_worker_id);
+
+        auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address);
+        raylet_client->CancelLocalTask(
+            request,
+            [this, task_spec = std::move(task_spec), recursive](
+                const Status &status, const rpc::CancelLocalTaskReply &reply) mutable {
+              if (!status.ok()) {
+                RAY_LOG(INFO) << "CancelLocalTask RPC failed for task "
+                              << task_spec.TaskId()
+                              << " due to node death: " << status.ToString();
+                return;
+              } else {
+                RAY_LOG(INFO) << "CancelLocalTask RPC response received for "
+                              << task_spec.TaskId()
+                              << " with attempt_succeeded: " << reply.attempt_succeeded()
+                              << " requested_task_running: "
+                              << reply.requested_task_running();
+              }
+              // Keep retrying until a task is officially finished.
+              if (!reply.attempt_succeeded()) {
+                RetryCancelTask(std::move(task_spec), recursive);
+              }
+            });
+      };
+  SendCancelLocalTask(gcs_client_, node_id, std::move(do_cancel_local_task), []() {});
 }

 bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) {
diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h
index ada84d9a0d25..16769e27f7d3 100644
--- a/src/ray/core_worker/task_submission/actor_task_submitter.h
+++ b/src/ray/core_worker/task_submission/actor_task_submitter.h
@@ -68,6 +68,8 @@ class ActorTaskSubmitterInterface {
 class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
  public:
   ActorTaskSubmitter(rpc::CoreWorkerClientPool &core_worker_client_pool,
+                     rpc::RayletClientPool &raylet_client_pool,
+                     std::shared_ptr<gcs::GcsClient> gcs_client,
                      CoreWorkerMemoryStore &store,
                      TaskManagerInterface &task_manager,
                      ActorCreatorInterface &actor_creator,
@@ -77,6 +79,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
                      instrumented_io_context &io_service,
                      std::shared_ptr reference_counter)
       : core_worker_client_pool_(core_worker_client_pool),
+        raylet_client_pool_(raylet_client_pool),
+        gcs_client_(std::move(gcs_client)),
         actor_creator_(actor_creator),
         resolver_(store, task_manager, actor_creator, tensor_transport_getter),
         task_manager_(task_manager),
@@ -232,8 +236,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
   /// \param recursive If true, it will cancel all child tasks.
   void CancelTask(TaskSpecification task_spec, bool recursive);

-  /// Retry the CancelTask in milliseconds.
-  void RetryCancelTask(TaskSpecification task_spec, bool recursive, int64_t milliseconds);
+  /// Retry the CancelTask after a configured delay.
+  void RetryCancelTask(TaskSpecification task_spec, bool recursive);

   /// Queue the streaming generator up for resubmission.
/// \return true if the task is still executing and the submitter agrees to resubmit @@ -301,7 +305,7 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { int64_t num_restarts_due_to_lineage_reconstructions_ = 0; /// Whether this actor exits by spot preemption. bool preempted_ = false; - /// The RPC client address. + /// The RPC client address of the actor. std::optional client_address_; /// The intended worker ID of the actor. std::string worker_id_; @@ -412,6 +416,11 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { /// Pool for producing new core worker clients. rpc::CoreWorkerClientPool &core_worker_client_pool_; + /// Pool for producing new raylet clients. + rpc::RayletClientPool &raylet_client_pool_; + + std::shared_ptr gcs_client_; + ActorCreatorInterface &actor_creator_; /// Mutex to protect the various maps below. diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index aa84b217db41..ebe984ea3b41 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -22,8 +22,10 @@ #include #include "absl/strings/str_format.h" +#include "ray/common/asio/asio_util.h" #include "ray/common/lease/lease_spec.h" #include "ray/common/protobuf_utils.h" +#include "ray/core_worker/task_submission/task_submission_util.h" #include "ray/util/time.h" namespace ray { @@ -665,7 +667,9 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.GetRuntimeEnvHash()); - std::shared_ptr client = nullptr; + + NodeID node_id; + std::string executor_worker_id; { absl::MutexLock lock(&mu_); generators_to_resubmit_.erase(task_id); @@ -700,9 +704,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, // This will get removed either when the RPC call to cancel is returned, when all // dependencies are resolved, or when dependency resolution is successfully cancelled. RAY_CHECK(cancelled_tasks_.emplace(task_id).second); - auto rpc_client = executing_tasks_.find(task_id); - - if (rpc_client == executing_tasks_.end()) { + auto rpc_client_address = executing_tasks_.find(task_id); + if (rpc_client_address == executing_tasks_.end()) { if (failed_tasks_pending_failure_cause_.contains(task_id)) { // We are waiting for the task failure cause. Do not fail it here; instead, // wait for the cause to come in and then handle it appropriately. @@ -722,68 +725,81 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, } return; } - // Looks for an RPC handle for the worker executing the task. 
-    client = core_worker_client_pool_->GetOrConnect(rpc_client->second);
+    node_id = NodeID::FromBinary(rpc_client_address->second.node_id());
+    executor_worker_id = rpc_client_address->second.worker_id();
   }
-  RAY_CHECK(client != nullptr);
-  auto request = rpc::CancelTaskRequest();
-  request.set_intended_task_id(task_spec.TaskIdBinary());
-  request.set_force_kill(force_kill);
-  request.set_recursive(recursive);
-  request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
-  client->CancelTask(
-      request,
-      [this,
-       task_spec = std::move(task_spec),
-       scheduling_key = std::move(scheduling_key),
-       force_kill,
-       recursive](const Status &status, const rpc::CancelTaskReply &reply) mutable {
-        absl::MutexLock lock(&mu_);
-        RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId()
-                       << " with status " << status.ToString();
-        cancelled_tasks_.erase(task_spec.TaskId());
-
-        // Retry is not attempted if !status.ok() because force-kill may kill the worker
-        // before the reply is sent.
-        if (!status.ok()) {
-          RAY_LOG(DEBUG) << "Failed to cancel a task due to " << status.ToString();
-          return;
-        }
-
-        if (!reply.attempt_succeeded()) {
-          if (reply.requested_task_running()) {
-            // Retry cancel request if failed.
-            if (cancel_retry_timer_.expiry().time_since_epoch() <=
-                std::chrono::high_resolution_clock::now().time_since_epoch()) {
-              cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds(
-                  RayConfig::instance().cancellation_retry_ms()));
-            }
-            cancel_retry_timer_.async_wait(boost::bind(&NormalTaskSubmitter::CancelTask,
-                                                       this,
-                                                       std::move(task_spec),
-                                                       force_kill,
-                                                       recursive));
+  auto do_cancel_local_task = [this,
+                               task_spec = std::move(task_spec),
+                               scheduling_key = std::move(scheduling_key),
+                               executor_worker_id,
+                               force_kill,
+                               recursive](const rpc::Address &raylet_address) mutable {
+    rpc::CancelLocalTaskRequest request;
+    request.set_intended_task_id(task_spec.TaskIdBinary());
+    request.set_force_kill(force_kill);
+    request.set_recursive(recursive);
+    request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
+    request.set_executor_worker_id(executor_worker_id);
+
+    auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(raylet_address);
+    raylet_client->CancelLocalTask(
+        request,
+        [this,
+         task_spec = std::move(task_spec),
+         scheduling_key = std::move(scheduling_key),
+         force_kill,
+         recursive](const Status &status,
+                    const rpc::CancelLocalTaskReply &reply) mutable {
+          absl::MutexLock callback_lock(&mu_);
+          cancelled_tasks_.erase(task_spec.TaskId());
+          if (!status.ok()) {
+            RAY_LOG(INFO) << "CancelLocalTask RPC failed for task " << task_spec.TaskId()
+                          << " due to node death: " << status.ToString();
+            return;
           } else {
-            RAY_LOG(DEBUG) << "Attempt to cancel task " << task_spec.TaskId()
-                           << " in a worker that doesn't have this task.";
+            RAY_LOG(INFO) << "CancelLocalTask RPC response received for "
+                          << task_spec.TaskId()
+                          << " with attempt_succeeded: " << reply.attempt_succeeded()
+                          << " requested_task_running: "
+                          << reply.requested_task_running();
           }
-        }
-      });
+          if (!reply.attempt_succeeded()) {
+            if (reply.requested_task_running()) {
+              execute_after(
+                  io_service_,
+                  [this, task_spec = std::move(task_spec), force_kill, recursive] {
+                    CancelTask(task_spec, force_kill, recursive);
+                  },
+                  std::chrono::milliseconds(
+                      RayConfig::instance().cancellation_retry_ms()));
+            } else {
+              RAY_LOG(DEBUG) << "Attempt to cancel task " << task_spec.TaskId()
+                             << " in a worker that doesn't have this task.";
+            }
+          }
+        });
+  };
+  auto failure_callback = [this, task_id]() {
+    absl::MutexLock inner_lock(&mu_);
+    cancelled_tasks_.erase(task_id);
+  };
+  SendCancelLocalTask(
+      gcs_client_, node_id, std::move(do_cancel_local_task), std::move(failure_callback));
 }

-void NormalTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
-                                           const rpc::Address &worker_addr,
-                                           bool force_kill,
-                                           bool recursive) {
+void NormalTaskSubmitter::RequestOwnerToCancelTask(const ObjectID &object_id,
+                                                   const rpc::Address &worker_addr,
+                                                   bool force_kill,
+                                                   bool recursive) {
   auto client = core_worker_client_pool_->GetOrConnect(worker_addr);
-  auto request = rpc::CancelRemoteTaskRequest();
+  auto request = rpc::RequestOwnerToCancelTaskRequest();
   request.set_force_kill(force_kill);
   request.set_recursive(recursive);
   request.set_remote_object_id(object_id.Binary());
-  client->CancelRemoteTask(
+  client->RequestOwnerToCancelTask(
       std::move(request),
-      [](const Status &status, const rpc::CancelRemoteTaskReply &reply) {
+      [](const Status &status, const rpc::RequestOwnerToCancelTaskReply &reply) {
         if (!status.ok()) {
           RAY_LOG(ERROR) << "Failed to cancel remote task: " << status.ToString();
         }
diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.h b/src/ray/core_worker/task_submission/normal_task_submitter.h
index ce044826cfdf..d8090de144bc 100644
--- a/src/ray/core_worker/task_submission/normal_task_submitter.h
+++ b/src/ray/core_worker/task_submission/normal_task_submitter.h
@@ -34,6 +34,11 @@
 #include "ray/raylet_rpc_client/raylet_client_pool.h"

 namespace ray {
+
+namespace gcs {
+class GcsClient;
+}  // namespace gcs
+
 namespace core {

 // The task queues are keyed on resource shape & function descriptor
@@ -85,6 +90,7 @@ class NormalTaskSubmitter {
                      std::shared_ptr local_raylet_client,
                      std::shared_ptr core_worker_client_pool,
                      std::shared_ptr raylet_client_pool,
+                     std::shared_ptr<gcs::GcsClient> gcs_client,
                      std::unique_ptr lease_policy,
                      std::shared_ptr store,
                      TaskManagerInterface &task_manager,
@@ -95,11 +101,12 @@ class NormalTaskSubmitter {
                      const JobID &job_id,
                      std::shared_ptr lease_request_rate_limiter,
                      const TensorTransportGetter &tensor_transport_getter,
-                     boost::asio::steady_timer cancel_timer,
+                     instrumented_io_context &io_service,
                      ray::observability::MetricInterface &scheduler_placement_time_ms_histogram)
       : rpc_address_(std::move(rpc_address)),
         local_raylet_client_(std::move(local_raylet_client)),
         raylet_client_pool_(std::move(raylet_client_pool)),
+        gcs_client_(std::move(gcs_client)),
         lease_policy_(std::move(lease_policy)),
         resolver_(*store, task_manager, *actor_creator, tensor_transport_getter),
         task_manager_(task_manager),
@@ -110,7 +117,7 @@ class NormalTaskSubmitter {
         core_worker_client_pool_(std::move(core_worker_client_pool)),
         job_id_(job_id),
         lease_request_rate_limiter_(std::move(lease_request_rate_limiter)),
-        cancel_retry_timer_(std::move(cancel_timer)),
+        io_service_(io_service),
         scheduler_placement_time_ms_histogram_(scheduler_placement_time_ms_histogram) {}

   /// Schedule a task for direct submission to a worker.
@@ -126,10 +133,10 @@ class NormalTaskSubmitter {
   /// It is used when a object ID is not owned by the current process.
   /// We cannot cancel the task in this case because we don't have enough
   /// information to cancel a task.
-  void CancelRemoteTask(const ObjectID &object_id,
-                        const rpc::Address &worker_addr,
-                        bool force_kill,
-                        bool recursive);
+  void RequestOwnerToCancelTask(const ObjectID &object_id,
+                                const rpc::Address &worker_addr,
+                                bool force_kill,
+                                bool recursive);

   /// Queue the streaming generator up for resubmission.
 /// \return true if the task is still executing and the submitter agrees to resubmit
@@ -244,6 +251,9 @@ class NormalTaskSubmitter {
   /// Raylet client pool for producing new clients to request leases from remote nodes.
   std::shared_ptr raylet_client_pool_;

+  /// GCS client for checking node liveness.
+  std::shared_ptr<gcs::GcsClient> gcs_client_;
+
   /// Provider of worker leasing decisions for the first lease request (not on
   /// spillback).
   std::unique_ptr lease_policy_;
@@ -363,7 +373,7 @@ class NormalTaskSubmitter {
   std::shared_ptr lease_request_rate_limiter_;

-  // Retries cancelation requests if they were not successful.
-  boost::asio::steady_timer cancel_retry_timer_ ABSL_GUARDED_BY(mu_);
+  // Event loop used to schedule cancellation retries.
+  instrumented_io_context &io_service_;

   ray::observability::MetricInterface &scheduler_placement_time_ms_histogram_;
 };
diff --git a/src/ray/core_worker/task_submission/task_submission_util.h b/src/ray/core_worker/task_submission/task_submission_util.h
new file mode 100644
index 000000000000..2a37e3e732c0
--- /dev/null
+++ b/src/ray/core_worker/task_submission/task_submission_util.h
@@ -0,0 +1,88 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "ray/common/asio/instrumented_io_context.h"
+#include "ray/common/id.h"
+#include "ray/gcs_rpc_client/gcs_client.h"
+#include "ray/raylet_rpc_client/raylet_client_pool.h"
+
+namespace ray {
+namespace core {
+
+/// Send a CancelLocalTask operation after checking the GCS node cache for node
+/// liveness. The GCS query is done because we don't store the address of the raylet
+/// in the task submission path. Since it's only needed in cancellation, we check the
+/// pubsub cache (and fall back to the GCS if it's not in the cache, which is rare)
+/// instead of polluting the hot path.
+///
+/// \param gcs_client GCS client to query node information.
+/// \param node_id ID of the node where the task is executing.
+/// \param cancel_callback Callback that sends the CancelLocalTask RPC; invoked if the
+/// node is alive.
+/// \param failure_callback Callback invoked when the CancelLocalTask RPC cannot be
+/// sent. Used for cleanup.
+inline void SendCancelLocalTask(std::shared_ptr<gcs::GcsClient> gcs_client,
+                                const NodeID &node_id,
+                                std::function<void(const rpc::Address &)> cancel_callback,
+                                std::function<void()> failure_callback) {
+  // Check GCS node cache. If node info is not in the cache, query the GCS instead.
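+  // The async fallback below asks the GCS for just this node's entry (the {node_id}
+  // filter; -1 places no limit on the number of returned entries).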
+  auto node_info =
+      gcs_client->Nodes().GetNodeAddressAndLiveness(node_id,
+                                                    /*filter_dead_nodes=*/false);
+  if (!node_info) {
+    gcs_client->Nodes().AsyncGetAllNodeAddressAndLiveness(
+        [cancel_callback = std::move(cancel_callback),
+         failure_callback = std::move(failure_callback),
+         node_id](const Status &status,
+                  std::vector<rpc::GcsNodeAddressAndLiveness> &&nodes) mutable {
+          if (!status.ok()) {
+            RAY_LOG(INFO) << "Failed to get node info from GCS";
+            failure_callback();
+            return;
+          }
+          if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) {
+            RAY_LOG(INFO).WithField(node_id)
+                << "Not sending CancelLocalTask because node is dead";
+            failure_callback();
+            return;
+          }
+          auto raylet_address = rpc::RayletClientPool::GenerateRayletAddress(
+              node_id, nodes[0].node_manager_address(), nodes[0].node_manager_port());
+          cancel_callback(raylet_address);
+        },
+        -1,
+        {node_id});
+    return;
+  }
+  if (node_info->state() == rpc::GcsNodeInfo::DEAD) {
+    RAY_LOG(INFO).WithField(node_id)
+        << "Not sending CancelLocalTask because node is dead";
+    failure_callback();
+    return;
+  }
+  auto raylet_address = rpc::RayletClientPool::GenerateRayletAddress(
+      node_id, node_info->node_manager_address(), node_info->node_manager_port());
+  cancel_callback(raylet_address);
+}
+
+}  // namespace core
+}  // namespace ray
diff --git a/src/ray/core_worker/task_submission/tests/BUILD.bazel b/src/ray/core_worker/task_submission/tests/BUILD.bazel
index 1a0f0abaf274..e0e5850847f8 100644
--- a/src/ray/core_worker/task_submission/tests/BUILD.bazel
+++ b/src/ray/core_worker/task_submission/tests/BUILD.bazel
@@ -38,6 +38,7 @@ ray_cc_test(
         "//src/ray/core_worker/task_submission:actor_task_submitter",
         "//src/ray/pubsub:fake_publisher",
         "//src/ray/pubsub:fake_subscriber",
+        "//src/ray/raylet_rpc_client:raylet_client_pool",
         "@com_google_googletest//:gtest",
         "@com_google_googletest//:gtest_main",
     ],
@@ -59,6 +60,7 @@ ray_cc_test(
         "//src/ray/core_worker_rpc_client:fake_core_worker_client",
         "//src/ray/pubsub:fake_publisher",
         "//src/ray/pubsub:fake_subscriber",
+        "//src/ray/raylet_rpc_client:raylet_client_pool",
         "@com_google_googletest//:gtest",
         "@com_google_googletest//:gtest_main",
     ],
diff --git a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc
index fad1e7545474..859065a6f36e 100644
--- a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc
+++ b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc
@@ -21,6 +21,7 @@
 #include "gtest/gtest.h"
 #include "mock/ray/core_worker/task_manager_interface.h"
+#include "mock/ray/gcs_client/gcs_client.h"
 #include "ray/common/test_utils.h"
 #include "ray/core_worker/fake_actor_creator.h"
 #include "ray/core_worker/reference_counter.h"
@@ -29,6 +30,7 @@
 #include "ray/observability/fake_metric.h"
 #include "ray/pubsub/fake_publisher.h"
 #include "ray/pubsub/fake_subscriber.h"
+#include "ray/raylet_rpc_client/raylet_client_pool.h"

 namespace ray::core {

@@ -92,9 +94,14 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam {
   ActorTaskSubmitterTest()
       : client_pool_(std::make_shared(
             [&](const rpc::Address &addr) { return worker_client_; })),
+        raylet_client_pool_(std::make_shared<rpc::RayletClientPool>(
+            [](const rpc::Address &) -> std::shared_ptr<RayletClientInterface> {
+              return nullptr;
+            })),
         worker_client_(std::make_shared()),
         store_(std::make_shared(io_context)),
         task_manager_(std::make_shared()),
+        mock_gcs_client_(std::make_shared<gcs::MockGcsClient>()),
         io_work(io_context.get_executor()),
         publisher_(std::make_unique()),
subscriber_(std::make_unique()), @@ -110,6 +117,8 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { /*lineage_pinning_enabled=*/false)), submitter_( *client_pool_, + *raylet_client_pool_, + mock_gcs_client_, *store_, *task_manager_, actor_creator_, @@ -125,9 +134,11 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { int64_t last_queue_warning_ = 0; FakeActorCreator actor_creator_; std::shared_ptr client_pool_; + std::shared_ptr raylet_client_pool_; std::shared_ptr worker_client_; std::shared_ptr store_; std::shared_ptr task_manager_; + std::shared_ptr mock_gcs_client_; instrumented_io_context io_context; boost::asio::executor_work_guard io_work; std::unique_ptr publisher_; diff --git a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc index fdb086ea41a8..be7ee28586a6 100644 --- a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc +++ b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc @@ -26,6 +26,7 @@ #include "ray/observability/fake_metric.h" #include "ray/pubsub/fake_publisher.h" #include "ray/pubsub/fake_subscriber.h" +#include "ray/raylet_rpc_client/raylet_client_pool.h" namespace ray { namespace core { @@ -42,6 +43,10 @@ class DirectTaskTransportTest : public ::testing::Test { task_manager = std::make_shared(); client_pool = std::make_shared( [&](const rpc::Address &) { return nullptr; }); + raylet_client_pool = std::make_shared( + [](const rpc::Address &) -> std::shared_ptr { + return nullptr; + }); memory_store = DefaultCoreWorkerMemoryStoreWithThread::Create(); publisher = std::make_unique(); subscriber = std::make_unique(); @@ -55,6 +60,8 @@ class DirectTaskTransportTest : public ::testing::Test { /*lineage_pinning_enabled=*/false); actor_task_submitter = std::make_unique( *client_pool, + *raylet_client_pool, + gcs_client, *memory_store, *task_manager, *actor_creator, @@ -94,6 +101,7 @@ class DirectTaskTransportTest : public ::testing::Test { boost::asio::executor_work_guard io_work; std::unique_ptr actor_task_submitter; std::shared_ptr client_pool; + std::shared_ptr raylet_client_pool; std::unique_ptr memory_store; std::shared_ptr task_manager; std::unique_ptr actor_creator; diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index b5325598f788..0b59561ef65e 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "mock/ray/core_worker/memory_store.h" #include "mock/ray/core_worker/task_manager_interface.h" +#include "mock/ray/gcs_client/gcs_client.h" #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" #include "ray/common/test_utils.h" @@ -124,12 +125,6 @@ class MockWorkerClient : public rpc::FakeCoreWorkerClient { return true; } - void CancelTask(const rpc::CancelTaskRequest &request, - const rpc::ClientCallback &callback) override { - kill_requests.push_front(request); - cancel_callbacks.push_back(callback); - } - void ReplyCancelTask(Status status = Status::OK(), bool attempt_succeeded = true, bool requested_task_running = false) { @@ -387,6 +382,24 @@ class MockRayletClient : public rpc::FakeRayletClient { return GenericPopCallbackInLock(cancel_callbacks); } + void CancelLocalTask( + const 
rpc::CancelLocalTaskRequest &request,
+      const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) override {
+    cancel_local_task_requests.push_back(request);
+    cancel_local_task_callbacks.push_back(callback);
+  }
+
+  void ReplyCancelLocalTask(Status status,
+                            bool attempt_succeeded,
+                            bool requested_task_running) {
+    auto &callback = cancel_local_task_callbacks.front();
+    rpc::CancelLocalTaskReply reply;
+    reply.set_attempt_succeeded(attempt_succeeded);
+    reply.set_requested_task_running(requested_task_running);
+    callback(status, std::move(reply));
+    cancel_local_task_callbacks.pop_front();
+  }
+
   ~MockRayletClient() = default;

   // Protects all internal fields.
@@ -401,10 +414,12 @@ class MockRayletClient : public rpc::FakeRayletClient {
   int num_get_task_failure_causes = 0;
   int reported_backlog_size = 0;
   std::map reported_backlogs;
-  std::list> callbacks = {};
-  std::list> cancel_callbacks = {};
+  std::list> callbacks;
+  std::list> cancel_callbacks;
   std::list>
-      get_task_failure_cause_callbacks = {};
+      get_task_failure_cause_callbacks;
+  std::list<rpc::CancelLocalTaskRequest> cancel_local_task_requests;
+  std::list<rpc::ClientCallback<rpc::CancelLocalTaskReply>> cancel_local_task_callbacks;
 };

 class MockLeasePolicy : public LeasePolicyInterface {
@@ -450,7 +465,9 @@ class NormalTaskSubmitterTest : public testing::Test {
         task_manager(std::make_unique()),
         actor_creator(std::make_shared()),
         lease_policy(std::make_unique()),
-        lease_policy_ptr(lease_policy.get()) {
+        lease_policy_ptr(lease_policy.get()),
+        mock_gcs_client_(std::make_shared<gcs::MockGcsClient>()),
+        io_work_(boost::asio::make_work_guard(io_context)) {
     address.set_node_id(local_node_id.Binary());
     lease_policy_ptr->SetNodeID(local_node_id);
   }
@@ -485,6 +502,7 @@ class NormalTaskSubmitterTest : public testing::Test {
         raylet_client,
         client_pool,
         raylet_client_pool,
+        mock_gcs_client_,
         std::move(lease_policy),
         store,
         *task_manager,
@@ -495,7 +513,7 @@ class NormalTaskSubmitterTest : public testing::Test {
         JobID::Nil(),
         rate_limiter,
         [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; },
-        boost::asio::steady_timer(io_context),
+        io_context,
         fake_scheduler_placement_time_ms_histogram_);
   }

@@ -512,7 +530,9 @@
   // the submitter.
   std::unique_ptr lease_policy;
   MockLeasePolicy *lease_policy_ptr = nullptr;
+  std::shared_ptr<gcs::MockGcsClient> mock_gcs_client_;
   instrumented_io_context io_context;
+  boost::asio::executor_work_guard<boost::asio::io_context::executor_type> io_work_;
   ray::observability::FakeHistogram fake_scheduler_placement_time_ms_histogram_;
 };

@@ -646,6 +666,19 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) {
   // the task cancellation races between ReplyPushTask and ReplyGetWorkerFailureCause.
// For an example of a python integration test, see // https://github.com/ray-project/ray/blob/2b6807f4d9c4572e6309f57bc404aa641bc4b185/python/ray/tests/test_cancel.py#L35 + + // Set up GCS node mock to return node as alive + using testing::_; + + rpc::GcsNodeAddressAndLiveness node_info; + node_info.set_node_id(local_node_id.Binary()); + node_info.set_node_manager_address("127.0.0.1"); + node_info.set_node_manager_port(9999); + node_info.set_state(rpc::GcsNodeInfo::ALIVE); + + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, GetNodeAddressAndLiveness(_, false)) + .WillRepeatedly(testing::Return(std::make_optional(node_info))); + auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); @@ -657,6 +690,9 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops"))); // Cancel the task while GetWorkerFailureCause has not been completed. submitter.CancelTask(task, true, false); + // ReplyPushTask removes the task from the executing_tasks_ map hence + // we shouldn't have triggered CancelLocalTask RPC. + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); // Completing the GetWorkerFailureCause call. Check that the reply runs without error // and FailPendingTask is not called. ASSERT_TRUE(raylet_client->ReplyGetWorkerFailureCause()); @@ -1445,12 +1481,14 @@ void TestSchedulingKey(const std::shared_ptr store, auto actor_creator = std::make_shared(); auto lease_policy = std::make_unique(); lease_policy->SetNodeID(local_node_id); + auto mock_gcs_client = std::make_shared(); instrumented_io_context io_context; NormalTaskSubmitter submitter( address, raylet_client, client_pool, raylet_client_pool, + mock_gcs_client, std::move(lease_policy), store, *task_manager, @@ -1461,7 +1499,7 @@ void TestSchedulingKey(const std::shared_ptr store, JobID::Nil(), std::make_shared(1), [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_context), + io_context, fake_scheduler_placement_time_ms_histogram_); submitter.SubmitTask(same1); @@ -1717,6 +1755,17 @@ TEST_F(NormalTaskSubmitterTest, TestWorkerLeaseTimeout) { } TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { + rpc::GcsNodeAddressAndLiveness node_info; + node_info.set_node_id(local_node_id.Binary()); + node_info.set_node_manager_address("127.0.0.1"); + node_info.set_node_manager_port(9999); + node_info.set_state(rpc::GcsNodeInfo::ALIVE); + + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, + GetNodeAddressAndLiveness(local_node_id, false)) + .WillOnce(testing::Return(std::make_optional(node_info))) + .WillOnce(testing::Return(std::make_optional(node_info))); + auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); TaskSpecification task = BuildEmptyTaskSpec(); @@ -1726,7 +1775,8 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try force kill, exiting the worker submitter.CancelTask(task, true, false); - ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskIdBinary()); + ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), + task.TaskIdBinary()); ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); ASSERT_TRUE(raylet_client->ReplyGetWorkerFailureCause()); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -1744,7 +1794,8 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try non-force kill, worker returns normally submitter.CancelTask(task, false, false); 
ASSERT_TRUE(worker_client->ReplyPushTask()); - ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskIdBinary()); + ASSERT_EQ(raylet_client->cancel_local_task_requests.back().intended_task_id(), + task.TaskIdBinary()); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_returned_exiting, 0); @@ -1764,6 +1815,9 @@ TEST_F(NormalTaskSubmitterTest, TestKillPendingTask) { submitter.SubmitTask(task); submitter.CancelTask(task, true, false); + // We haven't been granted a worker lease yet, so the task is not executing. + // So we shouldn't have triggered CancelLocalTask RPC. + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); ASSERT_EQ(worker_client->kill_requests.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); @@ -1792,6 +1846,9 @@ TEST_F(NormalTaskSubmitterTest, TestKillResolvingTask) { submitter.SubmitTask(task); ASSERT_EQ(task_manager->num_inlined_dependencies, 0); submitter.CancelTask(task, true, false); + // We haven't been granted a worker lease yet, so the task is not executing. + // So we shouldn't have triggered CancelLocalTask RPC. + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); auto data = GenerateRandomObject(); store->Put(*data, obj1, /*has_reference=*/true); WaitForObjectIdInMemoryStore(*store, obj1); @@ -1824,6 +1881,18 @@ TEST_F(NormalTaskSubmitterTest, TestQueueGeneratorForResubmit) { TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) { // Cancel -> failed queue generator for resubmit -> cancel reply -> successful queue for // resubmit -> push task reply -> honor the cancel not the queued resubmit. + + rpc::GcsNodeAddressAndLiveness node_info; + node_info.set_node_id(local_node_id.Binary()); + node_info.set_node_manager_address("127.0.0.1"); + node_info.set_node_manager_port(9999); + node_info.set_state(rpc::GcsNodeInfo::ALIVE); + + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, + GetNodeAddressAndLiveness(local_node_id, false)) + .WillOnce(testing::Return(std::make_optional(node_info))) + .WillOnce(testing::Return(std::make_optional(node_info))); + auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); TaskSpecification task = BuildEmptyTaskSpec(); @@ -1831,7 +1900,9 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id)); submitter.CancelTask(task, /*force_kill=*/false, /*recursive=*/true); ASSERT_FALSE(submitter.QueueGeneratorForResubmit(task)); - worker_client->ReplyCancelTask(); + raylet_client->ReplyCancelLocalTask(Status::OK(), + /*attempt_succeeded=*/true, + /*requested_task_running=*/false); ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task)); ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), /*exit=*/false, @@ -1849,9 +1920,9 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task2)); submitter.CancelTask(task2, /*force_kill=*/false, /*recursive=*/true); ASSERT_TRUE(worker_client->ReplyPushTask()); - worker_client->ReplyCancelTask(Status::OK(), - /*attempt_succeeded=*/true, - /*requested_task_running=*/false); + raylet_client->ReplyCancelLocalTask(Status::OK(), + /*attempt_succeeded=*/true, + /*requested_task_running=*/false); ASSERT_EQ(task_manager->num_tasks_complete, 1); ASSERT_EQ(task_manager->num_tasks_failed, 1); 
ASSERT_EQ(task_manager->num_generator_failed_and_resubmitted, 0); diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index 3e280b0d2eee..31b5c2cc6f17 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -216,6 +216,7 @@ class CoreWorkerTest : public ::testing::Test { fake_local_raylet_rpc_client, core_worker_client_pool, raylet_client_pool, + mock_gcs_client_, std::move(lease_policy), memory_store_, *task_manager_, @@ -226,11 +227,13 @@ class CoreWorkerTest : public ::testing::Test { JobID::Nil(), lease_request_rate_limiter, [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_service_), + io_service_, fake_scheduler_placement_time_ms_histogram_); auto actor_task_submitter = std::make_unique( *core_worker_client_pool, + *raylet_client_pool, + mock_gcs_client_, *memory_store_, *task_manager_, *actor_creator_, diff --git a/src/ray/core_worker_rpc_client/core_worker_client.h b/src/ray/core_worker_rpc_client/core_worker_client.h index b9fa6b2ea71f..be6c84698283 100644 --- a/src/ray/core_worker_rpc_client/core_worker_client.h +++ b/src/ray/core_worker_rpc_client/core_worker_client.h @@ -86,7 +86,7 @@ class CoreWorkerClient : public std::enable_shared_from_this, VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_, CoreWorkerService, - CancelRemoteTask, + RequestOwnerToCancelTask, grpc_client_, /*method_timeout_ms*/ -1, override) diff --git a/src/ray/core_worker_rpc_client/core_worker_client_interface.h b/src/ray/core_worker_rpc_client/core_worker_client_interface.h index 80c37f709d34..e4c7f2f97cb1 100644 --- a/src/ray/core_worker_rpc_client/core_worker_client_interface.h +++ b/src/ray/core_worker_rpc_client/core_worker_client_interface.h @@ -76,9 +76,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { virtual void CancelTask(const CancelTaskRequest &request, const ClientCallback &callback) = 0; - virtual void CancelRemoteTask( - CancelRemoteTaskRequest &&request, - const ClientCallback &callback) = 0; + virtual void RequestOwnerToCancelTask( + RequestOwnerToCancelTaskRequest &&request, + const ClientCallback &callback) = 0; virtual void RegisterMutableObjectReader( const RegisterMutableObjectReaderRequest &request, diff --git a/src/ray/core_worker_rpc_client/fake_core_worker_client.h b/src/ray/core_worker_rpc_client/fake_core_worker_client.h index 368cda8e8628..41ac5546f44d 100644 --- a/src/ray/core_worker_rpc_client/fake_core_worker_client.h +++ b/src/ray/core_worker_rpc_client/fake_core_worker_client.h @@ -82,8 +82,9 @@ class FakeCoreWorkerClient : public CoreWorkerClientInterface { void CancelTask(const CancelTaskRequest &request, const ClientCallback &callback) override {} - void CancelRemoteTask(CancelRemoteTaskRequest &&request, - const ClientCallback &callback) override {} + void RequestOwnerToCancelTask( + RequestOwnerToCancelTaskRequest &&request, + const ClientCallback &callback) override {} void RegisterMutableObjectReader( const RegisterMutableObjectReaderRequest &request, diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index c4dc447b9736..402f15b3a3a8 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -293,7 +293,7 @@ message CancelTaskReply { bool attempt_succeeded = 2; } -message CancelRemoteTaskRequest { +message RequestOwnerToCancelTaskRequest { // Object ID of the remote task that should be killed. 
bytes remote_object_id = 1;
   // Whether to kill the worker.
@@ -302,7 +302,7 @@ message CancelRemoteTaskRequest {
   bool recursive = 3;
 }

-message CancelRemoteTaskReply {}
+message RequestOwnerToCancelTaskReply {}

 message GetCoreWorkerStatsRequest {
   // The ID of the worker this message is intended for.
@@ -526,13 +526,15 @@ service CoreWorkerService {
   // KillLocalActor from the raylet which does implement retries
   rpc KillActor(KillActorRequest) returns (KillActorReply);

-  // Request from owner worker to executor worker to cancel a task.
-  // Failure: Will retry, TODO: Needs tests for failure behavior.
+  // Request from the local raylet to the executor worker to cancel a task.
+  // Failure: Idempotent, does not retry. However, requests should only be sent via
+  // CancelLocalTask from the raylet, which does implement retries.
   rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);

   // Request from a worker to the owner worker to issue a cancellation.
   // Failure: Retries, it's idempotent.
-  rpc CancelRemoteTask(CancelRemoteTaskRequest) returns (CancelRemoteTaskReply);
+  rpc RequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest)
+      returns (RequestOwnerToCancelTaskReply);

   // From raylet to get metrics from its workers.
   // Failure: Should not fail, always from local raylet.
diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto
index 0ee61435fc31..d9e3e1ee6dae 100644
--- a/src/ray/protobuf/node_manager.proto
+++ b/src/ray/protobuf/node_manager.proto
@@ -419,6 +419,26 @@ message KillLocalActorRequest {

 message KillLocalActorReply {}

+message CancelLocalTaskRequest {
+  // ID of the task that should be killed.
+  bytes intended_task_id = 1;
+  // Whether to kill the worker.
+  bool force_kill = 2;
+  // Whether to recursively cancel tasks.
+  bool recursive = 3;
+  // The worker ID of the caller.
+  bytes caller_worker_id = 4;
+  // The worker ID of the executor.
+  bytes executor_worker_id = 5;
+}
+
+message CancelLocalTaskReply {
+  // Whether the requested task is the currently running task.
+  bool requested_task_running = 1;
+  // Whether the task is canceled.
+  bool attempt_succeeded = 2;
+}
+
 // Service for inter-node-manager communication.
 service NodeManagerService {
   // Handle the case when GCS restarted.
@@ -542,4 +562,11 @@ service NodeManagerService {
   // without completing outstanding work.
   // Failure: Retries, it's idempotent.
   rpc KillLocalActor(KillLocalActorRequest) returns (KillLocalActorReply);
+  // Cancels a task on the local node. If the reply is OK, the task is canceled.
+  // If force_kill is true, the worker executing the task will be killed. We will
+  // initially try to gracefully kill the worker, and will escalate to a SIGKILL
+  // in the case of a hang.
+  // If force_kill is false, the worker remains alive for subsequent tasks.
+  // Failure: Retries, it's idempotent.
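+  // Idempotency: a retried request either finds the executor worker already gone
+  // (and replies with success immediately) or re-sends the same CancelTask to the
+  // executor, so duplicate requests are safe.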
+ rpc CancelLocalTask(CancelLocalTaskRequest) returns (CancelLocalTaskReply); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 043a4711ef0c..b6082b598015 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3363,6 +3363,9 @@ void NodeManager::HandleKillLocalActor(rpc::KillLocalActorRequest request, auto timer = execute_after( io_service_, [this, send_reply_callback, worker_id, replied]() { + if (*replied) { + return; + } auto current_worker = worker_pool_.GetRegisteredWorker(worker_id); if (current_worker) { // If the worker is still alive, force kill it @@ -3388,12 +3391,16 @@ void NodeManager::HandleKillLocalActor(rpc::KillLocalActorRequest request, timer, send_reply_callback, replied](const ray::Status &status, const rpc::KillActorReply &) { - if (!status.ok() && !*replied) { + if (*replied) { + return; + } + if (!status.ok()) { std::ostringstream stream; stream << "KillActor RPC failed for actor " << actor_id << ": " << status.ToString(); const auto &msg = stream.str(); RAY_LOG(DEBUG) << msg; + *replied = true; timer->cancel(); send_reply_callback(Status::Invalid(msg), nullptr, nullptr); } @@ -3403,4 +3410,91 @@ void NodeManager::HandleKillLocalActor(rpc::KillLocalActorRequest request, }); } +void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, + rpc::CancelLocalTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + auto executor_worker_id = WorkerID::FromBinary(request.executor_worker_id()); + + auto worker = worker_pool_.GetRegisteredWorker(executor_worker_id); + // If the worker is not registered, then it must have already been killed + if (!worker || worker->IsDead()) { + reply->set_attempt_succeeded(true); + reply->set_requested_task_running(false); + send_reply_callback(Status::OK(), nullptr, nullptr); + return; + } + + WorkerID worker_id = worker->WorkerId(); + + rpc::CancelTaskRequest cancel_task_request; + cancel_task_request.set_intended_task_id(request.intended_task_id()); + cancel_task_request.set_force_kill(request.force_kill()); + cancel_task_request.set_recursive(request.recursive()); + cancel_task_request.set_caller_worker_id(request.caller_worker_id()); + + // The timer and RPC response can come back in any order since they can be queued on the + // io service before either is executed. 
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index b2d2138bee20..8d669162f8a0 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -318,6 +318,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
                            rpc::KillLocalActorReply *reply,
                            rpc::SendReplyCallback send_reply_callback) override;
 
+  void HandleCancelLocalTask(rpc::CancelLocalTaskRequest request,
+                             rpc::CancelLocalTaskReply *reply,
+                             rpc::SendReplyCallback send_reply_callback) override;
+
  private:
   FRIEND_TEST(NodeManagerStaticTest, TestHandleReportWorkerBacklog);
 
diff --git a/src/ray/raylet_rpc_client/fake_raylet_client.h b/src/ray/raylet_rpc_client/fake_raylet_client.h
index 931c5d52643b..37645bd7aa62 100644
--- a/src/ray/raylet_rpc_client/fake_raylet_client.h
+++ b/src/ray/raylet_rpc_client/fake_raylet_client.h
@@ -290,23 +290,28 @@ class FakeRayletClient : public RayletClientInterface {
 
   int64_t GetPinsInFlight() const override { return 0; }
 
+  void CancelLocalTask(const CancelLocalTaskRequest &request,
+                       const ClientCallback<CancelLocalTaskReply> &callback) override {
+    num_cancel_local_task_requested += 1;
+  }
+
   int num_workers_requested = 0;
   int num_workers_returned = 0;
   int num_workers_disconnected = 0;
   int num_leases_canceled = 0;
   int num_release_unused_workers = 0;
   int num_get_task_failure_causes = 0;
+  int num_lease_requested = 0;
+  int num_return_requested = 0;
+  int num_commit_requested = 0;
+  int num_cancel_local_task_requested = 0;
+  int num_release_unused_bundles_requested = 0;
   NodeID node_id_ = NodeID::FromRandom();
   std::vector<ActorID> killed_actors;
   std::list<rpc::ClientCallback<rpc::DrainRayletReply>> drain_raylet_callbacks =
       {};
   std::list<rpc::ClientCallback<rpc::RequestWorkerLeaseReply>> callbacks = {};
   std::list<rpc::ClientCallback<rpc::CancelWorkerLeaseReply>> cancel_callbacks = {};
   std::list<rpc::ClientCallback<rpc::ReleaseUnusedActorWorkersReply>> release_callbacks = {};
-  int num_lease_requested = 0;
-  int num_return_requested = 0;
-  int num_commit_requested = 0;
-
-  int num_release_unused_bundles_requested = 0;
   std::list<rpc::ClientCallback<rpc::PrepareBundleResourcesReply>> lease_callbacks = {};
   std::list<rpc::ClientCallback<rpc::CommitBundleResourcesReply>> commit_callbacks = {};
   std::list<rpc::ClientCallback<rpc::ReturnWorkerReply>> return_callbacks = {};
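A hypothetical unit test could assert on the new counter like this. This is a sketch only; the test name and setup are assumptions, not code from this PR, and it relies on the fake simply counting calls without invoking the callback.

    // Sketch of exercising FakeRayletClient's new counter in a gtest-style test.
    TEST(FakeRayletClientTest, CancelLocalTaskIsCounted) {
      FakeRayletClient client;
      rpc::CancelLocalTaskRequest request;
      request.set_force_kill(true);
      client.CancelLocalTask(
          request, [](const ray::Status &, const rpc::CancelLocalTaskReply &) {});
      ASSERT_EQ(client.num_cancel_local_task_requested, 1);
    }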
diff --git a/src/ray/raylet_rpc_client/raylet_client.cc b/src/ray/raylet_rpc_client/raylet_client.cc
index d76b49716331..83a6507b754d 100644
--- a/src/ray/raylet_rpc_client/raylet_client.cc
+++ b/src/ray/raylet_rpc_client/raylet_client.cc
@@ -508,5 +508,17 @@ void RayletClient::KillLocalActor(
       /*method_timeout_ms*/ -1);
 }
 
+void RayletClient::CancelLocalTask(
+    const rpc::CancelLocalTaskRequest &request,
+    const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) {
+  INVOKE_RETRYABLE_RPC_CALL(retryable_grpc_client_,
+                            NodeManagerService,
+                            CancelLocalTask,
+                            request,
+                            callback,
+                            grpc_client_,
+                            /*method_timeout_ms*/ -1);
+}
+
 }  // namespace rpc
 }  // namespace ray
diff --git a/src/ray/raylet_rpc_client/raylet_client.h b/src/ray/raylet_rpc_client/raylet_client.h
index 07de0073ef02..98db543fe258 100644
--- a/src/ray/raylet_rpc_client/raylet_client.h
+++ b/src/ray/raylet_rpc_client/raylet_client.h
@@ -174,6 +174,10 @@ class RayletClient : public RayletClientInterface {
   void GetWorkerPIDs(const gcs::OptionalItemCallback<std::vector<int32_t>> &callback,
                      int64_t timeout_ms);
 
+  void CancelLocalTask(
+      const rpc::CancelLocalTaskRequest &request,
+      const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) override;
+
  protected:
   /// gRPC client to the NodeManagerService.
   std::shared_ptr<GrpcClient<NodeManagerService>> grpc_client_;
diff --git a/src/ray/raylet_rpc_client/raylet_client_interface.h b/src/ray/raylet_rpc_client/raylet_client_interface.h
index fdd0035e224c..19c07e3a09b5 100644
--- a/src/ray/raylet_rpc_client/raylet_client_interface.h
+++ b/src/ray/raylet_rpc_client/raylet_client_interface.h
@@ -217,6 +217,10 @@ class RayletClientInterface {
 
   virtual int64_t GetPinsInFlight() const = 0;
 
+  virtual void CancelLocalTask(
+      const rpc::CancelLocalTaskRequest &request,
+      const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) = 0;
+
   virtual ~RayletClientInterface() = default;
 };
 
diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h
index f68f54b56862..ff453d249b37 100644
--- a/src/ray/rpc/node_manager/node_manager_server.h
+++ b/src/ray/rpc/node_manager/node_manager_server.h
@@ -64,7 +64,8 @@ class ServerCallFactory;
   RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(RegisterMutableObject) \
   RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(PushMutableObject)     \
   RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetWorkerPIDs)         \
-  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(KillLocalActor)
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(KillLocalActor)        \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CancelLocalTask)
 
 /// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
 class NodeManagerServiceHandler {
@@ -196,6 +197,10 @@ class NodeManagerServiceHandler {
   virtual void HandleKillLocalActor(KillLocalActorRequest request,
                                     KillLocalActorReply *reply,
                                     SendReplyCallback send_reply_callback) = 0;
+
+  virtual void HandleCancelLocalTask(CancelLocalTaskRequest request,
+                                     CancelLocalTaskReply *reply,
+                                     SendReplyCallback send_reply_callback) = 0;
 };
 
 /// The `GrpcService` for `NodeManagerService`.
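A closing note on the handler list: each RAY_NODE_MANAGER_RPC_SERVICE_HANDLER entry must be matched by a Handle* pure virtual on NodeManagerServiceHandler, which is why this PR touches both the macro list and the interface. The following is a generic, self-contained illustration of that macro-per-RPC pattern; all names below are invented for the sketch, and this is not Ray's actual macro.

    #include <iostream>

    struct EchoRequest {};
    struct EchoReply {};

    // One macro invocation per RPC keeps the declaration boilerplate in sync:
    // an RPC named in the list without a matching override fails to compile.
    #define DECLARE_RPC_HANDLER(METHOD)                    \
      virtual void Handle##METHOD(METHOD##Request request, \
                                  METHOD##Reply *reply) = 0;

    class ServiceHandler {
     public:
      virtual ~ServiceHandler() = default;
      DECLARE_RPC_HANDLER(Echo)
    };

    class ServiceImpl : public ServiceHandler {
     public:
      void HandleEcho(EchoRequest request, EchoReply *reply) override {
        std::cout << "handled Echo\n";
      }
    };

    int main() {
      ServiceImpl impl;
      EchoRequest request;
      EchoReply reply;
      impl.HandleEcho(request, &reply);
      return 0;
    }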