Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changelog.d/agent-threading.internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reworked agent's threading model to avoid spawning excessive threads.
1 change: 0 additions & 1 deletion mirrord/agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,3 @@ rcgen.workspace = true
reqwest.workspace = true
rstest.workspace = true
tempfile.workspace = true
test_bin = "0.4"
2 changes: 0 additions & 2 deletions mirrord/agent/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ pub enum Mode {
},
#[default]
Targetless,
#[clap(hide = true)]
BlackboxTest,
}

impl Mode {
Expand Down
92 changes: 55 additions & 37 deletions mirrord/agent/src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{future, io, path::PathBuf, sync::atomic::Ordering, time::Duration};
use std::{
collections::HashMap, future, io, path::PathBuf, sync::atomic::Ordering, time::Duration,
};

use futures::{stream::FuturesOrdered, StreamExt};
use hickory_resolver::{
Expand All @@ -24,16 +26,12 @@ use tokio::{
mpsc::{Receiver, Sender},
oneshot,
},
task::JoinSet,
task::{Id, JoinSet},
};
use tokio_util::sync::CancellationToken;
use tracing::{warn, Level};

use crate::{
error::{AgentError, AgentResult},
metrics::DNS_REQUEST_COUNT,
watched_task::TaskStatus,
};
use crate::{error::AgentResult, metrics::DNS_REQUEST_COUNT, util::remote_runtime::BgTaskStatus};

#[derive(Debug)]
pub(crate) enum ClientGetAddrInfoRequest {
Expand All @@ -54,7 +52,7 @@ impl ClientGetAddrInfoRequest {
#[derive(Debug)]
pub(crate) struct DnsCommand {
request: ClientGetAddrInfoRequest,
response_tx: oneshot::Sender<Result<DnsLookup, InternalLookupError>>,
response_tx: oneshot::Sender<Result<DnsLookup, ResolveErrorKindInternal>>,
}

/// Background task for resolving hostnames to IP addresses.
Expand All @@ -80,12 +78,11 @@ pub(crate) struct DnsWorker {
/// Background tasks that handle the DNS requests.
///
/// Each of these builds a new [`TokioAsyncResolver`] and performs one lookup.
tasks: JoinSet<()>,
tasks: JoinSet<Result<DnsLookup, InternalLookupError>>,
response_txs: HashMap<Id, oneshot::Sender<Result<DnsLookup, ResolveErrorKindInternal>>>,
}

impl DnsWorker {
pub const TASK_NAME: &'static str = "DNS worker";

/// Creates a new instance of this worker.
/// To run this worker, call [`Self::run`].
///
Expand Down Expand Up @@ -124,6 +121,7 @@ impl DnsWorker {
attempts,
support_ipv6,
tasks: Default::default(),
response_txs: Default::default(),
}
}

Expand Down Expand Up @@ -203,34 +201,51 @@ impl DnsWorker {
let attempts = self.attempts;
let support_ipv6 = self.support_ipv6;

let lookup_future = async move {
let result = Self::do_lookup(
etc_path,
message.request.into_v2(),
attempts,
timeout,
support_ipv6,
)
.await;

let _ = message.response_tx.send(result);
};
let handle = self.tasks.spawn(Self::do_lookup(
etc_path,
message.request.into_v2(),
attempts,
timeout,
support_ipv6,
));
self.response_txs.insert(handle.id(), message.response_tx);

DNS_REQUEST_COUNT.fetch_add(1, Ordering::Relaxed);
self.tasks.spawn(lookup_future);
}

pub(crate) async fn run(mut self, cancellation_token: CancellationToken) -> AgentResult<()> {
pub(crate) async fn run(mut self, cancellation_token: CancellationToken) {
loop {
tokio::select! {
_ = cancellation_token.cancelled() => break Ok(()),
_ = cancellation_token.cancelled() => break,

Some(..) = self.tasks.join_next() => {
Some(result) = self.tasks.join_next_with_id() => {
DNS_REQUEST_COUNT.fetch_sub(1, Ordering::Relaxed);
let (id, result) = match result {
Ok((id, result)) => (
id,
result.map_err(Into::into),
),
Err(error) => {
(
error.id(),
Err(ResolveErrorKindInternal::Message("DNS task panicked".into()))
)
}
};

let response_tx = self.response_txs.remove(&id);
match response_tx {
Some(response_tx) => {
let _ = response_tx.send(result);
}
None => {
warn!(?id, "Received a DNS result with no matching response channel");
}
}
}

message = self.request_rx.recv() => match message {
None => break Ok(()),
None => break,
Some(message) => self.handle_message(message),
},
}
Expand All @@ -246,15 +261,15 @@ impl Drop for DnsWorker {
}

pub(crate) struct DnsApi {
task_status: TaskStatus,
task_status: BgTaskStatus,
request_tx: Sender<DnsCommand>,
/// [`DnsWorker`] processes all requests concurrently, so we use a combination of [`oneshot`]
/// channels and [`FuturesOrdered`] to preserve order of responses.
responses: FuturesOrdered<oneshot::Receiver<Result<DnsLookup, InternalLookupError>>>,
responses: FuturesOrdered<oneshot::Receiver<Result<DnsLookup, ResolveErrorKindInternal>>>,
}

impl DnsApi {
pub(crate) fn new(task_status: TaskStatus, task_sender: Sender<DnsCommand>) -> Self {
pub(crate) fn new(task_status: BgTaskStatus, task_sender: Sender<DnsCommand>) -> Self {
Self {
task_status,
request_tx: task_sender,
Expand All @@ -276,7 +291,7 @@ impl DnsApi {
response_tx,
};
if self.request_tx.send(command).await.is_err() {
return Err(self.task_status.unwrap_err().await);
return Err(self.task_status.wait_assert_running().await);
}

self.responses.push_back(response_rx);
Expand All @@ -294,11 +309,14 @@ impl DnsApi {
return future::pending().await;
};

let response = response
.map_err(|_| AgentError::DnsTaskPanic)?
.map_err(|error| ResponseError::DnsLookup(DnsLookupError { kind: error.into() }));

Ok(GetAddrInfoResponse(response))
match response {
Ok(response) => {
Ok(GetAddrInfoResponse(response.map_err(|kind| {
ResponseError::DnsLookup(DnsLookupError { kind })
})))
}
Err(..) => Err(self.task_status.wait_assert_running().await),
}
}
}

Expand Down
Loading