Skip to content

Commit

Permalink
refactor resolver tuples into structs
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Volk <[email protected]>
  • Loading branch information
jevolk committed Jul 4, 2024
1 parent 97e55dd commit c4a2164
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/service/globals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod client;
mod data;
pub(super) mod emerg_access;
pub(super) mod migrations;
mod resolver;
pub(crate) mod resolver;
pub(super) mod updates;

use std::{
Expand Down
22 changes: 11 additions & 11 deletions src/service/globals/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::OwnedServerName;

use crate::sending::FedDest;
use crate::sending::{CachedDest, CachedOverride};

type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
type TlsNameMap = HashMap<String, CachedOverride>;

pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
pub hooked: Arc<Hooked>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
}

pub struct Hooked {
pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
pub(crate) struct Hooked {
overrides: Arc<RwLock<TlsNameMap>>,
resolver: Arc<TokioAsyncResolver>,
}

impl Resolver {
Expand Down Expand Up @@ -117,15 +117,15 @@ impl Resolve for Resolver {

impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving {
let addr_port = self
let cached = self
.overrides
.read()
.expect("locked for reading")
.get(name.as_str())
.cloned();

if let Some((addr, port)) = addr_port {
cached_to_reqwest(&addr, port)
if let Some(cached) = cached {
cached_to_reqwest(&cached.ips, cached.port)
} else {
resolve_to_reqwest(self.resolver.clone(), name)
}
Expand Down
2 changes: 1 addition & 1 deletion src/service/sending/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use conduit::{Error, Result};
use data::Data;
pub use resolve::{resolve_actual_dest, FedDest};
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
Expand Down
93 changes: 68 additions & 25 deletions src/service/sending/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{

use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress;
use ruma::ServerName;
use ruma::{OwnedServerName, ServerName};
use tracing::{debug, error, trace};

use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
Expand Down Expand Up @@ -35,26 +35,39 @@ pub enum FedDest {
Named(String, String),
}

#[derive(Clone, Debug)]
pub(crate) struct ActualDest {
pub(crate) dest: FedDest,
pub(crate) host: String,
pub(crate) string: String,
pub(crate) cached: bool,
}

#[derive(Clone, Debug)]
pub struct CachedDest {
pub dest: FedDest,
pub host: String,
}

#[derive(Clone, Debug)]
pub struct CachedOverride {
pub ips: Vec<IpAddr>,
pub port: u16,
}

#[tracing::instrument(skip_all, name = "resolve")]
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
let cached;
let cached_result = services()
.globals
.resolver
.destinations
.read()
.expect("locked for reading")
.get(server_name)
.cloned();
.get_cached_destination(server_name);

let (dest, host) = if let Some(result) = cached_result {
let CachedDest {
dest,
host,
..
} = if let Some(result) = cached_result {
cached = true;
result
} else {
Expand All @@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
/// Numbers in comments below refer to bullet points in linked section of
/// specification
#[tracing::instrument(skip_all, name = "actual")]
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedDest, String)> {
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<CachedDest> {
trace!("Finding actual destination for {dest}");
let mut host = dest.as_str().to_owned();
let actual_dest = match get_ip_with_port(dest.as_str()) {
Expand Down Expand Up @@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD
};

debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
Ok((actual_dest, host.into_uri_string()))
Ok(CachedDest {
dest: actual_dest,
host: host.into_uri_string(),
})
}

fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
Expand Down Expand Up @@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
#[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(dest: &str) -> Result<Option<String>> {
trace!("Requesting well known for {dest}");
if !services()
.globals
.resolver
.overrides
.read()
.unwrap()
.contains_key(dest)
{
if !services().globals.resolver.has_cached_override(dest) {
query_and_cache_override(dest, dest, 8448).await?;
}

Expand Down Expand Up @@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
Err(e) => handle_resolve_error(&e),
Ok(override_ip) => {
if hostname != overname {
debug_info!("{:?} overriden by {:?}", overname, hostname);
debug_info!("{overname:?} overriden by {hostname:?}");
}
services()
.globals
.resolver
.overrides
.write()
.unwrap()
.insert(overname.to_owned(), (override_ip.iter().collect(), port));

services().globals.resolver.set_cached_override(
overname.to_owned(),
CachedOverride {
ips: override_ip.iter().collect(),
port,
},
);

Ok(())
},
Expand Down Expand Up @@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest {
FedDest::Named(host.to_owned(), port.to_owned())
}

impl crate::globals::resolver::Resolver {
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.destinations
.write()
.expect("locked for writing")
.insert(name, dest)
}

pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
self.destinations
.read()
.expect("locked for reading")
.get(name)
.cloned()
}

pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.overrides
.write()
.expect("locked for writing")
.insert(name, over)
}

pub(crate) fn has_cached_override(&self, name: &str) -> bool {
self.overrides
.read()
.expect("locked for reading")
.contains_key(name)
}
}

impl FedDest {
fn into_https_string(self) -> String {
match self {
Expand Down
21 changes: 12 additions & 9 deletions src/service/sending/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ use ruma::{
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken,
},
OwnedServerName, ServerName,
ServerName,
};
use tracing::{debug, trace};

use super::{resolve, resolve::ActualDest};
use super::{
resolve,
resolve::{ActualDest, CachedDest},
};
use crate::{debug_error, debug_warn, services, Error, Result};

#[tracing::instrument(skip_all, name = "send")]
Expand Down Expand Up @@ -103,13 +106,13 @@ where

let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached {
services()
.globals
.resolver
.destinations
.write()
.expect("locked for writing")
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
services().globals.resolver.set_cached_destination(
dest.to_owned(),
CachedDest {
dest: actual.dest.clone(),
host: actual.host.clone(),
},
);
}

match response {
Expand Down

0 comments on commit c4a2164

Please sign in to comment.