Skip to content

Commit

Permalink
fix SRV override loss on cache expiration
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Volk <[email protected]>
  • Loading branch information
jevolk committed Jan 23, 2025
1 parent 265802d commit a5520e8
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 40 deletions.
10 changes: 6 additions & 4 deletions src/admin/query/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,22 @@ async fn destinations_cache(
async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessageEventContent> {
use service::resolver::cache::CachedOverride;

writeln!(self, "| Server Name | IP | Port | Expires |").await?;
writeln!(self, "| ----------- | --- | ----:| ------- |").await?;
writeln!(self, "| Server Name | IP | Port | Expires | Overriding |").await?;
writeln!(self, "| ----------- | --- | ----:| ------- | ---------- |").await?;

let mut overrides = self.services.resolver.cache.overrides().boxed();

while let Some((name, CachedOverride { ips, port, expire })) = overrides.next().await {
while let Some((name, CachedOverride { ips, port, expire, overriding })) =
overrides.next().await
{
if let Some(server_name) = server_name.as_ref() {
if name != server_name {
continue;
}
}

let expire = time::format(expire, "%+");
self.write_str(&format!("| {name} | {ips:?} | {port} | {expire} |\n"))
self.write_str(&format!("| {name} | {ips:?} | {port} | {expire} | {overriding:?} |\n"))
.await?;
}

Expand Down
62 changes: 32 additions & 30 deletions src/service/resolver/actual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,8 @@ impl super::Service {
async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
debug!("2: Hostname with included port");
let (host, port) = dest.as_str().split_at(pos);
self.conditional_query_and_cache_override(
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?;
self.conditional_query_and_cache(host, port.parse::<u16>().unwrap_or(8448), cache)
.await?;

Ok(FedDest::Named(
host.to_owned(),
Expand Down Expand Up @@ -163,13 +158,8 @@ impl super::Service {
) -> Result<FedDest> {
debug!("3.2: Hostname with port in .well-known file");
let (host, port) = delegated.split_at(pos);
self.conditional_query_and_cache_override(
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?;
self.conditional_query_and_cache(host, port.parse::<u16>().unwrap_or(8448), cache)
.await?;

Ok(FedDest::Named(
host.to_owned(),
Expand Down Expand Up @@ -208,7 +198,7 @@ impl super::Service {

/// Resolution step 3.4: no SRV records were found, so the hostname taken
/// from the .well-known response is used directly.
///
/// Hands `delegated` and port 8448 to the conditional query-and-cache helper
/// (any error from it is propagated with `?`), then returns
/// `add_port_to_hostname(&delegated)`.
// NOTE(review): the two consecutive `self.…` call lines below look like the
// removed and the added line of the diff shown together — confirm against the
// commit before treating this span as compilable code.
async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result<FedDest> {
debug!("3.4: No SRV records, just use the hostname from .well-known");
self.conditional_query_and_cache_override(&delegated, &delegated, 8448, cache)
self.conditional_query_and_cache(&delegated, 8448, cache)
.await?;
Ok(add_port_to_hostname(&delegated))
}
Expand Down Expand Up @@ -243,17 +233,15 @@ impl super::Service {

/// Resolution step 5: no SRV record was found for `dest`.
///
/// Hands `dest` and port 8448 to the conditional query-and-cache helper
/// (any error from it is propagated with `?`), then returns
/// `add_port_to_hostname(dest.as_str())`.
// NOTE(review): the two consecutive `self.…` call lines below look like the
// removed and the added line of the diff shown together — confirm against the
// commit before treating this span as compilable code.
async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> {
debug!("5: No SRV record found");
self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache)
self.conditional_query_and_cache(dest.as_str(), 8448, cache)
.await?;

Ok(add_port_to_hostname(dest.as_str()))
}

#[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(&self, dest: &str) -> Result<Option<String>> {
if !self.cache.has_override(dest).await {
self.query_and_cache_override(dest, dest, 8448).await?;
}
self.conditional_query_and_cache(dest, 8448, true).await?;

self.services.server.check_running()?;
trace!("Requesting well known for {dest}");
Expand Down Expand Up @@ -301,20 +289,35 @@ impl super::Service {
Ok(Some(m_server.to_owned()))
}

/// Convenience wrapper around `conditional_query_and_cache_override` for the
/// case where the cached name and the queried name are the same.
///
/// `hostname` is passed as both the override name and the name to look up;
/// `port` and `cache` are forwarded unchanged, and the callee's result is
/// returned as-is.
#[inline]
async fn conditional_query_and_cache(
&self,
hostname: &str,
port: u16,
cache: bool,
) -> Result {
self.conditional_query_and_cache_override(hostname, hostname, port, cache)
.await
}

/// Queries and caches an override for `overname` (resolved via `hostname` and
/// `port`) only when it is actually needed.
///
/// In the guard-clause form of the body, the call returns `Ok(())` without
/// querying when `cache` is false or when `self.cache.has_override(overname)`
/// reports an existing override; otherwise it delegates to
/// `query_and_cache_override` and returns that result.
// NOTE(review): this span shows two signatures tails and two bodies back to
// back (`) -> Result<()> { if cache { … } else { … }` followed by
// `) -> Result { if !cache { … } … }`). They appear to be the removed and the
// added side of the diff respectively — confirm against the commit.
#[inline]
async fn conditional_query_and_cache_override(
&self,
overname: &str,
hostname: &str,
port: u16,
cache: bool,
) -> Result<()> {
if cache {
self.query_and_cache_override(overname, hostname, port)
.await
} else {
Ok(())
) -> Result {
// Caller opted out of caching: nothing to do.
if !cache {
return Ok(());
}

// An override for `overname` is already cached: skip the query.
if self.cache.has_override(overname).await {
return Ok(());
}

self.query_and_cache_override(overname, hostname, port)
.await
}

#[tracing::instrument(skip(self, overname, port), name = "ip")]
Expand All @@ -323,21 +326,20 @@ impl super::Service {
overname: &'_ str,
hostname: &'_ str,
port: u16,
) -> Result<()> {
) -> Result {
self.services.server.check_running()?;

debug!("querying IP for {overname:?} ({hostname:?}:{port})");
match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
| Err(e) => Self::handle_resolve_error(&e, hostname),
| Ok(override_ip) => {
if hostname != overname {
debug_info!("{overname:?} overriden by {hostname:?}");
}

self.cache.set_override(overname, &CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),
overriding: (hostname != overname)
.then_some(hostname.into())
.inspect(|_| debug_info!("{overname:?} overriden by {hostname:?}")),
});

Ok(())
Expand Down
9 changes: 5 additions & 4 deletions src/service/resolver/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct CachedOverride {
pub ips: IpAddrs,
pub port: u16,
pub expire: SystemTime,
pub overriding: Option<String>,
}

pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>;
Expand Down Expand Up @@ -63,7 +64,10 @@ pub async fn has_destination(&self, destination: &ServerName) -> bool {
/// Reports whether an override is cached for `destination`.
///
/// The chained form (`.iter().any(CachedOverride::valid)`) is true only when
/// `get_override` succeeds *and* the returned entry passes
/// `CachedOverride::valid`; a failed lookup yields `false`.
// NOTE(review): the single-line `….is_ok()` expression directly below appears
// to be the removed side of the diff (it only checks that the lookup
// succeeded); the multi-line chain after it appears to be the replacement —
// confirm against the commit.
#[implement(Cache)]
#[must_use]
pub async fn has_override(&self, destination: &str) -> bool {
self.get_override(destination).await.is_ok()
self.get_override(destination)
.await
.iter()
.any(CachedOverride::valid)
}

#[implement(Cache)]
Expand All @@ -85,9 +89,6 @@ pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
.await
.deserialized::<Cbor<_>>()
.map(at!(0))
.into_iter()
.find(CachedOverride::valid)
.ok_or(err!(Request(NotFound("Expired from cache"))))
}

#[implement(Cache)]
Expand Down
22 changes: 20 additions & 2 deletions src/service/resolver/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,33 @@ impl Resolve for Hooked {
}
}

/// Resolves `name`, consulting the override cache before falling back to the
/// resolver.
///
/// Going by the guarded arms of the match:
/// 1. a cached entry for which `valid()` is true is converted with
///    `cached_to_reqwest`;
/// 2. a cached entry that failed the first guard but has `overriding` set is
///    resolved through `resolve_to_reqwest` using the `overriding` name
///    instead of `name`;
/// 3. anything else (lookup error, or no `overriding`) is resolved through
///    `resolve_to_reqwest` using `name` itself.
// NOTE(review): the first two arms (`Ok(cached) =>` and `Err(_) =>`, both
// unguarded) appear to be the removed side of the diff shown alongside the
// added arms — confirm against the commit. Read literally, they would make
// the arms after them unreachable.
#[tracing::instrument(
level = "debug",
skip_all,
fields(name = ?name.as_str())
)]
async fn hooked_resolve(
cache: Arc<Cache>,
server: Arc<Server>,
resolver: Arc<TokioAsyncResolver>,
name: Name,
) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
match cache.get_override(name.as_str()).await {
| Ok(cached) => cached_to_reqwest(cached).await,
| Err(_) => resolve_to_reqwest(server, resolver, name).boxed().await,
| Ok(cached) if cached.valid() => cached_to_reqwest(cached).await,
// Entry did not pass `valid()` above but records which name it overrides:
// resolve that name rather than `name`.
| Ok(CachedOverride { overriding, .. }) if overriding.is_some() =>
resolve_to_reqwest(
server,
resolver,
overriding
.as_deref()
.map(str::parse)
// Cannot fail: the arm guard already checked `overriding.is_some()`.
.expect("overriding is set for this record")
// NOTE(review): this unwraps the `str::parse` result, so a cached
// `overriding` string that does not parse panics here instead of
// returning an error — consider falling back to `name`.
.expect("overriding is a valid internet name"),
)
.boxed()
.await,

| _ => resolve_to_reqwest(server, resolver, name).boxed().await,
}
}

Expand Down

0 comments on commit a5520e8

Please sign in to comment.