From 388456fcfe92c6620aef9369543f258386f179c2 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 15:47:44 +0200 Subject: [PATCH 01/11] Introduced RemoteRuntime type --- changelog.d/agent-threading.internal.md | 1 + mirrord/agent/src/dns.rs | 92 ++++--- mirrord/agent/src/entrypoint.rs | 293 +++++++++----------- mirrord/agent/src/entrypoint/setup.rs | 51 +--- mirrord/agent/src/error.rs | 57 +--- mirrord/agent/src/main.rs | 1 - mirrord/agent/src/namespace.rs | 47 ++-- mirrord/agent/src/outgoing.rs | 46 +--- mirrord/agent/src/outgoing/udp.rs | 39 +-- mirrord/agent/src/sniffer.rs | 16 +- mirrord/agent/src/sniffer/api.rs | 17 +- mirrord/agent/src/steal/api.rs | 20 +- mirrord/agent/src/steal/connection.rs | 16 +- mirrord/agent/src/util.rs | 73 +---- mirrord/agent/src/util/remote_runtime.rs | 331 +++++++++++++++++++++++ mirrord/agent/src/vpn.rs | 131 +++++---- mirrord/agent/src/watched_task.rs | 128 --------- 17 files changed, 697 insertions(+), 662 deletions(-) create mode 100644 changelog.d/agent-threading.internal.md create mode 100644 mirrord/agent/src/util/remote_runtime.rs delete mode 100644 mirrord/agent/src/watched_task.rs diff --git a/changelog.d/agent-threading.internal.md b/changelog.d/agent-threading.internal.md new file mode 100644 index 00000000000..e2adceb008a --- /dev/null +++ b/changelog.d/agent-threading.internal.md @@ -0,0 +1 @@ +Reworked agent's threading model to avoid spawning excessive threads. 
diff --git a/mirrord/agent/src/dns.rs b/mirrord/agent/src/dns.rs index 6b0ab7783d9..a541ea7e8be 100644 --- a/mirrord/agent/src/dns.rs +++ b/mirrord/agent/src/dns.rs @@ -1,4 +1,6 @@ -use std::{future, io, path::PathBuf, sync::atomic::Ordering, time::Duration}; +use std::{ + collections::HashMap, future, io, path::PathBuf, sync::atomic::Ordering, time::Duration, +}; use futures::{stream::FuturesOrdered, StreamExt}; use hickory_resolver::{ @@ -24,16 +26,12 @@ use tokio::{ mpsc::{Receiver, Sender}, oneshot, }, - task::JoinSet, + task::{Id, JoinSet}, }; use tokio_util::sync::CancellationToken; use tracing::{warn, Level}; -use crate::{ - error::{AgentError, AgentResult}, - metrics::DNS_REQUEST_COUNT, - watched_task::TaskStatus, -}; +use crate::{error::AgentResult, metrics::DNS_REQUEST_COUNT, util::remote_runtime::BgTaskStatus}; #[derive(Debug)] pub(crate) enum ClientGetAddrInfoRequest { @@ -54,7 +52,7 @@ impl ClientGetAddrInfoRequest { #[derive(Debug)] pub(crate) struct DnsCommand { request: ClientGetAddrInfoRequest, - response_tx: oneshot::Sender>, + response_tx: oneshot::Sender>, } /// Background task for resolving hostnames to IP addresses. @@ -80,12 +78,11 @@ pub(crate) struct DnsWorker { /// Background tasks that handle the DNS requests. /// /// Each of these builds a new [`TokioAsyncResolver`] and performs one lookup. - tasks: JoinSet<()>, + tasks: JoinSet>, + response_txs: HashMap>>, } impl DnsWorker { - pub const TASK_NAME: &'static str = "DNS worker"; - /// Creates a new instance of this worker. /// To run this worker, call [`Self::run`]. 
/// @@ -124,6 +121,7 @@ impl DnsWorker { attempts, support_ipv6, tasks: Default::default(), + response_txs: Default::default(), } } @@ -203,34 +201,51 @@ impl DnsWorker { let attempts = self.attempts; let support_ipv6 = self.support_ipv6; - let lookup_future = async move { - let result = Self::do_lookup( - etc_path, - message.request.into_v2(), - attempts, - timeout, - support_ipv6, - ) - .await; - - let _ = message.response_tx.send(result); - }; + let handle = self.tasks.spawn(Self::do_lookup( + etc_path, + message.request.into_v2(), + attempts, + timeout, + support_ipv6, + )); + self.response_txs.insert(handle.id(), message.response_tx); DNS_REQUEST_COUNT.fetch_add(1, Ordering::Relaxed); - self.tasks.spawn(lookup_future); } - pub(crate) async fn run(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { + pub(crate) async fn run(mut self, cancellation_token: CancellationToken) { loop { tokio::select! { - _ = cancellation_token.cancelled() => break Ok(()), + _ = cancellation_token.cancelled() => break, - _ = self.tasks.join_next() => { + Some(result) = self.tasks.join_next_with_id() => { DNS_REQUEST_COUNT.fetch_sub(1, Ordering::Relaxed); + let (id, result) = match result { + Ok((id, result)) => ( + id, + result.map_err(Into::into), + ), + Err(error) => { + ( + error.id(), + Err(ResolveErrorKindInternal::Message("DNS task panicked".into())) + ) + } + }; + + let response_tx = self.response_txs.remove(&id); + match response_tx { + Some(response_tx) => { + let _ = response_tx.send(result); + } + None => { + warn!(?id, "Received a DNS result with no matching response channel"); + } + } } message = self.request_rx.recv() => match message { - None => break Ok(()), + None => break, Some(message) => self.handle_message(message), }, } @@ -246,15 +261,15 @@ impl Drop for DnsWorker { } pub(crate) struct DnsApi { - task_status: TaskStatus, + task_status: BgTaskStatus, request_tx: Sender, /// [`DnsWorker`] processes all requests concurrently, so we use a 
combination of [`oneshot`] /// channels and [`FuturesOrdered`] to preserve order of responses. - responses: FuturesOrdered>>, + responses: FuturesOrdered>>, } impl DnsApi { - pub(crate) fn new(task_status: TaskStatus, task_sender: Sender) -> Self { + pub(crate) fn new(task_status: BgTaskStatus, task_sender: Sender) -> Self { Self { task_status, request_tx: task_sender, @@ -276,7 +291,7 @@ impl DnsApi { response_tx, }; if self.request_tx.send(command).await.is_err() { - return Err(self.task_status.unwrap_err().await); + return Err(self.task_status.wait_assert_running().await); } self.responses.push_back(response_rx); @@ -294,11 +309,14 @@ impl DnsApi { return future::pending().await; }; - let response = response - .map_err(|_| AgentError::DnsTaskPanic)? - .map_err(|error| ResponseError::DnsLookup(DnsLookupError { kind: error.into() })); - - Ok(GetAddrInfoResponse(response)) + match response { + Ok(response) => { + Ok(GetAddrInfoResponse(response.map_err(|kind| { + ResponseError::DnsLookup(DnsLookupError { kind }) + }))) + } + Err(..) 
=> Err(self.task_status.wait_assert_running().await), + } } } diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 955ea317ea7..2730e51625b 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -22,7 +22,6 @@ use mirrord_agent_iptables::{ IPTABLE_STANDARD_ENV, }; use mirrord_protocol::{ClientMessage, DaemonMessage, GetEnvVarsRequest, LogMessage}; -use sniffer::tcp_capture::RawSocketTcpCapture; use steal::StealerMessage; use tokio::{ net::{TcpListener, TcpStream}, @@ -34,23 +33,28 @@ use tokio::{ time::{timeout, Duration}, }; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, trace, warn, Level}; +use tracing::{debug, error, trace, warn, Level}; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; use crate::{ - cli::Args, - client_connection::ClientConnection, + cli::{self, Args}, + client_connection::{self, ClientConnection}, container_handle::ContainerHandle, - dns::DnsApi, + dns::{self, DnsApi}, + env, error::{AgentError, AgentResult}, file::FileManager, + metrics, + namespace::NamespaceType, outgoing::{TcpOutgoingApi, UdpOutgoingApi}, - runtime::get_container, + runtime::{self, get_container}, sniffer::{api::TcpSnifferApi, messages::SnifferCommand, TcpConnectionSniffer}, - steal::{StealTlsHandlerStore, StealerCommand, TcpConnectionStealer, TcpStealerApi}, - util::{path_resolver::InTargetPathResolver, run_thread_in_namespace, ClientId}, - watched_task::{TaskStatus, WatchedTask}, - *, + steal::{self, StealTlsHandlerStore, StealerCommand, TcpConnectionStealer, TcpStealerApi}, + util::{ + path_resolver::InTargetPathResolver, + remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime, RemoteRuntime}, + ClientId, + }, }; mod setup; @@ -73,6 +77,8 @@ struct State { ephemeral: bool, /// When present, it is used to secure incoming TCP connections. tls_connector: Option, + /// [`tokio::runtime`] that should be used for network operations. 
+ network_runtime: MaybeRemoteRuntime, } impl State { @@ -122,6 +128,13 @@ impl State { cli::Mode::Targetless | cli::Mode::BlackboxTest => (false, None, "self".to_string()), }; + let network_runtime = match container.as_ref().map(ContainerHandle::pid) { + Some(pid) if ephemeral.not() => MaybeRemoteRuntime::Remote( + RemoteRuntime::new_in_namespace(pid, NamespaceType::Net).await?, + ), + None | Some(..) => MaybeRemoteRuntime::Local, + }; + let environ_path = PathBuf::from("/proc").join(pid).join("environ"); match env::get_proc_environ(environ_path).await { @@ -137,6 +150,7 @@ impl State { env: Arc::new(env), ephemeral, tls_connector, + network_runtime, }) } @@ -178,7 +192,7 @@ impl State { } enum BackgroundTask { - Running(TaskStatus, Sender), + Running(BgTaskStatus, Sender), Disabled, } @@ -206,7 +220,9 @@ struct ClientConnectionHandler { /// Handles mirrord's file operations, see [`FileManager`]. file_manager: FileManager, connection: ClientConnection, + /// [`None`] when targetless. tcp_sniffer_api: Option, + /// [`None`] when targetless. tcp_stealer_api: Option, tcp_outgoing_api: TcpOutgoingApi, udp_outgoing_api: UdpOutgoingApi, @@ -240,8 +256,8 @@ impl ClientConnectionHandler { Self::create_stealer_api(id, bg_tasks.stealer, &mut connection).await?; let dns_api = Self::create_dns_api(bg_tasks.dns); - let tcp_outgoing_api = TcpOutgoingApi::new(pid); - let udp_outgoing_api = UdpOutgoingApi::new(pid); + let tcp_outgoing_api = TcpOutgoingApi::new(&state.network_runtime); + let udp_outgoing_api = UdpOutgoingApi::new(&state.network_runtime); let client_handler = Self { id, @@ -460,16 +476,21 @@ impl ClientConnectionHandler { if let Some(sniffer_api) = &mut self.tcp_sniffer_api { sniffer_api.handle_client_message(message).await? } else { - warn!("received tcp sniffer request while not available"); - Err(AgentError::SnifferNotRunning)? 
+ self.respond(DaemonMessage::Close( + "component responsible for mirroring incoming traffic is not running, \ + which might be due to Kubernetes node kernel version <4.20. \ + Check agent logs for errors and please report a bug if kernel version >=4.20".into(), + )).await?; } } ClientMessage::TcpSteal(message) => { if let Some(tcp_stealer_api) = self.tcp_stealer_api.as_mut() { tcp_stealer_api.handle_client_message(message).await? } else { - warn!("received tcp steal request while not available"); - Err(AgentError::StealerNotRunning)? + self.respond(DaemonMessage::Close( + "incoming traffic stealing is not available in the targetless mode".into(), + )) + .await?; } } ClientMessage::Close => { @@ -498,8 +519,8 @@ impl ClientConnectionHandler { self.ready_for_logs = true; } ClientMessage::Vpn(_message) => { - unreachable!("VPN is not supported"); - // self.vpn_api.layer_message(message).await?; + self.respond(DaemonMessage::Close("VPN is not supported".into())) + .await?; } } @@ -559,117 +580,85 @@ async fn start_agent(args: Args) -> AgentResult<()> { }); } - let (sniffer_command_tx, sniffer_command_rx) = mpsc::channel::(1000); - let (stealer_command_tx, stealer_command_rx) = mpsc::channel::(1000); - let (dns_command_tx, dns_command_rx) = mpsc::channel::(1000); - - let (sniffer_task, sniffer_status) = if args.mode.is_targetless() { - (None, None) + let sniffer = if args.mode.is_targetless() { + BackgroundTask::Disabled } else { let cancellation_token = cancellation_token.clone(); + let (command_tx, command_rx) = mpsc::channel::(1000); + let is_mesh = args.is_mesh(); - // We're using this to avoid crashing on old kernels when initializing the - // `RawSocketTcpCapture`. 
failed task causes the agent to exit - // so we just check that initialization was successful - // then decide whether to store the task or drop it - // https://github.com/metalbear-co/mirrord/pull/2910 - let (sniffer_init_tx, sniffer_init_rx) = tokio::sync::oneshot::channel::(); - let watched_task = WatchedTask::new( - TcpConnectionSniffer::::TASK_NAME, - async move { - if let Ok(sniffer) = - TcpConnectionSniffer::new(sniffer_command_rx, args.network_interface, is_mesh) - .await - { - if let Err(error) = sniffer_init_tx.send(true) { - tracing::error!(%error, "Failed to send sniffer init result"); - }; - // will block from this point on - let res = sniffer.start(cancellation_token).await; - if let Err(err) = res { - error!(%err, "Sniffer failed"); - } - } else if let Err(error) = sniffer_init_tx.send(false) { - tracing::error!(%error, "Failed to send sniffer init result"); - } + let sniffer = state + .network_runtime + .spawn(TcpConnectionSniffer::new( + command_rx, + args.network_interface, + is_mesh, + )) + .await; - Ok(()) - }, - ); - let status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - TcpConnectionSniffer::::TASK_NAME.to_string(), - state.container_pid(), - "net", - ); - - match sniffer_init_rx.await { - Ok(true) => (Some(task), Some(status)), - Ok(false) => (None, None), + match sniffer { + Ok(Ok(sniffer)) => { + let task_status = state + .network_runtime + .spawn(sniffer.start(cancellation_token.clone())) + .into_status("TcpSnifferTask"); + + BackgroundTask::Running(task_status, command_tx) + } + Ok(Err(error)) => { + error!(%error, "Failed to create a TCP sniffer"); + BackgroundTask::Disabled + } Err(error) => { - tracing::error!(%error, "unexpected error while waiting for sniffer init"); - (None, None) + error!(%error, "Failed to create a TCP sniffer"); + BackgroundTask::Disabled } } }; - let (stealer_task, stealer_status) = match state.container_pid() { - None => (None, None), - Some(pid) => { - let 
steal_handle = setup::start_traffic_redirector(pid).await?; + let stealer = match &state.network_runtime { + MaybeRemoteRuntime::Local => BackgroundTask::Disabled, + MaybeRemoteRuntime::Remote(runtime) => { + let steal_handle = setup::start_traffic_redirector(runtime).await?; let cancellation_token = cancellation_token.clone(); + let (command_tx, command_rx) = mpsc::channel::(1000); + let tls_steal_config = envs::STEAL_TLS_CONFIG.from_env_or_default(); let tls_handler_store = tls_steal_config.is_empty().not().then(|| { - StealTlsHandlerStore::new(tls_steal_config, InTargetPathResolver::new(pid)) + StealTlsHandlerStore::new( + tls_steal_config, + InTargetPathResolver::new(runtime.target_pid()), + ) }); - let watched_task = WatchedTask::new(TcpConnectionStealer::TASK_NAME, async move { - TcpConnectionStealer::new(stealer_command_rx, steal_handle, tls_handler_store) - .start(cancellation_token) - .await - .inspect_err(|error| { - error!(%error, "Stealer failed"); - }) - }); - let status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - TcpConnectionStealer::TASK_NAME.to_string(), - state.container_pid(), - "net", - ); - - (Some(task), Some(status)) + let task_status = state + .network_runtime + .spawn( + TcpConnectionStealer::new(command_rx, steal_handle, tls_handler_store) + .start(cancellation_token), + ) + .into_status("TcpStealerTask"); + + BackgroundTask::Running(task_status, command_tx) } }; - let (dns_task, dns_status) = { - let cancellation_token = cancellation_token.clone(); - let watched_task = WatchedTask::new( - DnsWorker::TASK_NAME, - DnsWorker::new(state.container_pid(), dns_command_rx, args.ipv6) - .run(cancellation_token), - ); - let status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - DnsWorker::TASK_NAME.to_string(), - state.container_pid(), - "net", - ); - - (task, status) + let dns = { + let (command_tx, command_rx) = mpsc::channel::(1000); + let task_status = state + 
.network_runtime + .spawn( + DnsWorker::new(state.container_pid(), command_rx, args.ipv6) + .run(cancellation_token.clone()), + ) + .into_status("DnsTask"); + BackgroundTask::Running(task_status, command_tx) }; let bg_tasks = BackgroundTasks { - sniffer: sniffer_status - .map(|status| BackgroundTask::Running(status, sniffer_command_tx)) - .unwrap_or(BackgroundTask::Disabled), - stealer: stealer_status - .map(|status| BackgroundTask::Running(status, stealer_command_tx)) - .unwrap_or(BackgroundTask::Disabled), - dns: BackgroundTask::Running(dns_status, dns_command_tx), + sniffer, + stealer, + dns, }; // WARNING: `wait_for_agent_startup` in `mirrord/kube/src/api/container.rs` expects a line @@ -732,7 +721,6 @@ async fn start_agent(args: Args) -> AgentResult<()> { Some(Err(error)) => { error!(?error, "start_agent -> Failed to join client handler task"); - Err(error)? } None => { @@ -753,30 +741,29 @@ async fn start_agent(args: Args) -> AgentResult<()> { dns, } = bg_tasks; - if let (Some(sniffer_task), BackgroundTask::Running(mut sniffer_status, _)) = - (sniffer_task, sniffer) - { - sniffer_task.join().map_err(|_| AgentError::JoinTask)?; - if let Some(err) = sniffer_status.err().await { - error!("start_agent -> sniffer task failed with error: {}", err); - } - } - - if let (Some(stealer_task), BackgroundTask::Running(mut stealer_status, _)) = - (stealer_task, stealer) - { - stealer_task.join().map_err(|_| AgentError::JoinTask)?; - if let Some(err) = stealer_status.err().await { - error!("start_agent -> stealer task failed with error: {}", err); - } - } - - if let BackgroundTask::Running(mut dns_status, _) = dns { - dns_task.join().map_err(|_| AgentError::JoinTask)?; - if let Some(err) = dns_status.err().await { - error!("start_agent -> dns task failed with error: {}", err); - } - } + tokio::join!( + async move { + if let BackgroundTask::Running(status, _) = sniffer { + if let Err(error) = status.wait().await { + error!("start_agent -> {error}"); + } + } + }, + async move 
{ + if let BackgroundTask::Running(status, _) = stealer { + if let Err(error) = status.wait().await { + error!("start_agent -> {error}"); + } + } + }, + async move { + if let BackgroundTask::Running(status, _) = dns { + if let Err(error) = status.wait().await { + error!("start_agent -> {error}"); + } + } + }, + ); trace!("start_agent -> Agent shutdown"); @@ -841,14 +828,18 @@ async fn start_iptable_guard(args: Args) -> AgentResult<()> { result = run_child_agent() => result, }; - let _ = run_thread_in_namespace( - clear_iptable_chain(), - "clear iptables".to_owned(), - pid, - "net", - ) - .join() - .map_err(|_| AgentError::JoinTask)?; + let Some(pid) = pid else { + return result; + }; + + let runtime = RemoteRuntime::new_in_namespace(pid, NamespaceType::Net).await?; + runtime + .spawn(clear_iptable_chain()) + .await + .map_err(|error| AgentError::BackgroundTaskFailed { + task: "IPTablesCleaner", + error: Arc::new(error), + })??; result } @@ -898,7 +889,7 @@ pub async fn main() -> AgentResult<()> { let args = cli::parse_args(); - let agent_result = if args.mode.is_targetless() + if args.mode.is_targetless() || (std::env::var(IPTABLE_PREROUTING_ENV).is_ok() && std::env::var(IPTABLE_MESH_ENV).is_ok() && std::env::var(IPTABLE_STANDARD_ENV).is_ok()) @@ -906,19 +897,5 @@ pub async fn main() -> AgentResult<()> { start_agent(args).await } else { start_iptable_guard(args).await - }; - - match agent_result { - Ok(_) => { - info!("main -> mirrord-agent `start` exiting successfully.") - } - Err(fail) => { - error!( - "main -> mirrord-agent `start` exiting with error {:#?}", - fail - ) - } } - - Ok(()) } diff --git a/mirrord/agent/src/entrypoint/setup.rs b/mirrord/agent/src/entrypoint/setup.rs index 094e8685460..444d9fe0141 100644 --- a/mirrord/agent/src/entrypoint/setup.rs +++ b/mirrord/agent/src/entrypoint/setup.rs @@ -1,53 +1,30 @@ use mirrord_agent_env::envs; -use tokio::sync::oneshot; use crate::{ error::{AgentError, AgentResult}, incoming::{self, RedirectorTask, 
StealHandle}, - util::run_thread_in_namespace, + util::remote_runtime::RemoteRuntime, }; -/// Starts a [`RedirectorTask`] in the target's network namespace. +/// Starts a [`RedirectorTask`] on the given `runtime`. /// /// Returns the [`StealHandle`] that can be used to steal incoming traffic. -pub(crate) async fn start_traffic_redirector(target_pid: u64) -> AgentResult { +pub(super) async fn start_traffic_redirector(runtime: &RemoteRuntime) -> AgentResult { let flush_connections = envs::STEALER_FLUSH_CONNECTIONS.from_env_or_default(); let pod_ips = envs::POD_IPS.from_env_or_default(); let support_ipv6 = envs::IPV6_SUPPORT.from_env_or_default(); - let (handle_tx, handle_rx) = oneshot::channel(); + let (task, handle) = runtime + .spawn(async move { + incoming::create_iptables_redirector(flush_connections, &pod_ips, support_ipv6) + .await + .map(|redirector| RedirectorTask::new(redirector)) + }) + .await + .map_err(|error| AgentError::IPTablesSetupError(error.into()))? + .map_err(|error| AgentError::IPTablesSetupError(error.into()))?; - run_thread_in_namespace( - async move { - let redirector_result = - incoming::create_iptables_redirector(flush_connections, &pod_ips, support_ipv6) - .await; + runtime.spawn(task.run()); - let redirector = match redirector_result { - Ok(redirector) => redirector, - Err(error) => { - let _ = handle_tx.send(Err(error)); - return; - } - }; - - let (task, handle) = RedirectorTask::new(redirector); - - if handle_tx.send(Ok(handle)).is_err() { - return; - } - - let _ = task.run().await.inspect_err(|error| { - tracing::error!(%error, "Incoming traffic redirector task failed"); - }); - }, - "IncomingTrafficRedirector".into(), - Some(target_pid), - "net", - ); - - match handle_rx.await { - Ok(result) => result.map_err(|error| AgentError::IPTablesSetupError(error.into())), - Err(..) 
=> Err(AgentError::IPTablesSetupError("task panicked".into())), - } + Ok(handle) } diff --git a/mirrord/agent/src/error.rs b/mirrord/agent/src/error.rs index abeb51d5cd9..cde2c8fded1 100644 --- a/mirrord/agent/src/error.rs +++ b/mirrord/agent/src/error.rs @@ -1,12 +1,10 @@ -use std::process::ExitStatus; +use std::{process::ExitStatus, sync::Arc}; -use mirrord_protocol::outgoing::udp::DaemonUdpOutgoing; use thiserror::Error; -use tokio::sync::mpsc::{self, error::SendError}; use crate::{ client_connection::TlsSetupError, incoming::RedirectorTaskError, namespace::NamespaceError, - runtime, sniffer::messages::SnifferCommand, steal::StealerCommand, + runtime, util::remote_runtime::RemoteRuntimeError, }; #[derive(Debug, Error)] @@ -14,49 +12,22 @@ pub(crate) enum AgentError { #[error("io error: {0}")] IO(#[from] std::io::Error), - #[error("SnifferCommand sender failed with `{0}`")] - SendSnifferCommand(#[from] SendError), - - #[error("TCP stealer task is dead")] - TcpStealerTaskDead, - - #[error("UdpConnectRequest sender failed with `{0}`")] - SendUdpOutgoingTrafficResponse(#[from] SendError), - - #[error("task::Join failed with `{0}`")] - Join(#[from] tokio::task::JoinError), - #[error("Container runtime error: {0}")] ContainerRuntimeError(#[from] runtime::ContainerRuntimeError), #[error("Path failed with `{0}`")] StripPrefixError(#[from] std::path::StripPrefixError), - #[error("Join task failed")] - JoinTask, - - #[error("DNS request send failed with `{0}`")] - DnsRequestSendError(#[from] SendError), - - #[error("DNS background task panicked")] - DnsTaskPanic, - #[error(r#"Failed to set socket flag PACKET_IGNORE_OUTGOING, this might be due to kernel version before 4.20. Original error `{0}`"#)] PacketIgnoreOutgoing(#[source] std::io::Error), - #[error( - r#"Couldn't send message to sniffer (mirror) api, sniffer probably not running. 
- Possible reason can be node kernel version before 4.20 - Check agent logs for errors and please report a bug if kernel version >=4.20"# - )] - SnifferNotRunning, - - #[error("Couldn't send message to stealer (steal) api, stealer probably not running.")] - StealerNotRunning, - - #[error("Background task `{task}` failed with `{cause}`")] - BackgroundTaskFailed { task: &'static str, cause: String }, + #[error("Background task `{task}` failed: `{error}`")] + BackgroundTaskFailed { + task: &'static str, + #[source] + error: Arc, + }, #[error("Returning an error to test the agent's error cleanup. Should only ever be used when testing mirrord.")] TestError, @@ -77,22 +48,14 @@ pub(crate) enum AgentError { #[error("Timeout on accepting first client connection")] FirstConnectionTimeout, - #[allow(dead_code)] - /// Temporary error for vpn feature - #[error("Generic error in vpn: {0}")] - VpnError(String), - #[error("Incoming traffic redirector failed: {0}")] PortRedirectorError(#[from] RedirectorTaskError), #[error("IP tables setup failed: {0}")] IPTablesSetupError(#[source] Box), -} -impl From> for AgentError { - fn from(_: mpsc::error::SendError) -> Self { - Self::TcpStealerTaskDead - } + #[error("Failed to start a tokio runtime in the target's namespace: {0}")] + RemoteRuntimeError(#[from] RemoteRuntimeError), } pub(crate) type AgentResult = std::result::Result; diff --git a/mirrord/agent/src/main.rs b/mirrord/agent/src/main.rs index 03f8b3159e6..2788ae5ab52 100644 --- a/mirrord/agent/src/main.rs +++ b/mirrord/agent/src/main.rs @@ -31,7 +31,6 @@ mod sniffer; mod steal; mod util; mod vpn; -mod watched_task; #[cfg(target_os = "linux")] #[tokio::main(flavor = "current_thread")] diff --git a/mirrord/agent/src/namespace.rs b/mirrord/agent/src/namespace.rs index d640134b7f1..2ff746abe43 100644 --- a/mirrord/agent/src/namespace.rs +++ b/mirrord/agent/src/namespace.rs @@ -1,27 +1,41 @@ -use std::fs::File; +use std::{fmt, fs::File}; use nix::sched::{setns, CloneFlags}; use 
thiserror::Error; +use tracing::Level; +/// Errors that can occur when entering a Linux namespace. #[derive(Debug, Error)] -pub(crate) enum NamespaceError { - #[error("Failed opening pid's namespace file: {0}")] +pub enum NamespaceError { + #[error("failed to open target's namespace file: {0}")] FailedNamespaceOpen(#[from] std::io::Error), - #[error("Failed to enter namespace: {0}")] + #[error("failed to enter target's namespace: {0}")] FailedNamespaceEnter(#[from] nix::Error), } -/// Non exhaustive namespace type enum. Add as needed -#[derive(Debug)] -pub(crate) enum NamespaceType { +/// Linux namespace types. +/// +/// Add more as needed. +#[derive(Debug, Clone, Copy)] +pub enum NamespaceType { Net, } impl NamespaceType { - #[tracing::instrument(level = "trace", ret)] - fn path_from_pid(&self, pid: u64) -> String { + /// Returns a path to the namespace file for the given target process ID. + /// + /// This path can be used with [`setns`] to enter the namespace. + fn path_for_target(self, target_pid: u64) -> String { match self { - NamespaceType::Net => format!("/proc/{}/ns/net", pid), + NamespaceType::Net => format!("/proc/{target_pid}/ns/net"), + } + } +} + +impl fmt::Display for NamespaceType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Net => f.write_str("net"), } } } @@ -34,14 +48,11 @@ impl From for CloneFlags { } } -/// Set namespace by cloneflags and pid. -/// NOTE: don't make it async in the case we're in an multi-thread scheduler and we want it to -/// happen on the same thread always. -#[tracing::instrument(level = "trace")] -pub(crate) fn set_namespace(pid: u64, namespace_type: NamespaceType) -> Result<(), NamespaceError> { - let fd = File::open(namespace_type.path_from_pid(pid))?; +/// Reassociates the current thread with the target's namespace. 
+#[tracing::instrument(level = Level::TRACE, ret, err)] +pub fn set_namespace(target_pid: u64, namespace_type: NamespaceType) -> Result<(), NamespaceError> { + let file = File::open(namespace_type.path_for_target(target_pid))?; + setns(file, namespace_type.into())?; - // use as_raw_fd to get reference so it will drop after setns - setns(fd, namespace_type.into())?; Ok(()) } diff --git a/mirrord/agent/src/outgoing.rs b/mirrord/agent/src/outgoing.rs index dda3813be7b..062e1d4b96f 100644 --- a/mirrord/agent/src/outgoing.rs +++ b/mirrord/agent/src/outgoing.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt, thread, time::Duration}; +use std::{collections::HashMap, fmt, time::Duration}; use bytes::Bytes; use mirrord_protocol::{ @@ -20,8 +20,7 @@ use tracing::Level; use crate::{ error::AgentResult, metrics::TCP_OUTGOING_CONNECTION, - util::run_thread_in_namespace, - watched_task::{TaskStatus, WatchedTask}, + util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, }; mod socket_stream; @@ -33,11 +32,7 @@ pub(crate) use udp::UdpOutgoingApi; /// Each agent client has their own independent instance (neither this wrapper nor the background /// task are shared). pub(crate) struct TcpOutgoingApi { - /// Holds the thread in which [`TcpOutgoingTask`] is running. - _task: thread::JoinHandle<()>, - - /// Status of the [`TcpOutgoingTask`]. - task_status: TaskStatus, + task_status: BgTaskStatus, /// Sends the layer messages to the [`TcpOutgoingTask`]. layer_tx: Sender, @@ -47,33 +42,22 @@ pub(crate) struct TcpOutgoingApi { } impl TcpOutgoingApi { - const TASK_NAME: &'static str = "TcpOutgoing"; - - /// Spawns a new background task for handling `outgoing` feature and creates a new instance of - /// this struct to serve as an interface. + /// Spawns a new background task for handling the `outgoing` feature and creates a new instance + /// of this struct to serve as an interface. 
/// /// # Params /// - /// * `pid` - process id of the agent's target container - #[tracing::instrument(level = Level::TRACE)] - pub(crate) fn new(pid: Option) -> Self { + /// * `runtime` - tokio runtime to spawn the background task on. + pub(crate) fn new(runtime: &MaybeRemoteRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); - let watched_task = WatchedTask::new( - Self::TASK_NAME, - TcpOutgoingTask::new(pid, layer_rx, daemon_tx).run(), - ); - let task_status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - Self::TASK_NAME.to_string(), - pid, - "net", - ); + let pid = runtime.target_pid(); + let task_status = runtime + .spawn(TcpOutgoingTask::new(pid, layer_rx, daemon_tx).run()) + .into_status("TcpOutgoingTask"); Self { - _task: task, task_status, layer_tx, daemon_rx, @@ -86,7 +70,7 @@ impl TcpOutgoingApi { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } @@ -95,7 +79,7 @@ impl TcpOutgoingApi { pub(crate) async fn recv_from_task(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), - None => Err(self.task_status.unwrap_err().await), + None => Err(self.task_status.wait_assert_running().await), } } } @@ -160,7 +144,7 @@ impl TcpOutgoingTask { /// Runs this task as long as the channels connecting it with [`TcpOutgoingApi`] are open. /// This routine never fails and returns [`Result`] only due to [`WatchedTask`] constraints. #[tracing::instrument(level = Level::TRACE, skip(self))] - async fn run(mut self) -> AgentResult<()> { + async fn run(mut self) { loop { let channel_closed = select! 
{ biased; @@ -182,7 +166,7 @@ impl TcpOutgoingTask { if channel_closed { tracing::trace!("Client channel closed, exiting"); - break Ok(()); + break; } } } diff --git a/mirrord/agent/src/outgoing/udp.rs b/mirrord/agent/src/outgoing/udp.rs index d2a2128da96..3e592235fe5 100644 --- a/mirrord/agent/src/outgoing/udp.rs +++ b/mirrord/agent/src/outgoing/udp.rs @@ -2,7 +2,6 @@ use core::fmt; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - thread, }; use bytes::{Bytes, BytesMut}; @@ -27,8 +26,7 @@ use tracing::Level; use crate::{ error::AgentResult, metrics::UDP_OUTGOING_CONNECTION, - util::run_thread_in_namespace, - watched_task::{TaskStatus, WatchedTask}, + util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, }; /// Task that handles [`LayerUdpOutgoing`] and [`DaemonUdpOutgoing`] messages. @@ -91,7 +89,7 @@ impl UdpOutgoingTask { /// This routine never fails and returns [`AgentResult`] only due to [`WatchedTask`] /// constraints. #[tracing::instrument(level = Level::TRACE, skip(self))] - pub(super) async fn run(mut self) -> AgentResult<()> { + pub(super) async fn run(mut self) { loop { let channel_closed = select! { biased; @@ -113,7 +111,7 @@ impl UdpOutgoingTask { if channel_closed { tracing::trace!("Client channel closed, exiting"); - break Ok(()); + break; } } } @@ -276,12 +274,7 @@ impl UdpOutgoingTask { /// Handles (briefly) the `UdpOutgoingRequest` and `UdpOutgoingResponse` messages, mostly the /// passing of these messages to the `interceptor_task` thread. pub(crate) struct UdpOutgoingApi { - /// Holds the `interceptor_task`. - _task: thread::JoinHandle<()>, - - /// Status of the `interceptor_task`. - task_status: TaskStatus, - + task_status: BgTaskStatus, /// Sends the `Layer` message to the `interceptor_task`. 
layer_tx: Sender, @@ -312,27 +305,15 @@ async fn connect(remote_address: SocketAddress) -> Result) -> Self { + pub(crate) fn new(runtime: &MaybeRemoteRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); - let watched_task = WatchedTask::new( - Self::TASK_NAME, - UdpOutgoingTask::new(pid, layer_rx, daemon_tx).run(), - ); - - let task_status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - Self::TASK_NAME.to_string(), - pid, - "net", - ); + let task_status = runtime + .spawn(UdpOutgoingTask::new(runtime.target_pid(), layer_rx, daemon_tx).run()) + .into_status("UdpOutgoingTask"); Self { - _task: task, task_status, layer_tx, daemon_rx, @@ -345,7 +326,7 @@ impl UdpOutgoingApi { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } @@ -353,7 +334,7 @@ impl UdpOutgoingApi { pub(crate) async fn recv_from_task(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), - None => Err(self.task_status.unwrap_err().await), + None => Err(self.task_status.wait_assert_running().await), } } } diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs index 49d0e62b109..976ae0e0b6a 100644 --- a/mirrord/agent/src/sniffer.rs +++ b/mirrord/agent/src/sniffer.rs @@ -187,8 +187,6 @@ impl TcpConnectionSniffer where R: TcpCapture, { - pub const TASK_NAME: &'static str = "Sniffer"; - /// Capacity of [`broadcast`] channels used to distribute incoming TCP packets between clients. 
const CONNECTION_DATA_CHANNEL_CAPACITY: usize = 512; @@ -418,11 +416,11 @@ mod test { use tokio::sync::mpsc; use super::*; - use crate::watched_task::{TaskStatus, WatchedTask}; + use crate::util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}; struct TestSnifferSetup { command_tx: Sender, - task_status: TaskStatus, + task_status: BgTaskStatus, packet_tx: Sender<(TcpSessionDirectionId, TcpPacketData)>, times_filter_changed: Arc, next_client_id: ClientId, @@ -458,12 +456,10 @@ mod test { client_txs: Default::default(), clients_closed: Default::default(), }; - let watched_task = WatchedTask::new( - TcpConnectionSniffer::::TASK_NAME, - sniffer.start(CancellationToken::new()), - ); - let task_status = watched_task.status(); - tokio::spawn(watched_task.start()); + + let task_status = MaybeRemoteRuntime::Local + .spawn(sniffer.start(CancellationToken::new())) + .into_status("TcpSnifferTask"); Self { command_tx, diff --git a/mirrord/agent/src/sniffer/api.rs b/mirrord/agent/src/sniffer/api.rs index 08874e93124..44e19bc9eda 100644 --- a/mirrord/agent/src/sniffer/api.rs +++ b/mirrord/agent/src/sniffer/api.rs @@ -18,7 +18,10 @@ use super::{ messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}, AgentResult, }; -use crate::{error::AgentError, util::ClientId, watched_task::TaskStatus}; +use crate::{ + error::AgentError, + util::{remote_runtime::BgTaskStatus, ClientId}, +}; /// Interface used by clients to interact with the /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). Multiple instances of this struct operate @@ -34,7 +37,7 @@ pub(crate) struct TcpSnifferApi { /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). receiver: Receiver, /// View on the sniffer task's status. - task_status: TaskStatus, + task_status: BgTaskStatus, /// Currently sniffed connections. connections: StreamMap>>>, /// Ids for sniffed connections. 
@@ -59,7 +62,7 @@ impl TcpSnifferApi { pub async fn new( client_id: ClientId, sniffer_sender: Sender, - mut task_status: TaskStatus, + task_status: BgTaskStatus, ) -> AgentResult { let (sender, receiver) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE); @@ -68,7 +71,7 @@ impl TcpSnifferApi { command: SnifferCommandInner::NewClient(sender), }; if sniffer_sender.send(command).await.is_err() { - return Err(task_status.unwrap_err().await); + return Err(task_status.wait_assert_running().await); } Ok(Self { @@ -93,7 +96,7 @@ impl TcpSnifferApi { if self.sender.send(command).await.is_ok() { Ok(()) } else { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } @@ -120,7 +123,7 @@ impl TcpSnifferApi { }, None => { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) }, }, @@ -157,7 +160,7 @@ impl TcpSnifferApi { Some(result) = self.subscriptions_in_progress.next() => match result { Ok(port) => Ok((DaemonTcp::SubscribeResult(Ok(port)), None)), Err(..) => { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } } diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index fd7346bca44..8e651b3a724 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -12,8 +12,9 @@ use tracing::Level; use super::{http::ReceiverStreamBody, *}; use crate::{ - error::AgentResult, metrics::HTTP_REQUEST_IN_PROGRESS_COUNT, util::ClientId, - watched_task::TaskStatus, + error::AgentResult, + metrics::HTTP_REQUEST_IN_PROGRESS_COUNT, + util::{remote_runtime::BgTaskStatus, ClientId}, }; type ResponseBodyTx = Sender, Infallible>>; @@ -43,7 +44,7 @@ pub(crate) struct TcpStealerApi { daemon_rx: Receiver, /// View on the stealer task's status. - task_status: TaskStatus, + task_status: BgTaskStatus, /// [`Sender`]s that allow us to provide body [`Frame`]s of responses to filtered HTTP /// requests. 
@@ -63,17 +64,20 @@ impl TcpStealerApi { pub(crate) async fn new( client_id: ClientId, command_tx: Sender, - task_status: TaskStatus, + task_status: BgTaskStatus, channel_size: usize, ) -> AgentResult { let (daemon_tx, daemon_rx) = mpsc::channel(channel_size); - command_tx + let init_result = command_tx .send(StealerCommand { client_id, command: Command::NewClient(daemon_tx), }) - .await?; + .await; + if init_result.is_err() { + return Err(task_status.wait_assert_running().await); + } Ok(Self { client_id, @@ -94,7 +98,7 @@ impl TcpStealerApi { if self.command_tx.send(command).await.is_ok() { Ok(()) } else { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } @@ -117,7 +121,7 @@ impl TcpStealerApi { } Ok(msg) } - None => Err(self.task_status.unwrap_err().await), + None => Err(self.task_status.wait_assert_running().await), } } diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 8749291fab8..2b7240902c7 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -392,8 +392,6 @@ pub(crate) struct TcpConnectionStealer { } impl TcpConnectionStealer { - pub const TASK_NAME: &'static str = "Stealer"; - /// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work. /// You need to call [`TcpConnectionStealer::start`] to do so. 
pub(crate) fn new( @@ -807,7 +805,7 @@ mod test { net::{TcpListener, TcpStream}, sync::{ mpsc::{self, Receiver, Sender}, - oneshot, watch, + oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; @@ -820,7 +818,7 @@ mod test { connection::{Client, MatchedHttpRequest}, TcpConnectionStealer, TcpStealerApi, }, - watched_task::TaskStatus, + util::remote_runtime::{IntoStatus, MaybeRemoteRuntime}, }; async fn prepare_dummy_service() -> ( @@ -1025,11 +1023,11 @@ mod test { let (task, handle) = RedirectorTask::new(redirector); tokio::spawn(task.run()); - let stealer = TcpConnectionStealer::new(command_rx, handle, None); - tokio::spawn(stealer.start(CancellationToken::new())); - - let (_dummy_tx, dummy_rx) = watch::channel(None); - let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, dummy_rx); + let task_status = MaybeRemoteRuntime::Local + .spawn( + TcpConnectionStealer::new(command_rx, handle, None).start(CancellationToken::new()), + ) + .into_status("TcpStealerTask"); let mut api = TcpStealerApi::new(0, command_tx.clone(), task_status, 8) .await .unwrap(); diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs index 24944066a42..28588b84ae2 100644 --- a/mirrord/agent/src/util.rs +++ b/mirrord/agent/src/util.rs @@ -5,19 +5,13 @@ use std::{ hash::Hash, pin::Pin, task::{Context, Poll}, - thread::JoinHandle, }; use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc; -use tracing::error; - -use crate::{ - error::AgentResult, - namespace::{set_namespace, NamespaceType}, -}; pub mod path_resolver; +pub mod remote_runtime; /// Struct that helps you manage topic -> subscribers /// @@ -99,71 +93,6 @@ where } } -/// Helper that creates a new [`tokio::runtime::Runtime`], and immediately blocks on it. -/// -/// Used to start new tasks that would be too heavy for just [`tokio::task::spawn()`] in the -/// caller's runtime. -/// -/// These tasks will execute `on_start_fn` to change namespace (see [`enter_namespace`] for more -/// details). 
-#[tracing::instrument(level = "trace", skip_all)] -pub(crate) fn run_thread( - future: F, - thread_name: String, - on_start_fn: StartFn, -) -> JoinHandle -where - F: Future + Send + 'static, - F::Output: Send + 'static, - StartFn: Fn() + Send + Sync + 'static, -{ - std::thread::spawn(move || { - on_start_fn(); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .thread_name(thread_name) - .on_thread_start(on_start_fn) - .build() - .unwrap(); - rt.block_on(future) - }) -} - -/// Calls [`run_thread`] with `on_start_fn` always being [`enter_namespace`]. -#[tracing::instrument(level = "trace", skip_all)] -pub(crate) fn run_thread_in_namespace( - future: F, - thread_name: String, - pid: Option, - namespace: &str, -) -> JoinHandle -where - F: Future + Send + 'static, - F::Output: Send + 'static, -{ - let namespace = namespace.to_string(); - - run_thread(future, thread_name, move || { - enter_namespace(pid, &namespace).expect("Failed setting namespace!") - }) -} - -/// Used to enter a different (so far only used for "net") namespace for a task. -/// -/// Many of the agent's TCP/UDP connections require that they're made from the `pid`'s namespace to -/// work. -#[tracing::instrument(level = "trace")] -pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> AgentResult<()> { - if let Some(pid) = pid { - Ok(set_namespace(pid, NamespaceType::Net).inspect_err(|fail| { - error!("Failed setting pid {pid:#?} namespace {namespace:#?} with {fail:#?}") - })?) - } else { - Ok(()) - } -} - /// [`Future`] that resolves to [`ClientId`] when the client drops their [`mpsc::Receiver`]. pub(crate) struct ChannelClosedFuture(BoxFuture<'static, ClientId>); diff --git a/mirrord/agent/src/util/remote_runtime.rs b/mirrord/agent/src/util/remote_runtime.rs new file mode 100644 index 00000000000..7a097962b7d --- /dev/null +++ b/mirrord/agent/src/util/remote_runtime.rs @@ -0,0 +1,331 @@ +//! 
This module contains utilities for running async code in the agent target's namespace. +//! +//! This is useful for running tasks that require access to the target's network namespace, +//! such as traffic stealing, traffic mirroring, DNS resolution, outgoing traffic. +//! +//! This module provides: +//! 1. A [`RemoteRuntime`] struct, that can be used to run tasks in the target's namespace. +//! 2. A [`MaybeRemoteRuntime`] enum, that don't necessarily require a target (DNS and outgoing +//! traffic), but should be run in the target's namespace if available. +//! 3. A [`BgTaskStatus`] struct, that can be used to poll for a spawned task's status. + +use std::{ + error::Error, + fmt, + future::Future, + io, + ops::Not, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + thread, +}; + +use futures::{ + future::{BoxFuture, Shared}, + FutureExt, +}; +use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; + +use crate::{ + error::AgentError, + namespace::{self, NamespaceError, NamespaceType}, +}; + +/// Errors that can occur when creating a [`RemoteRuntime`]. +#[derive(Error, Debug)] +pub enum RemoteRuntimeError { + #[error("failed to spawn runtime thread: {0}")] + ThreadSpawnError(#[source] io::Error), + #[error(transparent)] + NamespaceError(#[from] NamespaceError), + #[error("failed to build tokio runtime: {0}")] + TokioRuntimeError(#[source] io::Error), + #[error("runtime thread panicked")] + Panicked, +} + +/// A cloneable handle to a remote [`tokio::runtime::Runtime`] that runs in its own thread. +/// +/// Can be used to spawn tasks with [`RemoteRuntime::spawn`]. +/// +/// The runtime will be aborted when all handles are dropped. +#[derive(Clone)] +pub struct RemoteRuntime { + target_pid: u64, + future_tx: mpsc::Sender>, +} + +impl RemoteRuntime { + /// Creates a new remote runtime. + /// + /// This runtime's thread will enter the specified namespace of the target. 
+ pub async fn new_in_namespace( + target_pid: u64, + namespace_type: NamespaceType, + ) -> Result { + let (future_tx, mut future_rx) = mpsc::channel(16); + let (result_tx, result_rx) = oneshot::channel(); + let thread_name = format!("remote-{namespace_type}-runtime-thread"); + let thread_logic = move || { + if let Err(error) = namespace::set_namespace(target_pid, namespace_type) { + let _ = result_tx.send(Err(error.into())); + return; + } + + let rt_result = tokio::runtime::Builder::new_current_thread() + .enable_all() + .thread_name(format!("remote-{namespace_type}-runtime-worker")) + .build(); + let rt = match rt_result { + Ok(rt) => rt, + Err(error) => { + let _ = result_tx.send(Err(RemoteRuntimeError::TokioRuntimeError(error))); + return; + } + }; + + if result_tx.send(Ok(())).is_err() { + return; + } + + rt.block_on(async move { + while let Some(future) = future_rx.recv().await { + tokio::spawn(future); + } + }); + }; + + thread::Builder::new() + .name(thread_name) + .spawn(thread_logic) + .map_err(RemoteRuntimeError::ThreadSpawnError)?; + + match result_rx.await { + Ok(Ok(())) => Ok(Self { + target_pid, + future_tx, + }), + Ok(Err(error)) => Err(error), + Err(..) => Err(RemoteRuntimeError::Panicked), + } + } + + /// Spawns the given future on this remote runtime. + pub fn spawn(&self, future: F) -> BgTask + where + F: 'static + Future + Send, + F::Output: 'static + Send, + { + let (result_tx, result_rx) = oneshot::channel(); + + let future = async move { + let result = future.await; + let _ = result_tx.send(result); + } + .boxed(); + + let future_tx = self.future_tx.clone(); + tokio::spawn(async move { + let _ = future_tx.send(future).await; + }); + + BgTask { + future_result: result_rx, + } + } + + /// Returns the target's PID. 
+ pub fn target_pid(&self) -> u64 { + self.target_pid + } +} + +impl fmt::Debug for RemoteRuntime { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RemoteRuntime") + .field("running", &self.future_tx.is_closed().not()) + .finish() + } +} + +/// An error that occurs when polling a future spawned with [`RemoteRuntime::spawn`] or +/// [`MaybeRemoteRuntime::spawn`]. +/// +/// This error indicated that the future has panicked. +#[derive(Debug, Error)] +#[error("task panicked")] +pub struct BgTaskPanicked; + +/// A future spawned with [`RemoteRuntime::spawn`] or +/// [`MaybeRemoteRuntime::spawn`] +pub struct BgTask { + future_result: oneshot::Receiver, +} + +impl Future for BgTask { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + Pin::new(&mut this.future_result) + .poll(cx) + .map_err(|_| BgTaskPanicked) + } +} + +/// A cloneable status of a future spawned with [`RemoteRuntime::spawn`] or +/// [`MaybeRemoteRuntime::spawn`]. +#[derive(Clone)] +pub struct BgTaskStatus { + task_name: &'static str, + result: Shared>>>, +} + +impl BgTaskStatus { + /// Waits for the future to finish and returns its result. + /// + /// Should the future fail or panic, this function will return + /// [`AgentError::BackgroundTaskFailed`]. + pub async fn wait(&self) -> Result<(), AgentError> { + match self.result.clone().await { + Ok(Ok(())) => Ok(()), + Ok(Err(error)) => Err(AgentError::BackgroundTaskFailed { + task: self.task_name, + error, + }), + Err(..) => Err(AgentError::BackgroundTaskFailed { + task: self.task_name, + error: Arc::new(BgTaskPanicked) as Arc, + }), + } + } + + /// Waits for the future to finish and returns its result. + /// + /// This function always returns [`AgentError::BackgroundTaskFailed`]. Use it when the task is + /// not expected to finish yet. 
+ pub async fn wait_assert_running(&self) -> AgentError { + match self.result.clone().await { + Ok(Ok(())) => AgentError::BackgroundTaskFailed { + task: self.task_name, + error: Box::::from("task finished unexpectedly").into(), + }, + Ok(Err(error)) => AgentError::BackgroundTaskFailed { + task: self.task_name, + error, + }, + Err(..) => AgentError::BackgroundTaskFailed { + task: self.task_name, + error: Arc::new(BgTaskPanicked) as Arc, + }, + } + } +} + +impl fmt::Debug for BgTaskStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BgTaskStatus") + .field("task_name", &self.task_name) + .field("result", &self.result.clone().now_or_never()) + .finish() + } +} + +/// Convenience trait for transforming [`BgTask`] into [`BgTaskStatus`]. +pub trait IntoStatus { + fn into_status(self, task_name: &'static str) -> BgTaskStatus; +} + +impl IntoStatus for BgTask> +where + E: Error + Send + Sync + 'static, +{ + fn into_status(self, task_name: &'static str) -> BgTaskStatus { + let (result_tx, result_rx) = oneshot::channel(); + + tokio::spawn(async move { + let result = match self.future_result.await { + Ok(Ok(())) => Ok(()), + Ok(Err(e)) => Err(Arc::new(e) as Arc), + Err(..) => Err(Arc::new(BgTaskPanicked) as Arc), + }; + + let _ = result_tx.send(result); + }); + + BgTaskStatus { + task_name, + result: result_rx.shared(), + } + } +} + +impl IntoStatus for BgTask<()> { + fn into_status(self, task_name: &'static str) -> BgTaskStatus { + let (result_tx, result_rx) = oneshot::channel(); + + tokio::spawn(async move { + let result = match self.future_result.await { + Ok(()) => Ok(()), + Err(..) => Err(Arc::new(BgTaskPanicked) as Arc), + }; + + let _ = result_tx.send(result); + }); + + BgTaskStatus { + task_name, + result: result_rx.shared(), + } + } +} + +/// A runtime to spawn tasks on, either remote or local. +/// +/// This can be used to spawn tasks that can either run in the target's namespace or the agent's. 
+/// +/// If the agent has a target, you should use [`MaybeRemoteRuntime::Remote`]. +/// If the agent does not have a target, you should fallback to [`MaybeRemoteRuntime::Local`]. +#[derive(Clone)] +pub enum MaybeRemoteRuntime { + /// Remote runtime, which runs in the target's namespace. + Remote(RemoteRuntime), + /// Local runtime ([`tokio::runtime::Handle::current`]). + Local, +} + +impl MaybeRemoteRuntime { + /// Spawns the given future on this runtime. + pub fn spawn(&self, future: F) -> BgTask + where + F: 'static + Future + Send, + F::Output: 'static + Send, + { + match self { + Self::Remote(remote_runtime) => remote_runtime.spawn(future), + Self::Local => { + let (result_tx, result_rx) = oneshot::channel(); + + tokio::spawn(async move { + let result = future.await; + let _ = result_tx.send(result); + }); + + BgTask { + future_result: result_rx, + } + } + } + } + + /// If this is a remote runtime, returns the target's PID. + /// Otherwise, returns [`None`]. + pub fn target_pid(&self) -> Option { + match self { + Self::Remote(remote_runtime) => Some(remote_runtime.target_pid()), + Self::Local => None, + } + } +} diff --git a/mirrord/agent/src/vpn.rs b/mirrord/agent/src/vpn.rs index d7d30d5ca6f..cff649e90f5 100644 --- a/mirrord/agent/src/vpn.rs +++ b/mirrord/agent/src/vpn.rs @@ -1,9 +1,10 @@ +//! This code is not used anywhere. + #![allow(dead_code)] use std::{ fmt, - io::Read, + io::{self, Read}, net::{IpAddr, Ipv4Addr, SocketAddr}, - thread, }; use mirrord_protocol::vpn::{ClientVpn, NetworkConfiguration, ServerVpn}; @@ -17,21 +18,15 @@ use tokio::{ }; use crate::{ - error::{AgentError, AgentResult}, - util::run_thread_in_namespace, - watched_task::{TaskStatus, WatchedTask}, + error::AgentResult, + util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, }; /// An interface for a background task handling [`ClientVpn`] messages. /// Each agent client has their own independent instance (neither this wrapper nor the background /// task are shared). 
pub(crate) struct VpnApi { - /// Holds the thread in which [`VpnTask`] is running. - _task: thread::JoinHandle<()>, - - /// Status of the [`VpnTask`]. - task_status: TaskStatus, - + task_status: BgTaskStatus, /// Sends the layer messages to the [`VpnTask`]. layer_tx: Sender, @@ -40,33 +35,22 @@ pub(crate) struct VpnApi { } impl VpnApi { - const TASK_NAME: &'static str = "Vpn"; - - /// Spawns a new background task for handling `outgoing` feature and creates a new instance of + /// Spawns a new background task for handling the `vpn` feature and creates a new instance of /// this struct to serve as an interface. /// /// # Params /// - /// * `pid` - process id of the agent's target container - #[tracing::instrument(level = "trace")] - pub(crate) fn new(pid: Option) -> Self { + /// * `runtime` - tokio runtime to spawn the task on. + pub(crate) fn new(runtime: &MaybeRemoteRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); + let pid = runtime.target_pid(); - let watched_task = WatchedTask::new( - Self::TASK_NAME, - VpnTask::new(pid, layer_rx, daemon_tx).run(), - ); - let task_status = watched_task.status(); - let task = run_thread_in_namespace( - watched_task.start(), - Self::TASK_NAME.to_string(), - pid, - "net", - ); + let task_status = runtime + .spawn(VpnTask::new(pid, layer_rx, daemon_tx).run()) + .into_status("VpnTask"); Self { - _task: task, task_status, layer_tx, daemon_rx, @@ -74,12 +58,11 @@ impl VpnApi { } /// Sends the [`ClientVpn`] message to the background task. 
- #[tracing::instrument(level = "trace", skip(self))] pub(crate) async fn layer_message(&mut self, message: ClientVpn) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { - Err(self.task_status.unwrap_err().await) + Err(self.task_status.wait_assert_running().await) } } @@ -87,7 +70,7 @@ impl VpnApi { pub(crate) async fn daemon_message(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), - None => Err(self.task_status.unwrap_err().await), + None => Err(self.task_status.wait_assert_running().await), } } } @@ -121,25 +104,27 @@ impl AsyncRawSocket { } } -async fn create_raw_socket() -> AgentResult { - let index = nix::net::if_::if_nametoindex("eth0") - .map_err(|err| AgentError::VpnError(err.to_string()))?; +async fn create_raw_socket() -> io::Result { + let index = nix::net::if_::if_nametoindex("eth0")?; let socket = Socket::new( Domain::PACKET, Type::DGRAM, Some(Protocol::from(libc::ETH_P_IP.to_be())), )?; - let sock_addr = interface_index_to_sock_addr( - i32::try_from(index).map_err(|err| AgentError::VpnError(err.to_string()))?, - )?; + let sock_addr = interface_index_to_sock_addr(i32::try_from(index).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + format!("invalid interface index {index}"), + ) + })?)?; socket.bind(&sock_addr)?; socket.set_nonblocking(true)?; - AsyncRawSocket::new(socket, sock_addr).map_err(From::from) + AsyncRawSocket::new(socket, sock_addr) } #[tracing::instrument(level = "debug", ret)] -async fn resolve_interface() -> AgentResult<(IpAddr, IpAddr, IpAddr)> { +async fn resolve_interface() -> io::Result<(IpAddr, IpAddr, IpAddr)> { // Connect to a remote address so we can later get the default network interface. 
let temporary_socket = UdpSocket::bind("0.0.0.0:0").await?; temporary_socket.connect("8.8.8.8:53").await?; @@ -152,36 +137,50 @@ async fn resolve_interface() -> AgentResult<(IpAddr, IpAddr, IpAddr)> { let raw_local_address = SockaddrStorage::from(local_address); // Try to find an interface that matches the local ip we have. - let usable_interface = nix::ifaddrs::getifaddrs() - .map_err(|err| AgentError::VpnError(err.to_string()))? + let usable_interface = nix::ifaddrs::getifaddrs()? .find(|iface| { iface .address .map(|addr| addr == raw_local_address) .unwrap_or(false) }) - .ok_or_else(|| AgentError::VpnError("usable_interface".to_owned()))?; + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no usable interface"))?; let ip = usable_interface .address - .ok_or_else(|| AgentError::VpnError("usable_interface.address".to_owned()))? - .as_sockaddr_in() - .ok_or_else(|| AgentError::VpnError("usable_interface.address.as_sockaddr_in".to_owned()))? + .as_ref() + .and_then(SockaddrStorage::as_sockaddr_in) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "usable_interface.address.as_sockaddr_in", + ) + })? .ip() .into(); let net_mask = usable_interface .netmask - .ok_or_else(|| AgentError::VpnError("usable_interface.netmask".to_owned()))? - .as_sockaddr_in() - .ok_or_else(|| AgentError::VpnError("usable_interface.netmask.as_sockaddr_in".to_owned()))? + .as_ref() + .and_then(SockaddrStorage::as_sockaddr_in) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "usable_interface.netmask.as_sockaddr_in", + ) + })? .ip() .into(); // extracting gateway is more difficult, ugly patch for now. let temp_gateway = usable_interface .address - .ok_or_else(|| AgentError::VpnError("usable_interface.address".to_owned()))? - .as_sockaddr_in() - .ok_or_else(|| AgentError::VpnError("usable_interface.address.as_sockaddr_in".to_owned()))? 
+ .as_ref() + .and_then(SockaddrStorage::as_sockaddr_in) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "usable_interface.address.as_sockaddr_in", + ) + })? .ip() .octets(); @@ -209,16 +208,16 @@ impl fmt::Debug for VpnTask { } } -fn interface_index_to_sock_addr(index: i32) -> AgentResult { +fn interface_index_to_sock_addr(index: i32) -> io::Result { let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; let len = std::mem::size_of::() as libc::socklen_t; - let macs = procfs::net::arp().map_err(|err| AgentError::VpnError(err.to_string()))?; + let macs = procfs::net::arp().map_err(|error| io::Error::new(io::ErrorKind::Other, error))?; tracing::debug!(?macs, "arp entries"); let hw_addr = macs .into_iter() .find_map(|entry| entry.hw_address) - .ok_or_else(|| AgentError::VpnError("no entry with hw_address".to_owned()))?; + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no entry with hw address"))?; unsafe { let sock_addr = std::ptr::addr_of_mut!(addr_storage) as *mut libc::sockaddr_ll; @@ -245,7 +244,7 @@ impl VpnTask { } #[allow(clippy::indexing_slicing)] - async fn run(mut self) -> AgentResult<()> { + async fn run(mut self) -> io::Result<()> { // so host won't respond with RST to our packets. // TODO: need to do it for UDP as well to avoid ICMP unreachable. 
let output = std::process::Command::new("iptables") @@ -260,8 +259,7 @@ impl VpnTask { "-j", "DROP", ]) - .output() - .map_err(|err| AgentError::VpnError(err.to_string()))?; + .output()?; tracing::debug!(?output, "iptables output"); let (ip, net_mask, gateway) = resolve_interface().await?; @@ -298,7 +296,7 @@ impl VpnTask { self.daemon_tx .send(ServerVpn::Packet(packet)) .await - .map_err(|err| AgentError::VpnError(err.to_string()))?; + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; buffer[..len].fill(0); } @@ -318,7 +316,7 @@ impl VpnTask { &mut self, message: ClientVpn, network_configuration: &NetworkConfiguration, - ) -> AgentResult<()> { + ) -> io::Result<()> { match message { // We make connection to the requested address, split the stream into halves with // `io::split`, and put them into respective maps. @@ -328,24 +326,17 @@ impl VpnTask { network_configuration.clone(), )) .await - .map_err(|err| AgentError::VpnError(err.to_string()))?; + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; } ClientVpn::Packet(packet) => { if let Some(socket) = self.socket.as_mut() { - socket - .write(&packet) - .await - .map_err(|err| AgentError::VpnError(err.to_string()))?; + socket.write(&packet).await?; } else { tracing::error!(?packet, "unable to send packet"); } } ClientVpn::OpenSocket => { - self.socket.replace( - create_raw_socket() - .await - .map_err(|err| AgentError::VpnError(err.to_string()))?, - ); + self.socket.replace(create_raw_socket().await?); } } diff --git a/mirrord/agent/src/watched_task.rs b/mirrord/agent/src/watched_task.rs deleted file mode 100644 index 2e7370b262c..00000000000 --- a/mirrord/agent/src/watched_task.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::future::Future; - -use tokio::sync::watch::{self, Receiver, Sender}; - -use crate::error::{AgentError, AgentResult}; - -/// A shared clonable view on a background task's status. -#[derive(Debug, Clone)] -pub(crate) struct TaskStatus { - /// Name of the task. 
- task_name: &'static str, - /// Channel to receive the result of the task. - /// Initially, this channel contains [`None`]. - /// Only one value should ever be sent through this channel and it should be [`Some`]. - result_rx: Receiver>>, -} - -impl TaskStatus { - /// Wait for the task to complete and return the error. - /// Can be called multiple times and safely cancelled. - /// - /// # Panics - /// Panic if the task has not failed. - pub async fn unwrap_err(&mut self) -> AgentError { - self.err().await.expect("task did not fail") - } - - /// Wait for the task to complete. - /// If the task has failed, return the error. - /// Can be called multiple times and safely cancelled. - pub async fn err(&mut self) -> Option { - if self.result_rx.borrow().is_none() && self.result_rx.changed().await.is_err() { - return Some(AgentError::BackgroundTaskFailed { - task: self.task_name, - cause: "task panicked".into(), - }); - } - - self.result_rx - .borrow() - .as_ref() - .expect("WatchedTask set an empty status on exit") - .as_ref() - .err() - .map(|e| AgentError::BackgroundTaskFailed { - task: self.task_name, - cause: e.to_string(), - }) - } -} - -/// A wrapper around asynchronous task. -/// Captures the task's status and exposes it through [`TaskStatus`]. -pub(crate) struct WatchedTask { - /// Shared view on the task status. - status: TaskStatus, - /// The task to be executed. - task: F, - /// Channel to send the task result. - result_tx: Sender>>, -} - -impl WatchedTask { - /// Wrap the given task in a new instance of this struct. - pub(crate) fn new(task_name: &'static str, task: F) -> Self { - let (result_tx, result_rx) = watch::channel(None); - - Self { - status: TaskStatus { - task_name, - result_rx, - }, - task, - result_tx, - } - } - - /// Return a shared view over the inner [`TaskStatus`]. - pub(crate) fn status(&self) -> TaskStatus { - self.status.clone() - } -} - -impl WatchedTask -where - T: Future>, -{ - /// Execute the wrapped task. 
- /// Store its result in the inner [`TaskStatus`]. - pub(crate) async fn start(self) { - let result = self.task.await; - self.result_tx.send(Some(result)).ok(); // All receivers may be dropped. - } -} - -#[cfg(test)] -pub(crate) mod test { - use super::*; - - impl TaskStatus { - pub fn dummy( - task_name: &'static str, - result_rx: Receiver>>, - ) -> Self { - Self { - task_name, - result_rx, - } - } - } - - #[tokio::test] - async fn simple_successful() { - let task = WatchedTask::new("task", async move { Ok(()) }); - let mut status = task.status(); - task.start().await; - assert!(status.err().await.is_none()); - } - - #[tokio::test] - async fn simple_failing() { - let task = WatchedTask::new("task", async move { Err(AgentError::TestError) }); - let mut status = task.status(); - task.start().await; - assert!(status.err().await.is_some()); - status.unwrap_err().await; - } -} From dfe08e1578baf7f7dd6516f215a15955b59570d1 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 16:25:28 +0200 Subject: [PATCH 02/11] Fix --- mirrord/agent/src/entrypoint.rs | 6 +++++- mirrord/agent/src/entrypoint/setup.rs | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 2730e51625b..4d206338621 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -839,7 +839,11 @@ async fn start_iptable_guard(args: Args) -> AgentResult<()> { .map_err(|error| AgentError::BackgroundTaskFailed { task: "IPTablesCleaner", error: Arc::new(error), - })??; + })? 
+ .map_err(|error| AgentError::BackgroundTaskFailed { + task: "IPTablesCleaner", + error: Arc::new(error), + })?; result } diff --git a/mirrord/agent/src/entrypoint/setup.rs b/mirrord/agent/src/entrypoint/setup.rs index 444d9fe0141..c069f0d794d 100644 --- a/mirrord/agent/src/entrypoint/setup.rs +++ b/mirrord/agent/src/entrypoint/setup.rs @@ -18,7 +18,7 @@ pub(super) async fn start_traffic_redirector(runtime: &RemoteRuntime) -> AgentRe .spawn(async move { incoming::create_iptables_redirector(flush_connections, &pod_ips, support_ipv6) .await - .map(|redirector| RedirectorTask::new(redirector)) + .map(RedirectorTask::new) }) .await .map_err(|error| AgentError::IPTablesSetupError(error.into()))? From 78e7807bbb74d2467972b16511289ac42cd06d66 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 18:27:40 +0200 Subject: [PATCH 03/11] Moved all BG tasks setup to a submodule --- mirrord/agent/src/entrypoint.rs | 84 ++++--------------------- mirrord/agent/src/entrypoint/setup.rs | 88 ++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 74 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 4d206338621..a6f12ae3598 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -11,7 +11,7 @@ use std::{ }; use client_connection::AgentTlsConnector; -use dns::{ClientGetAddrInfoRequest, DnsCommand, DnsWorker}; +use dns::{ClientGetAddrInfoRequest, DnsCommand}; use futures::TryFutureExt; use metrics::{start_metrics, CLIENT_COUNT}; use mirrord_agent_env::envs; @@ -28,7 +28,7 @@ use tokio::{ process::Command, select, signal::unix::SignalKind, - sync::mpsc::{self, Sender}, + sync::mpsc::Sender, task::JoinSet, time::{timeout, Duration}, }; @@ -48,11 +48,10 @@ use crate::{ namespace::NamespaceType, outgoing::{TcpOutgoingApi, UdpOutgoingApi}, runtime::{self, get_container}, - sniffer::{api::TcpSnifferApi, messages::SnifferCommand, TcpConnectionSniffer}, - steal::{self, StealTlsHandlerStore, 
StealerCommand, TcpConnectionStealer, TcpStealerApi}, + sniffer::{api::TcpSnifferApi, messages::SnifferCommand}, + steal::{self, StealerCommand, TcpStealerApi}, util::{ - path_resolver::InTargetPathResolver, - remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime, RemoteRuntime}, + remote_runtime::{BgTaskStatus, MaybeRemoteRuntime, RemoteRuntime}, ClientId, }, }; @@ -580,81 +579,20 @@ async fn start_agent(args: Args) -> AgentResult<()> { }); } - let sniffer = if args.mode.is_targetless() { - BackgroundTask::Disabled - } else { - let cancellation_token = cancellation_token.clone(); - let (command_tx, command_rx) = mpsc::channel::(1000); - - let is_mesh = args.is_mesh(); - let sniffer = state - .network_runtime - .spawn(TcpConnectionSniffer::new( - command_rx, - args.network_interface, - is_mesh, - )) - .await; - - match sniffer { - Ok(Ok(sniffer)) => { - let task_status = state - .network_runtime - .spawn(sniffer.start(cancellation_token.clone())) - .into_status("TcpSnifferTask"); - - BackgroundTask::Running(task_status, command_tx) - } - Ok(Err(error)) => { - error!(%error, "Failed to create a TCP sniffer"); - BackgroundTask::Disabled - } - Err(error) => { - error!(%error, "Failed to create a TCP sniffer"); - BackgroundTask::Disabled - } + let sniffer = match &state.network_runtime { + MaybeRemoteRuntime::Remote(runtime) => { + setup::start_sniffer(&args, runtime, cancellation_token.clone()).await } + MaybeRemoteRuntime::Local => BackgroundTask::Disabled, }; - let stealer = match &state.network_runtime { MaybeRemoteRuntime::Local => BackgroundTask::Disabled, MaybeRemoteRuntime::Remote(runtime) => { let steal_handle = setup::start_traffic_redirector(runtime).await?; - - let cancellation_token = cancellation_token.clone(); - let (command_tx, command_rx) = mpsc::channel::(1000); - - let tls_steal_config = envs::STEAL_TLS_CONFIG.from_env_or_default(); - let tls_handler_store = tls_steal_config.is_empty().not().then(|| { - StealTlsHandlerStore::new( - 
tls_steal_config, - InTargetPathResolver::new(runtime.target_pid()), - ) - }); - let task_status = state - .network_runtime - .spawn( - TcpConnectionStealer::new(command_rx, steal_handle, tls_handler_store) - .start(cancellation_token), - ) - .into_status("TcpStealerTask"); - - BackgroundTask::Running(task_status, command_tx) + setup::start_stealer(runtime, steal_handle, cancellation_token.clone()) } }; - - let dns = { - let (command_tx, command_rx) = mpsc::channel::(1000); - let task_status = state - .network_runtime - .spawn( - DnsWorker::new(state.container_pid(), command_rx, args.ipv6) - .run(cancellation_token.clone()), - ) - .into_status("DnsTask"); - BackgroundTask::Running(task_status, command_tx) - }; - + let dns = setup::start_dns(&args, &state.network_runtime, cancellation_token.clone()); let bg_tasks = BackgroundTasks { sniffer, stealer, diff --git a/mirrord/agent/src/entrypoint/setup.rs b/mirrord/agent/src/entrypoint/setup.rs index c069f0d794d..cf4176aeafb 100644 --- a/mirrord/agent/src/entrypoint/setup.rs +++ b/mirrord/agent/src/entrypoint/setup.rs @@ -1,9 +1,20 @@ +use std::ops::Not; + use mirrord_agent_env::envs; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use super::BackgroundTask; use crate::{ + dns::{DnsCommand, DnsWorker}, error::{AgentError, AgentResult}, incoming::{self, RedirectorTask, StealHandle}, - util::remote_runtime::RemoteRuntime, + sniffer::{messages::SnifferCommand, TcpConnectionSniffer}, + steal::{StealTlsHandlerStore, StealerCommand, TcpConnectionStealer}, + util::{ + path_resolver::InTargetPathResolver, + remote_runtime::{IntoStatus, MaybeRemoteRuntime, RemoteRuntime}, + }, }; /// Starts a [`RedirectorTask`] on the given `runtime`. 
@@ -28,3 +39,78 @@ pub(super) async fn start_traffic_redirector(runtime: &RemoteRuntime) -> AgentRe Ok(handle) } + +pub(super) async fn start_sniffer( + args: &super::Args, + runtime: &RemoteRuntime, + cancellation_token: CancellationToken, +) -> BackgroundTask { + let (command_tx, command_rx) = mpsc::channel::(1000); + + let sniffer = runtime + .spawn(TcpConnectionSniffer::new( + command_rx, + args.network_interface.clone(), + args.is_mesh(), + )) + .await; + + match sniffer { + Ok(Ok(sniffer)) => { + let task_status = runtime + .spawn(sniffer.start(cancellation_token.clone())) + .into_status("TcpSnifferTask"); + + BackgroundTask::Running(task_status, command_tx) + } + Ok(Err(error)) => { + tracing::error!(%error, "Failed to create a TCP sniffer"); + BackgroundTask::Disabled + } + Err(error) => { + tracing::error!(%error, "Failed to create a TCP sniffer"); + BackgroundTask::Disabled + } + } +} + +pub(super) fn start_stealer( + runtime: &RemoteRuntime, + steal_handle: StealHandle, + cancellation_token: CancellationToken, +) -> BackgroundTask { + let (command_tx, command_rx) = mpsc::channel::(1000); + + let tls_steal_config = envs::STEAL_TLS_CONFIG.from_env_or_default(); + let tls_handler_store = tls_steal_config.is_empty().not().then(|| { + StealTlsHandlerStore::new( + tls_steal_config, + InTargetPathResolver::new(runtime.target_pid()), + ) + }); + let task_status = runtime + .spawn( + TcpConnectionStealer::new(command_rx, steal_handle, tls_handler_store) + .start(cancellation_token), + ) + .into_status("TcpStealerTask"); + + BackgroundTask::Running(task_status, command_tx) +} + +pub(super) fn start_dns( + args: &super::Args, + runtime: &MaybeRemoteRuntime, + cancellation_token: CancellationToken, +) -> BackgroundTask { + let (command_tx, command_rx) = mpsc::channel::(1000); + + let task_status = runtime + .spawn( + DnsWorker::new(runtime.target_pid(), command_rx, args.ipv6) + .run(cancellation_token.clone()), + ) + .into_status("DnsTask"); + + 
BackgroundTask::Running(task_status, command_tx) +} From c8015e60d4e95cbf284e6fc81dae163770d1e7d6 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 18:30:53 +0200 Subject: [PATCH 04/11] Fixed docs --- mirrord/agent/src/entrypoint.rs | 3 ++- mirrord/agent/src/outgoing.rs | 3 +-- mirrord/agent/src/outgoing/udp.rs | 6 ++---- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index a6f12ae3598..40c21a4ee2b 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -58,7 +58,8 @@ use crate::{ mod setup; -/// Size of [`mpsc`] channels connecting [`TcpStealerApi`] with the background task. +/// Size of [`mpsc`](tokio::sync::mpsc) channels connecting [`TcpStealerApi`]s with the background +/// task. const CHANNEL_SIZE: usize = 1024; /// Keeps track of next client id. diff --git a/mirrord/agent/src/outgoing.rs b/mirrord/agent/src/outgoing.rs index 062e1d4b96f..33a3353d72a 100644 --- a/mirrord/agent/src/outgoing.rs +++ b/mirrord/agent/src/outgoing.rs @@ -141,8 +141,7 @@ impl TcpOutgoingTask { } } - /// Runs this task as long as the channels connecting it with [`TcpOutgoingApi`] are open. - /// This routine never fails and returns [`Result`] only due to [`WatchedTask`] constraints. + /// Runs this task as long as the channels connecting it with the [`TcpOutgoingApi`] are open. #[tracing::instrument(level = Level::TRACE, skip(self))] async fn run(mut self) { loop { diff --git a/mirrord/agent/src/outgoing/udp.rs b/mirrord/agent/src/outgoing/udp.rs index 3e592235fe5..9f39c4dde2c 100644 --- a/mirrord/agent/src/outgoing/udp.rs +++ b/mirrord/agent/src/outgoing/udp.rs @@ -31,7 +31,7 @@ use crate::{ /// Task that handles [`LayerUdpOutgoing`] and [`DaemonUdpOutgoing`] messages. /// -/// We start these tasks from the [`UdpOutgoingApi`] as a [`WatchedTask`]. +/// We start these tasks from the [`UdpOutgoingApi`] on a [`MaybeRemoteRuntime`]. 
struct UdpOutgoingTask { next_connection_id: ConnectionId, /// Writing halves of peer connections made on layer's requests. @@ -85,9 +85,7 @@ impl UdpOutgoingTask { } } - /// Runs this task as long as the channels connecting it with [`UdpOutgoingApi`] are open. - /// This routine never fails and returns [`AgentResult`] only due to [`WatchedTask`] - /// constraints. + /// Runs this task as long as the channels connecting it with the [`UdpOutgoingApi`] are open. #[tracing::instrument(level = Level::TRACE, skip(self))] pub(super) async fn run(mut self) { loop { From 314d2c43f5894ad91a64eac06133eff50fdce75a Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 23:07:42 +0200 Subject: [PATCH 05/11] Blackbox text improved --- mirrord/agent/tests/blackbox.rs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/mirrord/agent/tests/blackbox.rs b/mirrord/agent/tests/blackbox.rs index b40d0ea9d2b..6ab5e195b12 100644 --- a/mirrord/agent/tests/blackbox.rs +++ b/mirrord/agent/tests/blackbox.rs @@ -73,14 +73,18 @@ async fn sanity() { ))) .await .expect("port subscribe failed"); - assert!(matches!( - codec - .next() - .await - .expect("couldn't get next message") - .expect("got invalid message"), - DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok(_))) - )); + let response = codec + .next() + .await + .expect("couldn't get next message") + .expect("got invalid message"); + assert!( + matches!( + response, + DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok(_))) + ), + "unexpected response: {response:?}" + ); let mut test_conn = TcpStream::connect("127.0.0.1:1337") .await .expect("connection to dummy failed"); @@ -114,7 +118,8 @@ async fn sanity() { local_address: IpAddr::V4("127.0.0.1".parse().unwrap()), destination_port: 1337, source_port: port - })) + })), + "unexpected message: {new_conn_msg:?}", ); assert_eq!( @@ -122,12 +127,14 @@ async fn sanity() { DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 0, bytes: 
test_data.to_vec() - })) + })), + "unexpected message: {data_msg:?}", ); assert_eq!( close_msg, - DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 0 })) + DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 0 })), + "unexpected message: {close_msg:?}", ); drop(codec); From 401845d993125917df7496099bf7d3556a459cc1 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 23:16:31 +0200 Subject: [PATCH 06/11] Some more fixes --- mirrord/agent/src/entrypoint.rs | 12 ++++-------- mirrord/agent/tests/blackbox.rs | 21 +++++++++++++++++---- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 40c21a4ee2b..85304485536 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -136,7 +136,6 @@ impl State { }; let environ_path = PathBuf::from("/proc").join(pid).join("environ"); - match env::get_proc_environ(environ_path).await { Ok(environ) => env.extend(environ.into_iter()), Err(err) => { @@ -553,13 +552,10 @@ async fn start_agent(args: Args) -> AgentResult<()> { ipv4_listener_result }?; - match listener.local_addr() { - Ok(addr) => debug!( - client_listener_address = addr.to_string(), - "Created listener." - ), - Err(err) => error!(%err, "listener local address error"), - } + debug!( + client_listener_address = %listener.local_addr()?, + "Created the client listener.", + ); let state = State::new(&args).await?; diff --git a/mirrord/agent/tests/blackbox.rs b/mirrord/agent/tests/blackbox.rs index 6ab5e195b12..26fdf05794e 100644 --- a/mirrord/agent/tests/blackbox.rs +++ b/mirrord/agent/tests/blackbox.rs @@ -15,19 +15,32 @@ use tokio::{ time::{sleep, Duration}, }; use tokio_stream::StreamExt; +use tracing_subscriber::{fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt}; /// This test requires root or CAP_NET_RAW to setup TCP sniffing. 
#[tokio::test] async fn sanity() { - let mut bin = get_test_bin("mirrord-agent"); - // we do wait, not sure what's happened - #[allow(clippy::zombie_processes)] + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_thread_ids(true) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .pretty() + .with_line_number(true), + ) + .with(tracing_subscriber::EnvFilter::from_env( + "mirrord=trace,warn", + )) + .init(); + + let mut bin = tokio::process::Command::from(get_test_bin("mirrord-agent")); let child = bin .arg("-t") .arg("2") .arg("-i") .arg("lo") .arg("blackbox-test") + .kill_on_drop(true) .spawn() .expect("mirrord-agent failed to start"); // Wait for agent to listen @@ -142,7 +155,7 @@ async fn sanity() { drop(mutex); task.await.unwrap(); - let result = child.wait_with_output().unwrap(); + let result = child.wait_with_output().await.unwrap(); assert!(result.status.success()); let stderr = String::from_utf8_lossy(&result.stderr); From 2deba33a3de1814bcc077dc36b90ce92ef4b983c Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 23:25:41 +0200 Subject: [PATCH 07/11] More fixes --- mirrord/agent/src/entrypoint.rs | 16 +++++++++------- mirrord/agent/tests/blackbox.rs | 15 +-------------- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 85304485536..2d3072e150e 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -92,7 +92,7 @@ impl State { let mut env: HashMap = HashMap::new(); - let (ephemeral, container, pid) = match &args.mode { + let (ephemeral, container) = match &args.mode { cli::Mode::Targeted { container_id, container_runtime, @@ -101,11 +101,10 @@ impl State { let container = get_container(container_id.clone(), container_runtime).await?; let container_handle = ContainerHandle::new(container).await?; - let pid = container_handle.pid().to_string(); env.extend(container_handle.raw_env().clone()); - 
(false, Some(container_handle), pid) + (false, Some(container_handle)) } cli::Mode::Ephemeral { .. } => { let container_handle = ContainerHandle::new(runtime::Container::Ephemeral( @@ -119,13 +118,12 @@ impl State { )) .await?; - let pid = container_handle.pid().to_string(); env.extend(container_handle.raw_env().clone()); // If we are in an ephemeral container, we use pid 1. - (true, Some(container_handle), pid) + (true, Some(container_handle)) } - cli::Mode::Targetless | cli::Mode::BlackboxTest => (false, None, "self".to_string()), + cli::Mode::Targetless | cli::Mode::BlackboxTest => (false, None), }; let network_runtime = match container.as_ref().map(ContainerHandle::pid) { @@ -135,7 +133,11 @@ impl State { None | Some(..) => MaybeRemoteRuntime::Local, }; - let environ_path = PathBuf::from("/proc").join(pid).join("environ"); + let env_pid = match container.as_ref().map(ContainerHandle::pid) { + Some(pid) => pid.to_string(), + None => "self".to_string(), + }; + let environ_path = PathBuf::from("/proc").join(env_pid).join("environ"); match env::get_proc_environ(environ_path).await { Ok(environ) => env.extend(environ.into_iter()), Err(err) => { diff --git a/mirrord/agent/tests/blackbox.rs b/mirrord/agent/tests/blackbox.rs index 26fdf05794e..c2d20cfd1c5 100644 --- a/mirrord/agent/tests/blackbox.rs +++ b/mirrord/agent/tests/blackbox.rs @@ -15,24 +15,10 @@ use tokio::{ time::{sleep, Duration}, }; use tokio_stream::StreamExt; -use tracing_subscriber::{fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt}; /// This test requires root or CAP_NET_RAW to setup TCP sniffing. 
#[tokio::test] async fn sanity() { - tracing_subscriber::registry() - .with( - tracing_subscriber::fmt::layer() - .with_thread_ids(true) - .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) - .pretty() - .with_line_number(true), - ) - .with(tracing_subscriber::EnvFilter::from_env( - "mirrord=trace,warn", - )) - .init(); - let mut bin = tokio::process::Command::from(get_test_bin("mirrord-agent")); let child = bin .arg("-t") @@ -40,6 +26,7 @@ async fn sanity() { .arg("-i") .arg("lo") .arg("blackbox-test") + .env("RUST_LOG", "mirrord=trace,warn") .kill_on_drop(true) .spawn() .expect("mirrord-agent failed to start"); From 55e67506b10af9c57b480f33b921ec26507d69ce Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 23:34:38 +0200 Subject: [PATCH 08/11] Removed blackbox test --- Cargo.lock | 7 -- mirrord/agent/Cargo.toml | 1 - mirrord/agent/src/cli.rs | 2 - mirrord/agent/src/entrypoint.rs | 2 +- mirrord/agent/src/main.rs | 6 -- mirrord/agent/tests/blackbox.rs | 156 -------------------------------- 6 files changed, 1 insertion(+), 173 deletions(-) delete mode 100644 mirrord/agent/tests/blackbox.rs diff --git a/Cargo.lock b/Cargo.lock index ef89b634778..f3c3dd340f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4332,7 +4332,6 @@ dependencies = [ "socket2", "streammap-ext", "tempfile", - "test_bin", "thiserror 2.0.12", "tokio", "tokio-rustls 0.26.2", @@ -7076,12 +7075,6 @@ dependencies = [ "toml 0.5.11", ] -[[package]] -name = "test_bin" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e7a7de15468c6e65dd7db81cf3822c1ec94c71b2a3c1a976ea8e4696c91115c" - [[package]] name = "tests" version = "0.1.0" diff --git a/mirrord/agent/Cargo.toml b/mirrord/agent/Cargo.toml index 6bec1139e3b..91cb344de10 100644 --- a/mirrord/agent/Cargo.toml +++ b/mirrord/agent/Cargo.toml @@ -82,4 +82,3 @@ rcgen.workspace = true reqwest.workspace = true rstest.workspace = true tempfile.workspace = true -test_bin = "0.4" diff --git 
a/mirrord/agent/src/cli.rs b/mirrord/agent/src/cli.rs index c3626bf85d4..e77e494cf95 100644 --- a/mirrord/agent/src/cli.rs +++ b/mirrord/agent/src/cli.rs @@ -98,8 +98,6 @@ pub enum Mode { }, #[default] Targetless, - #[clap(hide = true)] - BlackboxTest, } impl Mode { diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 2d3072e150e..94f2d1bcfee 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -123,7 +123,7 @@ impl State { // If we are in an ephemeral container, we use pid 1. (true, Some(container_handle)) } - cli::Mode::Targetless | cli::Mode::BlackboxTest => (false, None), + cli::Mode::Targetless => (false, None), }; let network_runtime = match container.as_ref().map(ContainerHandle::pid) { diff --git a/mirrord/agent/src/main.rs b/mirrord/agent/src/main.rs index 2788ae5ab52..e582f3eb85d 100644 --- a/mirrord/agent/src/main.rs +++ b/mirrord/agent/src/main.rs @@ -7,12 +7,6 @@ #![warn(clippy::indexing_slicing)] #![deny(unused_crate_dependencies)] -/// Silences `deny(unused_crate_dependencies)`. -/// -/// This dependency is only used in integration tests. -#[cfg(test)] -use test_bin as _; - mod cli; mod client_connection; mod container_handle; diff --git a/mirrord/agent/tests/blackbox.rs b/mirrord/agent/tests/blackbox.rs deleted file mode 100644 index c2d20cfd1c5..00000000000 --- a/mirrord/agent/tests/blackbox.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::{io::ErrorKind, net::IpAddr, sync::Arc}; - -use actix_codec::Framed; -use futures::SinkExt; -use mirrord_protocol::{ - tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, - ClientCodec, ClientMessage, DaemonMessage, -}; -use test_bin::get_test_bin; -use tokio::{ - io::AsyncWriteExt, - net::{TcpListener, TcpStream}, - select, - sync::Mutex, - time::{sleep, Duration}, -}; -use tokio_stream::StreamExt; - -/// This test requires root or CAP_NET_RAW to setup TCP sniffing. 
-#[tokio::test] -async fn sanity() { - let mut bin = tokio::process::Command::from(get_test_bin("mirrord-agent")); - let child = bin - .arg("-t") - .arg("2") - .arg("-i") - .arg("lo") - .arg("blackbox-test") - .env("RUST_LOG", "mirrord=trace,warn") - .kill_on_drop(true) - .spawn() - .expect("mirrord-agent failed to start"); - // Wait for agent to listen - sleep(Duration::from_millis(2000)).await; - let stream = TcpStream::connect("127.0.0.1:61337") - .await - .expect("connection to agent failed"); - let mutex = Arc::new(Mutex::new(0)); - let task_mutex = Arc::clone(&mutex); - let guard = mutex.lock().await; - let task = tokio::spawn(async move { - let listener = TcpListener::bind("127.0.0.1:1337") - .await - .expect("couldn't bind socket"); - loop { - select! { - Ok((socket, _)) = listener.accept() => { - let mut buf = [0; 4096]; - loop { - match socket.try_read(&mut buf) { - Ok(0) => break, - Ok(_) => {} - Err(ref e) if e.kind() == ErrorKind::WouldBlock => { - sleep(Duration::from_millis(10)).await; - } - Err(e) => panic!("socket error {e:?}") - } - } - }, - _ = task_mutex.lock() => { - break - } - } - } - }); - - let mut codec = Framed::new(stream, ClientCodec::default()); - let subscription_port = 1337; - - codec - .send(ClientMessage::Tcp(LayerTcp::PortSubscribe( - subscription_port, - ))) - .await - .expect("port subscribe failed"); - let response = codec - .next() - .await - .expect("couldn't get next message") - .expect("got invalid message"); - assert!( - matches!( - response, - DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok(_))) - ), - "unexpected response: {response:?}" - ); - let mut test_conn = TcpStream::connect("127.0.0.1:1337") - .await - .expect("connection to dummy failed"); - let port = test_conn.local_addr().unwrap().port(); - let test_data = [0, 3, 5]; - test_conn - .write_all(&test_data) - .await - .expect("couldn't write test data"); - drop(test_conn); - let new_conn_msg = codec - .next() - .await - .expect("couldn't get next message") - 
.expect("got invalid message"); - let data_msg = codec - .next() - .await - .expect("couldn't get next message") - .expect("got invalid message"); - let close_msg = codec - .next() - .await - .expect("couldn't get next message") - .expect("got invalid message"); - assert_eq!( - new_conn_msg, - DaemonMessage::Tcp(DaemonTcp::NewConnection(NewTcpConnection { - connection_id: 0, - remote_address: IpAddr::V4("127.0.0.1".parse().unwrap()), - local_address: IpAddr::V4("127.0.0.1".parse().unwrap()), - destination_port: 1337, - source_port: port - })), - "unexpected message: {new_conn_msg:?}", - ); - - assert_eq!( - data_msg, - DaemonMessage::Tcp(DaemonTcp::Data(TcpData { - connection_id: 0, - bytes: test_data.to_vec() - })), - "unexpected message: {data_msg:?}", - ); - - assert_eq!( - close_msg, - DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 0 })), - "unexpected message: {close_msg:?}", - ); - - drop(codec); - drop(guard); - drop(mutex); - - task.await.unwrap(); - let result = child.wait_with_output().await.unwrap(); - assert!(result.status.success()); - - let stderr = String::from_utf8_lossy(&result.stderr); - println!("stderr: {stderr:?}"); - - let stdout = String::from_utf8_lossy(&result.stdout); - println!("stdout: {stdout:?}"); - - assert!(!stderr.to_ascii_lowercase().contains("error")); - assert!(!stdout.to_ascii_lowercase().contains("error")); -} From daafa7bdc7f097fade326fd5f2bdfaab3e4f4c56 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 1 Apr 2025 23:45:35 +0200 Subject: [PATCH 09/11] Task setup fix --- mirrord/agent/src/entrypoint.rs | 24 ++++++++++++++---------- mirrord/agent/src/entrypoint/setup.rs | 16 ++++++++-------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 94f2d1bcfee..6701179fcde 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -578,17 +578,21 @@ async fn start_agent(args: Args) -> AgentResult<()> { }); } - 
let sniffer = match &state.network_runtime { - MaybeRemoteRuntime::Remote(runtime) => { - setup::start_sniffer(&args, runtime, cancellation_token.clone()).await - } - MaybeRemoteRuntime::Local => BackgroundTask::Disabled, + let sniffer = if state.container_pid().is_some() { + setup::start_sniffer(&args, &state.network_runtime, cancellation_token.clone()).await + } else { + BackgroundTask::Disabled }; - let stealer = match &state.network_runtime { - MaybeRemoteRuntime::Local => BackgroundTask::Disabled, - MaybeRemoteRuntime::Remote(runtime) => { - let steal_handle = setup::start_traffic_redirector(runtime).await?; - setup::start_stealer(runtime, steal_handle, cancellation_token.clone()) + let stealer = match state.container_pid() { + None => BackgroundTask::Disabled, + Some(pid) => { + let steal_handle = setup::start_traffic_redirector(&state.network_runtime).await?; + setup::start_stealer( + &state.network_runtime, + pid, + steal_handle, + cancellation_token.clone(), + ) } }; let dns = setup::start_dns(&args, &state.network_runtime, cancellation_token.clone()); diff --git a/mirrord/agent/src/entrypoint/setup.rs b/mirrord/agent/src/entrypoint/setup.rs index cf4176aeafb..d45c4df38c1 100644 --- a/mirrord/agent/src/entrypoint/setup.rs +++ b/mirrord/agent/src/entrypoint/setup.rs @@ -13,14 +13,16 @@ use crate::{ steal::{StealTlsHandlerStore, StealerCommand, TcpConnectionStealer}, util::{ path_resolver::InTargetPathResolver, - remote_runtime::{IntoStatus, MaybeRemoteRuntime, RemoteRuntime}, + remote_runtime::{IntoStatus, MaybeRemoteRuntime}, }, }; /// Starts a [`RedirectorTask`] on the given `runtime`. /// /// Returns the [`StealHandle`] that can be used to steal incoming traffic. 
-pub(super) async fn start_traffic_redirector(runtime: &RemoteRuntime) -> AgentResult { +pub(super) async fn start_traffic_redirector( + runtime: &MaybeRemoteRuntime, +) -> AgentResult { let flush_connections = envs::STEALER_FLUSH_CONNECTIONS.from_env_or_default(); let pod_ips = envs::POD_IPS.from_env_or_default(); let support_ipv6 = envs::IPV6_SUPPORT.from_env_or_default(); @@ -42,7 +44,7 @@ pub(super) async fn start_traffic_redirector(runtime: &RemoteRuntime) -> AgentRe pub(super) async fn start_sniffer( args: &super::Args, - runtime: &RemoteRuntime, + runtime: &MaybeRemoteRuntime, cancellation_token: CancellationToken, ) -> BackgroundTask { let (command_tx, command_rx) = mpsc::channel::(1000); @@ -75,7 +77,8 @@ pub(super) async fn start_sniffer( } pub(super) fn start_stealer( - runtime: &RemoteRuntime, + runtime: &MaybeRemoteRuntime, + target_pid: u64, steal_handle: StealHandle, cancellation_token: CancellationToken, ) -> BackgroundTask { @@ -83,10 +86,7 @@ pub(super) fn start_stealer( let tls_steal_config = envs::STEAL_TLS_CONFIG.from_env_or_default(); let tls_handler_store = tls_steal_config.is_empty().not().then(|| { - StealTlsHandlerStore::new( - tls_steal_config, - InTargetPathResolver::new(runtime.target_pid()), - ) + StealTlsHandlerStore::new(tls_steal_config, InTargetPathResolver::new(target_pid)) }); let task_status = runtime .spawn( From e5d58168a7147e2828991251cfd34bbdea1deb13 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Wed, 2 Apr 2025 23:00:04 +0200 Subject: [PATCH 10/11] CR suggestions --- mirrord/agent/src/error.rs | 2 +- mirrord/agent/src/util.rs | 1 + mirrord/agent/src/util/error.rs | 27 +++++++++++++++++ mirrord/agent/src/util/remote_runtime.rs | 38 ++++++------------------ 4 files changed, 38 insertions(+), 30 deletions(-) create mode 100644 mirrord/agent/src/util/error.rs diff --git a/mirrord/agent/src/error.rs b/mirrord/agent/src/error.rs index cde2c8fded1..b80c9b90eee 100644 --- a/mirrord/agent/src/error.rs +++ b/mirrord/agent/src/error.rs 
@@ -4,7 +4,7 @@ use thiserror::Error; use crate::{ client_connection::TlsSetupError, incoming::RedirectorTaskError, namespace::NamespaceError, - runtime, util::remote_runtime::RemoteRuntimeError, + runtime, util::error::RemoteRuntimeError, }; #[derive(Debug, Error)] diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs index 28588b84ae2..c3ec87d45d9 100644 --- a/mirrord/agent/src/util.rs +++ b/mirrord/agent/src/util.rs @@ -10,6 +10,7 @@ use std::{ use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc; +pub mod error; pub mod path_resolver; pub mod remote_runtime; diff --git a/mirrord/agent/src/util/error.rs b/mirrord/agent/src/util/error.rs new file mode 100644 index 00000000000..f114feca04c --- /dev/null +++ b/mirrord/agent/src/util/error.rs @@ -0,0 +1,27 @@ +use std::io; + +use thiserror::Error; + +use crate::namespace::NamespaceError; + +/// Errors that can occur when creating a [`RemoteRuntime`](super::remote_runtime::RemoteRuntime). +#[derive(Error, Debug)] +pub enum RemoteRuntimeError { + #[error("failed to spawn runtime thread: {0}")] + ThreadSpawnError(#[source] io::Error), + #[error(transparent)] + NamespaceError(#[from] NamespaceError), + #[error("failed to build tokio runtime: {0}")] + TokioRuntimeError(#[source] io::Error), + #[error("runtime thread panicked")] + Panicked, +} + +/// An error that occurs when polling a future spawned with +/// [`RemoteRuntime::spawn`](super::remote_runtime::RemoteRuntime::spawn) or +/// [`MaybeRemoteRuntime::spawn`](super::remote_runtime::MaybeRemoteRuntime::spawn). +/// +/// This error indicated that the future has panicked. +#[derive(Debug, Error)] +#[error("task panicked")] +pub struct BgTaskPanicked; diff --git a/mirrord/agent/src/util/remote_runtime.rs b/mirrord/agent/src/util/remote_runtime.rs index 7a097962b7d..ebef651696a 100644 --- a/mirrord/agent/src/util/remote_runtime.rs +++ b/mirrord/agent/src/util/remote_runtime.rs @@ -1,9 +1,9 @@ -//! 
This module contains utilities for running async code in the agent target's namespace. +//! Utilities for running async code in the agent target's namespace. //! -//! This is useful for running tasks that require access to the target's network namespace, +//! Useful for running tasks that require access to the target's network namespace, //! such as traffic stealing, traffic mirroring, DNS resolution, outgoing traffic. //! -//! This module provides: +//! Provides: //! 1. A [`RemoteRuntime`] struct, that can be used to run tasks in the target's namespace. //! 2. A [`MaybeRemoteRuntime`] enum, that don't necessarily require a target (DNS and outgoing //! traffic), but should be run in the target's namespace if available. @@ -13,7 +13,6 @@ use std::{ error::Error, fmt, future::Future, - io, ops::Not, pin::Pin, sync::Arc, @@ -25,27 +24,14 @@ use futures::{ future::{BoxFuture, Shared}, FutureExt, }; -use thiserror::Error; use tokio::sync::{mpsc, oneshot}; +use super::error::{BgTaskPanicked, RemoteRuntimeError}; use crate::{ error::AgentError, - namespace::{self, NamespaceError, NamespaceType}, + namespace::{self, NamespaceType}, }; -/// Errors that can occur when creating a [`RemoteRuntime`]. -#[derive(Error, Debug)] -pub enum RemoteRuntimeError { - #[error("failed to spawn runtime thread: {0}")] - ThreadSpawnError(#[source] io::Error), - #[error(transparent)] - NamespaceError(#[from] NamespaceError), - #[error("failed to build tokio runtime: {0}")] - TokioRuntimeError(#[source] io::Error), - #[error("runtime thread panicked")] - Panicked, -} - /// A cloneable handle to a remote [`tokio::runtime::Runtime`] that runs in its own thread. /// /// Can be used to spawn tasks with [`RemoteRuntime::spawn`]. 
@@ -67,7 +53,7 @@ impl RemoteRuntime { ) -> Result { let (future_tx, mut future_rx) = mpsc::channel(16); let (result_tx, result_rx) = oneshot::channel(); - let thread_name = format!("remote-{namespace_type}-runtime-thread"); + let thread_name = format!("remote-{target_pid}-{namespace_type}-runtime-thread"); let thread_logic = move || { if let Err(error) = namespace::set_namespace(target_pid, namespace_type) { let _ = result_tx.send(Err(error.into())); @@ -76,7 +62,9 @@ impl RemoteRuntime { let rt_result = tokio::runtime::Builder::new_current_thread() .enable_all() - .thread_name(format!("remote-{namespace_type}-runtime-worker")) + .thread_name(format!( + "remote-{target_pid}-{namespace_type}-runtime-worker" + )) .build(); let rt = match rt_result { Ok(rt) => rt, @@ -150,14 +138,6 @@ impl fmt::Debug for RemoteRuntime { } } -/// An error that occurs when polling a future spawned with [`RemoteRuntime::spawn`] or -/// [`MaybeRemoteRuntime::spawn`]. -/// -/// This error indicated that the future has panicked. 
-#[derive(Debug, Error)] -#[error("task panicked")] -pub struct BgTaskPanicked; - /// A future spawned with [`RemoteRuntime::spawn`] or /// [`MaybeRemoteRuntime::spawn`] pub struct BgTask { From a4206ab29797de86d401e41d4c08f505593ec45e Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Wed, 2 Apr 2025 23:02:04 +0200 Subject: [PATCH 11/11] MaybeRemoteRuntime -> BgTaskRuntime --- mirrord/agent/src/entrypoint.rs | 8 ++++---- mirrord/agent/src/entrypoint/setup.rs | 12 +++++------- mirrord/agent/src/outgoing.rs | 4 ++-- mirrord/agent/src/outgoing/udp.rs | 6 +++--- mirrord/agent/src/sniffer.rs | 4 ++-- mirrord/agent/src/steal/connection.rs | 4 ++-- mirrord/agent/src/util/error.rs | 2 +- mirrord/agent/src/util/remote_runtime.rs | 16 ++++++++-------- mirrord/agent/src/vpn.rs | 4 ++-- 9 files changed, 29 insertions(+), 31 deletions(-) diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index a8544ed85a8..967f2e01536 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -51,7 +51,7 @@ use crate::{ sniffer::{api::TcpSnifferApi, messages::SnifferCommand}, steal::{self, StealerCommand, TcpStealerApi}, util::{ - remote_runtime::{BgTaskStatus, MaybeRemoteRuntime, RemoteRuntime}, + remote_runtime::{BgTaskRuntime, BgTaskStatus, RemoteRuntime}, ClientId, }, }; @@ -78,7 +78,7 @@ struct State { /// When present, it is used to secure incoming TCP connections. tls_connector: Option, /// [`tokio::runtime`] that should be used for network operations. - network_runtime: MaybeRemoteRuntime, + network_runtime: BgTaskRuntime, } impl State { @@ -127,10 +127,10 @@ impl State { }; let network_runtime = match container.as_ref().map(ContainerHandle::pid) { - Some(pid) if ephemeral.not() => MaybeRemoteRuntime::Remote( + Some(pid) if ephemeral.not() => BgTaskRuntime::Remote( RemoteRuntime::new_in_namespace(pid, NamespaceType::Net).await?, ), - None | Some(..) => MaybeRemoteRuntime::Local, + None | Some(..) 
=> BgTaskRuntime::Local, }; let env_pid = match container.as_ref().map(ContainerHandle::pid) { diff --git a/mirrord/agent/src/entrypoint/setup.rs b/mirrord/agent/src/entrypoint/setup.rs index d45c4df38c1..38ba7d44f29 100644 --- a/mirrord/agent/src/entrypoint/setup.rs +++ b/mirrord/agent/src/entrypoint/setup.rs @@ -13,16 +13,14 @@ use crate::{ steal::{StealTlsHandlerStore, StealerCommand, TcpConnectionStealer}, util::{ path_resolver::InTargetPathResolver, - remote_runtime::{IntoStatus, MaybeRemoteRuntime}, + remote_runtime::{BgTaskRuntime, IntoStatus}, }, }; /// Starts a [`RedirectorTask`] on the given `runtime`. /// /// Returns the [`StealHandle`] that can be used to steal incoming traffic. -pub(super) async fn start_traffic_redirector( - runtime: &MaybeRemoteRuntime, -) -> AgentResult { +pub(super) async fn start_traffic_redirector(runtime: &BgTaskRuntime) -> AgentResult { let flush_connections = envs::STEALER_FLUSH_CONNECTIONS.from_env_or_default(); let pod_ips = envs::POD_IPS.from_env_or_default(); let support_ipv6 = envs::IPV6_SUPPORT.from_env_or_default(); @@ -44,7 +42,7 @@ pub(super) async fn start_traffic_redirector( pub(super) async fn start_sniffer( args: &super::Args, - runtime: &MaybeRemoteRuntime, + runtime: &BgTaskRuntime, cancellation_token: CancellationToken, ) -> BackgroundTask { let (command_tx, command_rx) = mpsc::channel::(1000); @@ -77,7 +75,7 @@ pub(super) async fn start_sniffer( } pub(super) fn start_stealer( - runtime: &MaybeRemoteRuntime, + runtime: &BgTaskRuntime, target_pid: u64, steal_handle: StealHandle, cancellation_token: CancellationToken, @@ -100,7 +98,7 @@ pub(super) fn start_stealer( pub(super) fn start_dns( args: &super::Args, - runtime: &MaybeRemoteRuntime, + runtime: &BgTaskRuntime, cancellation_token: CancellationToken, ) -> BackgroundTask { let (command_tx, command_rx) = mpsc::channel::(1000); diff --git a/mirrord/agent/src/outgoing.rs b/mirrord/agent/src/outgoing.rs index 33a3353d72a..e36dcc7838d 100644 --- 
a/mirrord/agent/src/outgoing.rs +++ b/mirrord/agent/src/outgoing.rs @@ -20,7 +20,7 @@ use tracing::Level; use crate::{ error::AgentResult, metrics::TCP_OUTGOING_CONNECTION, - util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, + util::remote_runtime::{BgTaskRuntime, BgTaskStatus, IntoStatus}, }; mod socket_stream; @@ -48,7 +48,7 @@ impl TcpOutgoingApi { /// # Params /// /// * `runtime` - tokio runtime to spawn the background task on. - pub(crate) fn new(runtime: &MaybeRemoteRuntime) -> Self { + pub(crate) fn new(runtime: &BgTaskRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); diff --git a/mirrord/agent/src/outgoing/udp.rs b/mirrord/agent/src/outgoing/udp.rs index 9f39c4dde2c..8985815544d 100644 --- a/mirrord/agent/src/outgoing/udp.rs +++ b/mirrord/agent/src/outgoing/udp.rs @@ -26,12 +26,12 @@ use tracing::Level; use crate::{ error::AgentResult, metrics::UDP_OUTGOING_CONNECTION, - util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, + util::remote_runtime::{BgTaskRuntime, BgTaskStatus, IntoStatus}, }; /// Task that handles [`LayerUdpOutgoing`] and [`DaemonUdpOutgoing`] messages. /// -/// We start these tasks from the [`UdpOutgoingApi`] on a [`MaybeRemoteRuntime`]. +/// We start these tasks from the [`UdpOutgoingApi`] on a [`BgTaskRuntime`]. struct UdpOutgoingTask { next_connection_id: ConnectionId, /// Writing halves of peer connections made on layer's requests. 
@@ -303,7 +303,7 @@ async fn connect(remote_address: SocketAddress) -> Result Self { + pub(crate) fn new(runtime: &BgTaskRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs index 976ae0e0b6a..b1a6f1cb0e0 100644 --- a/mirrord/agent/src/sniffer.rs +++ b/mirrord/agent/src/sniffer.rs @@ -416,7 +416,7 @@ mod test { use tokio::sync::mpsc; use super::*; - use crate::util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}; + use crate::util::remote_runtime::{BgTaskRuntime, BgTaskStatus, IntoStatus}; struct TestSnifferSetup { command_tx: Sender, @@ -457,7 +457,7 @@ mod test { clients_closed: Default::default(), }; - let task_status = MaybeRemoteRuntime::Local + let task_status = BgTaskRuntime::Local .spawn(sniffer.start(CancellationToken::new())) .into_status("TcpSnifferTask"); diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 2b7240902c7..cf7d73774dc 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -818,7 +818,7 @@ mod test { connection::{Client, MatchedHttpRequest}, TcpConnectionStealer, TcpStealerApi, }, - util::remote_runtime::{IntoStatus, MaybeRemoteRuntime}, + util::remote_runtime::{BgTaskRuntime, IntoStatus}, }; async fn prepare_dummy_service() -> ( @@ -1023,7 +1023,7 @@ mod test { let (task, handle) = RedirectorTask::new(redirector); tokio::spawn(task.run()); - let task_status = MaybeRemoteRuntime::Local + let task_status = BgTaskRuntime::Local .spawn( TcpConnectionStealer::new(command_rx, handle, None).start(CancellationToken::new()), ) diff --git a/mirrord/agent/src/util/error.rs b/mirrord/agent/src/util/error.rs index f114feca04c..f482b7841c7 100644 --- a/mirrord/agent/src/util/error.rs +++ b/mirrord/agent/src/util/error.rs @@ -19,7 +19,7 @@ pub enum RemoteRuntimeError { /// An error that occurs when polling a future 
spawned with /// [`RemoteRuntime::spawn`](super::remote_runtime::RemoteRuntime::spawn) or -/// [`MaybeRemoteRuntime::spawn`](super::remote_runtime::MaybeRemoteRuntime::spawn). +/// [`BgTaskRuntime::spawn`](super::remote_runtime::BgTaskRuntime::spawn). /// /// This error indicated that the future has panicked. #[derive(Debug, Error)] diff --git a/mirrord/agent/src/util/remote_runtime.rs b/mirrord/agent/src/util/remote_runtime.rs index ebef651696a..48e82ce9e16 100644 --- a/mirrord/agent/src/util/remote_runtime.rs +++ b/mirrord/agent/src/util/remote_runtime.rs @@ -5,8 +5,8 @@ //! //! Provides: //! 1. A [`RemoteRuntime`] struct, that can be used to run tasks in the target's namespace. -//! 2. A [`MaybeRemoteRuntime`] enum, that don't necessarily require a target (DNS and outgoing -//! traffic), but should be run in the target's namespace if available. +//! 2. A [`BgTaskRuntime`] enum, for tasks that don't necessarily require a target (DNS and outgoing +//! traffic), but should be run in the target's namespace if available. //! 3. A [`BgTaskStatus`] struct, that can be used to poll for a spawned task's status. use std::{ @@ -139,7 +139,7 @@ impl fmt::Debug for RemoteRuntime { } /// A future spawned with [`RemoteRuntime::spawn`] or -/// [`MaybeRemoteRuntime::spawn`] +/// [`BgTaskRuntime::spawn`] pub struct BgTask { future_result: oneshot::Receiver, } @@ -156,7 +156,7 @@ impl Future for BgTask { } /// A cloneable status of a future spawned with [`RemoteRuntime::spawn`] or -/// [`MaybeRemoteRuntime::spawn`]. +/// [`BgTaskRuntime::spawn`]. #[derive(Clone)] pub struct BgTaskStatus { task_name: &'static str, @@ -266,17 +266,17 @@ impl IntoStatus for BgTask<()> { /// /// This can be used to spawn tasks that can either run in the target's namespace or the agent's. /// -/// If the agent has a target, you should use [`MaybeRemoteRuntime::Remote`]. -/// If the agent does not have a target, you should fallback to [`MaybeRemoteRuntime::Local`]. 
+/// If the agent has a target, you should use [`BgTaskRuntime::Remote`]. +/// If the agent does not have a target, you should fall back to [`BgTaskRuntime::Local`]. #[derive(Clone)] -pub enum MaybeRemoteRuntime { +pub enum BgTaskRuntime { /// Remote runtime, which runs in the target's namespace. Remote(RemoteRuntime), /// Local runtime ([`tokio::runtime::Handle::current`]). Local, } -impl MaybeRemoteRuntime { +impl BgTaskRuntime { /// Spawns the given future on this runtime. pub fn spawn(&self, future: F) -> BgTask where diff --git a/mirrord/agent/src/vpn.rs b/mirrord/agent/src/vpn.rs index cff649e90f5..276db95e558 100644 --- a/mirrord/agent/src/vpn.rs +++ b/mirrord/agent/src/vpn.rs @@ -19,7 +19,7 @@ use tokio::{ use crate::{ error::AgentResult, - util::remote_runtime::{BgTaskStatus, IntoStatus, MaybeRemoteRuntime}, + util::remote_runtime::{BgTaskRuntime, BgTaskStatus, IntoStatus}, }; /// An interface for a background task handling [`ClientVpn`] messages. @@ -41,7 +41,7 @@ impl VpnApi { /// # Params /// /// * `runtime` - tokio runtime to spawn the task on. - pub(crate) fn new(runtime: &MaybeRemoteRuntime) -> Self { + pub(crate) fn new(runtime: &BgTaskRuntime) -> Self { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); let pid = runtime.target_pid();