Merged
Commits
34 commits
f8150c0
Make CancelTask RPC Fault Tolerant
Sparks0219 Oct 22, 2025
0a630a7
Addressing comments
Sparks0219 Oct 22, 2025
8ae4e3a
clean up and cpp test failures
Sparks0219 Oct 22, 2025
a733422
Addressing comments
Sparks0219 Oct 23, 2025
8a2e428
Fix broken cpp tests
Sparks0219 Oct 23, 2025
901099d
Fix merge conflicts
Sparks0219 Nov 7, 2025
7d4ab2e
Clean up
Sparks0219 Nov 7, 2025
9070db5
lint
Sparks0219 Nov 7, 2025
dcec398
Addressing comments
Sparks0219 Nov 12, 2025
9df37aa
Fix cpp test failures
Sparks0219 Nov 12, 2025
d846b90
Addressing comments
Sparks0219 Nov 13, 2025
873a17c
Addressing comments
Sparks0219 Nov 13, 2025
430a4a6
Merge remote-tracking branch 'upstream/master' into joshlee/make-canc…
Sparks0219 Nov 13, 2025
a253c81
fix build error
Sparks0219 Nov 14, 2025
9d5cf6f
Addressing comments
Sparks0219 Nov 14, 2025
d0fddda
Merge conflicts
Sparks0219 Nov 18, 2025
c8e0ed6
Addressing comments
Sparks0219 Nov 18, 2025
49250fb
Bad merge conflict fix
Sparks0219 Nov 18, 2025
3dbcc22
Addressing comments
Sparks0219 Nov 18, 2025
73445ab
Fix cpp test
Sparks0219 Nov 18, 2025
f737eef
Addressing comments
Sparks0219 Nov 19, 2025
0429f79
Addressing comments
Sparks0219 Nov 20, 2025
2f9c24e
Addressing comments
Sparks0219 Nov 20, 2025
bed7884
Fix cpp test error
Sparks0219 Nov 21, 2025
2a66834
Addressing comments
Sparks0219 Nov 22, 2025
0fb240e
Merge remote-tracking branch 'upstream/master' into joshlee/make-canc…
Sparks0219 Nov 26, 2025
6bab852
Removing io context posts now that accessor node cache is thread safe
Sparks0219 Nov 26, 2025
c1f1e0f
Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant
edoakes Nov 26, 2025
758ecd6
Merge branch 'master' into joshlee/make-cancel-task-fault-tolerant
jjyao Dec 2, 2025
22a53f5
Deflake serve test
Sparks0219 Dec 3, 2025
e053c2f
AI comment
Sparks0219 Dec 3, 2025
6b2674b
Addressing AI comments
Sparks0219 Dec 3, 2025
07da167
More AI comments
Sparks0219 Dec 3, 2025
41fa586
Addressing comments
Sparks0219 Dec 4, 2025
59 changes: 58 additions & 1 deletion python/ray/tests/test_raylet_fault_tolerance.py
@@ -1,16 +1,20 @@
import os
import sys

import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.core.generated import autoscaler_pb2
from ray.exceptions import GetTimeoutError, TaskCancelledError
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import (
    NodeAffinitySchedulingStrategy,
    PlacementGroupSchedulingStrategy,
)

import psutil


@pytest.mark.parametrize("deterministic_failure", ["request", "response"])
def test_request_worker_lease_idempotent(
@@ -138,5 +142,58 @@ def task():
    assert result == "success"


@pytest.fixture
def inject_cancel_local_task_rpc_failure(monkeypatch, request):
    deterministic_failure = request.param
    monkeypatch.setenv(
        "RAY_testing_rpc_failure",
        "NodeManagerService.grpc_client.CancelLocalTask=1:"
        + ("100:0" if deterministic_failure == "request" else "0:100"),
    )
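For readers unfamiliar with the injection hook: the fixture drives Ray's RPC failure-injection env var. Below is a minimal sketch of how that spec string decomposes, under the assumption (implied by the fixture above, not confirmed elsewhere in this page) that the format is <method>=<num_failures>:<request-failure %>:<response-failure %>. The helper and field names are illustrative, not Ray internals.

# Sketch only: "NodeManagerService.grpc_client.CancelLocalTask=1:100:0" reads
# as "fail the first CancelLocalTask RPC, with 100% probability on the request
# side and 0% on the response side".
def parse_rpc_failure_spec(spec: str) -> dict:
    method, rest = spec.split("=")
    num_failures, req_prob, resp_prob = (int(x) for x in rest.split(":"))
    return {
        "method": method,
        "num_failures": num_failures,
        "request_failure_percent": req_prob,
        "response_failure_percent": resp_prob,
    }

assert parse_rpc_failure_spec(
    "NodeManagerService.grpc_client.CancelLocalTask=1:100:0"
)["request_failure_percent"] == 100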


@pytest.mark.parametrize(
    "inject_cancel_local_task_rpc_failure", ["request", "response"], indirect=True
)
@pytest.mark.parametrize("force_kill", [True, False])
def test_cancel_local_task_rpc_retry_and_idempotency(
    inject_cancel_local_task_rpc_failure, force_kill, shutdown_only
):
    """Test that CancelLocalTask RPC retries work correctly.

    Verify that the RPC is idempotent when network failures occur.
    When force_kill=True, verify the worker process is actually killed using psutil.
    """
    ray.init(num_cpus=2)
    signaler = SignalActor.remote()

    @ray.remote(num_cpus=1)
    def get_pid():
        return os.getpid()

    @ray.remote(num_cpus=1)
    def blocking_task():
        return ray.get(signaler.wait.remote())

    worker_pid = ray.get(get_pid.remote())

    blocking_ref = blocking_task.remote()

    with pytest.raises(GetTimeoutError):
        ray.get(blocking_ref, timeout=1)

    ray.cancel(blocking_ref, force=force_kill)

    with pytest.raises(TaskCancelledError):
        ray.get(blocking_ref, timeout=10)

    if force_kill:

        def verify_process_killed():
            return not psutil.pid_exists(worker_pid)

        wait_for_condition(verify_process_killed, timeout=30)


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", __file__]))
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
@@ -151,6 +151,11 @@ class MockRayletClientInterface : public RayletClientInterface {
              (const rpc::ClientCallback<rpc::GlobalGCReply> &callback),
              (override));
  MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override));
  MOCK_METHOD(void,
              CancelLocalTask,
              (const rpc::CancelLocalTaskRequest &request,
               const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback),
              (override));
};

} // namespace ray
2 changes: 2 additions & 0 deletions src/ray/core_worker/core_worker_process.cc
@@ -498,6 +498,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(

  auto actor_task_submitter = std::make_unique<ActorTaskSubmitter>(
      *core_worker_client_pool,
      *raylet_client_pool,
      gcs_client,
      *memory_store,
      *task_manager,
      *actor_creator,
3 changes: 3 additions & 0 deletions src/ray/core_worker/task_submission/BUILD.bazel
@@ -71,6 +71,9 @@ ray_cc_library(
"//src/ray/common:protobuf_utils",
"//src/ray/core_worker:actor_creator",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/rpc:rpc_callback_types",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
53 changes: 31 additions & 22 deletions src/ray/core_worker/task_submission/actor_task_submitter.cc
@@ -272,6 +272,7 @@ void ActorTaskSubmitter::CancelDependencyResolution(const TaskID &task_id) {

void ActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) {
  queue.client_address_ = std::nullopt;
  queue.raylet_address_ = std::nullopt;
  // If the actor on the worker is dead, the worker is also dead.
  core_worker_client_pool_.Disconnect(WorkerID::FromBinary(queue.worker_id_));
  queue.worker_id_.clear();
@@ -336,7 +337,12 @@ void ActorTaskSubmitter::ConnectActor(const ActorID &actor_id,
  // So new RPCs go out with the right intended worker id to the right address.
  queue->second.worker_id_ = address.worker_id();
  queue->second.client_address_ = address;
  NodeID node_id = NodeID::FromBinary(address.node_id());

  auto node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness(node_id);
  RAY_CHECK(node_info != nullptr);
  queue->second.raylet_address_ = rpc::RayletClientPool::GenerateRayletAddress(
      node_id, node_info->node_manager_address(), node_info->node_manager_port());
  SendPendingTasks(actor_id);
}

@@ -1008,32 +1014,35 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)
      return;
    }

    rpc::CancelTaskRequest request;
    rpc::CancelLocalTaskRequest request;
    request.set_intended_task_id(task_spec.TaskIdBinary());
    request.set_force_kill(force_kill);
    request.set_recursive(recursive);
    request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
    auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_);
    client->CancelTask(request,
                       [this, task_spec = std::move(task_spec), recursive, task_id](
                           const Status &status, const rpc::CancelTaskReply &reply) {
                         RAY_LOG(DEBUG).WithField(task_spec.TaskId())
                             << "CancelTask RPC response received with status "
                             << status.ToString();

                         // Keep retrying every 2 seconds until a task is officially
                         // finished.
                         if (!task_manager_.GetTaskSpec(task_id)) {
                           // Task is already finished.
                           RAY_LOG(DEBUG).WithField(task_spec.TaskId())
                               << "Task is finished. Stop a cancel request.";
                           return;
                         }

                         if (!reply.attempt_succeeded()) {
                           RetryCancelTask(task_spec, recursive, 2000);
                         }
                       });
    request.set_executor_worker_address(
        queue->second.client_address_->SerializeAsString());
    auto raylet_client =
        raylet_client_pool_.GetOrConnectByAddress(*queue->second.raylet_address_);
    raylet_client->CancelLocalTask(
        request,
        [this, task_spec = std::move(task_spec), recursive, task_id](
            const Status &status, const rpc::CancelLocalTaskReply &reply) {
          RAY_LOG(DEBUG).WithField(task_spec.TaskId())
              << "CancelTask RPC response received with status " << status.ToString();

          // Keep retrying every 2 seconds until a task is officially
          // finished.
          if (!task_manager_.GetTaskSpec(task_id)) {
            // Task is already finished.
            RAY_LOG(DEBUG).WithField(task_spec.TaskId())
                << "Task is finished. Stop a cancel request.";
            return;
          }

          if (!reply.attempt_succeeded()) {
            RetryCancelTask(task_spec, recursive, 2000);
          }
        });
  }
}
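The callback above encodes the fault-tolerance contract of this PR: CancelLocalTask is idempotent, so the submitter can keep re-issuing it every 2 seconds until the task manager no longer knows the task. A minimal sketch of that loop (illustrative Python, not Ray source; task_is_known and send_cancel_rpc are stand-in callables, not Ray APIs):

import time

def cancel_until_finished(task_is_known, send_cancel_rpc, retry_delay_s=2.0):
    """Re-issue an idempotent cancel RPC until the task is gone or cancel lands."""
    while task_is_known():
        reply = send_cancel_rpc()  # safe to repeat: the RPC is idempotent
        if reply.get("attempt_succeeded"):
            return True  # the executor accepted the cancellation
        time.sleep(retry_delay_s)  # mirror the 2000 ms retry delay above
    return False  # task already finished; nothing left to cancel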

13 changes: 12 additions & 1 deletion src/ray/core_worker/task_submission/actor_task_submitter.h
@@ -67,6 +67,8 @@ class ActorTaskSubmitterInterface {
class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
 public:
  ActorTaskSubmitter(rpc::CoreWorkerClientPool &core_worker_client_pool,
                     rpc::RayletClientPool &raylet_client_pool,
                     std::shared_ptr<gcs::GcsClient> gcs_client,
                     CoreWorkerMemoryStore &store,
                     TaskManagerInterface &task_manager,
                     ActorCreatorInterface &actor_creator,
@@ -76,6 +78,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
                     instrumented_io_context &io_service,
                     std::shared_ptr<ReferenceCounterInterface> reference_counter)
      : core_worker_client_pool_(core_worker_client_pool),
        raylet_client_pool_(raylet_client_pool),
        gcs_client_(std::move(gcs_client)),
        actor_creator_(actor_creator),
        resolver_(store, task_manager, actor_creator, tensor_transport_getter),
        task_manager_(task_manager),
@@ -300,8 +304,10 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
    int64_t num_restarts_due_to_lineage_reconstructions_ = 0;
    /// Whether this actor exits by spot preemption.
    bool preempted_ = false;
    /// The RPC client address.
    /// The RPC client address of the actor.
    std::optional<rpc::Address> client_address_;
    /// The local raylet address of the actor.
    std::optional<rpc::Address> raylet_address_;
    /// The intended worker ID of the actor.
    std::string worker_id_;
    /// The actor is out of scope but the death info is not published
@@ -411,6 +417,11 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
  /// Pool for producing new core worker clients.
  rpc::CoreWorkerClientPool &core_worker_client_pool_;

  /// Pool for producing new raylet clients.
  rpc::RayletClientPool &raylet_client_pool_;

  /// GCS client, used here to look up a node's raylet address when connecting an actor.
  std::shared_ptr<gcs::GcsClient> gcs_client_;

  ActorCreatorInterface &actor_creator_;

  /// Mutex to protect the various maps below.
23 changes: 13 additions & 10 deletions src/ray/core_worker/task_submission/normal_task_submitter.cc
@@ -179,7 +179,8 @@ void NormalTaskSubmitter::OnWorkerIdle(
  task_spec.GetMutableMessage().set_lease_grant_timestamp_ms(current_sys_time_ms());
  task_spec.EmitTaskMetrics();

  executing_tasks_.emplace(task_spec.TaskId(), addr);
  executing_tasks_.emplace(task_spec.TaskId(),
                           std::make_pair(addr, lease_entry.addr));
  PushNormalTask(
      addr, client, scheduling_key, std::move(task_spec), assigned_resources);
}
@@ -665,7 +666,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec,
  SchedulingKey scheduling_key(task_spec.GetSchedulingClass(),
                               task_spec.GetDependencyIds(),
                               task_spec.GetRuntimeEnvHash());
  std::shared_ptr<rpc::CoreWorkerClientInterface> client = nullptr;
  std::shared_ptr<RayletClientInterface> raylet_client = nullptr;
  rpc::Address executor_worker_address;
  {
    absl::MutexLock lock(&mu_);
    generators_to_resubmit_.erase(task_id);
@@ -700,9 +702,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec,
    // This will get removed either when the RPC call to cancel is returned, when all
    // dependencies are resolved, or when dependency resolution is successfully cancelled.
    RAY_CHECK(cancelled_tasks_.emplace(task_id).second);
    auto rpc_client = executing_tasks_.find(task_id);

    if (rpc_client == executing_tasks_.end()) {
    auto rpc_client_address = executing_tasks_.find(task_id);
    if (rpc_client_address == executing_tasks_.end()) {
      if (failed_tasks_pending_failure_cause_.contains(task_id)) {
        // We are waiting for the task failure cause. Do not fail it here; instead,
        // wait for the cause to come in and then handle it appropriately.
@@ -723,22 +724,24 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec,
      return;
    }
    // Looks for an RPC handle for the worker executing the task.
    client = core_worker_client_pool_->GetOrConnect(rpc_client->second);
    raylet_client =
        raylet_client_pool_->GetOrConnectByAddress(rpc_client_address->second.second);
    executor_worker_address = rpc_client_address->second.first;
  }

  RAY_CHECK(client != nullptr);
  auto request = rpc::CancelTaskRequest();
  auto request = rpc::CancelLocalTaskRequest();
  request.set_intended_task_id(task_spec.TaskIdBinary());
  request.set_force_kill(force_kill);
  request.set_recursive(recursive);
  request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
  client->CancelTask(
  request.set_executor_worker_address(executor_worker_address.SerializeAsString());
  raylet_client->CancelLocalTask(
      request,
      [this,
       task_spec = std::move(task_spec),
       scheduling_key = std::move(scheduling_key),
       force_kill,
       recursive](const Status &status, const rpc::CancelTaskReply &reply) mutable {
       recursive](const Status &status, const rpc::CancelLocalTaskReply &reply) mutable {
        absl::MutexLock lock(&mu_);
        RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId()
                       << " with status " << status.ToString();
src/ray/core_worker/task_submission/normal_task_submitter.h
@@ -348,7 +348,10 @@ class NormalTaskSubmitter {
  absl::flat_hash_set<TaskID> cancelled_tasks_ ABSL_GUARDED_BY(mu_);

  // Keeps track of where currently executing tasks are being run.
  absl::flat_hash_map<TaskID, rpc::Address> executing_tasks_ ABSL_GUARDED_BY(mu_);
  // The first address is the executor worker's; the second is that executor's
  // local raylet.
  absl::flat_hash_map<TaskID, std::pair<rpc::Address, rpc::Address>> executing_tasks_
      ABSL_GUARDED_BY(mu_);

  // Generators that are currently running and need to be resubmitted.
  absl::flat_hash_set<TaskID> generators_to_resubmit_ ABSL_GUARDED_BY(mu_);
5 changes: 3 additions & 2 deletions src/ray/protobuf/core_worker.proto
@@ -529,8 +529,9 @@ service CoreWorkerService {
  // Failure: TODO: Never retries
  rpc KillActor(KillActorRequest) returns (KillActorReply);

  // Request from owner worker to executor worker to cancel a task.
  // Failure: Will retry, TODO: Needs tests for failure behavior.
  // Request from local raylet to executor worker to cancel a task.
  // Failure: Idempotent; does not retry on network failures. However, requests
  // should only be sent via CancelLocalTask from the raylet, which does implement
  // retries.
  rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);

  // Request from a worker to the owner worker to issue a cancellation.
23 changes: 23 additions & 0 deletions src/ray/protobuf/node_manager.proto
@@ -406,6 +406,26 @@ message GetWorkerPIDsReply {
  repeated int32 pids = 1;
}

message CancelLocalTaskRequest {
  // ID of task that should be killed.
  bytes intended_task_id = 1;
  // Whether to kill the worker.
  bool force_kill = 2;
  // Whether to recursively cancel tasks.
  bool recursive = 3;
  // The worker ID of the caller.
  bytes caller_worker_id = 4;
  // The worker address of the executor.
  bytes executor_worker_address = 5;
}

message CancelLocalTaskReply {
  // Whether the requested task is the currently running task.
  bool requested_task_running = 1;
  // Whether the task is canceled.
  bool attempt_succeeded = 2;
}
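To make the new messages concrete, here is a hedged sketch of building a CancelLocalTaskRequest from the generated Python bindings. The module paths follow the ray.core.generated layout used in the test file above; the Address field names are read from common.proto and may differ across Ray versions, and all byte values are placeholders.

from ray.core.generated import common_pb2, node_manager_pb2

# Sketch only: the executor worker's rpc::Address is serialized into
# executor_worker_address so the raylet can forward the cancellation
# to the right worker.
executor_addr = common_pb2.Address(
    node_id=b"<executor node id>",       # placeholder, not a real ID
    ip_address="10.0.0.5",               # assumed example value
    port=10001,                          # assumed example value
    worker_id=b"<executor worker id>",
)
request = node_manager_pb2.CancelLocalTaskRequest(
    intended_task_id=b"<task id>",
    force_kill=False,   # interrupt the task but keep the worker alive
    recursive=True,     # also cancel tasks spawned by this task
    caller_worker_id=b"<caller worker id>",
    executor_worker_address=executor_addr.SerializeToString(),
)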

// Service for inter-node-manager communication.
service NodeManagerService {
  // Handle the case when GCS restarted.
@@ -525,4 +545,7 @@ service NodeManagerService {
  // Failure: Will retry with the default timeout of 1000 ms. If it fails, the reply
  // returns an empty list.
  rpc GetWorkerPIDs(GetWorkerPIDsRequest) returns (GetWorkerPIDsReply);
  // Forwards the CancelTask request from the caller core worker to the executor.
  // Failure: Retries; it's idempotent.
  rpc CancelLocalTask(CancelLocalTaskRequest) returns (CancelLocalTaskReply);
}