Merged

34 commits:
f8150c0 Make CancelTask RPC Fault Tolerant (Sparks0219, Oct 22, 2025)
0a630a7 Addressing comments (Sparks0219, Oct 22, 2025)
8ae4e3a clean up and cpp test failures (Sparks0219, Oct 22, 2025)
a733422 Addressing comments (Sparks0219, Oct 23, 2025)
8a2e428 Fix broken cpp tests (Sparks0219, Oct 23, 2025)
901099d Fix merge conflicts (Sparks0219, Nov 7, 2025)
7d4ab2e Clean up (Sparks0219, Nov 7, 2025)
9070db5 lint (Sparks0219, Nov 7, 2025)
dcec398 Addressing comments (Sparks0219, Nov 12, 2025)
9df37aa Fix cpp test failures (Sparks0219, Nov 12, 2025)
d846b90 Addressing comments (Sparks0219, Nov 13, 2025)
873a17c Addressing comments (Sparks0219, Nov 13, 2025)
430a4a6 Merge remote-tracking branch 'upstream/master' into joshlee/make-canc… (Sparks0219, Nov 13, 2025)
a253c81 fix build error (Sparks0219, Nov 14, 2025)
9d5cf6f Addressing comments (Sparks0219, Nov 14, 2025)
d0fddda Merge conflicts (Sparks0219, Nov 18, 2025)
c8e0ed6 Addressing comments (Sparks0219, Nov 18, 2025)
49250fb Bad merge conflict fix (Sparks0219, Nov 18, 2025)
3dbcc22 Addressing comments (Sparks0219, Nov 18, 2025)
73445ab Fix cpp test (Sparks0219, Nov 18, 2025)
f737eef Addressing comments (Sparks0219, Nov 19, 2025)
0429f79 Addressing comments (Sparks0219, Nov 20, 2025)
2f9c24e Addressing comments (Sparks0219, Nov 20, 2025)
bed7884 Fix cpp test error (Sparks0219, Nov 21, 2025)
2a66834 Addressing comments (Sparks0219, Nov 22, 2025)
0fb240e Merge remote-tracking branch 'upstream/master' into joshlee/make-canc… (Sparks0219, Nov 26, 2025)
6bab852 Removing io context posts now that accessor node cache is thread safe (Sparks0219, Nov 26, 2025)
c1f1e0f Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant (edoakes, Nov 26, 2025)
758ecd6 Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant (jjyao, Dec 2, 2025)
22a53f5 Deflake serve test (Sparks0219, Dec 3, 2025)
e053c2f AI comment (Sparks0219, Dec 3, 2025)
6b2674b Addressing AI comments (Sparks0219, Dec 3, 2025)
07da167 More AI comments (Sparks0219, Dec 3, 2025)
41fa586 Addressing comments (Sparks0219, Dec 4, 2025)
2 changes: 1 addition & 1 deletion python/ray/tests/test_core_worker_fault_tolerance.py
@@ -172,7 +172,7 @@ def inject_cancel_remote_task_rpc_failure(monkeypatch, request):
failure = RPC_FAILURE_MAP[deterministic_failure]
monkeypatch.setenv(
"RAY_testing_rpc_failure",
f"CoreWorkerService.grpc_client.CancelRemoteTask=1:{failure}",
f"CoreWorkerService.grpc_client.RequestOwnerToCancelTask=1:{failure}",
)


55 changes: 54 additions & 1 deletion python/ray/tests/test_raylet_fault_tolerance.py
@@ -4,12 +4,13 @@
import pytest

import ray
from ray._common.test_utils import SignalActor, wait_for_condition
from ray._private.test_utils import (
RPC_FAILURE_MAP,
RPC_FAILURE_TYPES,
wait_for_condition,
)
from ray.core.generated import autoscaler_pb2
from ray.exceptions import GetTimeoutError, TaskCancelledError
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
@@ -247,5 +248,57 @@ def verify_process_killed():
wait_for_condition(verify_process_killed, timeout=30)


@pytest.fixture
def inject_cancel_local_task_rpc_failure(monkeypatch, request):
failure = RPC_FAILURE_MAP[request.param]
monkeypatch.setenv(
"RAY_testing_rpc_failure",
f"NodeManagerService.grpc_client.CancelLocalTask=1:{failure}",
)


@pytest.mark.parametrize(
"inject_cancel_local_task_rpc_failure", RPC_FAILURE_TYPES, indirect=True
)
@pytest.mark.parametrize("force_kill", [True, False])
def test_cancel_local_task_rpc_retry_and_idempotency(
inject_cancel_local_task_rpc_failure, force_kill, shutdown_only
):
"""Test that CancelLocalTask RPC retries work correctly.

Verify that the RPC is idempotent when network failures occur.
When force_kill=True, verify the worker process is actually killed using psutil.
"""
ray.init(num_cpus=1)
signaler = SignalActor.remote()

@ray.remote(num_cpus=1)
def get_pid():
return os.getpid()

@ray.remote(num_cpus=1)
def blocking_task():
return ray.get(signaler.wait.remote())

worker_pid = ray.get(get_pid.remote())

blocking_ref = blocking_task.remote()

with pytest.raises(GetTimeoutError):
ray.get(blocking_ref, timeout=1)

ray.cancel(blocking_ref, force=force_kill)

with pytest.raises(TaskCancelledError):
ray.get(blocking_ref, timeout=10)

if force_kill:

def verify_process_killed():
return not psutil.pid_exists(worker_pid)

wait_for_condition(verify_process_killed, timeout=30)


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))
6 changes: 3 additions & 3 deletions src/mock/ray/core_worker/core_worker.h
@@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker {
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleCancelRemoteTask,
(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
HandleRequestOwnerToCancelTask,
(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
@@ -156,6 +156,11 @@ class MockRayletClientInterface : public RayletClientInterface {
(const rpc::ClientCallback<rpc::GlobalGCReply> &callback),
(override));
MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override));
MOCK_METHOD(void,
CancelLocalTask,
(const rpc::CancelLocalTaskRequest &request,
const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback),
(override));
};

} // namespace ray
6 changes: 3 additions & 3 deletions src/mock/ray/rpc/worker/core_worker_client.h
@@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface {
const ClientCallback<CancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
CancelRemoteTask,
(CancelRemoteTaskRequest && request,
const ClientCallback<CancelRemoteTaskReply> &callback),
RequestOwnerToCancelTask,
(RequestOwnerToCancelTaskRequest && request,
const ClientCallback<RequestOwnerToCancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
GetCoreWorkerStats,
12 changes: 7 additions & 5 deletions src/ray/core_worker/core_worker.cc
@@ -2460,12 +2460,13 @@ Status CoreWorker::CancelTask(const ObjectID &object_id,
}

if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) {
// We don't have CancelRemoteTask for actor_task_submitter_
// We don't have RequestOwnerToCancelTask for actor_task_submitter_
// because it requires the same implementation.
RAY_LOG(DEBUG).WithField(object_id)
<< "Request to cancel a task of object to an owner "
<< obj_addr.SerializeAsString();
normal_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, recursive);
normal_task_submitter_->RequestOwnerToCancelTask(
object_id, obj_addr, force_kill, recursive);
return Status::OK();
}

@@ -3910,9 +3911,10 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address);
}

void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
void CoreWorker::HandleRequestOwnerToCancelTask(
rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
request.force_kill(),
request.recursive());
6 changes: 3 additions & 3 deletions src/ray/core_worker/core_worker.h
@@ -1208,9 +1208,9 @@ class CoreWorker {
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
void HandleRequestOwnerToCancelTask(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request,
5 changes: 4 additions & 1 deletion src/ray/core_worker/core_worker_process.cc
@@ -499,6 +499,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(

auto actor_task_submitter = std::make_unique<ActorTaskSubmitter>(
*core_worker_client_pool,
*raylet_client_pool,
gcs_client,
*memory_store,
*task_manager,
*actor_creator,
@@ -537,6 +539,7 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
local_raylet_rpc_client,
core_worker_client_pool,
raylet_client_pool,
gcs_client,
std::move(lease_policy),
memory_store,
*task_manager,
@@ -553,7 +556,7 @@
// OBJECT_STORE.
return rpc::TensorTransport::OBJECT_STORE;
},
boost::asio::steady_timer(io_service_),
io_service_,
*scheduler_placement_time_ms_histogram_);

auto report_locality_data_callback = [this](
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker_rpc_proxy.h
@@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler {
RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns)
RAY_CORE_WORKER_RPC_PROXY(KillActor)
RAY_CORE_WORKER_RPC_PROXY(CancelTask)
RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask)
RAY_CORE_WORKER_RPC_PROXY(RequestOwnerToCancelTask)
RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader)
RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats)
RAY_CORE_WORKER_RPC_PROXY(LocalGC)
2 changes: 1 addition & 1 deletion src/ray/core_worker/grpc_service.cc
@@ -77,7 +77,7 @@ void CoreWorkerGrpcService::InitServerCallFactories(
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
CancelRemoteTask,
RequestOwnerToCancelTask,
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
6 changes: 3 additions & 3 deletions src/ray/core_worker/grpc_service.h
@@ -96,9 +96,9 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler {
CancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request,
CancelRemoteTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleRequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest request,
RequestOwnerToCancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterMutableObjectReader(
RegisterMutableObjectReaderRequest request,
18 changes: 18 additions & 0 deletions src/ray/core_worker/task_submission/BUILD.bazel
@@ -53,6 +53,17 @@ ray_cc_library(
],
)

ray_cc_library(
name = "task_submission_util",
hdrs = ["task_submission_util.h"],
visibility = [":__subpackages__"],
deps = [
"//src/ray/common:asio",
"//src/ray/common:id",
"//src/ray/gcs_rpc_client:gcs_client",
],
)

ray_cc_library(
name = "actor_task_submitter",
srcs = ["actor_task_submitter.cc"],
@@ -66,12 +77,16 @@ ray_cc_library(
":dependency_resolver",
":out_of_order_actor_submit_queue",
":sequential_actor_submit_queue",
":task_submission_util",
"//src/ray/common:asio",
"//src/ray/common:id",
"//src/ray/common:protobuf_utils",
"//src/ray/core_worker:actor_creator",
"//src/ray/core_worker:reference_counter_interface",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/rpc:rpc_callback_types",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
@@ -90,14 +105,17 @@
],
deps = [
":dependency_resolver",
":task_submission_util",
"//src/ray/common:id",
"//src/ray/common:lease",
"//src/ray/common:protobuf_utils",
"//src/ray/core_worker:lease_policy",
"//src/ray/core_worker:memory_store",
"//src/ray/core_worker:task_manager_interface",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
],
88 changes: 54 additions & 34 deletions src/ray/core_worker/task_submission/actor_task_submitter.cc
@@ -21,6 +21,7 @@
#include <vector>

#include "ray/common/protobuf_utils.h"
#include "ray/core_worker/task_submission/task_submission_util.h"
#include "ray/util/time.h"

namespace ray {
@@ -912,17 +913,16 @@ std::string ActorTaskSubmitter::DebugString(const ActorID &actor_id) const {
return stream.str();
}

void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec,
bool recursive,
int64_t milliseconds) {
void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec, bool recursive) {
auto delay_ms = RayConfig::instance().cancellation_retry_ms();
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task cancelation will be retried in " << milliseconds << " ms";
<< "Task cancelation will be retried in " << delay_ms << " ms";
execute_after(
io_service_,
[this, task_spec = std::move(task_spec), recursive] {
CancelTask(task_spec, recursive);
},
std::chrono::milliseconds(milliseconds));
std::chrono::milliseconds(delay_ms));
}

void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) {
@@ -997,44 +997,64 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)
// an executor tells us to stop retrying.

// If there's no client, it means actor is not created yet.
// Retry in 1 second.
// Retry after the configured delay.
NodeID node_id;
std::string executor_worker_id;
{
absl::MutexLock lock(&mu_);
RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC.";
auto queue = client_queues_.find(actor_id);
RAY_CHECK(queue != client_queues_.end());
if (!queue->second.client_address_.has_value()) {
RetryCancelTask(task_spec, recursive, 1000);
RetryCancelTask(task_spec, recursive);
return;
}

rpc::CancelTaskRequest request;
request.set_intended_task_id(task_spec.TaskIdBinary());
request.set_force_kill(force_kill);
request.set_recursive(recursive);
request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_);
client->CancelTask(request,
[this, task_spec = std::move(task_spec), recursive, task_id](
const Status &status, const rpc::CancelTaskReply &reply) {
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "CancelTask RPC response received with status "
<< status.ToString();

// Keep retrying every 2 seconds until a task is officially
// finished.
if (!task_manager_.GetTaskSpec(task_id)) {
// Task is already finished.
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task is finished. Stop a cancel request.";
return;
}

if (!reply.attempt_succeeded()) {
RetryCancelTask(task_spec, recursive, 2000);
}
});
node_id = NodeID::FromBinary(queue->second.client_address_.value().node_id());
executor_worker_id = queue->second.client_address_.value().worker_id();
}

auto do_cancel_local_task =
[this, task_spec = std::move(task_spec), force_kill, recursive, executor_worker_id](
const rpc::Address &raylet_address) mutable {
rpc::CancelLocalTaskRequest request;
request.set_intended_task_id(task_spec.TaskIdBinary());
request.set_force_kill(force_kill);
request.set_recursive(recursive);
request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
request.set_executor_worker_id(executor_worker_id);

auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address);
raylet_client->CancelLocalTask(
request,
[this, task_spec = std::move(task_spec), recursive](
const Status &status, const rpc::CancelLocalTaskReply &reply) mutable {
if (!status.ok()) {
RAY_LOG(INFO) << "CancelLocalTask RPC failed for task "
<< task_spec.TaskId() << ": " << status.ToString()
<< " due to node death";
return;
} else {
RAY_LOG(INFO) << "CancelLocalTask RPC response received for "
<< task_spec.TaskId()
<< " with attempt_succeeded: " << reply.attempt_succeeded()
<< " requested_task_running: "
<< reply.requested_task_running();
}
// Keep retrying until a task is officially finished.
if (!reply.attempt_succeeded()) {
RetryCancelTask(std::move(task_spec), recursive);
}
});
};

// Cancel can execute on the user's python thread, but the GCS node cache is updated on
// the io service thread and is not thread-safe. Hence we need to post the entire
// cache access to the io service thread.
io_service_.post(
[this, node_id, do_cancel_local_task = std::move(do_cancel_local_task)]() mutable {
SendCancelLocalTask(gcs_client_, node_id, std::move(do_cancel_local_task));
},
"ActorTaskSubmitter.CancelTask");
Review comment (Contributor):
almost all of this new code to make the rpc and retry is shared with the normal task submitter, can it be deduplicated with some shared utility func

Reply (Contributor Author):
didn't see a util file for task_submission so created one and moved the io_service_.post part to it since it's identical as you mentioned between actor and normal task submitter
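As context for the thread above: a minimal sketch of what such a shared helper in task_submission_util.h could look like, assuming it resolves the target raylet's address from the GCS node cache before invoking the caller-supplied callback. The signature and field accesses below are inferred from the call site (SendCancelLocalTask(gcs_client_, node_id, std::move(do_cancel_local_task))) and are hypothetical, not the PR's verbatim code.

// task_submission_util.h -- hypothetical sketch, inferred from the call site above.
#pragma once

#include <utility>

#include "ray/common/id.h"
#include "ray/gcs_rpc_client/gcs_client.h"

namespace ray {
namespace core {

// Looks up `node_id` in the GCS node cache and, if the node is still alive,
// invokes `do_cancel_local_task` with the corresponding raylet address.
// Shared by the actor and normal task submitters so the lookup lives in one place.
template <typename Callback>
void SendCancelLocalTask(gcs::GcsClient &gcs_client,
                         const NodeID &node_id,
                         Callback do_cancel_local_task) {
  // A dead node means there is nothing left to cancel; node-death handling
  // fails the task through its own path.
  const rpc::GcsNodeInfo *node_info =
      gcs_client.Nodes().Get(node_id, /*filter_dead_nodes=*/true);
  if (node_info == nullptr) {
    return;
  }
  rpc::Address raylet_address;
  raylet_address.set_node_id(node_info->node_id());
  raylet_address.set_ip_address(node_info->node_manager_address());
  raylet_address.set_port(node_info->node_manager_port());
  do_cancel_local_task(raylet_address);
}

}  // namespace core
}  // namespace ray

At the call site above, the helper is wrapped in io_service_.post because, at this point in the PR's history, the node cache was only safe to touch from the io service thread; a later commit (6bab852, "Removing io context posts now that accessor node cache is thread safe") removes those posts.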

}

bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) {