diff --git a/Cargo.lock b/Cargo.lock index b84f4633..10ccb0db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,9 +1570,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "lru" -version = "0.12.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2994eeba8ed550fd9b47a0b38f0242bc3344e496483c6180b69139cc2fa5d1d7" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" [[package]] name = "lru-cache" diff --git a/Cargo.toml b/Cargo.toml index fe498acd..f23d4d0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "smartdns" version = "0.8.4" authors = ["YISH "] edition = "2021" -rust-version = "1.70.0" +rust-version = "1.75.0" keywords = ["DNS", "BIND", "dig", "named", "dnssec", "SmartDNS", "Dnsmasq"] categories = ["network-programming"] diff --git a/src/app.rs b/src/app.rs index f25bf7d1..8da3ff76 100644 --- a/src/app.rs +++ b/src/app.rs @@ -15,11 +15,12 @@ use crate::{ config::ServerOpts, dns::{DnsRequest, DnsResponse, SerialMessage}, dns_conf::RuntimeConfig, + dns_error::LookupError, dns_mw::{DnsMiddlewareBuilder, DnsMiddlewareHandler}, dns_mw_cache::DnsCache, log, server::{DnsHandle, IncomingDnsRequest, ServerHandle}, - third_ext::FutureJoinAllExt as _, + third_ext::{FutureJoinAllExt as _, FutureTimeoutExt}, }; pub struct App { @@ -412,7 +413,28 @@ async fn process( response_header.set_authoritative(false); let response = { - match handler.search(&request, &server_opts).await { + let res = + handler + .search(&request, &server_opts) + .timeout(Duration::from_secs( + if server_opts.is_background { 60 } else { 5 }, + )) + .await + .unwrap_or_else(|_| { + let query = request.query().original().to_owned(); + log::warn!( + "Query {} {} {} timeout.", + query.name(), + query.query_type(), + if server_opts.is_background { + "in background" + } else { + "" + } + ); + Err(LookupError::no_records_found(query, 10)) + }); + match res { Ok(lookup) => lookup, Err(e) => { if e.is_nx_domain() { diff --git a/src/cli.rs b/src/cli.rs index 00684127..cb794158 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -138,6 +138,10 @@ struct CompatibleCli { /// Verbose screen. #[arg(short = 'x', long)] verbose: bool, + + /// ignore segment fault signal + #[arg(short = 'S')] + segment_fault_signal: bool, } impl From for Cli { @@ -147,6 +151,7 @@ impl From for Cli { pid, verbose, foreground, + segment_fault_signal: _, }: CompatibleCli, ) -> Self { if !foreground { @@ -299,4 +304,18 @@ mod tests { assert_eq!(cli.log_level(), Some(log::Level::INFO)); } + + #[test] + fn test_cli_args_parse_compatible_run_4() { + let cli = Cli::parse_from(["smartdns", "-f", "-c", "/etc/smartdns.conf", "-S"]); + assert!(matches!( + cli.command, + Commands::Run { + conf: Some(_), + pid: None, + } + )); + + assert_eq!(cli.log_level(), Some(log::Level::INFO)); + } } diff --git a/src/config/speed_mode.rs b/src/config/speed_mode.rs index 47c0f0d6..d3bb1d94 100644 --- a/src/config/speed_mode.rs +++ b/src/config/speed_mode.rs @@ -48,7 +48,7 @@ impl std::fmt::Debug for SpeedCheckMode { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct SpeedCheckModeList(pub Vec); impl SpeedCheckModeList { @@ -72,6 +72,16 @@ impl From> for SpeedCheckModeList { } } +impl std::fmt::Debug for SpeedCheckModeList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (i, m) in self.0.iter().enumerate() { + let last = i == self.len() - 1; + write!(f, "{:?}{}", m, if !last { ", " } else { "" })?; + } + Ok(()) + } +} + impl std::ops::Deref for SpeedCheckModeList { type Target = Vec; diff --git a/src/dns.rs b/src/dns.rs index 35eed947..9a28410e 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -193,6 +193,17 @@ mod serial_message { .into()) } } + + impl TryFrom for Message { + type Error = ProtoError; + + fn try_from(value: SerialMessage) -> Result { + match value { + SerialMessage::Raw(message, _, _) => Ok(message), + SerialMessage::Bytes(bytes, _, _) => Message::from_vec(&bytes), + } + } + } } mod request { diff --git a/src/dns_client.rs b/src/dns_client.rs index d7a96e43..42c43bc9 100644 --- a/src/dns_client.rs +++ b/src/dns_client.rs @@ -684,25 +684,30 @@ impl GenericResolver for NameServer { let name = name.into_name()?; let options: LookupOptions = options.into(); + let query = Query::query(name, options.record_type); + + let client_subnet = options.client_subnet.or(self.opts.client_subnet); + + if options.client_subnet.is_none() { + if let Some(subnet) = client_subnet.as_ref() { + log::debug!( + "query name: {} type: {} subnet: {}/{}", + query.name(), + query.query_type(), + subnet.addr(), + subnet.scope_prefix(), + ); + } + } + let request_options = { let opts = &self.options(); let mut request_opts = DnsRequestOptions::default(); request_opts.recursion_desired = opts.recursion_desired; - request_opts.use_edns = opts.edns0; + request_opts.use_edns = opts.edns0 || client_subnet.is_some(); request_opts }; - let query = Query::query(name, options.record_type); - - let client_subnet = options.client_subnet.or(self.opts.client_subnet); - - log::debug!( - "query name: {} type: {}, {:?}", - query.name(), - query.query_type(), - client_subnet - ); - let req = DnsRequest::new( build_message(query, request_options, client_subnet, options.is_dnssec), request_options, diff --git a/src/dns_conf.rs b/src/dns_conf.rs index 24b8bf34..9b173a80 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -134,6 +134,35 @@ impl RuntimeConfig { DEFAULT_GROUP ); } + + info!( + "cache: {}", + if self.cache_size() > 0 { + format!("size({})", self.cache_size()) + } else { + "OFF".to_string() + } + ); + + if self.cache_size() > 0 { + info!( + "cache persist: {}", + if self.cache_persist() { "YES" } else { "NO" } + ); + + info!( + "domain prefetch: {}", + if self.prefetch_domain() { "ON" } else { "OFF" } + ); + } + + info!( + "speed check mode: {}", + match self.speed_check_mode() { + Some(mode) => format!("{:?}", mode), + None => "OFF".to_string(), + } + ); } pub fn server_name(&self) -> Name { diff --git a/src/dns_error.rs b/src/dns_error.rs index 264bb0cf..b3610a7e 100644 --- a/src/dns_error.rs +++ b/src/dns_error.rs @@ -36,6 +36,22 @@ pub enum LookupError { Io(Arc), } +impl PartialEq for LookupError { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::ResponseCode(l0), Self::ResponseCode(r0)) => l0 == r0, + (Self::Proto(l0), Self::Proto(r0)) => l0.to_string() == r0.to_string(), + (Self::ResolveError(l0), Self::ResolveError(r0)) => l0.to_string() == r0.to_string(), + #[cfg(feature = "hickory-recursor")] + (Self::RecursiveError(l0), Self::RecursiveError(r0)) => { + l0.to_string() == r0.to_string() + } + (Self::Io(l0), Self::Io(r0)) => l0.to_string() == r0.to_string(), + _ => core::mem::discriminant(self) == core::mem::discriminant(other), + } + } +} + impl LookupError { pub fn is_nx_domain(&self) -> bool { matches!(self, Self::ResponseCode(resc) if resc.eq(&ResponseCode::NXDomain)) diff --git a/src/dns_mw_cache.rs b/src/dns_mw_cache.rs index 7a8663f5..d0f383cf 100644 --- a/src/dns_mw_cache.rs +++ b/src/dns_mw_cache.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use serde::{Deserialize, Serialize}; use std::fs::File; @@ -15,6 +14,7 @@ use std::time::Instant; use crate::config::ServerOpts; use crate::dns_conf::RuntimeConfig; use crate::libdns::proto::error::ProtoResult; +use crate::log; use crate::server::DnsHandle; use crate::{ dns::*, @@ -27,7 +27,7 @@ use crate::{ }; use lru::LruCache; use tokio::sync::Notify; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock}; use tokio::time::sleep; pub struct DnsCacheMiddleware { @@ -39,20 +39,7 @@ pub struct DnsCacheMiddleware { impl DnsCacheMiddleware { pub fn new(cfg: &Arc, dns_handle: DnsHandle) -> Self { - // create - let mut ttl = TtlOpts::default(); - - if let Some(positive_min_ttl) = cfg.rr_ttl_min().map(Duration::from_secs) { - ttl.set_positive_min(positive_min_ttl); - } - - if let Some(positive_max_ttl) = cfg.rr_ttl_max().map(Duration::from_secs) { - ttl.set_positive_min(positive_max_ttl); - } - ttl.set_negative_max(Duration::from_secs(cfg.serve_expired_ttl())); - ttl.set_negative_min(Duration::from_secs(cfg.serve_expired_reply_ttl())); - - let cache = DnsCache::new(cfg.cache_size(), ttl); + let cache = DnsCache::new(cfg.cache_size()); if cfg.cache_persist() { let cache_file = cfg.cache_file(); @@ -74,6 +61,7 @@ impl DnsCacheMiddleware { prefetch_notify: Arc::new(DomainPrefetchingNotify::new()), bg_client: dns_handle.with_new_opt(ServerOpts { is_background: true, + no_cache: Some(true), ..Default::default() }), }; @@ -92,121 +80,124 @@ impl DnsCacheMiddleware { fn start_prefetching(&self) { let prefetch_notify = self.prefetch_notify.clone(); - let (tx, mut rx) = mpsc::channel::>(100); - let client = self.bg_client.clone(); - let cache = self.cache.cache(); + tokio::spawn(async move { + let num_workers = std::cmp::max( + tokio::runtime::Handle::current().metrics().num_workers() / 5, + 1, + ); - { - // prefetch domain. - tokio::spawn(async move { - let querying: Arc>> = Default::default(); - - loop { - if let Some(queries) = rx.recv().await { - let client = client.clone(); - let querying = querying.clone(); - - for query in queries { - if !querying.lock().await.insert(query.clone()) { - continue; - } - let querying = querying.clone(); - - let (client, name, typ) = - (client.clone(), query.name().to_owned(), query.query_type()); - tokio::spawn(async move { - let now = Instant::now(); - let mut message = Message::new(); - message.add_query(query.clone()); - client.send(message.into()).await; - - debug!( - "Prefetch domain {} {}, elapsed {:?}", - name, - typ, - now.elapsed() - ); - querying.lock().await.remove(&query); - }); - } - } - } - }); - } + let concurrent = Arc::new(tokio::sync::Semaphore::new(num_workers)); - { - // check expired domain. - let cache = cache.clone(); - let prefetch_notify = prefetch_notify.clone(); + let min_interval = Duration::from_secs( + std::env::var("PREFETCH_MIN_INTERVAL") + .as_deref() + .unwrap_or("1") + .parse() + .unwrap_or(1), + ); + let mut last_check = Instant::now(); - tokio::spawn(async move { - let min_interval = Duration::from_secs( - std::env::var("PREFETCH_MIN_INTERVAL") - .as_deref() - .unwrap_or("1") - .parse() - .unwrap_or(1), - ); - let mut last_check = Instant::now(); + loop { + prefetch_notify.notified().await; - loop { - prefetch_notify.notified().await; + let now = Instant::now(); + let mut most_recent; + if now - last_check > min_interval { + last_check = now; - let now = Instant::now(); - let mut most_recent; - if now - last_check > min_interval { - last_check = now; + most_recent = Duration::from_secs(MAX_TTL as u64); + let mut expired = vec![]; - most_recent = Duration::from_secs(MAX_TTL as u64); - let mut expired = vec![]; + { + let mut cache = cache.lock().await; + let len = cache.len(); + if len == 0 { + continue; + } - { - let mut cache = cache.lock().await; - let len = cache.len(); - if len == 0 { + for (query, entry) in cache.iter_mut() { + if entry.is_in_prefetching { + continue; + } + // only prefetch query type ip addr + if !query.query_type().is_ip_addr() { continue; } - for (query, entry) in cache.iter_mut() { - if entry.is_in_prefetching { - continue; - } - // only prefetch query type ip addr - if !query.query_type().is_ip_addr() { - continue; - } - - if entry.is_current(now) { - most_recent = most_recent.min(entry.ttl(now)); - continue; - } + if entry.is_current(now) { + most_recent = most_recent.min(entry.ttl(now)); + continue; + } - entry.is_in_prefetching = true; + entry.is_in_prefetching = true; - expired.push(query.to_owned()); - } - debug!( - "Domain prefetch check(total: {}), elapsed {:?}", - len, - now.elapsed() - ); + expired.push(query.to_owned()); } + debug!( + "Domain prefetch check(total: {}), elapsed {:?}", + len, + now.elapsed() + ); + } - if !expired.is_empty() && tx.send(expired).await.is_err() { - error!("Failed to send queries to prefetch domain!"); + if !expired.is_empty() { + for query in expired { + let client = client.clone(); + let cache = cache.clone(); + let concurrent = concurrent.clone(); + + tokio::spawn(async move { + match concurrent.acquire().await { + Ok(_) => { + let now = Instant::now(); + let mut message = Message::new(); + message.add_query(query.clone()); + let serial_message = client.send(message.into()).await; + + if let Ok(message) = Message::try_from(serial_message) { + if let Some(entry) = cache.lock().await.peek_mut(&query) + { + let data = message.into(); + entry.set_data(data); + entry.set_valid_until( + Instant::now() + + Duration::from_secs( + entry + .data + .min_ttl() + .unwrap_or_default() + .min(600) + .into(), + ), + ) + } + } + + debug!( + "Prefetch domain {} {}, elapsed {:?}", + query.name(), + query.query_type(), + now.elapsed() + ); + } + Err(err) => { + log::error!("{:?}", err); + } + } + }); } - } else { - most_recent = Duration::ZERO; } - - // sleep and wait for next check. - let dura = most_recent.max(min_interval); - prefetch_notify.notify_after(dura).await; + } else { + most_recent = Duration::ZERO; } - }); - } + + // sleep and wait for next check. + let dura = most_recent.max(min_interval); + prefetch_notify.notify_after(dura).await; + } + }); } } @@ -352,24 +343,21 @@ const MAX_TTL: u32 = 86400_u32; /// An LRU eviction cache specifically for storing DNS records pub struct DnsCache { cache: Arc>>, - ttl: TtlOpts, } impl DnsCache { - fn new(cache_size: usize, ttl: TtlOpts) -> Self { + fn new(cache_size: usize) -> Self { let cache = Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(cache_size).unwrap(), ))); - Self { cache, ttl } + Self { cache } } fn cache(&self) -> Arc>> { self.cache.clone() } - // fn insert - pub async fn clear(&self) { self.cache.lock().await.clear(); } @@ -379,14 +367,11 @@ impl DnsCache { .lock() .await .iter() - .flat_map(|(query, v)| match &v.lookup { - Ok(lookup) => Some(CachedQueryRecord { - name: query.name().clone(), - query_type: query.query_type(), - query_class: query.query_class(), - records: lookup.records().to_vec().into_boxed_slice(), - }), - Err(_) => None, + .map(|(query, entry)| CachedQueryRecord { + name: query.name().clone(), + query_type: query.query_type(), + query_class: query.query_class(), + records: entry.data.records().to_vec().into_boxed_slice(), }) .collect() } @@ -401,7 +386,7 @@ impl DnsCache { let len = records_and_ttl.len(); // collapse the values, we're going to take the Minimum TTL as the correct one let (records, ttl): (Vec, Duration) = records_and_ttl.into_iter().fold( - (Vec::with_capacity(len), self.ttl.positive_max), + (Vec::with_capacity(len), Duration::from_secs(600)), |(mut records, mut min_ttl), (record, ttl)| { records.push(record); let ttl = Duration::from_secs(u64::from(ttl)); @@ -410,28 +395,21 @@ impl DnsCache { }, ); - // If the cache was configured with a minimum TTL, and that value is higher - // than the minimum TTL in the values, use it instead. - let ttl = self.ttl.positive_min.max(ttl); - let ttl = self.ttl.positive_max.min(ttl); - let valid_until = now + ttl; // insert into the LRU let lookup = DnsResponse::new_with_deadline(query.clone(), records, valid_until) .with_name_server_group(name_server_group.to_string()); - if let Ok(mut cache) = self.cache.try_lock() { - cache.put( - query, - DnsCacheEntry { - lookup: Ok(lookup.clone()), - valid_until, - is_in_prefetching: false, - }, - ); - } else { - debug!("Get dns cache lock to write failed"); + { + let cache = self.cache.clone(); + let lookup = lookup.clone(); + tokio::spawn(async move { + cache + .lock() + .await + .put(query, DnsCacheEntry::new(lookup, valid_until)); + }); } lookup @@ -484,6 +462,7 @@ impl DnsCache { .into_iter() .flat_map(|(_, r)| r) .collect::>(); + lookup = Some( self.insert(original_query.clone(), records, now, name_server_group) .await, @@ -534,37 +513,24 @@ impl DnsCache { } }; - let mut should_pop = false; - let lookup = cache.get_mut(query).and_then(|value| { + let mut expired = false; + let lookup = cache.get_mut(query).map(|value| { + value.last_access = Instant::now(); if value.is_current(now) { - let result = match value.lookup.clone() { - Ok(mut res) => { - res.set_max_ttl(value.ttl(now).as_secs() as u32); - Ok(res) - } - Err(mut err) => { - Self::nx_error_with_ttl(&mut err, value.ttl(now)); - Err(err) - } - }; + let mut res = value.data.clone(); + res.set_max_ttl(value.ttl(now).as_secs() as u32); - Some((OutOfDate::No, result)) + (OutOfDate::No, Ok(res)) } else { + expired = true; let negative_ttl = now - value.valid_until; - if negative_ttl < self.ttl.negative_max { - let result = value.lookup.clone(); - if let Ok(ref mut lookup) = value.lookup { - lookup.set_new_ttl(negative_ttl.as_secs() as u32) - } - Some((OutOfDate::Yes, result)) - } else { - should_pop = true; - None - } + let mut res = value.data.clone(); + res.set_new_ttl(negative_ttl.as_secs() as u32); + (OutOfDate::Yes, Ok(res)) } }); - if should_pop { + if expired { cache.pop(query).unwrap(); } lookup @@ -579,106 +545,38 @@ pub struct CachedQueryRecord { records: Box<[Record]>, } -struct TtlOpts { - /// A minimum TTL value for positive responses. - /// - /// Positive responses with TTLs under `positive_max_ttl` will use - /// `positive_max_ttl` instead. - /// - /// If this value is not set on the `TtlConfig` used to construct this - /// `DnsLru`, it will default to 0. - positive_min: Duration, - - /// A maximum TTL value for positive responses. - /// - /// Positive responses with TTLs over `positive_max_ttl` will use - /// `positive_max_ttl` instead. - /// - /// If this value is not set on the `TtlConfig` used to construct this - /// `DnsLru`, it will default to [`MAX_TTL`] seconds. - /// - /// [`MAX_TTL`]: const.MAX_TTL.html - positive_max: Duration, - - /// A minimum TTL value for negative (`NXDOMAIN`) responses. - /// - /// `NXDOMAIN` responses with TTLs under `negative_min_ttl` will use - /// `negative_min_ttl` instead. - /// - /// If this value is not set on the `TtlConfig` used to construct this - /// `DnsLru`, it will default to 0. - negative_min: Duration, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OutOfDate { + Yes, + No, +} - /// A maximum TTL value for negative (`NXDOMAIN`) responses. - /// - /// `NXDOMAIN` responses with TTLs over `negative_max_ttl` will use - /// `negative_max_ttl` instead. - /// - /// If this value is not set on the `TtlConfig` used to construct this - /// `DnsLru`, it will default to [`MAX_TTL`] seconds. - /// - /// [`MAX_TTL`]: const.MAX_TTL.html - negative_max: Duration, +struct DnsCacheEntry { + data: T, + valid_until: Instant, + is_in_prefetching: bool, + last_access: Instant, } -impl TtlOpts { - fn default() -> Self { +impl DnsCacheEntry { + fn new(data: T, valid_until: Instant) -> Self { Self { - positive_min: Duration::from_secs(0), - positive_max: Duration::from_secs(u64::from(MAX_TTL)), - negative_min: Duration::from_secs(0), - negative_max: Duration::from_secs(u64::from(MAX_TTL)), + data, + valid_until, + is_in_prefetching: false, + last_access: Instant::now(), } } - fn with_positive_min(mut self, ttl: Duration) -> Self { - self.positive_min = ttl; - self - } - - fn with_positive_max(mut self, ttl: Duration) -> Self { - self.positive_max = ttl; - self - } - - fn with_negative_min(mut self, ttl: Duration) -> Self { - self.negative_min = ttl; - self - } - fn with_negative_max(mut self, ttl: Duration) -> Self { - self.negative_max = ttl; - self - } - - fn set_positive_min(&mut self, ttl: Duration) { - self.positive_min = ttl; - } - - fn set_positive_max(&mut self, ttl: Duration) { - self.positive_max = ttl; + fn set_data(&mut self, data: T) { + self.data = data; + self.is_in_prefetching = false; } - fn set_negative_min(&mut self, ttl: Duration) { - self.negative_min = ttl; - } - fn set_negative_max(&mut self, ttl: Duration) { - self.negative_max = ttl; + fn set_valid_until(&mut self, valid_until: Instant) { + self.valid_until = valid_until; } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum OutOfDate { - Yes, - No, -} - -struct DnsCacheEntry { - lookup: Result, - valid_until: Instant, - is_in_prefetching: bool, -} - -impl DnsCacheEntry { /// Returns true if this set of ips is still valid fn is_current(&self, now: Instant) -> bool { now <= self.valid_until @@ -794,7 +692,7 @@ impl PersistCache for LruCache { let lookups = self .iter() - .filter_map(|(_, entry)| entry.lookup.clone().ok()) + .map(|(_, entry)| entry.data.clone()) .collect::>(); match cache_to_file(&lookups, path) { @@ -823,12 +721,7 @@ impl PersistCache for LruCache { cache.put(query, { let valid_until = lookup.valid_until(); - - DnsCacheEntry { - lookup: Ok(lookup), - valid_until, - is_in_prefetching: false, - } + DnsCacheEntry::new(lookup, valid_until) }); } info!( @@ -877,85 +770,85 @@ mod tests { assert_eq!(&lookups[1], &lookup2[1]); } - #[test] - fn test_cache_persist() { - tokio::runtime::Runtime::new().unwrap().block_on(async { - let lookup1 = create_lookup( - "abc.exmample.com.", - RData::A("127.0.0.1".parse().unwrap()), - 3000, - ); - let lookup2 = create_lookup( - "xyz.exmample.com.", - RData::AAAA("::1".parse().unwrap()), - 3000, - ); + #[tokio::test] + async fn test_cache_persist() { + let lookup1 = create_lookup( + "abc.exmample.com.", + RData::A("127.0.0.1".parse().unwrap()), + 3000, + ); + let lookup2 = create_lookup( + "xyz.exmample.com.", + RData::AAAA("::1".parse().unwrap()), + 3000, + ); - let cache = DnsCache::new(10, TtlOpts::default()); + let cache = DnsCache::new(10); - let now = Instant::now(); + let now = Instant::now(); - cache - .insert_records( - lookup1.query().clone(), - lookup1.record_iter().cloned(), - now, - "default", - ) - .await; + cache + .insert_records( + lookup1.query().clone(), + lookup1.record_iter().cloned(), + now, + "default", + ) + .await; + + cache + .insert_records( + lookup2.query().clone(), + lookup2.record_iter().cloned(), + now, + "default", + ) + .await; - cache - .insert_records( - lookup2.query().clone(), - lookup2.record_iter().cloned(), - now, - "default", - ) - .await; + sleep(Duration::from_millis(500)).await; - assert!(cache.get(lookup1.query(), now).await.is_some()); + assert!(cache.get(lookup1.query(), now).await.is_some()); - { - let lru_cache = cache.cache(); - let mut lru_cache = lru_cache.lock().await; - assert_eq!(lru_cache.len(), 2); + { + let lru_cache = cache.cache(); + let mut lru_cache = lru_cache.lock().await; + assert_eq!(lru_cache.len(), 2); - lru_cache.persist("./logs/smartdns-test.cache"); + lru_cache.persist("./logs/smartdns-test.cache"); - assert!(lru_cache.get(lookup1.query()).is_some()); + assert!(lru_cache.get(lookup1.query()).is_some()); - lru_cache.clear(); + lru_cache.clear(); - assert_eq!(lru_cache.len(), 0); + assert_eq!(lru_cache.len(), 0); - lru_cache.load("./logs/smartdns-test.cache"); + lru_cache.load("./logs/smartdns-test.cache"); - assert_eq!(lru_cache.len(), 2); + assert_eq!(lru_cache.len(), 2); - assert!(lru_cache - .iter() - .map(|(q, _)| q) - .any(|q| q == lookup1.query())); - assert!(lru_cache - .iter() - .map(|(q, _)| q) - .any(|q| q == lookup2.query())); + assert!(lru_cache + .iter() + .map(|(q, _)| q) + .any(|q| q == lookup1.query())); + assert!(lru_cache + .iter() + .map(|(q, _)| q) + .any(|q| q == lookup2.query())); - assert!(lru_cache.contains(lookup1.query())); - assert!(lru_cache.contains(lookup2.query())); - }; + assert!(lru_cache.contains(lookup1.query())); + assert!(lru_cache.contains(lookup2.query())); + }; - let res = cache.get(lookup1.query(), now).await; + let res = cache.get(lookup1.query(), now).await; - assert!(res.is_some()); + assert!(res.is_some()); - let (out_of_date, res) = res.unwrap(); + let (out_of_date, res) = res.unwrap(); - assert_eq!(out_of_date, OutOfDate::No); + assert_eq!(out_of_date, OutOfDate::No); - let lookup = res.unwrap(); - assert_eq!(lookup.query(), lookup1.query()); - assert_eq!(lookup.records(), lookup1.records()); - }) + let lookup = res.unwrap(); + assert_eq!(lookup.query(), lookup1.query()); + assert_eq!(lookup.records(), lookup1.records()); } } diff --git a/src/dns_mw_dualstack.rs b/src/dns_mw_dualstack.rs index a9cc4f2e..c372d020 100644 --- a/src/dns_mw_dualstack.rs +++ b/src/dns_mw_dualstack.rs @@ -209,8 +209,8 @@ async fn multi_mode_ping_fastest( match ping_res { Ok(ping_out) => { // ping success - let ip = ping_out.destination().ip(); - let duration = ping_out.duration(); + let ip = ping_out.dest().ip_addr(); + let duration = ping_out.elapsed(); fastest_ip = Some((ip, duration)); break; } diff --git a/src/dns_mw_ns.rs b/src/dns_mw_ns.rs index ecc42eba..617068a5 100644 --- a/src/dns_mw_ns.rs +++ b/src/dns_mw_ns.rs @@ -6,6 +6,8 @@ use std::{borrow::Borrow, net::IpAddr, time::Duration}; use crate::dns_client::{LookupOptions, NameServer}; use crate::infra::ipset::IpSet; +use crate::infra::ping::{PingError, PingOutput}; +use crate::third_ext::FutureTimeoutExt; use crate::{ config::{ResponseMode, SpeedCheckMode, SpeedCheckModeList}, dns::*, @@ -93,8 +95,14 @@ impl Middleware for NameServerMid }; debug!( - "query name: {} type: {} via [Group: {}]", - name, rtype, group_name + "query name: {} type: {}{} via [Group: {}]", + name, + rtype, + match lookup_options.client_subnet.as_ref() { + Some(subnet) => format!("\tsubnet: {}/{}", subnet.addr(), subnet.scope_prefix()), + None => String::with_capacity(0), + }, + group_name ); ctx.source = LookupFrom::Server(group_name.to_string()); @@ -175,18 +183,17 @@ async fn lookup_ip( name: Name, options: &LookupIpOptions, ) -> Result { - use crate::third_ext::FutureJoinAllExt; use futures_util::future::{select, select_all, Either}; use ResponseMode::*; assert!(options.record_type.is_ip_addr()); - let mut tasks = server + let mut query_tasks = server .iter() .map(|ns| per_nameserver_lookup_ip(ns, name.clone(), options).boxed()) .collect::>(); - if tasks.is_empty() { + if query_tasks.is_empty() { return Err(ProtoErrorKind::NoConnections.into()); } @@ -208,11 +215,10 @@ async fn lookup_ip( let selected_ip = match response_strategy { FirstPing => { - let mut tasks = tasks; let mut ping_tasks = vec![]; - let mut selected_ip = None; + let mut fastest_ip = None; loop { - let (fastest_ip, res) = match (tasks.len(), ping_tasks.len()) { + let (ping_res, query_res) = match (query_tasks.len(), ping_tasks.len()) { (0, 0) => break, (0, _) => { let (fastest_ip, _, rest) = select_all(ping_tasks).await; @@ -220,22 +226,22 @@ async fn lookup_ip( (fastest_ip, None) } (_, 0) => { - let (res, _idx, rest) = select_all(tasks).await; - tasks = rest; + let (res, _idx, rest) = select_all(query_tasks).await; + query_tasks = rest; (None, Some(res)) } _ => { let a = select_all(ping_tasks); - let b = select_all(tasks); + let b = select_all(query_tasks); let c = select(a, b).await; match c { Either::Left(((fastest_ip, _, rest), other)) => { ping_tasks = rest; - tasks = other.into_inner(); + query_tasks = other.into_inner(); (fastest_ip, None) } Either::Right(((res, _, rest), other)) => { - tasks = rest; + query_tasks = rest; ping_tasks = other.into_inner(); (None, Some(res)) } @@ -243,12 +249,12 @@ async fn lookup_ip( } }; - if let Some(fastest_ip) = fastest_ip { - selected_ip = Some(fastest_ip); + if let Some(ip) = ping_res { + fastest_ip = Some(ip); break; } - match res { + match query_res { Some(v) => match v { Ok(lookup) => { let ip_addrs = lookup.ip_addrs(); @@ -273,17 +279,17 @@ async fn lookup_ip( } } - let selected_ip = match selected_ip { - Some(selected_ip) => Some(selected_ip), + let selected_ip = match fastest_ip { + Some(ip) => Some(ip), None => { - let ip_addrs_map = ok_tasks.iter().flat_map(|r| r.ip_addrs()).fold( + let ip_addr_stats = ok_tasks.iter().flat_map(|r| r.ip_addrs()).fold( HashMap::::new(), |mut map, ip| { map.entry(ip).and_modify(|n| *n += 1).or_insert(1); map }, ); - ip_addrs_map + ip_addr_stats .into_iter() .max_by_key(|(_, n)| *n) .map(|(ip, _)| ip) @@ -293,53 +299,128 @@ async fn lookup_ip( selected_ip } FastestIp => { - for res in tasks.join_all().await { - match res { - Ok(v) => ok_tasks.push(v), - Err(e) => err_tasks.push(e), + let mut ping_tasks = vec![]; + + let mut ip_addr_stats = HashMap::new(); + + let mut fastest_ip: Option = None; + + loop { + #[allow(clippy::type_complexity)] + let (ping_res, query_res): ( + Option>, + Option>, + ) = match (query_tasks.len(), ping_tasks.len()) { + (0, 0) => break, + (0, _) => { + let (res, _idx, rest) = select_all(ping_tasks).await; + ping_tasks = rest; + (Some(res), None) + } + (_, 0) => { + let (res, _idx, rest) = select_all(query_tasks).await; + query_tasks = rest; + (None, Some(res)) + } + _ => { + let a = select_all(ping_tasks); + let b = select_all(query_tasks); + let c = select(a, b).await; + match c { + Either::Left(((res, _, rest), other)) => { + ping_tasks = rest; + query_tasks = other.into_inner(); + (Some(res), None) + } + Either::Right(((res, _, rest), other)) => { + query_tasks = rest; + ping_tasks = other.into_inner(); + (None, Some(res)) + } + } + } + }; + + if let Some(Ok(out)) = ping_res { + if match fastest_ip.as_ref() { + Some(t) => out.elapsed() < t.elapsed(), + None => { + // first get speed, add timeout + query_tasks = query_tasks + .into_iter() + .map(|q| { + async { + match q.timeout(Duration::from_millis(200)).await { + Ok(t) => t, + Err(_) => Err(ProtoErrorKind::Timeout.into()), + } + } + .boxed() + }) + .collect(); + + true + } + } { + fastest_ip = Some(out); + } + } + + if let Some(res) = query_res { + match res { + Ok(lookup) => { + let ip_addrs = lookup.ip_addrs(); + + for ip_addr in &ip_addrs { + *ip_addr_stats.entry(*ip_addr).or_insert_with(|| { + ping_tasks.push( + multi_mode_ping( + name.clone(), + *ip_addr, + speed_check_mode.to_vec(), + ) + .boxed(), + ); + 0u8 + }) += 1; + } + ok_tasks.push(lookup); + } + Err(err) => { + err_tasks.push(err); + } + } } } - if ok_tasks.is_empty() { - return Err(err_tasks.into_iter().next().unwrap()); // There is definitely one. + match fastest_ip { + Some(fastest_ip) => Some(fastest_ip.dest().ip_addr()), + None => ip_addr_stats + .into_iter() + .max_by_key(|(_, n)| *n) + .map(|(ip, _)| ip), } + } + FastestResponse => { + let mut last_error = None; + loop { + let (res, _idx, rest) = select_all(query_tasks).await; + if rest.is_empty() + || matches!(&res, Ok(res) if res.answers().iter().any(|r| r.record_type() == options.record_type)) + { + return res; + } - let ip_addrs_map = ok_tasks.iter().flat_map(|r| r.ip_addrs()).fold( - HashMap::::new(), - |mut map, ip| { - map.entry(ip).and_modify(|n| *n += 1).or_insert(1); - map - }, - ); - - let mut ip_addrs = ip_addrs_map.keys().cloned().collect::>(); - - match ip_addrs.len() { - 0 => None, - 1 => ip_addrs.pop(), - _ => { - let fastest_ip = - multi_mode_ping_fastest(name.clone(), ip_addrs, speed_check_mode.to_vec()) - .await; - - fastest_ip.or_else(|| { - ip_addrs_map - .into_iter() - .max_by_key(|(_, n)| *n) - .map(|(ip, _)| ip) - }) + if let Err(err) = res { + if matches!(last_error, Some(e) if e == err) { + return Err(err); + } else { + last_error = Some(err); + } } + query_tasks = rest; } } - FastestResponse => loop { - let (res, _idx, rest) = select_all(tasks).await; - if rest.is_empty() - || matches!(&res, Ok(res) if res.answers().iter().any(|r| r.record_type() == options.record_type)) - { - return res; - } - tasks = rest; - }, }; if let Some(selected_ip) = selected_ip { @@ -387,12 +468,12 @@ async fn multi_mode_ping_fastest( match ping_res { Ok(ping_out) => { // ping success - let ip = ping_out.destination().ip(); + let ip = ping_out.dest().ip_addr(); debug!( "The fastest ip of {} is {}, delay: {:?}", name, ip, - ping_out.duration() + ping_out.elapsed() ); fastest_ip = Some(ip); break; @@ -410,6 +491,44 @@ async fn multi_mode_ping_fastest( fastest_ip } +async fn multi_mode_ping( + name: Name, + ip_addr: IpAddr, + modes: Vec, +) -> Result { + use crate::infra::ping::{ping, PingOptions}; + let duration = Duration::from_millis(200); + let ping_ops = PingOptions::default().with_timeout_secs(2); + + for mode in &modes { + let dest = mode.to_ping_addr(ip_addr); + + let ping_task = ping(dest, ping_ops).boxed(); + let timeout_task = sleep(duration).boxed(); + match futures_util::future::select(ping_task, timeout_task).await { + futures::future::Either::Left((ping_res, _)) => match ping_res { + Ok(ping_out) => { + debug!( + "Speed test {} {:?} ping {:?} elapsed {:?}", + name, + mode, + ip_addr, + ping_out.elapsed() + ); + return Ok(ping_out); + } + Err(_) => continue, + }, + futures::future::Either::Right((_, _)) => { + // timeout + continue; + } + } + } + + Err(PingError::Timeout) +} + async fn per_nameserver_lookup_ip( server: &NameServer, name: Name, diff --git a/src/infra/ping.rs b/src/infra/ping.rs index 915ecc13..9ba78a24 100644 --- a/src/infra/ping.rs +++ b/src/infra/ping.rs @@ -8,7 +8,19 @@ use std::{ }; use thiserror::Error; -pub async fn ping(dests: &[PingAddr], opts: PingOptions) -> Vec> { +pub async fn ping(dest: PingAddr, opts: PingOptions) -> Result { + match dest { + PingAddr::Icmp(addr) => icmp::ping(addr, opts).await, + PingAddr::Tcp(addr) => tcp::ping(addr, opts).await, + PingAddr::Http(addr) => http::ping(addr, opts).await, + PingAddr::Https(addr) => https::ping(addr, opts).await, + } +} + +pub async fn ping_batch( + dests: &[PingAddr], + opts: PingOptions, +) -> Vec> { let mut outs = Vec::new(); for dest in dests.iter() { @@ -22,19 +34,6 @@ pub async fn ping(dests: &[PingAddr], opts: PingOptions) -> Vec>( - dest: D, - opts: PingOptions, -) -> Result { - let dest = dest.try_into()?; - match dest { - PingAddr::Icmp(addr) => icmp::ping(addr, opts).await, - PingAddr::Tcp(addr) => tcp::ping(addr, opts).await, - PingAddr::Http(addr) => http::ping(addr, opts).await, - PingAddr::Https(addr) => https::ping(addr, opts).await, - } -} - pub async fn ping_fastest( dests: Vec, opts: PingOptions, @@ -117,7 +116,7 @@ pub enum PingAddr { } impl PingAddr { - pub fn ip(self) -> IpAddr { + pub fn ip_addr(self) -> IpAddr { match self { PingAddr::Icmp(ip) => ip, PingAddr::Tcp(addr) => addr.ip(), @@ -129,7 +128,7 @@ impl PingAddr { impl PartialEq for PingAddr { fn eq(&self, other: &IpAddr) -> bool { - self.ip() == *other + self.ip_addr() == *other } } impl PartialEq for PingAddr { @@ -187,7 +186,7 @@ impl TryFrom<&str> for PingAddr { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PingOutput { seq: u16, duration: Duration, @@ -201,12 +200,12 @@ impl PingOutput { } #[inline] - pub fn duration(&self) -> Duration { + pub fn elapsed(&self) -> Duration { self.duration } #[inline] - pub fn destination(&self) -> PingAddr { + pub fn dest(&self) -> PingAddr { self.destination } } @@ -217,6 +216,8 @@ pub enum PingError { PingTargetParseError, #[error("addr parse error {0}")] AddrParseError(#[from] AddrParseError), + #[error("addr parse error {0}")] + AddrParseError2(String), #[error("Ping timeout")] Timeout, #[error("io error {0}")] @@ -232,6 +233,7 @@ impl Clone for PingError { match self { Self::PingTargetParseError => Self::PingTargetParseError, Self::AddrParseError(arg0) => Self::AddrParseError(arg0.clone()), + Self::AddrParseError2(arg0) => Self::AddrParseError2(arg0.clone()), Self::Timeout => Self::Timeout, Self::IoError(err) => Self::IoError(err.kind().into()), Self::SurgeError => Self::SurgeError, @@ -855,7 +857,7 @@ mod tests { .unwrap(); rt.block_on(async { - let results = ping( + let results = ping_batch( &[ "127.0.0.1".parse().unwrap(), "icmp://223.6.6.6".parse().unwrap(), @@ -908,7 +910,7 @@ mod tests { .build() .unwrap() .block_on(async { - let res = ping_one("https://1.1.1.1:443", Default::default()) + let res = ping("https://1.1.1.1:443".parse().unwrap(), Default::default()) .await .unwrap(); assert!(res.duration < Duration::from_secs(5))