Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f8150c0
Make CancelTask RPC Fault Tolerant
Sparks0219 Oct 22, 2025
0a630a7
Addressing comments
Sparks0219 Oct 22, 2025
8ae4e3a
clean up and cpp test failures
Sparks0219 Oct 22, 2025
a733422
Addressing comments
Sparks0219 Oct 23, 2025
8a2e428
Fix broken cpp tests
Sparks0219 Oct 23, 2025
901099d
Fix merge conflicts
Sparks0219 Nov 7, 2025
7d4ab2e
Clean up
Sparks0219 Nov 7, 2025
9070db5
lint
Sparks0219 Nov 7, 2025
dcec398
Addressing comments
Sparks0219 Nov 12, 2025
9df37aa
Fix cpp test failures
Sparks0219 Nov 12, 2025
d846b90
Addressing comments
Sparks0219 Nov 13, 2025
873a17c
Addressing comments
Sparks0219 Nov 13, 2025
430a4a6
Merge remote-tracking branch 'upstream/master' into joshlee/make-canc…
Sparks0219 Nov 13, 2025
a253c81
fix build error
Sparks0219 Nov 14, 2025
9d5cf6f
Addressing comments
Sparks0219 Nov 14, 2025
d0fddda
Merge conflicts
Sparks0219 Nov 18, 2025
c8e0ed6
Addressing comments
Sparks0219 Nov 18, 2025
49250fb
Bad merge conflict fix
Sparks0219 Nov 18, 2025
3dbcc22
Addressing comments
Sparks0219 Nov 18, 2025
73445ab
Fix cpp test
Sparks0219 Nov 18, 2025
f737eef
Addressing comments
Sparks0219 Nov 19, 2025
0429f79
Addressing comments
Sparks0219 Nov 20, 2025
2f9c24e
Addressing comments
Sparks0219 Nov 20, 2025
bed7884
Fix cpp test error
Sparks0219 Nov 21, 2025
2a66834
Addressing comments
Sparks0219 Nov 22, 2025
0fb240e
Merge remote-tracking branch 'upstream/master' into joshlee/make-canc…
Sparks0219 Nov 26, 2025
6bab852
Removing io context posts now that accessor node cache is thread safe
Sparks0219 Nov 26, 2025
c1f1e0f
Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant
edoakes Nov 26, 2025
758ecd6
Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant
jjyao Dec 2, 2025
22a53f5
Deflake serve test
Sparks0219 Dec 3, 2025
e053c2f
AI comment
Sparks0219 Dec 3, 2025
6b2674b
Addressing AI comments
Sparks0219 Dec 3, 2025
07da167
More AI comments
Sparks0219 Dec 3, 2025
41fa586
Addressing comments
Sparks0219 Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/tests/test_core_worker_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def inject_cancel_remote_task_rpc_failure(monkeypatch, request):
deterministic_failure = request.param
monkeypatch.setenv(
"RAY_testing_rpc_failure",
"CoreWorkerService.grpc_client.CancelRemoteTask=1:"
"CoreWorkerService.grpc_client.RequestOwnerToCancelTask=1:"
+ ("100:0" if deterministic_failure == "request" else "0:100"),
)

Expand Down
56 changes: 55 additions & 1 deletion python/ray/tests/test_raylet_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.core.generated import autoscaler_pb2
from ray.exceptions import GetTimeoutError, TaskCancelledError
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
Expand Down Expand Up @@ -180,5 +181,58 @@ def verify_process_killed():
wait_for_condition(verify_process_killed, timeout=30)


@pytest.fixture
def inject_cancel_local_task_rpc_failure(monkeypatch, request):
    """Set RAY_testing_rpc_failure for NodeManagerService's CancelLocalTask RPC.

    The indirect parameter selects the suffix appended to the spec:
    "request" yields "100:0", any other value yields "0:100".
    NOTE(review): the two numbers are presumably request/response failure
    percentages -- confirm against the RPC failure-injection config.
    """
    if request.param == "request":
        failure_suffix = "100:0"
    else:
        failure_suffix = "0:100"

    failure_spec = "NodeManagerService.grpc_client.CancelLocalTask=1:" + failure_suffix
    monkeypatch.setenv("RAY_testing_rpc_failure", failure_spec)


