diff --git a/Cargo.toml b/Cargo.toml index bd63c03b..d601a2ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,9 +25,13 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(nightly)'] } [features] -default = ["dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "self-update" ] +default = ["dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "nom-recipes-all", "self-update" ] -homebrew = ["dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft" ] +homebrew = ["dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "nom-recipes-all" ] + +nom-recipes-all =["nom-recipes-ipv4", "nom-recipes-ipv6"] +nom-recipes-ipv4 = [] +nom-recipes-ipv6 = [] failed_tests = [] disable_icmp_ping = [] diff --git a/src/config/domain_rule.rs b/src/config/domain_rule.rs index 73b0f924..d83fa384 100644 --- a/src/config/domain_rule.rs +++ b/src/config/domain_rule.rs @@ -10,10 +10,12 @@ pub struct DomainRule { pub address: Option, - pub cname: Option, + pub cname: Option, pub srv: Option, + pub https: Option, + /// The mode of speed checking. pub speed_check_mode: Option, @@ -23,7 +25,7 @@ pub struct DomainRule { pub no_cache: Option, pub no_serve_expired: Option, - pub nftset: Option>>, + pub nftset: Option>>, pub rr_ttl: Option, pub rr_ttl_min: Option, diff --git a/src/config/mod.rs b/src/config/mod.rs index 4d797056..23f0b043 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -7,7 +7,10 @@ use std::{ use crate::{ infra::{file_mode::FileMode, ipset::IpSet}, - libdns::proto::rr::{rdata::SRV, Name, RecordType}, + libdns::proto::rr::{ + rdata::{HTTPS, SRV}, + Name, RecordType, + }, log::Level, proxy::ProxyConfig, third_ext::serde_str, @@ -48,8 +51,9 @@ pub type DomainSets = HashMap>; pub type ForwardRules = Vec; pub type AddressRules = Vec; pub type DomainRules = Vec>; -pub type CNameRules = Vec>; +pub type CNameRules = Vec>; pub type SrvRecords = Vec>; +pub type HttpsRecords = Vec>; #[derive(Default)] pub struct Config { @@ -237,10 +241,12 @@ pub struct Config { pub srv_records: SrvRecords, + pub https_records: HttpsRecords, + /// The proxy server for upstream querying. pub proxy_servers: HashMap, - pub nftsets: Vec>>>, + pub nftsets: Vec>>>, pub resolv_file: Option, pub domain_set_providers: HashMap>, @@ -274,7 +280,7 @@ pub enum ConfigForIP { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct NftsetConfig { +pub struct NFTsetConfig { pub family: &'static str, pub table: String, pub name: String, @@ -331,7 +337,7 @@ pub enum Ignorable { Value(T), } -pub type CName = Ignorable; +pub type CNameRule = Ignorable; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AddressRule { @@ -349,6 +355,18 @@ pub struct ForwardRule { pub nameserver: String, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[allow(clippy::upper_case_acronyms)] +pub enum HttpsRecordRule { + SOA, + Ignore, + Filter { + no_ipv4_hint: bool, + no_ipv6_hint: bool, + }, + RecordData(HTTPS), +} + macro_rules! impl_from_str { ($($type:ty),*) => { $( diff --git a/src/config/parser/cname.rs b/src/config/parser/cname.rs index 0086380d..ea5a898c 100644 --- a/src/config/parser/cname.rs +++ b/src/config/parser/cname.rs @@ -1,10 +1,10 @@ use super::*; -impl NomParser for CName { +impl NomParser for CNameRule { fn parse(input: &str) -> IResult<&str, Self> { alt(( - value(CName::IGN, char('-')), - map(NomParser::parse, CName::Value), + value(CNameRule::IGN, char('-')), + map(NomParser::parse, CNameRule::Value), ))(input) } } @@ -16,10 +16,10 @@ mod tests { #[test] fn test() { - assert_eq!(CName::parse("-"), Ok(("", CName::IGN))); + assert_eq!(CNameRule::parse("-"), Ok(("", CNameRule::IGN))); assert_eq!( - CName::parse("example.com"), - Ok(("", CName::Value("example.com".parse().unwrap()))) + CNameRule::parse("example.com"), + Ok(("", CNameRule::Value("example.com".parse().unwrap()))) ); } } diff --git a/src/config/parser/config_for_domain.rs b/src/config/parser/config_for_domain.rs index 3a965273..51dfdb8d 100644 --- a/src/config/parser/config_for_domain.rs +++ b/src/config/parser/config_for_domain.rs @@ -2,7 +2,9 @@ use super::*; impl NomParser for ConfigForDomain { fn parse(input: &str) -> IResult<&str, Self> { - let domain = delimited(char('/'), Domain::parse, char('/')); + let domain = map(opt(delimited(char('/'), Domain::parse, char('/'))), |n| { + n.unwrap_or_else(|| Domain::Name(WildcardName::Default(Name::root()))) + }); let config = T::parse; map( pair(domain, preceded(space0, config)), @@ -23,7 +25,7 @@ mod tests { "", ConfigForDomain { domain: Domain::Name("www.example.com".parse().unwrap()), - config: ConfigForIP::V4(NftsetConfig { + config: ConfigForIP::V4(NFTsetConfig { family: "inet", table: "tab".to_string(), name: "dns4".to_string() @@ -38,7 +40,22 @@ mod tests { "", ConfigForDomain { domain: Domain::Set("abc".to_string()), - config: ConfigForIP::V6(NftsetConfig { + config: ConfigForIP::V6(NFTsetConfig { + family: "inet", + table: "tab".to_string(), + name: "dns4".to_string() + }) + } + ) + ); + + assert_eq!( + ConfigForDomain::parse("#6:inet#tab#dns4").unwrap(), + ( + "", + ConfigForDomain { + domain: Domain::Name(WildcardName::Default(Name::root())), + config: ConfigForIP::V6(NFTsetConfig { family: "inet", table: "tab".to_string(), name: "dns4".to_string() diff --git a/src/config/parser/https_record.rs b/src/config/parser/https_record.rs new file mode 100644 index 00000000..c501aa1a --- /dev/null +++ b/src/config/parser/https_record.rs @@ -0,0 +1,101 @@ +use super::*; + +impl NomParser for HttpsRecordRule { + fn parse(input: &str) -> IResult<&str, Self> { + alt(( + map(char('#'), |_| Self::SOA), + map(char('-'), |_| Self::Ignore), + map( + separated_list1( + char(','), + delimited( + space0, + alt(( + value(4u8, tag_no_case("noipv4hint")), + value(6u8, tag_no_case("noipv6hint")), + )), + space0, + ), + ), + |no_hints| Self::Filter { + no_ipv4_hint: no_hints.contains(&4), + no_ipv6_hint: no_hints.contains(&6), + }, + ), + map(NomParser::parse, Self::RecordData), + ))(input) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::libdns::proto::rr::rdata::svcb::{Alpn, IpHint, SvcParamKey, SvcParamValue, SVCB}; + + #[test] + fn test_parse() { + assert_eq!(HttpsRecordRule::parse("#"), Ok(("", HttpsRecordRule::SOA))); + assert_eq!( + HttpsRecordRule::parse("-"), + Ok(("", HttpsRecordRule::Ignore)) + ); + assert_eq!( + HttpsRecordRule::parse("noipv4hint"), + Ok(( + "", + HttpsRecordRule::Filter { + no_ipv4_hint: true, + no_ipv6_hint: false + } + )) + ); + assert_eq!( + HttpsRecordRule::parse("noipv6hint, noipv4hint"), + Ok(( + "", + HttpsRecordRule::Filter { + no_ipv4_hint: true, + no_ipv6_hint: true + } + )) + ); + assert_eq!( + HttpsRecordRule::parse(r#"alpn="h2,http/1.1""#), + Ok(( + "", + HttpsRecordRule::RecordData(HTTPS(SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Alpn, + SvcParamValue::Alpn(Alpn(vec!["h2".to_string(), "http/1.1".to_string()])) + ),] + ))) + )) + ); + + assert_eq!( + HttpsRecordRule::parse(r#"ipv4hint=127.0.0.1,ipv6hint="::1, 2001:db8::1""#), + Ok(( + "", + HttpsRecordRule::RecordData(HTTPS(SVCB::new( + 0, + ".".parse().unwrap(), + vec![ + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ), + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(vec![ + "::1".parse().unwrap(), + "2001:db8::1".parse().unwrap() + ])) + ) + ] + ))) + )) + ); + } +} diff --git a/src/config/parser/mod.rs b/src/config/parser/mod.rs index a1ca7c9f..d034e801 100644 --- a/src/config/parser/mod.rs +++ b/src/config/parser/mod.rs @@ -15,11 +15,13 @@ mod domain_set; mod file_mode; mod forward_rule; mod glob_pattern; +mod https_record; mod ipnet; mod listener; mod log_level; mod nameserver; mod nftset; +mod nom_recipes; mod options; mod path; mod proxy_config; @@ -27,6 +29,7 @@ mod record_type; mod response_mode; mod speed_mode; mod srv; +mod svcb; use super::*; @@ -91,8 +94,9 @@ pub enum OneConfig { CacheCheckpointTime(u64), CaFile(PathBuf), CaPath(PathBuf), - CNAME(ConfigForDomain), - SRV(ConfigForDomain), + CNAME(ConfigForDomain), + SrvRecord(ConfigForDomain), + HttpsRecord(ConfigForDomain), ConfFile(PathBuf), DnsmasqLeaseFile(PathBuf), Domain(Name), @@ -119,7 +123,7 @@ pub enum OneConfig { LogFilter(String), MaxReplyIpNum(u8), MdnsLookup(bool), - NftSet(ConfigForDomain>>), + NftSet(ConfigForDomain>>), NumWorkers(usize), PrefetchDomain(bool), ProxyConfig(NamedProxyConfig), @@ -205,6 +209,7 @@ pub fn parse_config(input: &str) -> IResult<&str, OneConfig> { map(parse_item("num-workers"), OneConfig::NumWorkers), map(parse_item("domain"), OneConfig::Domain), map(parse_item("hosts-file"), OneConfig::HostsFile), + map(parse_item("https-record"), OneConfig::HttpsRecord), map(parse_item("local-ttl"), OneConfig::LocalTtl), map(parse_item("log-console"), OneConfig::LogConsole), map(parse_item("log-file-mode"), OneConfig::LogFileMode), @@ -238,7 +243,7 @@ pub fn parse_config(input: &str) -> IResult<&str, OneConfig> { )); let group4 = alt(( - map(parse_item("srv-record"), OneConfig::SRV), + map(parse_item("srv-record"), OneConfig::SrvRecord), map(parse_item("resolv-hostname"), OneConfig::ResolvHostname), map(parse_item("tcp-idle-time"), OneConfig::TcpIdleTime), map(parse_item("nftset"), OneConfig::NftSet), @@ -264,7 +269,7 @@ mod tests { "", OneConfig::NftSet(ConfigForDomain { domain: Domain::Name("www.example.com".parse().unwrap()), - config: vec![ConfigForIP::V4(NftsetConfig { + config: vec![ConfigForIP::V4(NFTsetConfig { family: "inet", table: "tab".to_string(), name: "dns4".to_string() @@ -279,7 +284,7 @@ mod tests { "", OneConfig::NftSet(ConfigForDomain { domain: Domain::Name("www.example.com".parse().unwrap()), - config: vec![ConfigForIP::V4(NftsetConfig { + config: vec![ConfigForIP::V4(NFTsetConfig { family: "inet", table: "tab".to_string(), name: "dns4".to_string() diff --git a/src/config/parser/nftset.rs b/src/config/parser/nftset.rs index 727903c9..909a01f1 100644 --- a/src/config/parser/nftset.rs +++ b/src/config/parser/nftset.rs @@ -1,8 +1,8 @@ use super::*; -use super::NftsetConfig; +use super::NFTsetConfig; -impl NomParser for NftsetConfig { +impl NomParser for NFTsetConfig { #[inline] fn parse(input: &str) -> IResult<&str, Self> { let family = alt(( @@ -23,7 +23,7 @@ impl NomParser for NftsetConfig { let (input, (family, table, name)) = tuple((family, table, name))(input)?; Ok(( input, - NftsetConfig { + NFTsetConfig { family, table: table.to_string(), name: name.to_string(), @@ -32,18 +32,18 @@ impl NomParser for NftsetConfig { } } -impl NomParser for ConfigForIP { +impl NomParser for ConfigForIP { #[inline] fn parse(input: &str) -> IResult<&str, Self> { let v4 = preceded( tag("#4:"), - verify(NftsetConfig::parse, |x| { + verify(NFTsetConfig::parse, |x| { x.family == "inet" || x.family == "ip" }), ); let v6 = preceded( tag("#6:"), - verify(NftsetConfig::parse, |x| { + verify(NFTsetConfig::parse, |x| { x.family == "inet" || x.family == "ip6" }), ); @@ -56,11 +56,11 @@ impl NomParser for ConfigForIP { } } -impl NomParser for Vec> { +impl NomParser for Vec> { fn parse(input: &str) -> IResult<&str, Self> { separated_list1( tuple((space0, char(','), space0)), - ConfigForIP::::parse, + ConfigForIP::::parse, )(input) } } @@ -72,10 +72,10 @@ mod tests { #[test] fn test() { assert_eq!( - NftsetConfig::parse("inet#tab1#dns_4").unwrap(), + NFTsetConfig::parse("inet#tab1#dns_4").unwrap(), ( "", - NftsetConfig { + NFTsetConfig { family: "inet", table: "tab1".to_string(), name: "dns_4".to_string() @@ -84,10 +84,10 @@ mod tests { ); assert_eq!( - NftsetConfig::parse("inet#tab1#dns4").unwrap(), + NFTsetConfig::parse("inet#tab1#dns4").unwrap(), ( "", - NftsetConfig { + NFTsetConfig { family: "inet", table: "tab1".to_string(), name: "dns4".to_string() @@ -96,10 +96,10 @@ mod tests { ); assert_eq!( - NftsetConfig::parse("ip6#tab1#dns6").unwrap(), + NFTsetConfig::parse("ip6#tab1#dns6").unwrap(), ( "", - NftsetConfig { + NFTsetConfig { family: "ip6", table: "tab1".to_string(), name: "dns6".to_string() @@ -111,10 +111,10 @@ mod tests { #[test] fn test_ip() { assert_eq!( - ConfigForIP::::parse("#4:inet#tab#dns4").unwrap(), + ConfigForIP::::parse("#4:inet#tab#dns4").unwrap(), ( "", - ConfigForIP::V4(NftsetConfig { + ConfigForIP::V4(NFTsetConfig { family: "inet", table: "tab".to_string(), name: "dns4".to_string(), @@ -123,10 +123,10 @@ mod tests { ); assert_eq!( - ConfigForIP::::parse("#6:ip6#tab#dns6").unwrap(), + ConfigForIP::::parse("#6:ip6#tab#dns6").unwrap(), ( "", - ConfigForIP::V6(NftsetConfig { + ConfigForIP::V6(NFTsetConfig { family: "ip6", table: "tab".to_string(), name: "dns6".to_string(), @@ -135,7 +135,7 @@ mod tests { ); assert_eq!( - ConfigForIP::::parse("-").unwrap(), + ConfigForIP::::parse("-").unwrap(), ("", ConfigForIP::None) ); } diff --git a/src/config/parser/nom_recipes/ipv4.rs b/src/config/parser/nom_recipes/ipv4.rs new file mode 100644 index 00000000..392a9812 --- /dev/null +++ b/src/config/parser/nom_recipes/ipv4.rs @@ -0,0 +1,48 @@ +use std::net::Ipv4Addr; + +use nom::{ + character::complete::{char, digit1}, + combinator::{map, map_res, recognize}, + error::context, + multi::many_m_n, + sequence::{preceded, tuple}, + IResult, +}; + +pub fn ipv4(input: &str) -> IResult<&str, Ipv4Addr> { + fn octal(input: &str) -> IResult<&str, u8> { + map_res(recognize(many_m_n(1, 3, digit1)), |s: &str| s.parse())(input) + } + + context( + "Ipv4Addr", + map( + tuple(( + octal, + preceded(char('.'), octal), + preceded(char('.'), octal), + preceded(char('.'), octal), + )), + |(a, b, c, d)| Ipv4Addr::new(a, b, c, d), + ), + )(input) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ipv4() { + assert_eq!(ipv4("127.0.0.1"), Ok(("", Ipv4Addr::new(127, 0, 0, 1)))); + assert_eq!( + ipv4("255.255.255.255"), + Ok(("", Ipv4Addr::new(255, 255, 255, 255))) + ); + assert_eq!(ipv4("0.0.0.0"), Ok(("", Ipv4Addr::new(0, 0, 0, 0)))); + assert_eq!(ipv4("1.2.3.4"), Ok(("", Ipv4Addr::new(1, 2, 3, 4)))); + assert!(ipv4("256.0.0.0").is_err()); + assert!(ipv4("0.0 .0.256").is_err()); + assert!(ipv4("0.0.0").is_err()); + } +} diff --git a/src/config/parser/nom_recipes/ipv6.rs b/src/config/parser/nom_recipes/ipv6.rs new file mode 100644 index 00000000..f2a8acf2 --- /dev/null +++ b/src/config/parser/nom_recipes/ipv6.rs @@ -0,0 +1,82 @@ +use std::net::Ipv6Addr; + +use nom::{ + bytes::complete::tag, + character::complete::{char, hex_digit1}, + combinator::{map, map_res, opt, recognize, verify}, + error::context, + multi::{many_m_n, separated_list0}, + sequence::{pair, preceded}, + IResult, +}; + +pub fn ipv6(input: &str) -> IResult<&str, Ipv6Addr> { + fn octal(input: &str) -> IResult<&str, u16> { + map_res(recognize(many_m_n(1, 4, hex_digit1)), |s| { + u16::from_str_radix(s, 16) + })(input) + } + + context( + "Ipv6Addr", + map( + verify( + pair( + separated_list0(char(':'), octal), + map( + opt(preceded(tag("::"), separated_list0(char(':'), octal))), + |v| v.unwrap_or_default(), + ), + ), + |(pre, post)| pre.len() == 8 || pre.len() + post.len() < 8, + ), + |(pre, post)| { + let mut octets = [0u16; 8]; + for (i, octet) in pre.iter().enumerate() { + octets[i] = *octet; + } + if !post.is_empty() { + let n = 8 - post.len(); + for (i, octet) in post.iter().enumerate() { + octets[i + n] = *octet; + } + } + Ipv6Addr::new( + octets[0], octets[1], octets[2], octets[3], octets[4], octets[5], octets[6], + octets[7], + ) + }, + ), + )(input) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ipv6() { + assert_eq!(ipv6("::1"), Ok(("", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))); + assert_eq!(ipv6("::"), Ok(("", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)))); + assert_eq!( + ipv6("::ffff:0:0"), + Ok(("", Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0, 0))) + ); + assert_eq!( + ipv6("::ffff:192:0:2:128"), + Ok(("", Ipv6Addr::new(0, 0, 0, 0xffff, 0x192, 0x0, 0x2, 0x128))) + ); + assert_eq!( + ipv6("2001:db8::1"), + Ok(("", Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))) + ); + assert_eq!( + ipv6("2001:db8:0:0:0:0:2:1"), + Ok(("", Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0x2, 0x1))) + ); + assert_eq!( + ipv6("2001:db8:0:0:0:0:2:1"), + Ok(("", Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0x2, 0x1))) + ); + } +} diff --git a/src/config/parser/nom_recipes/mod.rs b/src/config/parser/nom_recipes/mod.rs new file mode 100644 index 00000000..77ccbbac --- /dev/null +++ b/src/config/parser/nom_recipes/mod.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "nom-recipes-ipv4")] +mod ipv4; +#[cfg(feature = "nom-recipes-ipv4")] +pub use ipv4::ipv4; + +#[cfg(feature = "nom-recipes-ipv6")] +mod ipv6; +#[cfg(feature = "nom-recipes-ipv6")] +pub use ipv6::ipv6; diff --git a/src/config/parser/svcb.rs b/src/config/parser/svcb.rs new file mode 100644 index 00000000..65eb44e8 --- /dev/null +++ b/src/config/parser/svcb.rs @@ -0,0 +1,348 @@ +use super::*; +use crate::libdns::proto::rr::rdata::svcb::{ + Alpn, EchConfigList, IpHint, SvcParamKey, SvcParamValue, SVCB, +}; +use crate::libdns::proto::rr::rdata::{A, AAAA}; + +impl NomParser for SVCB { + fn parse(input: &str) -> IResult<&str, Self> { + let param_key = |name: &'static str| { + terminated(tag_no_case(name), delimited(space0, char('='), space0)) + }; + + let mut target_name = Name::root(); + let mut svc_priority = 0; + + let alpn = map( + preceded( + param_key("alpn"), + delimited( + char('"'), + separated_list1( + char(','), + delimited( + space0, + take_while1(|c: char| c != ',' && c != '"' && !c.is_whitespace()), + space0, + ), + ), + char('"'), + ), + ), + |alpn| { + ( + SvcParamKey::Alpn, + SvcParamValue::Alpn(Alpn( + alpn.into_iter().map(|s: &str| s.to_string()).collect(), + )), + ) + }, + ); + + let port = map(preceded(param_key("port"), u16), |port| { + (SvcParamKey::Port, SvcParamValue::Port(port)) + }); + + let ipv4hint = map( + preceded( + param_key("ipv4hint"), + alt(( + delimited( + char('"'), + separated_list1( + char(','), + delimited(space0, map(nom_recipes::ipv4, A::from), space0), + ), + char('"'), + ), + separated_list1( + char(','), + delimited(space0, map(nom_recipes::ipv4, A::from), space0), + ), + )), + ), + |ip_addrs| { + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(ip_addrs)), + ) + }, + ); + let ipv6hint = map( + preceded( + param_key("ipv6hint"), + alt(( + delimited( + char('"'), + separated_list1( + char(','), + delimited(space0, map(nom_recipes::ipv6, AAAA::from), space0), + ), + char('"'), + ), + separated_list1( + char(','), + delimited(space0, map(nom_recipes::ipv6, AAAA::from), space0), + ), + )), + ), + |ip_addrs| { + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(ip_addrs)), + ) + }, + ); + let ech = map( + preceded( + param_key("ech"), + delimited(char('"'), is_not("\""), char('"')), + ), + |ech| { + ( + SvcParamKey::EchConfigList, + SvcParamValue::EchConfigList(EchConfigList(ech.as_bytes().to_vec())), + ) + }, + ); + + let mut svc_params = vec![]; + + let (input, _) = separated_list0( + char(','), + delimited( + space0, + alt(( + map(alt((alpn, port, ipv4hint, ipv6hint, ech)), |v| { + svc_params.push(v); + }), + map(preceded(param_key("target"), NomParser::parse), |v| { + target_name = v; + }), + map(preceded(param_key("priority"), u16), |v| { + svc_priority = v; + }), + map(space0, |_| {}), + )), + space0, + ), + )(input)?; + + Ok((input, SVCB::new(svc_priority, target_name, svc_params))) + } +} + +impl NomParser for HTTPS { + fn parse(input: &str) -> IResult<&str, Self> { + map(SVCB::parse, HTTPS)(input) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_svcb() { + assert_eq!( + SVCB::parse(r#"ech="aaa""#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::EchConfigList, + SvcParamValue::EchConfigList(EchConfigList(b"aaa".to_vec())) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1"#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1,192.168.1.1"#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap() + ])) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1, 192.168.1.1"#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap() + ])) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=" 127.0.0.1, 192.168.1.1""#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap() + ])) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1,,ipv6hint=::1"#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![ + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ), + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(vec!["::1".parse().unwrap()])) + ) + ] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1, ,ipv6hint=::1"#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![ + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ), + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(vec!["::1".parse().unwrap()])) + ) + ] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1,ipv6hint="::1""#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![ + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ), + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(vec!["::1".parse().unwrap()])) + ) + ] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"ipv4hint=127.0.0.1,ipv6hint="::1, 2001:db8::1""#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![ + ( + SvcParamKey::Ipv4Hint, + SvcParamValue::Ipv4Hint(IpHint(vec!["127.0.0.1".parse().unwrap()])) + ), + ( + SvcParamKey::Ipv6Hint, + SvcParamValue::Ipv6Hint(IpHint(vec![ + "::1".parse().unwrap(), + "2001:db8::1".parse().unwrap() + ])) + ) + ] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"alpn="h2,http/1.1""#), + Ok(( + "", + SVCB::new( + 0, + ".".parse().unwrap(), + vec![( + SvcParamKey::Alpn, + SvcParamValue::Alpn(Alpn(vec!["h2".to_string(), "http/1.1".to_string()])) + ),] + ) + )) + ); + + assert_eq!( + SVCB::parse(r#"alpn="h2,http/1.1" , priority=3"#), + Ok(( + "", + SVCB::new( + 3, + ".".parse().unwrap(), + vec![( + SvcParamKey::Alpn, + SvcParamValue::Alpn(Alpn(vec!["h2".to_string(), "http/1.1".to_string()])) + ),] + ) + )) + ); + } +} diff --git a/src/dns_client.rs b/src/dns_client.rs index 199bfa96..9541e044 100644 --- a/src/dns_client.rs +++ b/src/dns_client.rs @@ -1454,7 +1454,7 @@ mod tests { .join_all() .await; - assert!(results.into_iter().all(|r| r)); + assert!(results.into_iter().any(|r| r)); } #[tokio::test] diff --git a/src/dns_conf.rs b/src/dns_conf.rs index 88e53d48..320c1028 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -543,7 +543,7 @@ impl RuntimeConfig { &self.cnames } - pub fn valid_nftsets(&self) -> Vec<&ConfigForIP> { + pub fn valid_nftsets(&self) -> Vec<&ConfigForIP> { self.nftsets .iter() .flat_map(|x| &x.config) @@ -625,6 +625,7 @@ impl RuntimeConfigBuilder { &domain_sets, &cfg.cnames, &cfg.srv_records, + &cfg.https_records, &cfg.nftsets, ); @@ -790,7 +791,8 @@ impl RuntimeConfigBuilder { CacheCheckpointTime(v) => self.cache.checkpoint_time = Some(v), CNAME(v) => self.cnames.push(v), ExpandPtrFromAddress(v) => self.expand_ptr_from_address = Some(v), - NftSet(config) => self.nftsets.push(config), + NftSet(v) => self.nftsets.push(v), + HttpsRecord(v) => self.https_records.push(v), Server(server) => self.nameservers.push(server), ResponseMode(mode) => self.response_mode = Some(mode), ResolvHostname(v) => self.resolv_hostname = Some(v), @@ -833,7 +835,7 @@ impl RuntimeConfigBuilder { ConfFile(v) => self.load_file(v).expect("load_file failed"), DnsmasqLeaseFile(v) => self.dnsmasq_lease_file = Some(v), ResolvFile(v) => self.resolv_file = Some(v), - SRV(v) => self.srv_records.push(v), + SrvRecord(v) => self.srv_records.push(v), DomainRule(v) => self.domain_rules.push(v), ForwardRule(v) => self.forward_rules.push(v), User(v) => self.user = Some(v), @@ -1532,4 +1534,11 @@ mod tests { // assert!(domain_set.is_match(&Name::from_str("ads3.net").unwrap().into())); // assert!(domain_set.is_match(&Name::from_str("q.ads3.net").unwrap().into())); } + + #[test] + fn test_parse_https_record() { + let cfg = RuntimeConfig::builder().with("https-record #").build(); + assert_eq!(cfg.https_records.len(), 1); + assert_eq!(cfg.https_records[0].config, HttpsRecordRule::SOA); + } } diff --git a/src/dns_mw_addr.rs b/src/dns_mw_addr.rs index 67bc8d8a..2addf756 100644 --- a/src/dns_mw_addr.rs +++ b/src/dns_mw_addr.rs @@ -105,7 +105,7 @@ impl Middleware for AddressMiddle } fn handle_rule_addr(query_type: RecordType, ctx: &DnsContext) -> Option { - use RecordType::{A, AAAA, HTTPS}; + use RecordType::{A, AAAA}; let cfg = ctx.cfg(); let server_opts = ctx.server_opts(); @@ -126,7 +126,7 @@ fn handle_rule_addr(query_type: RecordType, ctx: &DnsContext) -> Option { } // skip address rule. - if server_opts.no_rule_addr() || (!query_type.is_ip_addr() && query_type != HTTPS) { + if server_opts.no_rule_addr() || !query_type.is_ip_addr() { return None; } @@ -139,10 +139,7 @@ fn handle_rule_addr(query_type: RecordType, ctx: &DnsContext) -> Option { match address { IPv4(ipv4) if query_type == A => return Some(RData::A(ipv4.into())), IPv6(ipv6) if query_type == AAAA => return Some(RData::AAAA(ipv6.into())), - IPv4(_) | IPv6(_) - if !no_rule_soa - && (query_type == AAAA || query_type == A || query_type == HTTPS) => - { + IPv4(_) | IPv6(_) if !no_rule_soa && (query_type == AAAA || query_type == A) => { return Some(RData::default_soa()) } SOA if !no_rule_soa => return Some(RData::default_soa()), diff --git a/src/dns_mw_cname.rs b/src/dns_mw_cname.rs index 42ccd460..0803ca83 100644 --- a/src/dns_mw_cname.rs +++ b/src/dns_mw_cname.rs @@ -2,7 +2,7 @@ use std::time::{Duration, Instant}; use crate::libdns::proto::rr::rdata::CNAME; -use crate::config::CName; +use crate::config::CNameRule; use crate::dns::*; use crate::middleware::*; @@ -19,8 +19,8 @@ impl Middleware for DnsCNameMiddl let cname = match &ctx.domain_rule { Some(rule) => rule.get(|r| match &r.cname { Some(cname) => match cname { - CName::IGN => None, - CName::Value(n) => Some(n.clone()), + CNameRule::IGN => None, + CNameRule::Value(n) => Some(n.clone()), }, None => None, }), diff --git a/src/dns_mw_zone.rs b/src/dns_mw_zone.rs index d61e32d1..a4eea581 100644 --- a/src/dns_mw_zone.rs +++ b/src/dns_mw_zone.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use crate::libdns::proto::rr::rdata::PTR; use ipnet::IpNet; +use crate::config::HttpsRecordRule; use crate::dns::*; use crate::dns_conf::RuntimeConfig; use crate::infra::ipset::IpSet; @@ -101,13 +102,74 @@ impl Middleware for DnsZoneMiddle } } RecordType::SRV => { - if let Some(srv) = ctx.domain_rule.as_ref().and_then(|r| r.srv.clone()) { + if let Some(srv) = ctx + .domain_rule + .as_ref() + .and_then(|r| r.get_ref(|r| r.srv.as_ref())) + { return Ok(DnsResponse::from_rdata( req.query().original().to_owned(), - RData::SRV(srv), + RData::SRV(srv.clone()), )); } } + RecordType::HTTPS => { + if let Some(https_rule) = ctx + .domain_rule + .as_ref() + .and_then(|r| r.get_ref(|r| r.https.as_ref())) + { + match https_rule { + HttpsRecordRule::Ignore => (), + HttpsRecordRule::SOA => { + return Ok(DnsResponse::from_rdata( + req.query().original().to_owned(), + RData::default_soa(), + )); + } + HttpsRecordRule::Filter { + no_ipv4_hint, + no_ipv6_hint, + } => { + use crate::libdns::proto::rr::rdata::{svcb::SvcParamKey, SVCB}; + let no_ipv4_hint = *no_ipv4_hint; + let no_ipv6_hint = *no_ipv6_hint; + return match next.run(ctx, req).await { + Ok(mut lookup) => { + for record in lookup.answers_mut() { + if let Some(https) = record.data_mut().as_https_mut() { + let svc_params = https + .svc_params() + .iter() + .filter(|(k, _)| match k { + SvcParamKey::Ipv4Hint => !no_ipv4_hint, + SvcParamKey::Ipv6Hint => !no_ipv6_hint, + _ => true, + }) + .cloned() + .collect(); + + https.0 = SVCB::new( + https.svc_priority(), + https.target_name().clone(), + svc_params, + ); + } + } + Ok(lookup) + } + Err(err) => Err(err), + }; + } + HttpsRecordRule::RecordData(https) => { + return Ok(DnsResponse::from_rdata( + req.query().original().to_owned(), + RData::HTTPS(https.clone()), + )) + } + } + } + } _ => (), } diff --git a/src/dns_rule.rs b/src/dns_rule.rs index 051dc1c7..1d167cea 100644 --- a/src/dns_rule.rs +++ b/src/dns_rule.rs @@ -6,7 +6,7 @@ use crate::{ collections::DomainMap, config::{ AddressRules, CNameRules, ConfigForDomain, ConfigForIP, Domain, DomainRule, DomainRules, - DomainSets, ForwardRules, NftsetConfig, SrvRecords, + DomainSets, ForwardRules, HttpsRecords, NFTsetConfig, SrvRecords, }, }; @@ -16,6 +16,7 @@ pub struct DomainRuleMap { } impl DomainRuleMap { + #[allow(clippy::too_many_arguments)] pub fn create( domain_rules: &DomainRules, address_rules: &AddressRules, @@ -23,24 +24,24 @@ impl DomainRuleMap { domain_sets: &DomainSets, cnames: &CNameRules, srv_records: &SrvRecords, - nftsets: &Vec>>>, + https_records: &HttpsRecords, + nftsets: &Vec>>>, ) -> Self { + let expand_domain = |domain: &Domain| match &domain { + Domain::Name(name) => { + vec![name.clone()] + } + Domain::Set(s) => domain_sets + .get(s) + .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) + .unwrap_or_default(), + }; + let mut name_rule_map = HashMap::::new(); // append domain_rules - for rule in domain_rules { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - - for name in names { + for name in expand_domain(&rule.domain) { // overide *(name_rule_map.entry(name).or_default()) += rule.config.clone(); } @@ -48,82 +49,41 @@ impl DomainRuleMap { // append address rule for rule in address_rules.iter() { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - - for name in names { - name_rule_map.entry(name).or_default().address = Some(rule.address); + for name in expand_domain(&rule.domain) { + (name_rule_map.entry(name).or_default()).address = Some(rule.address); } } // append forward rule for rule in forward_rules.iter() { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - - for name in names { + for name in expand_domain(&rule.domain) { name_rule_map.entry(name).or_default().nameserver = Some(rule.nameserver.clone()) } } // set cname for rule in cnames { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - for name in names { + for name in expand_domain(&rule.domain) { name_rule_map.entry(name).or_default().cname = Some(rule.config.clone()) } } // set srv for rule in srv_records { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - for name in names { + for name in expand_domain(&rule.domain) { name_rule_map.entry(name).or_default().srv = Some(rule.config.clone()) } } + // set https + for rule in https_records { + for name in expand_domain(&rule.domain) { + name_rule_map.entry(name).or_default().https = Some(rule.config.clone()) + } + } + for rule in nftsets { - let names = match &rule.domain { - Domain::Name(name) => { - vec![name.clone()] - } - Domain::Set(s) => domain_sets - .get(s) - .map(|v| v.iter().map(|n| n.to_owned()).collect::>()) - .unwrap_or_default(), - }; - - for name in names { + for name in expand_domain(&rule.domain) { name_rule_map.entry(name).or_default().nftset = Some(rule.config.clone()); } } @@ -175,7 +135,11 @@ impl DomainRuleTreeNode { } pub fn get(&self, f: impl Fn(&Self) -> Option) -> Option { - f(self).or_else(|| self.zone().map(|z| f(z)).unwrap_or_default()) + f(self).or_else(|| self.zone().and_then(|z| f(z))) + } + + pub fn get_ref(&self, f: impl Fn(&Self) -> Option<&T>) -> Option<&T> { + f(self).or_else(|| self.zone().and_then(|z| f(z))) } } @@ -244,6 +208,7 @@ mod tests { &Default::default(), &Default::default(), &Default::default(), + &Default::default(), ); let rule1 = map.find(&"z.a.b.c.www.example.com".parse().unwrap());