Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions hyperactor/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ use crate::channel::Rx;
use crate::channel::Tx;
use crate::clock::Clock;
use crate::clock::RealClock;
use crate::context;
use crate::mailbox::BoxableMailboxSender;
use crate::mailbox::DialMailboxRouter;
use crate::mailbox::IntoBoxedMailboxSender as _;
Expand Down Expand Up @@ -404,6 +405,7 @@ pub trait SingleTerminate: Send + Sync {
/// Returns a tuple of (polite shutdown actors vec, forceful stop actors vec)
async fn terminate_proc(
&self,
cx: &impl context::Actor,
proc: &ProcId,
timeout: std::time::Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error>;
Expand Down Expand Up @@ -444,6 +446,7 @@ pub trait BulkTerminate: Send + Sync {
/// etc.).
async fn terminate_all(
&self,
cx: &impl context::Actor,
timeout: std::time::Duration,
max_in_flight: usize,
) -> TerminateSummary;
Expand All @@ -467,21 +470,23 @@ impl<M: ProcManager + BulkTerminate> Host<M> {
/// terminations.
pub async fn terminate_children(
&self,
cx: &impl context::Actor,
timeout: Duration,
max_in_flight: usize,
) -> TerminateSummary {
self.manager.terminate_all(timeout, max_in_flight).await
self.manager.terminate_all(cx, timeout, max_in_flight).await
}
}

#[async_trait::async_trait]
impl<M: ProcManager + SingleTerminate> SingleTerminate for Host<M> {
async fn terminate_proc(
&self,
cx: &impl context::Actor,
proc: &ProcId,
timeout: Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
self.manager.terminate_proc(proc, timeout).await
self.manager.terminate_proc(cx, proc, timeout).await
}
}

Expand Down Expand Up @@ -566,6 +571,7 @@ pub trait ProcHandle: Clone + Send + Sync + 'static {
/// termination.
async fn terminate(
&self,
cx: &impl context::Actor,
timeout: Duration,
) -> Result<Self::TerminalStatus, TerminateError<Self::TerminalStatus>>;

Expand Down Expand Up @@ -657,6 +663,7 @@ where
{
async fn terminate_all(
&self,
_cx: &impl context::Actor,
timeout: std::time::Duration,
max_in_flight: usize,
) -> TerminateSummary {
Expand Down Expand Up @@ -699,6 +706,7 @@ where
{
async fn terminate_proc(
&self,
_cx: &impl context::Actor,
proc: &ProcId,
timeout: std::time::Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
Expand Down Expand Up @@ -783,6 +791,7 @@ impl<A: Actor + Referable> ProcHandle for LocalHandle<A> {

async fn terminate(
&self,
_cx: &impl context::Actor,
timeout: Duration,
) -> Result<(), TerminateError<Self::TerminalStatus>> {
let mut proc = {
Expand Down Expand Up @@ -1010,6 +1019,7 @@ impl<A: Actor + Referable> ProcHandle for ProcessHandle<A> {

async fn terminate(
&self,
_cx: &impl context::Actor,
_deadline: Duration,
) -> Result<(), TerminateError<Self::TerminalStatus>> {
Err(TerminateError::Unsupported)
Expand Down Expand Up @@ -1441,6 +1451,7 @@ mod tests {
}
async fn terminate(
&self,
_cx: &impl context::Actor,
_timeout: Duration,
) -> Result<Self::TerminalStatus, TerminateError<Self::TerminalStatus>> {
Err(TerminateError::Unsupported)
Expand Down
41 changes: 37 additions & 4 deletions hyperactor_mesh/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use hyperactor::clock::RealClock;
use hyperactor::config::CONFIG;
use hyperactor::config::ConfigAttr;
use hyperactor::config::global as config;
use hyperactor::context;
use hyperactor::declare_attrs;
use hyperactor::host;
use hyperactor::host::Host;
Expand All @@ -64,6 +65,7 @@ use tokio::sync::watch;
use crate::logging::OutputTarget;
use crate::logging::StreamFwder;
use crate::proc_mesh::mesh_agent::ProcMeshAgent;
use crate::resource::StopAllClient;
use crate::v1;
use crate::v1::host_mesh::mesh_agent::HostAgentMode;
use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
Expand Down Expand Up @@ -1242,6 +1244,7 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle {
/// or the channel was lost.
async fn terminate(
&self,
cx: &impl context::Actor,
timeout: Duration,
) -> Result<ProcStatus, hyperactor::host::TerminateError<Self::TerminalStatus>> {
const HARD_WAIT_AFTER_KILL: Duration = Duration::from_secs(5);
Expand All @@ -1264,6 +1267,30 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle {
})?;

// Best-effort mark "Stopping" (ok if state races).

// Before sending SIGTERM, try to close actors normally. Only works if
// they are in the Ready state and have an Agent we can message.
let agent = self.agent_ref();
if let Some(agent) = agent {
let mailbox_result = RealClock.timeout(timeout, agent.stop_all(cx)).await;
if let Err(timeout_err) = mailbox_result {
// Agent didn't respond in time, proceed with SIGTERM.
tracing::warn!(
"ProcMeshAgent {} didn't respond in time to stop proc: {}",
agent.actor_id(),
timeout_err,
);
} else if let Ok(Err(e)) = mailbox_result {
// Other mailbox error, proceed with SIGTERM.
tracing::warn!(
"ProcMeshAgent {} did not successfully stop all actors: {}",
agent.actor_id(),
e
);
}
}
// Even if the stop-all message succeeded, the process itself is still
// running, so we still need to actually stop it below.
let _ = self.mark_stopping();

// Send SIGTERM (ESRCH is treated as "already gone").
Expand Down Expand Up @@ -1885,6 +1912,7 @@ impl hyperactor::host::SingleTerminate for BootstrapProcManager {
/// Logs a warning for each failure.
async fn terminate_proc(
&self,
cx: &impl context::Actor,
proc: &ProcId,
timeout: Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
Expand All @@ -1895,7 +1923,7 @@ impl hyperactor::host::SingleTerminate for BootstrapProcManager {
};

if let Some(h) = proc_handle {
h.terminate(timeout)
h.terminate(cx, timeout)
.await
.map(|_| (Vec::new(), Vec::new()))
.map_err(|e| e.into())
Expand All @@ -1920,7 +1948,12 @@ impl hyperactor::host::BulkTerminate for BootstrapProcManager {
/// those that were already terminal), and how many failed.
///
/// Logs a warning for each failure.
async fn terminate_all(&self, timeout: Duration, max_in_flight: usize) -> TerminateSummary {
async fn terminate_all(
&self,
cx: &impl context::Actor,
timeout: Duration,
max_in_flight: usize,
) -> TerminateSummary {
// Snapshot to avoid holding the lock across awaits.
let handles: Vec<BootstrapProcHandle> = {
let guard = self.children.lock().await;
Expand All @@ -1931,7 +1964,7 @@ impl hyperactor::host::BulkTerminate for BootstrapProcManager {
let mut ok = 0usize;

let results = stream::iter(handles.into_iter().map(|h| async move {
match h.terminate(timeout).await {
match h.terminate(cx, timeout).await {
Ok(_) | Err(hyperactor::host::TerminateError::AlreadyTerminated(_)) => {
// Treat "already terminal" as success.
true
Expand Down Expand Up @@ -3321,7 +3354,7 @@ mod tests {

let deadline = Duration::from_secs(2);
match RealClock
.timeout(deadline * 2, handle.terminate(deadline))
.timeout(deadline * 2, handle.terminate(&instance, deadline))
.await
{
Err(_) => panic!("terminate() future hung"),
Expand Down
28 changes: 28 additions & 0 deletions hyperactor_mesh/src/proc_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSu
MeshAgentMessage,
resource::CreateOrUpdate<ActorSpec> { cast = true },
resource::Stop { cast = true },
resource::StopAll { cast = true },
resource::GetState<ActorState> { cast = true },
resource::GetRankStatus { cast = true },
]
Expand Down Expand Up @@ -272,6 +273,14 @@ impl ProcMeshAgent {
};
proc.spawn::<Self>("agent", agent).await
}

/// Forwards to `self.proc.destroy_and_wait::<Self>`, handing it `timeout`
/// and this agent's own context as `Some(cx)`.
///
/// Returns the proc-level call's result unchanged: on success a pair of
/// actor-id lists (presumably the politely stopped and the forcefully
/// stopped actors — TODO confirm against the proc-level implementation),
/// otherwise its error.
async fn destroy_and_wait<'a>(
&mut self,
cx: &Context<'a, Self>,
timeout: tokio::time::Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
self.proc.destroy_and_wait::<Self>(timeout, Some(cx)).await
}
}

#[async_trait]
Expand Down Expand Up @@ -616,6 +625,25 @@ impl Handler<resource::Stop> for ProcMeshAgent {
}
}

#[async_trait]
impl Handler<resource::StopAll> for ProcMeshAgent {
    /// Handles `StopAll`: runs `destroy_and_wait` with the globally configured
    /// `STOP_ACTOR_TIMEOUT`, then flags every tracked actor state as stopped.
    ///
    /// If `destroy_and_wait` fails, its error is returned and no actor state
    /// is modified.
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        _message: resource::StopAll,
    ) -> anyhow::Result<()> {
        let stop_timeout =
            hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);

        // Handing over our own context makes destroy_and_wait stop this agent
        // last, once all the other actors have been stopped. The returned
        // actor-id lists are not needed here.
        self.destroy_and_wait(cx, stop_timeout).await?;

        // Record every tracked actor as stopped.
        self.actor_states
            .iter_mut()
            .for_each(|(_, state)| state.stopped = true);

        Ok(())
    }
}

#[async_trait]
impl Handler<resource::GetRankStatus> for ProcMeshAgent {
async fn handle(
Expand Down
16 changes: 16 additions & 0 deletions hyperactor_mesh/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ pub struct Stop {
pub reply: PortRef<StatusOverlay>,
}

/// Stop all resources owned by the receiver of this message.
///
/// The message has no fields and, unlike [`Stop`], carries no reply port:
/// it is meant to force a stop without waiting for acknowledgement.
#[derive(
Debug,
Clone,
Serialize,
Deserialize,
Named,
Handler,
HandleClient,
RefClient,
Bind,
Unbind
)]
pub struct StopAll {}

/// Retrieve the current state of the resource.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
Expand Down
12 changes: 7 additions & 5 deletions hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use hyperactor::Proc;
use hyperactor::ProcId;
use hyperactor::RefClient;
use hyperactor::channel::ChannelTransport;
use hyperactor::context;
use hyperactor::host::Host;
use hyperactor::host::HostError;
use hyperactor::host::LocalProcManager;
Expand Down Expand Up @@ -75,13 +76,14 @@ impl HostAgentMode {

async fn terminate_proc(
&self,
cx: &impl context::Actor,
proc: &ProcId,
timeout: Duration,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
#[allow(clippy::match_same_arms)]
match self {
HostAgentMode::Process(host) => host.terminate_proc(proc, timeout).await,
HostAgentMode::Local(host) => host.terminate_proc(proc, timeout).await,
HostAgentMode::Process(host) => host.terminate_proc(cx, proc, timeout).await,
HostAgentMode::Local(host) => host.terminate_proc(cx, proc, timeout).await,
}
}
}
Expand Down Expand Up @@ -212,7 +214,7 @@ impl Handler<resource::Stop> for HostMeshAgent {
!*stopped
};
if should_stop {
host.terminate_proc(proc_id, timeout).await?;
host.terminate_proc(&cx, proc_id, timeout).await?;
*stopped = true;
}
// use Stopped as a successful result for Stop.
Expand Down Expand Up @@ -328,13 +330,13 @@ impl Handler<ShutdownHost> for HostMeshAgent {
match host_mode {
HostAgentMode::Process(host) => {
let summary = host
.terminate_children(msg.timeout, msg.max_in_flight.clamp(1, 256))
.terminate_children(cx, msg.timeout, msg.max_in_flight.clamp(1, 256))
.await;
tracing::info!(?summary, "terminated children on host");
}
HostAgentMode::Local(host) => {
let summary = host
.terminate_children(msg.timeout, msg.max_in_flight)
.terminate_children(cx, msg.timeout, msg.max_in_flight)
.await;
tracing::info!(?summary, "terminated children on local host");
}
Expand Down
8 changes: 4 additions & 4 deletions python/tests/test_actor_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,13 +749,13 @@ async def test_slice_supervision() -> None:
slice_2 = error_mesh.slice(gpus=2)
slice_3 = error_mesh.slice(gpus=3)

# Trigger supervision error on gpus=3
with pytest.raises(SupervisionError, match="did not handle supervision event"):
await slice_3.fail_with_supervision_error.call()

match = (
"Actor .* (is unhealthy with reason:|exited because of the following reason:)"
)
# Trigger supervision error on gpus=3
with pytest.raises(SupervisionError, match=match):
await slice_3.fail_with_supervision_error.call()

# Mesh containing all gpus is unhealthy
with pytest.raises(SupervisionError, match=match):
await error_mesh.check.call()
Expand Down