diff --git a/Cargo.lock b/Cargo.lock index 89784b42..19e82390 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -434,9 +434,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ "shlex", ] @@ -1405,9 +1405,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -3140,9 +3140,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -3151,9 +3151,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -3166,9 +3166,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -3178,9 +3178,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3188,9 +3188,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -3201,15 +3201,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/src/app.rs b/src/app.rs index 856b4b76..b2a288a2 100644 --- a/src/app.rs +++ b/src/app.rs @@ -7,7 +7,7 @@ use std::{ }; use tokio::{ runtime::{Handle, Runtime}, - sync::RwLock, + sync::{RwLock, Semaphore}, task::JoinSet, }; @@ -15,12 +15,11 @@ use crate::{ config::ServerOpts, dns::{DnsRequest, DnsResponse, SerialMessage}, dns_conf::RuntimeConfig, - dns_error::LookupError, dns_mw::{DnsMiddlewareBuilder, DnsMiddlewareHandler}, dns_mw_cache::DnsCache, log, server::{DnsHandle, IncomingDnsRequest, ServerHandle}, - third_ext::{FutureJoinAllExt as _, FutureTimeoutExt}, + third_ext::FutureJoinAllExt as _, }; pub struct App { @@ -75,18 +74,12 @@ impl App { cfg.summary(); - let runtime = { - use tokio::runtime; - let mut builder = runtime::Builder::new_multi_thread(); - builder.enable_all(); - if let Some(num_workers) = cfg.num_workers() { - builder.worker_threads(num_workers); - } - builder - .thread_name("smartdns-runtime") - .build() - .expect("failed to initialize Tokio Runtime") - }; + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(cfg.num_workers()) + .enable_all() + .thread_name("smartdns-runtime") + .build() + .expect("failed to initialize Tokio Runtime"); let handler = DnsMiddlewareBuilder::new().build(cfg.clone()); @@ -230,13 +223,19 @@ pub fn bootstrap(conf: Option) { const MAX_IDLE: Duration = Duration::from_secs(30 * 60); // 30 min + let background_concurrency = Arc::new(Semaphore::new(1)); + while let Some((message, server_opts, sender)) = incoming_request.recv().await { let handler = app.mw_handler.read().await.clone(); if server_opts.is_background { if Instant::now() - last_activity < MAX_IDLE { + let background_concurrency = background_concurrency.clone(); inner_join_set.spawn(async move { - let _ = sender.send(process(handler, message, server_opts).await); + if let Ok(permit) = background_concurrency.acquire_owned().await { + let _ = sender.send(process(handler, message, server_opts).await); + drop(permit); + } }); } } else { @@ -388,27 +387,7 @@ async fn process( response_header.set_authoritative(false); let response = { - let res = - handler - .search(&request, &server_opts) - .timeout(Duration::from_secs( - if server_opts.is_background { 60 } else { 5 }, - )) - .await - .unwrap_or_else(|_| { - let query = request.query().original().to_owned(); - log::warn!( - "Query {} {} {} timeout.", - query.name(), - query.query_type(), - if server_opts.is_background { - "in background" - } else { - "" - } - ); - Err(LookupError::no_records_found(query, 10)) - }); + let res = handler.search(&request, &server_opts).await; match res { Ok(lookup) => lookup, Err(e) => { diff --git a/src/dns_client.rs b/src/dns_client.rs index b5647e7f..c69285e4 100644 --- a/src/dns_client.rs +++ b/src/dns_client.rs @@ -7,46 +7,43 @@ use std::{ sync::Arc, }; +use tokio::sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore}; +use url::Host; + use crate::{ dns::DnsResponse, - libdns::proto::rr::rdata::opt::{ClientSubnet, EdnsOption}, - log, + dns_conf::NameServerInfo, + dns_error::LookupError, + dns_url::DnsUrlParamExt, + log::{self, debug, info, warn}, + proxy::ProxyConfig, + rustls::TlsClientConfigBundle, }; -use tokio::sync::RwLock; -use crate::{ - dns_url::DnsUrlParamExt, - libdns::proto::{ - error::ProtoResult, +use crate::libdns::{ + proto::{ + error::{ProtoError, ProtoErrorKind, ProtoResult}, op::{Edns, Message, MessageType, OpCode, Query}, rr::{ domain::{IntoName, Name}, + rdata::opt::{ClientSubnet, EdnsOption}, Record, RecordType, }, xfer::{DnsRequest, DnsRequestOptions, FirstAnswer}, DnsHandle, }, - proxy::ProxyConfig, - rustls::TlsClientConfigBundle, -}; - -use crate::libdns::proto::xfer::Protocol; -use crate::libdns::resolver::{ - config::{ResolverOpts, TlsClientConfig}, - name_server::GenericConnector, - TryParseIp, -}; - -use crate::{ - dns_conf::NameServerInfo, - dns_error::LookupError, - log::{debug, info, warn}, + resolver::{ + config::{ResolverOpts, ServerOrderingStrategy, TlsClientConfig}, + name_server::GenericConnector, + TryParseIp, + }, }; use bootstrap::BootstrapResolver; -use connection_provider::TokioRuntimeProvider; -use connection_provider::{ConnectionProvider, RawNameServer, RawNameServerConfig}; +use connection_provider::{ + Connection, ConnectionProvider, RawNameServerConfig, TokioRuntimeProvider, +}; /// Maximum TTL as defined in https://tools.ietf.org/html/rfc2181, 2147483647 /// Setting this to a value of 1 day, in seconds @@ -59,6 +56,7 @@ pub struct DnsClientBuilder { ca_path: Option, proxies: Arc>, client_subnet: Option, + max_cocurrency: Option, } impl DnsClientBuilder { @@ -92,6 +90,11 @@ impl DnsClientBuilder { self } + pub fn with_max_cocurrency(mut self, cocurrent: usize) -> Self { + self.max_cocurrency = Some(cocurrent); + self + } + pub async fn build(self) -> DnsClient { let DnsClientBuilder { server_infos, @@ -99,6 +102,7 @@ impl DnsClientBuilder { ca_path, proxies, client_subnet, + max_cocurrency, } = self; let tls_client_config = TlsClientConfigBundle::new(ca_path, ca_file); @@ -162,6 +166,7 @@ impl DnsClientBuilder { Some(tls_client_config.clone()), None, client_subnet, + max_cocurrency, ) { Ok(s) => { let s = Arc::new(s); @@ -244,6 +249,7 @@ impl DnsClientBuilder { Some(tls_client_config.clone()), Some(bootstrap.clone()), client_subnet, + max_cocurrency, ) { Ok(s) => { let s = Arc::new(s); @@ -293,6 +299,10 @@ impl DnsClientBuilder { }) }; + for s in server_groups.values() { + s.warmup().await; + } + DnsClient { default, bootstrap, @@ -347,6 +357,11 @@ pub struct NameServerGroup { } impl NameServerGroup { + pub async fn warmup(&self) { + for server in self.servers.iter() { + _ = server.client().await; + } + } #[inline] pub fn iter(&self) -> Iter> { self.servers.iter() @@ -401,29 +416,19 @@ impl GenericResolver for NameServerGroup { } } -pub enum NameServer { - IpAddress((Arc, Arc)), - DomainName { - domain: String, - config: RawNameServerConfig, - opts: Arc, - connection_provider: ConnectionProvider, - resolver: Arc, - inner: RwLock<(Vec>, HashSet)>, - }, -} +pub struct NameServer { + max_cocurrency: Arc, + connections: Arc>>, -impl From for NameServer { - fn from(config: RawNameServerConfig) -> Self { - Self::IpAddress(( - Arc::new(RawNameServer::new( - config, - Default::default(), - Default::default(), - )), - Default::default(), - )) - } + server: Host, + + ip_addrs: RwLock>, + + config: RawNameServerConfig, + options: Arc, + connection_provider: ConnectionProvider, + + resolver: Option>, } impl NameServer { @@ -433,23 +438,21 @@ impl NameServer { tls_client_config: Option, resolver: Option>, default_client_subnet: Option, + cocurrent: Option, ) -> anyhow::Result { - use Protocol::*; - let url = &config.server; - let ip_addr = url.ip(); - let port = url.port(); - - let socket_addr = SocketAddr::new(ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED.into()), port); + let socket_addr = { + let ip_addr = url.ip(); + let port = url.port(); + SocketAddr::new(ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED.into()), port) + }; - if ip_addr.is_none() && resolver.is_none() { + if socket_addr.ip().is_unspecified() && resolver.is_none() { anyhow::bail!("Parameter resolver is required for non-ip upstream"); } - let tls_dns_name = Some(url.host().to_string()); - - let tls_config = if url.proto().is_encrypted() { + let (tls_dns_name, tls_config) = if url.proto().is_encrypted() { let Some(tls_client_config) = tls_client_config else { anyhow::bail!("Parameter tls_client_config is required for Encrypted upstream"); }; @@ -462,12 +465,12 @@ impl NameServer { tls_client_config.normal }; - Some(TlsClientConfig(config)) + (Some(url.host().to_string()), Some(TlsClientConfig(config))) } else { - None + (None, None) }; - let opts = Arc::new(NameServerOpts::new( + let mut options = NameServerOpts::new( config.blacklist_ip, config.whitelist_ip, config.check_edns, @@ -476,176 +479,143 @@ impl NameServer { .as_ref() .map(|r| r.options().clone()) .unwrap_or_default(), - )); + ); + options.resolver_opts.server_ordering_strategy = ServerOrderingStrategy::QueryStatistics; let so_mark = config.so_mark; let device = config.interface; let provider = GenericConnector::new(TokioRuntimeProvider::new(proxy, so_mark, device)); - use crate::libdns::resolver::config::NameServerConfig; - let protocol = *url.proto(); - - let config = match protocol { - Udp => NameServerConfig { - socket_addr, - protocol, - trust_negative_responses: true, - tls_dns_name: None, - tls_config: None, - bind_addr: None, - http_endpoint: None, - }, - Tcp => NameServerConfig { - socket_addr, - protocol, - trust_negative_responses: true, - tls_dns_name: None, - tls_config: None, - bind_addr: None, - http_endpoint: None, - }, - #[cfg(feature = "dns-over-https")] - Https => NameServerConfig { - socket_addr, - protocol, - tls_dns_name, - tls_config, - trust_negative_responses: true, - bind_addr: None, - http_endpoint: Some(url.path().to_string()), - }, - #[cfg(feature = "dns-over-quic")] - Quic => NameServerConfig { - socket_addr, - protocol, - tls_dns_name, - tls_config, - trust_negative_responses: true, - bind_addr: None, - http_endpoint: None, - }, - #[cfg(feature = "dns-over-tls")] - Tls => NameServerConfig { - socket_addr, - protocol, - tls_dns_name, - tls_config, - trust_negative_responses: true, - bind_addr: None, - http_endpoint: None, - }, - #[cfg(feature = "dns-over-h3")] - H3 => NameServerConfig { - socket_addr, - protocol, - tls_dns_name, - tls_config, - trust_negative_responses: true, - bind_addr: None, - http_endpoint: Some(url.path().to_string()), - }, - _ => unimplemented!(), + let http_endpoint = url.path().map(|s| s.to_string()); + + let config = RawNameServerConfig { + socket_addr, + protocol, + trust_negative_responses: true, + tls_config, + tls_dns_name, + bind_addr: None, + http_endpoint, }; - Ok(if ip_addr.is_some() { - Self::IpAddress(( - Arc::new(RawNameServer::new( - config, - opts.as_ref().deref().clone(), - provider, - )), - opts, - )) - } else { - Self::DomainName { - domain: url.host().to_string(), - config, - opts, - connection_provider: provider, - inner: Default::default(), - resolver: resolver.unwrap(), - } + Ok(Self { + max_cocurrency: Arc::new(Semaphore::const_new(cocurrent.unwrap_or(1))), + server: url.host().to_owned(), + connections: Default::default(), + ip_addrs: Default::default(), + config, + options: options.into(), + connection_provider: provider, + resolver, }) } - async fn handle(&self) -> Option> { - match self { - NameServer::IpAddress((v, _)) => Some(v.clone()), - NameServer::DomainName { - domain, - config, - opts, - connection_provider, - inner, - resolver, - } => { - { - let read = inner.read().await; - - let (servers, _) = read.deref(); - - if let Some(first) = servers.first().cloned() { - if servers.len() > 1 { - let mut servers = servers.to_vec(); - servers.sort_unstable(); - if matches!(servers.first(), Some(s) if *s == first) { - // Still first, return directly - return Some(first); - } - } else { - // There is only one, return directly. - // todo:// determine if it failed, but it requires `is_connected` public. - // https://github.com/hickory-dns/hickory-dns/blob/78f9b27649d3ee1b9894c22aedcdc9bad2daf331/crates/resolver/src/name_server/name_server.rs#L85 - return Some(first); - } - } - } - - let ip_addrs = match resolver.lookup_ip(domain).await { + async fn ip_addrs(&self) -> Result, ProtoError> { + let mut ip_addrs = self.ip_addrs.read().await.clone(); + if !ip_addrs.is_empty() { + return Ok(ip_addrs); + } + match &self.server { + Host::Domain(domain) => { + let resolver = self + .resolver + .as_ref() + .expect("resolver must be set when using domain name"); + + ip_addrs = match resolver.lookup_ip(domain).await { Ok(lookup_ip) => lookup_ip.ip_addrs().into_iter().collect::>(), Err(err) => { warn!("lookup ip: {domain} failed, {err}"); vec![] } - }; + } + .into(); - let mut write = inner.write().await; - let (servers, seen) = write.deref_mut(); + if ip_addrs.is_empty() { + return Err(ProtoErrorKind::NoConnections.into()); + } else { + *self.ip_addrs.write().await.deref_mut() = ip_addrs.clone(); + } + Ok(ip_addrs) + } + Host::Ipv4(ip) => { + ip_addrs = vec![IpAddr::V4(*ip)].into(); + *self.ip_addrs.write().await.deref_mut() = ip_addrs.clone(); + Ok(ip_addrs) + } + Host::Ipv6(ip) => { + ip_addrs = vec![IpAddr::V6(*ip)].into(); + *self.ip_addrs.write().await.deref_mut() = ip_addrs.clone(); + Ok(ip_addrs) + } + } + } - for ip_addr in ip_addrs { - if seen.contains(&ip_addr) { - continue; + async fn client(&self) -> Result { + let cocurrent_permit = self + .max_cocurrency + .clone() + .acquire_owned() + .await + .map_err(|_| ProtoErrorKind::Busy)?; + + let conn = self.connections.lock().await.pop(); + + let conn = match conn { + Some(conn) => conn, + None => { + let config = self.config.clone(); + let options = self.options.as_ref().deref().clone(); + let provider = self.connection_provider.clone(); + + match &self.server { + Host::Domain(_) => { + let configs = self + .ip_addrs() + .await? + .iter() + .map(|ip| { + let mut config = config.clone(); + config.socket_addr.set_ip(*ip); + config + }) + .collect::>(); + + Connection::from_config(configs.into(), options, provider) } - - let mut config = config.clone(); - config.socket_addr.set_ip(ip_addr); - - let opts = opts.clone(); - let connection_provider = connection_provider.clone(); - - let server = Arc::new(RawNameServer::new( - config, - opts.as_ref().deref().clone(), - connection_provider, - )); - seen.insert(ip_addr); - servers.push(server); + _ => Connection::from_config(vec![config].into(), options, provider), } - - servers.sort_unstable(); - - servers.first().cloned() } - } + }; + + Ok(ClientHandle { + connections: self.connections.clone(), + connection: Some((conn, cocurrent_permit)), + }) } #[inline] pub fn options(&self) -> &NameServerOpts { - match self { - NameServer::IpAddress((_, opts)) => opts, - NameServer::DomainName { opts, .. } => opts, + &self.options + } +} + +impl From for NameServer { + fn from(config: RawNameServerConfig) -> Self { + Self { + max_cocurrency: Arc::new(Semaphore::const_new(1)), + server: match config.socket_addr.ip() { + IpAddr::V4(ipv4_addr) => Host::Ipv4(ipv4_addr), + IpAddr::V6(ipv6_addr) => Host::Ipv6(ipv6_addr), + }, + config, + ip_addrs: Default::default(), + connections: Default::default(), + options: Default::default(), + connection_provider: Default::default(), + resolver: Default::default(), } } } @@ -694,7 +664,9 @@ impl GenericResolver for NameServer { request_options, ); - let Some(ns) = self.handle().await else { + let client = self.client().await?; + + let Some(ns) = client.connection.as_ref().map(|(conn, _)| conn) else { return Err(ProtoErrorKind::NoConnections.into()); }; @@ -704,6 +676,24 @@ impl GenericResolver for NameServer { } } +struct ClientHandle { + connections: Arc>>, + connection: Option<(Connection, OwnedSemaphorePermit)>, +} + +impl Drop for ClientHandle { + fn drop(&mut self) { + let connections = self.connections.clone(); + let connection = self.connection.take(); + tokio::spawn(async move { + if let Some((connection, permit)) = connection { + connections.lock().await.push(connection); + drop(permit); + } + }); + } +} + #[derive(Clone)] pub struct NameServerOpts { /// filter result with blacklist ip @@ -983,6 +973,9 @@ mod connection_provider { pub type RawNameServer = crate::libdns::resolver::name_server::NameServer>; pub type RawNameServerConfig = crate::libdns::resolver::config::NameServerConfig; + pub type Connection = crate::libdns::resolver::name_server::GenericNameServerPool< + connection_provider::TokioRuntimeProvider, + >; pub type ConnectionProvider = GenericConnector; /// The Tokio Runtime for async execution diff --git a/src/dns_conf.rs b/src/dns_conf.rs index 9cb8c979..9a9ea2ba 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -107,6 +107,8 @@ impl RuntimeConfig { pub fn summary(&self) { info!(r#"whoami 👉 {}"#, self.server_name()); + info!(r#"num workers: {}"#, self.num_workers()); + for server in self.nameservers.iter() { if !server.exclude_default_group && server.group.is_empty() { continue; @@ -182,8 +184,10 @@ impl RuntimeConfig { /// The number of worker threads #[inline] - pub fn num_workers(&self) -> Option { + pub fn num_workers(&self) -> usize { + use std::num::NonZeroUsize; self.num_workers + .unwrap_or(std::thread::available_parallelism().map_or(1, NonZeroUsize::get)) } pub fn listeners(&self) -> &[ListenerConfig] { diff --git a/src/dns_mw_cache.rs b/src/dns_mw_cache.rs index 26f80ac0..721639db 100644 --- a/src/dns_mw_cache.rs +++ b/src/dns_mw_cache.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; - use chrono::DateTime; use chrono::Local; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::num::NonZeroUsize; @@ -41,7 +40,12 @@ pub struct DnsCacheMiddleware { impl DnsCacheMiddleware { pub fn new(cfg: &Arc, dns_handle: DnsHandle) -> Self { - let cache = DnsCache::new(cfg.cache_size()); + let cache = DnsCache::new( + cfg.cache_size(), + cfg.serve_expired(), + cfg.serve_expired_ttl(), + cfg.serve_expired_reply_ttl(), + ); if cfg.cache_persist() { let cache_file = cfg.cache_file(); @@ -94,15 +98,8 @@ impl DnsCacheMiddleware { let prefetch_notify = self.prefetch_notify.clone(); let client = self.bg_client.clone(); - let cache = self.cache.cache(); + let cache = self.cache.clone(); tokio::spawn(async move { - let num_workers = std::cmp::max( - tokio::runtime::Handle::current().metrics().num_workers() / 5, - 1, - ); - - let concurrent = Arc::new(tokio::sync::Semaphore::new(num_workers)); - let min_interval = Duration::from_secs( std::env::var("PREFETCH_MIN_INTERVAL") .as_deref() @@ -116,89 +113,59 @@ impl DnsCacheMiddleware { prefetch_notify.notified().await; let now = Instant::now(); - let mut most_recent; + let most_recent; if now - last_check > min_interval { last_check = now; - most_recent = Duration::from_secs(MAX_TTL as u64); - let mut expired = vec![]; - - { - let mut cache = cache.lock().await; - let len = cache.len(); - if len == 0 { - continue; - } - - for (query, entry) in cache.iter_mut() { - if entry.is_in_prefetching { - continue; - } - // only prefetch query type ip addr - if !query.query_type().is_ip_addr() { - continue; - } - - if entry.is_current(now) { - most_recent = most_recent.min(entry.ttl(now)); - continue; - } - - entry.is_in_prefetching = true; + let expired = { + let (expired, most_recent0) = cache.get_expired(now, Some(5)).await; - expired.push(query.to_owned()); - } debug!( "Domain prefetch check(total: {}), elapsed {:?}", - len, + cache.cache().lock().await.len(), now.elapsed() ); - } + + most_recent = most_recent0; + + expired + }; if !expired.is_empty() { for query in expired { let client = client.clone(); - let cache = cache.clone(); - let concurrent = concurrent.clone(); + let cache = cache.cache(); tokio::spawn(async move { - match concurrent.acquire().await { - Ok(_) => { - let now = Instant::now(); - let mut message = Message::new(); - message.add_query(query.clone()); - let serial_message = client.send(message.into()).await; - - if let Ok(message) = Message::try_from(serial_message) { - if let Some(entry) = cache.lock().await.peek_mut(&query) - { - let data = message.into(); - entry.set_data(data); - entry.set_valid_until( - Instant::now() - + Duration::from_secs( - entry - .data - .min_ttl() - .unwrap_or_default() - .min(600) - .into(), - ), - ) - } - } - - debug!( - "Prefetch domain {} {}, elapsed {:?}", - query.name(), - query.query_type(), - now.elapsed() - ); - } - Err(err) => { - log::error!("{:?}", err); + let now = Instant::now(); + let mut message = Message::new(); + message.add_query(query.clone()); + let serial_message = client.send(message.into()).await; + + if let Ok(message) = Message::try_from(serial_message) { + if let Some(entry) = cache.lock().await.peek_mut(&query) { + let data = message.into(); + entry.set_data(data); + entry.set_valid_until( + Instant::now() + + Duration::from_secs( + entry + .data + .min_ttl() + .unwrap_or_default() + .min(600) + .into(), + ), + ) } } + + debug!( + "Prefetch domain {} {}, elapsed {:?}", + query.name(), + query.query_type(), + now.elapsed() + ); }); } } @@ -232,32 +199,27 @@ impl Middleware for DnsCacheMiddl let cached_res = if ctx.server_opts.is_background { None } else { - let cached_res = self.cache.get(&query, Instant::now()).await; - - if let Some((outdate, res)) = cached_res.as_ref() { - match outdate { - OutOfDate::No => { - let name_server_group = ctx.server_group_name(); - // check if it's the same nameserver group. - if matches!(res, Ok(r) if r.name_server_group() == Some(name_server_group)) - { - debug!("name: {} using caching", query.name()); - ctx.source = LookupFrom::Cache; - return res.clone(); - } - } - OutOfDate::Yes => { - if self.cfg.serve_expired() { - if let Ok(res) = res { - if matches!(res.max_ttl(), Some(ttl) if ttl < self.cfg.serve_expired_ttl() as u32 ) - { - let mut res = res.clone(); - res.set_max_ttl(self.cfg.serve_expired_reply_ttl() as u32); - return Ok(res); - } - } - } - } + let no_serve_expired = ctx + .domain_rule + .get(|r| r.no_serve_expired) + .unwrap_or_default(); + + let cached_res = self + .cache + .get(&query, Instant::now(), no_serve_expired) + .await; + + if let Some(res) = cached_res.as_ref() { + let name_server_group = ctx.server_group_name(); + // check if it's the same nameserver group. + if matches!(res, Ok(res) if res.name_server_group() == Some(name_server_group)) { + debug!( + "name: {} {} using caching", + query.name(), + query.query_type() + ); + ctx.source = LookupFrom::Cache; + return res.clone(); } } @@ -281,10 +243,12 @@ impl Middleware for DnsCacheMiddl ) .await; - if let Some(ttl) = lookup.min_ttl() { - self.prefetch_notify - .notify_after(Duration::from_secs(ttl as u64)) - .await; + if ctx.cfg().prefetch_domain() { + if let Some(ttl) = lookup.min_ttl() { + self.prefetch_notify + .notify_after(Duration::from_secs(ttl as u64)) + .await; + } } } Ok(lookup) @@ -292,7 +256,7 @@ impl Middleware for DnsCacheMiddl Err(err) => { // try to return expired result. if ctx.cfg().serve_expired() { - if let Some((_, Ok(res))) = cached_res { + if let Some(Ok(res)) = cached_res { return Ok(res); } } @@ -319,8 +283,8 @@ impl DomainPrefetchingNotify { if duration.is_zero() { self.notity.notify_one() } else { + let tick = *self.tick.read().await; let now = Instant::now(); - let tick = *(self.tick.read().await); let next_tick = now + duration; if tick > now && next_tick > tick { debug!( @@ -356,15 +320,28 @@ const MAX_TTL: u32 = 86400_u32; /// An LRU eviction cache specifically for storing DNS records pub struct DnsCache { cache: Arc>>, + serve_expired: bool, + expired_ttl: u64, + expired_reply_ttl: u64, } impl DnsCache { - fn new(cache_size: usize) -> Self { + fn new( + cache_size: usize, + serve_expired: bool, + expired_ttl: u64, + expired_reply_ttl: u64, + ) -> Self { let cache = Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(cache_size).unwrap(), ))); - Self { cache } + Self { + cache, + serve_expired, + expired_ttl, + expired_reply_ttl, + } } fn cache(&self) -> Arc>> { @@ -385,7 +362,8 @@ impl DnsCache { query_type: query.query_type(), query_class: query.query_class(), records: entry.data.records().to_vec().into_boxed_slice(), - last_access: entry.last_access, + hits: entry.stats.hits, + last_access: entry.stats.last_access, }) .collect() } @@ -518,7 +496,8 @@ impl DnsCache { &self, query: &Query, now: Instant, - ) -> Option<(OutOfDate, Result)> { + no_serve_expired: bool, + ) -> Option> { let mut cache = match self.cache.try_lock() { Ok(t) => t, Err(err) => { @@ -528,19 +507,22 @@ impl DnsCache { }; let mut expired = false; - let lookup = cache.get_mut(query).map(|value| { - value.last_access = Local::now(); + let lookup = cache.get_mut(query).and_then(|value| { + value.stats.hit(); if value.is_current(now) { let mut res = value.data.clone(); res.set_max_ttl(value.ttl(now).as_secs() as u32); - - (OutOfDate::No, Ok(res)) + Some(Ok(res)) + } else if !no_serve_expired + && self.serve_expired + && value.is_current(now - Duration::from_secs(self.expired_ttl)) + { + let mut res = value.data.clone(); + res.set_max_ttl(self.expired_reply_ttl as u32); + Some(Ok(res)) } else { expired = true; - let negative_ttl = now - value.valid_until; - let mut res = value.data.clone(); - res.set_new_ttl(negative_ttl.as_secs() as u32); - (OutOfDate::Yes, Ok(res)) + None } }); @@ -549,28 +531,67 @@ impl DnsCache { } lookup } + + async fn get_expired( + &self, + now: Instant, + seconds_ahead: Option, + ) -> (Vec, Duration) { + let mut cache = self.cache.lock().await; + let mut most_recent = Duration::from_secs(MAX_TTL as u64); + + if !cache.is_empty() { + let mut expired = vec![]; + let now = if self.expired_ttl > 0 { + now - Duration::from_secs(self.expired_ttl) + } else { + now + } + Duration::from_secs(seconds_ahead.unwrap_or(5)); // 5 seconds ahead + + for (query, entry) in cache.iter_mut() { + if entry.is_in_prefetching { + continue; + } + // only prefetch query type ip addr + if !query.query_type().is_ip_addr() { + continue; + } + + if entry.is_current(now) { + most_recent = most_recent.min(entry.ttl(now)); + continue; + } + + entry.is_in_prefetching = true; + + expired.push((query.to_owned(), entry.stats.hits)); + } + drop(cache); + + expired.sort_by_key(|(_, hits)| std::cmp::Reverse(*hits)); + + (expired.into_iter().map(|(q, _)| q).collect(), most_recent) + } else { + (Vec::with_capacity(0), most_recent) + } + } } #[derive(Deserialize, Serialize)] pub struct CachedQueryRecord { name: Name, + hits: usize, last_access: DateTime, query_type: RecordType, query_class: DNSClass, records: Box<[Record]>, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum OutOfDate { - Yes, - No, -} - struct DnsCacheEntry { data: T, valid_until: Instant, is_in_prefetching: bool, - last_access: DateTime, + stats: DnsCacheStats, } impl DnsCacheEntry { @@ -579,7 +600,7 @@ impl DnsCacheEntry { data, valid_until, is_in_prefetching: false, - last_access: Local::now(), + stats: DnsCacheStats::new(), } } @@ -603,51 +624,40 @@ impl DnsCacheEntry { } } -mod lookup { - - use crate::dns::DnsResponse; - use std::ops::Deref; - use std::time::Instant; - - use crate::libdns::proto::{ - error::ProtoResult, - op::Message, - serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder}, - }; +struct DnsCacheStats { + /// The number of lookups that have been performed + hits: usize, + last_access: DateTime, +} - pub fn serialize(lookups: &[DnsResponse], writer: &mut impl std::io::Write) -> ProtoResult<()> { - let mut buf = vec![]; - for lookup in lookups { - { - let mut encoder = BinEncoder::new(&mut buf); - serialize_one(lookup, &mut encoder)?; - } - writer.write_all(&buf)?; - buf.truncate(0); +impl DnsCacheStats { + fn new() -> Self { + Self { + hits: 0, + last_access: Local::now(), } + } - Ok(()) + fn hit(&mut self) { + self.hits += 1; + self.last_access = Local::now(); } +} - pub fn deserialize(data: &[u8]) -> ProtoResult> { - let mut lookups = vec![]; - let mut offset = 0; +use crate::libdns::proto::serialize::binary::{ + BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError, +}; - while offset < data.len() { - let mut decoder = BinDecoder::new(&data[offset..]); - lookups.push(deserialize_one(&mut decoder)?); - offset += decoder.index(); - } +impl BinEncodable for DnsCacheEntry { + fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> { + let res = &self.data; - Ok(lookups) - } - pub fn serialize_one(res: &DnsResponse, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> { - if let Some(group_name) = res.name_server_group().map(|n| n.as_bytes()) { - encoder.emit_u16(group_name.len() as u16)?; - encoder.emit_vec(&group_name[0..(group_name.len() as u16 as usize)])?; - } else { - encoder.emit_u16(0)?; - } + // message + encoder.emit_u8(1)?; + res.deref().emit(encoder)?; + + // valid_until + encoder.emit_u8(2)?; let valid_until_bytes = unsafe { std::slice::from_raw_parts( (&res.valid_until() as *const Instant) as *const u8, @@ -655,11 +665,44 @@ mod lookup { ) }; encoder.emit_vec(valid_until_bytes)?; - res.deref().emit(encoder)?; + + // group_name + encoder.emit_u8(3)?; + if let Some(group_name) = res.name_server_group().map(|n| n.as_bytes()) { + encoder.emit_u16(group_name.len() as u16)?; + encoder.emit_vec(&group_name[0..(group_name.len() as u16 as usize)])?; + } else { + encoder.emit_u16(0)?; + } + + // hits + encoder.emit_u8(4)?; + encoder.emit_u32(self.stats.hits as u32)?; Ok(()) } +} - pub fn deserialize_one(decoder: &mut BinDecoder<'_>) -> ProtoResult { +impl<'r> BinDecodable<'r> for DnsCacheEntry { + fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult { + // message + if !decoder.read_u8()?.verify(|v| *v == 1).is_valid() { + return Err(DecodeError::InsufficientBytes.into()); + } + let message = Message::read(decoder)?; + + // valid_until + if !decoder.read_u8()?.verify(|v| *v == 2).is_valid() { + return Err(DecodeError::InsufficientBytes.into()); + } + let valid_until_bytes = decoder + .read_slice(std::mem::size_of::())? + .unverified(); + let valid_until = unsafe { std::ptr::read(valid_until_bytes.as_ptr() as *const Instant) }; + + // group_name + if !decoder.read_u8()?.verify(|v| *v == 3).is_valid() { + return Err(DecodeError::InsufficientBytes.into()); + } let group_name = { let name_len = decoder.read_u16()?.unverified(); if name_len > 0 { @@ -669,19 +712,55 @@ mod lookup { None } }; - let valid_until_bytes = decoder - .read_slice(std::mem::size_of::())? - .unverified(); - let valid_until = unsafe { std::ptr::read(valid_until_bytes.as_ptr() as *const Instant) }; - let message = Message::read(decoder)?; + // hits + if !decoder.read_u8()?.verify(|v| *v == 4).is_valid() { + return Err(DecodeError::InsufficientBytes.into()); + } + let hits = decoder.read_u32()?.unverified(); + + // construct the response let mut res: DnsResponse = message.into(); res = res.with_valid_until(valid_until); if let Some(g) = group_name { res = res.with_name_server_group(g); } + let valid_until = res.valid_until(); + let mut entry = DnsCacheEntry::new(res, valid_until); + entry.stats.hits = hits as usize; - Ok(res) + Ok(entry) + } +} + +impl DnsCacheEntry { + fn serialize_many<'a>( + entries: impl Iterator, + writer: &mut impl std::io::Write, + ) -> ProtoResult<()> { + let mut buf = vec![]; + + for entry in entries { + buf.truncate(0); + let mut encoder = BinEncoder::new(&mut buf); + if (*entry).emit(&mut encoder).is_ok() { + let _ = writer.write_all(&buf); + } + } + Ok(()) + } + + fn deserialize_many(data: &[u8]) -> ProtoResult> { + let mut entries = vec![]; + let mut offset = 0; + + while offset < data.len() { + let mut decoder = BinDecoder::new(&data[offset..]); + entries.push(DnsCacheEntry::read(&mut decoder)?); + offset += decoder.index(); + } + + Ok(entries) } } @@ -694,23 +773,17 @@ trait PersistCache { impl PersistCache for LruCache { fn persist>(&self, path: P) { let path = path.as_ref(); - fn cache_to_file(lookups: &[DnsResponse], path: &Path) -> ProtoResult<()> { + let cache_to_file = || { let mut file = File::options() .create(true) .truncate(true) .write(true) .open(path)?; + let entries = self.iter().map(|(_, entry)| entry); + DnsCacheEntry::serialize_many(entries, &mut file) + }; - lookup::serialize(lookups, &mut file)?; - Ok(()) - } - - let lookups = self - .iter() - .map(|(_, entry)| entry.data.clone()) - .collect::>(); - - match cache_to_file(&lookups, path) { + match cache_to_file() { Ok(_) => info!("save DNS cache to file {:?} successfully.", path), Err(err) => error!("failed to save DNS cache to file {}", err), } @@ -721,23 +794,21 @@ impl PersistCache for LruCache { info!("reading DNS cache from file: {:?}", path); let now = Instant::now(); - fn read_from_cache_file(path: &Path) -> ProtoResult> { + let read_from_cache_file = || { let mut file = File::options().read(true).open(path)?; let mut data = vec![]; file.read_to_end(&mut data)?; - lookup::deserialize(&data) - } - match read_from_cache_file(path) { - Ok(lookups) => { - let count = lookups.len(); - let cache = self; - for lookup in lookups { - let query = lookup.query().clone().clone(); - cache.put(query, { - let valid_until = lookup.valid_until(); - DnsCacheEntry::new(lookup, valid_until) - }); + DnsCacheEntry::deserialize_many(&data) + }; + + match read_from_cache_file() { + Ok(entries) => { + let count = entries.len(); + let cache = self; + for entry in entries { + let query = entry.data.query().clone(); + cache.put(query, entry); } info!( "DNS cache {} records loaded, elapsed {:?}", @@ -755,13 +826,16 @@ mod tests { use super::*; - fn create_lookup(name: &str, data: RData, ttl: u64) -> DnsResponse { + fn create_lookup(name: &str, data: RData, ttl: u64) -> DnsCacheEntry { let name: Name = name.parse().unwrap(); let ttl = Duration::from_secs(ttl); let query = Query::query(name.clone(), data.record_type()); let records = vec![Record::from_rdata(name, ttl.as_secs() as u32, data)]; let valid_until = Instant::now() + ttl; - DnsResponse::new_with_deadline(query, records, valid_until) + DnsCacheEntry::new( + DnsResponse::new_with_deadline(query, records, valid_until), + valid_until, + ) } #[test] @@ -776,13 +850,13 @@ mod tests { ]; let mut data = vec![]; - lookup::serialize(&lookups, &mut data).unwrap(); - let lookup2 = lookup::deserialize(&data).unwrap(); + DnsCacheEntry::serialize_many(lookups.iter(), &mut data).unwrap(); + let lookup2 = DnsCacheEntry::deserialize_many(&data).unwrap(); assert_eq!(lookup2.len(), lookups.len()); - assert_eq!(&lookups[0], &lookup2[0]); - assert_eq!(&lookups[1], &lookup2[1]); + assert_eq!(&lookups[0].data, &lookup2[0].data); + assert_eq!(&lookups[1].data, &lookup2[1].data); } #[tokio::test] @@ -798,14 +872,14 @@ mod tests { 3000, ); - let cache = DnsCache::new(10); + let cache = DnsCache::new(10, true, 30, 5); let now = Instant::now(); cache .insert_records( - lookup1.query().clone(), - lookup1.record_iter().cloned(), + lookup1.data.query().clone(), + lookup1.data.record_iter().cloned(), now, "default", ) @@ -813,8 +887,8 @@ mod tests { cache .insert_records( - lookup2.query().clone(), - lookup2.record_iter().cloned(), + lookup2.data.query().clone(), + lookup2.data.record_iter().cloned(), now, "default", ) @@ -822,7 +896,7 @@ mod tests { sleep(Duration::from_millis(500)).await; - assert!(cache.get(lookup1.query(), now).await.is_some()); + assert!(cache.get(lookup1.data.query(), now, false).await.is_some()); { let lru_cache = cache.cache(); @@ -831,7 +905,7 @@ mod tests { lru_cache.persist("./logs/smartdns-test.cache"); - assert!(lru_cache.get(lookup1.query()).is_some()); + assert!(lru_cache.get(lookup1.data.query()).is_some()); lru_cache.clear(); @@ -844,26 +918,24 @@ mod tests { assert!(lru_cache .iter() .map(|(q, _)| q) - .any(|q| q == lookup1.query())); + .any(|q| q == lookup1.data.query())); assert!(lru_cache .iter() .map(|(q, _)| q) - .any(|q| q == lookup2.query())); + .any(|q| q == lookup2.data.query())); - assert!(lru_cache.contains(lookup1.query())); - assert!(lru_cache.contains(lookup2.query())); + assert!(lru_cache.contains(lookup1.data.query())); + assert!(lru_cache.contains(lookup2.data.query())); }; - let res = cache.get(lookup1.query(), now).await; + let res = cache.get(lookup1.data.query(), now, false).await; assert!(res.is_some()); - let (out_of_date, res) = res.unwrap(); - - assert_eq!(out_of_date, OutOfDate::No); + let res = res.unwrap(); let lookup = res.unwrap(); - assert_eq!(lookup.query(), lookup1.query()); - assert_eq!(lookup.records(), lookup1.records()); + assert_eq!(lookup.query(), lookup1.data.query()); + assert_eq!(lookup.records(), lookup1.data.records()); } } diff --git a/src/dns_url.rs b/src/dns_url.rs index e6ff84a1..a9bdc6af 100644 --- a/src/dns_url.rs +++ b/src/dns_url.rs @@ -47,13 +47,13 @@ impl DnsUrl { self.port() == dns_proto_default_port(&self.proto) } - pub fn path(&self) -> &str { + pub fn path(&self) -> Option<&str> { match self.proto { Protocol::Https | Protocol::H3 => match self.path.as_ref() { - Some(p) if !p.is_empty() => p, - _ => "/dns-query", + Some(p) if !p.is_empty() => Some(p), + _ => Some("/dns-query"), }, - _ => "", + _ => None, } } @@ -227,8 +227,8 @@ impl std::fmt::Display for DnsUrl { } // path - if matches!(self.proto, Protocol::Https | Protocol::H3) { - write!(f, "{}", self.path())?; + if let Some(path) = self.path() { + write!(f, "{}", path)?; } // fragment @@ -367,7 +367,7 @@ mod tests { assert_eq!(url.proto, Protocol::Udp); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 53); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "udp://8.8.8.8"); assert!(url.ip().is_some()); } @@ -378,7 +378,7 @@ mod tests { assert_eq!(url.proto, Protocol::Udp); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 53); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "udp://8.8.8.8"); assert!(url.ip().is_some()); } @@ -389,7 +389,7 @@ mod tests { assert_eq!(url.proto, Protocol::Udp); assert_eq!(url.host.to_string(), "1.1.1.1"); assert_eq!(url.port(), 8053); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "udp://1.1.1.1:8053"); assert!(url.ip().is_some()); } @@ -407,7 +407,7 @@ mod tests { assert_eq!(url.proto, Protocol::Tcp); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 53); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "tcp://8.8.8.8"); } @@ -417,7 +417,7 @@ mod tests { assert_eq!(url.proto, Protocol::Tcp); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 8053); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "tcp://8.8.8.8:8053"); } @@ -428,7 +428,7 @@ mod tests { assert_eq!(url.proto, Protocol::Tls); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 853); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "tls://8.8.8.8"); } @@ -439,7 +439,7 @@ mod tests { assert_eq!(url.proto, Protocol::Tls); assert_eq!(url.host.to_string(), "8.8.8.8"); assert_eq!(url.port(), 953); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "tls://8.8.8.8:953"); } @@ -451,7 +451,7 @@ mod tests { assert_eq!(url.proto, Protocol::Tls); assert_eq!(url.host.to_string(), "dns.google"); assert_eq!(url.port(), 953); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "tls://dns.google:953"); assert_eq!(url.ip(), "8.8.8.8".parse().ok()) } @@ -463,7 +463,7 @@ mod tests { assert_eq!(url.proto, Protocol::Https); assert_eq!(url.host.to_string(), "dns.google"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "https://dns.google/dns-query"); assert!(url.ip().is_none()); } @@ -475,7 +475,7 @@ mod tests { assert_eq!(url.proto, Protocol::Https); assert_eq!(url.host.to_string(), "dns.google"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query1"); + assert_eq!(url.path(), Some("/dns-query1")); assert_eq!(url.to_string(), "https://dns.google/dns-query1"); assert!(url.ip().is_none()); } @@ -488,7 +488,7 @@ mod tests { assert_eq!(url.proto, Protocol::Https); assert_eq!(url.host.to_string(), "dns.google"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "https://dns.google/dns-query"); assert!(url.ip().is_none()); } @@ -501,7 +501,7 @@ mod tests { assert_eq!(url.proto, Protocol::Quic); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 853); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "quic://dns.adguard-dns.com"); assert!(url.ip().is_none()); } @@ -514,7 +514,7 @@ mod tests { assert_eq!(url.proto, Protocol::H3); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "h3://dns.adguard-dns.com/dns-query"); assert!(url.ip().is_none()); } @@ -527,7 +527,7 @@ mod tests { assert_eq!(url.proto, Protocol::H3); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "https://dns.adguard-dns.com/dns-query#h3"); assert!(url.ip().is_none()); } @@ -540,7 +540,7 @@ mod tests { assert_eq!(url.proto, Protocol::H3); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "https://dns.adguard-dns.com/dns-query#h3"); assert!(url.ip().is_none()); } @@ -553,7 +553,7 @@ mod tests { assert_eq!(url.proto, Protocol::H3); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/dns-query"); + assert_eq!(url.path(), Some("/dns-query")); assert_eq!(url.to_string(), "https://dns.adguard-dns.com/dns-query#h3"); assert!(url.ip().is_none()); } @@ -566,7 +566,7 @@ mod tests { assert_eq!(url.proto, Protocol::H3); assert_eq!(url.host.to_string(), "dns.adguard-dns.com"); assert_eq!(url.port(), 443); - assert_eq!(url.path(), "/2dns-query"); + assert_eq!(url.path(), Some("/2dns-query")); assert_eq!(url.to_string(), "https://dns.adguard-dns.com/2dns-query#h3"); assert!(url.ip().is_none()); } @@ -585,7 +585,7 @@ mod tests { assert_eq!(url.proto, Protocol::Udp); assert_eq!(url.host.to_string(), "127.0.0.1"); assert_eq!(url.port(), 1053); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "udp://127.0.0.1:1053"); assert!(url.ip().is_some()); } @@ -596,7 +596,7 @@ mod tests { assert_eq!(url.proto, Protocol::Udp); assert_eq!(url.host.to_string(), "[240e:1f:1::1]"); assert_eq!(url.port(), 53); - assert_eq!(url.path(), ""); + assert_eq!(url.path(), None); assert_eq!(url.to_string(), "udp://[240e:1f:1::1]"); assert!(url.ip().is_some()); } diff --git a/src/main.rs b/src/main.rs index 0355b6cc..eca7140f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -206,6 +206,7 @@ impl RuntimeConfig { builder = builder.with_client_subnet(subnet); } builder = builder.with_proxies(proxies); + builder = builder.with_max_cocurrency(self.num_workers()); builder.build().await } }