@pytest.mark.parametrize(
    "inject_cancel_local_task_rpc_failure", ["request", "response"], indirect=True
)
@pytest.mark.parametrize("force_kill", [True, False])
def test_cancel_local_task_rpc_retry_and_idempotency(
    inject_cancel_local_task_rpc_failure, force_kill, shutdown_only
):
    """Cancel a blocked task while CancelLocalTask RPC failures are injected.

    The task must end up raising TaskCancelledError regardless of which
    failure mode the fixture configured. With force_kill=True the test also
    waits, via psutil, for the recorded worker PID to disappear.
    """
    ray.init(num_cpus=2)
    signal = SignalActor.remote()

    @ray.remote(num_cpus=1)
    def get_pid():
        return os.getpid()

    @ray.remote(num_cpus=1)
    def blocking_task():
        return ray.get(signal.wait.remote())

    # PID of the worker that executed get_pid.
    # NOTE(review): the force-kill check below presumes blocking_task runs on
    # this same worker process -- confirm.
    executor_pid = ray.get(get_pid.remote())

    pending_ref = blocking_task.remote()

    # The signal is never sent in this test, so the task stays blocked and
    # a short get must time out.
    with pytest.raises(GetTimeoutError):
        ray.get(pending_ref, timeout=1)

    ray.cancel(pending_ref, force=force_kill)

    with pytest.raises(TaskCancelledError):
        ray.get(pending_ref, timeout=10)

    if not force_kill:
        return

    def verify_process_killed():
        return not psutil.pid_exists(executor_pid)

    wait_for_condition(verify_process_killed, timeout=30)


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))
6 changes: 3 additions & 3 deletions src/mock/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker {
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleCancelRemoteTask,
(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
HandleRequestOwnerToCancelTask,
(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
Expand Down
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class MockRayletClientInterface : public RayletClientInterface {
(const rpc::ClientCallback<rpc::GlobalGCReply> &callback),
(override));
MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override));
// Mock of CancelLocalTask: takes the cancel request and a client callback
// typed on rpc::CancelLocalTaskReply.
MOCK_METHOD(void,
CancelLocalTask,
(const rpc::CancelLocalTaskRequest &request,
const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback),
(override));
};

} // namespace ray
6 changes: 3 additions & 3 deletions src/mock/ray/rpc/worker/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface {
const ClientCallback<CancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
CancelRemoteTask,
(CancelRemoteTaskRequest && request,
const ClientCallback<CancelRemoteTaskReply> &callback),
RequestOwnerToCancelTask,
(RequestOwnerToCancelTaskRequest && request,
const ClientCallback<RequestOwnerToCancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
GetCoreWorkerStats,
Expand Down
12 changes: 7 additions & 5 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2439,12 +2439,13 @@ Status CoreWorker::CancelTask(const ObjectID &object_id,
}

if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) {
// We don't have CancelRemoteTask for actor_task_submitter_
// We don't have RequestOwnerToCancelTask for actor_task_submitter_
// because it requires the same implementation.
RAY_LOG(DEBUG).WithField(object_id)
<< "Request to cancel a task of object to an owner "
<< obj_addr.SerializeAsString();
normal_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, recursive);
normal_task_submitter_->RequestOwnerToCancelTask(
object_id, obj_addr, force_kill, recursive);
return Status::OK();
}

Expand Down Expand Up @@ -3878,9 +3879,10 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address);
}

void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
void CoreWorker::HandleRequestOwnerToCancelTask(
rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
request.force_kill(),
request.recursive());
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1208,9 +1208,9 @@ class CoreWorker {
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
void HandleRequestOwnerToCancelTask(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request,
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(

auto actor_task_submitter = std::make_unique<ActorTaskSubmitter>(
*core_worker_client_pool,
*raylet_client_pool,
gcs_client,
*memory_store,
*task_manager,
*actor_creator,
Expand Down Expand Up @@ -535,6 +537,7 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
local_raylet_rpc_client,
core_worker_client_pool,
raylet_client_pool,
gcs_client,
std::move(lease_policy),
memory_store,
*task_manager,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker_rpc_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler {
RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns)
RAY_CORE_WORKER_RPC_PROXY(KillActor)
RAY_CORE_WORKER_RPC_PROXY(CancelTask)
RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask)
RAY_CORE_WORKER_RPC_PROXY(RequestOwnerToCancelTask)
RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader)
RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats)
RAY_CORE_WORKER_RPC_PROXY(LocalGC)
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/grpc_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void CoreWorkerGrpcService::InitServerCallFactories(
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
CancelRemoteTask,
RequestOwnerToCancelTask,
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler {
CancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request,
CancelRemoteTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleRequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest request,
RequestOwnerToCancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterMutableObjectReader(
RegisterMutableObjectReaderRequest request,
Expand Down
5 changes: 5 additions & 0 deletions src/ray/core_worker/task_submission/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ ray_cc_library(
"//src/ray/common:protobuf_utils",
"//src/ray/core_worker:actor_creator",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/rpc:rpc_callback_types",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
Expand All @@ -96,7 +99,9 @@ ray_cc_library(
"//src/ray/core_worker:memory_store",
"//src/ray/core_worker:task_manager_interface",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
],
Expand Down
92 changes: 70 additions & 22 deletions src/ray/core_worker/task_submission/actor_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)

// If there's no client, it means actor is not created yet.
// Retry in 1 second.
rpc::Address client_address;
{
absl::MutexLock lock(&mu_);
RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC.";
Expand All @@ -1007,34 +1008,81 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)
RetryCancelTask(task_spec, recursive, 1000);
return;
}
client_address = queue->second.client_address_.value();
}

rpc::CancelTaskRequest request;
const auto node_id = NodeID::FromBinary(client_address.node_id());
const auto executor_worker_id = client_address.worker_id();

auto do_cancel_local_task = [this,
task_spec = std::move(task_spec),
task_id,
force_kill,
recursive,
executor_worker_id = std::move(executor_worker_id)](
const rpc::GcsNodeInfo &node_info) mutable {
rpc::Address raylet_address;
raylet_address.set_node_id(node_info.node_id());
raylet_address.set_ip_address(node_info.node_manager_address());
raylet_address.set_port(node_info.node_manager_port());

rpc::CancelLocalTaskRequest request;
request.set_intended_task_id(task_spec.TaskIdBinary());
request.set_force_kill(force_kill);
request.set_recursive(recursive);
request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_);
client->CancelTask(request,
[this, task_spec = std::move(task_spec), recursive, task_id](
const Status &status, const rpc::CancelTaskReply &reply) {
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "CancelTask RPC response received with status "
<< status.ToString();

// Keep retrying every 2 seconds until a task is officially
// finished.
if (!task_manager_.GetTaskSpec(task_id)) {
// Task is already finished.
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task is finished. Stop a cancel request.";
return;
}

if (!reply.attempt_succeeded()) {
RetryCancelTask(task_spec, recursive, 2000);
}
});
request.set_executor_worker_id(executor_worker_id);

auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address);
raylet_client->CancelLocalTask(
request,
[this, task_spec = std::move(task_spec), recursive, task_id](
const Status &status, const rpc::CancelLocalTaskReply &reply) mutable {
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "CancelTask RPC response received with status " << status.ToString();

// Keep retrying every 2 seconds until a task is officially
// finished.
if (!task_manager_.GetTaskSpec(task_id)) {
// Task is already finished.
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task is finished. Stop a cancel request.";
return;
}

if (!reply.attempt_succeeded()) {
RetryCancelTask(std::move(task_spec), recursive, 2000);
}
});
};

// Check GCS node cache. If node info is not in the cache, query the GCS instead.
auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false);
if (node_info == nullptr) {
gcs_client_->Nodes().AsyncGetAll(
[do_cancel_local_task = std::move(do_cancel_local_task), node_id](
const Status &status, std::vector<rpc::GcsNodeInfo> &&nodes) mutable {
if (!status.ok()) {
RAY_LOG(INFO) << "Failed to get node info from GCS";
return;
}
if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) {
RAY_LOG(INFO).WithField(node_id)
<< "Not sending CancelLocalTask because node is dead";
return;
}
do_cancel_local_task(nodes[0]);
},
-1,
{node_id});
return;
}
if (node_info->state() == rpc::GcsNodeInfo::DEAD) {
RAY_LOG(INFO).WithField(node_id)
<< "Not sending CancelLocalTask because node is dead";
return;
}
do_cancel_local_task(*node_info);
}

bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) {
Expand Down
11 changes: 10 additions & 1 deletion src/ray/core_worker/task_submission/actor_task_submitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class ActorTaskSubmitterInterface {
class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
public:
ActorTaskSubmitter(rpc::CoreWorkerClientPool &core_worker_client_pool,
rpc::RayletClientPool &raylet_client_pool,
std::shared_ptr<gcs::GcsClient> gcs_client,
CoreWorkerMemoryStore &store,
TaskManagerInterface &task_manager,
ActorCreatorInterface &actor_creator,
Expand All @@ -76,6 +78,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
instrumented_io_context &io_service,
std::shared_ptr<ReferenceCounterInterface> reference_counter)
: core_worker_client_pool_(core_worker_client_pool),
raylet_client_pool_(raylet_client_pool),
gcs_client_(std::move(gcs_client)),
actor_creator_(actor_creator),
resolver_(store, task_manager, actor_creator, tensor_transport_getter),
task_manager_(task_manager),
Expand Down Expand Up @@ -300,7 +304,7 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
int64_t num_restarts_due_to_lineage_reconstructions_ = 0;
/// Whether this actor exits by spot preemption.
bool preempted_ = false;
/// The RPC client address.
/// The RPC client address of the actor.
std::optional<rpc::Address> client_address_;
/// The intended worker ID of the actor.
std::string worker_id_;
Expand Down Expand Up @@ -411,6 +415,11 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
/// Pool for producing new core worker clients.
rpc::CoreWorkerClientPool &core_worker_client_pool_;

/// Pool for producing new raylet clients.
rpc::RayletClientPool &raylet_client_pool_;

/// GCS client. Its node accessor (Nodes()) is used by CancelTask to look up
/// node info before sending CancelLocalTask.
std::shared_ptr<gcs::GcsClient> gcs_client_;

ActorCreatorInterface &actor_creator_;

/// Mutex to protect the various maps below.
Expand Down
Loading