diff --git a/.vscode/launch.json b/.vscode/launch.json index a43e7870..4cb8815a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,9 +19,9 @@ }, "args": [ "run", - // "-c", "tests/example/smartdns.conf", - "-c", "etc/smartdns/smartdns.conf", - "-d" + // "-c", "etc/smartdns/smartdns.conf", + "-c", "tests/example/smartdns.conf", + // "-d" ], "cwd": "${workspaceFolder}" }, diff --git a/Cargo.toml b/Cargo.toml index 47694906..bfc12db4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ async-trait = "0.1.43" time = "0.3" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["std", "fmt", "env-filter"] } +# tracing-appender = "0.2" tokio = { version = "1.21", features = ["time", "rt", "signal"] } url = "2.3.1" trust-dns-proto = { version = "0.22.0", features = ["dns-over-https-rustls"]} diff --git a/README.md b/README.md index 61c1950a..832004d1 100644 --- a/README.md +++ b/README.md @@ -132,10 +132,10 @@ sudo ./target/release/smartdns run -c ./etc/smartdns/smartdns.conf | rr-ttl-reply-max | 允许返回给客户端的最大 TTL 值 | :construction: | 远程查询结果 | 大于 0 的数字 | rr-ttl-reply-max 60 | | local-ttl | 本地HOST,address的TTL值 | :construction: | rr-ttl-min | 大于 0 的数字 | local-ttl 60 | | max-reply-ip-num | 允许返回给客户的最大IP数量 | :construction: | IP数量 | 大于 0 的数字 | max-reply-ip-num 1 | -| log-level | 设置日志级别 | :construction: | error | fatal、error、warn、notice、info 或 debug | log-level error | -| log-file | 日志文件路径 | :construction: | /var/log/smartdns/smartdns.log | 合法路径字符串 | log-file /var/log/smartdns/smartdns.log | -| log-size | 日志大小 | :construction: | 128K | 数字 + K、M 或 G | log-size 128K | -| log-num | 日志归档个数 | :construction: | 2 | 大于等于 0 的数字 | log-num 2 | +| log-level | 设置日志级别 | :white_check_mark: | error | fatal、error、warn、notice、info 或 debug | log-level error | +| log-file | 日志文件路径 | :white_check_mark: | /var/log/smartdns/smartdns.log | 合法路径字符串 | log-file /var/log/smartdns/smartdns.log | +| log-size | 日志大小 | :white_check_mark: | 128K | 数字 + K、M 或 G | log-size 128K | +| log-num | 日志归档个数 | :white_check_mark: | 2 | 大于等于 0 的数字 | log-num 2 | | audit-enable | 设置审计启用 | :white_check_mark: | no | [yes\|no] | audit-enable yes | | audit-file | 审计文件路径 | :white_check_mark: | /var/log/smartdns/smartdns-audit.log | 合法路径字符串,log 后缀可改成 csv | audit-file /var/log/smartdns/smartdns-audit.log | | audit-size | 审计大小 | :white_check_mark: | 128K | 数字 + K、M 或 G | audit-size 128K | diff --git a/src/dns.rs b/src/dns.rs index d95ae3a6..25c17093 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,10 +1,13 @@ +use cfg_if::cfg_if; use std::fmt::Debug; +use std::path::PathBuf; use std::{str::FromStr, sync::Arc, time::Duration}; use trust_dns_proto::rr::rdata::SOA; use trust_dns_resolver::error::ResolveError; use crate::dns_server::Request as OriginRequest; +use crate::log::info; use crate::{dns_client::DnsClient, dns_conf::SmartDnsConfig}; pub use trust_dns_proto::{ @@ -60,6 +63,34 @@ pub type DnsResponse = Lookup; pub type DnsError = ResolveError; impl SmartDnsConfig { + pub fn summary(&self) { + info!(r#"whoami 👉 {}"#, self.server_name()); + + const DEFAULT_GROUP: &'static str = "default"; + for (group, servers) in self.servers.iter() { + if group == DEFAULT_GROUP { + continue; + } + for server in servers { + info!( + "upstream server: {} [group: {}]", + server.url.to_string(), + group + ); + } + } + + if let Some(ss) = self.servers.get(DEFAULT_GROUP) { + for s in ss { + info!( + "upstream server: {} [group: {}]", + s.url.to_string(), + DEFAULT_GROUP + ); + } + } + } + pub fn server_name(&self) -> Name { match self.server_name { Some(ref server_name) => Some(server_name.clone()), @@ -90,6 +121,54 @@ impl SmartDnsConfig { pub fn audit_num(&self) -> usize { self.audit_num.unwrap_or(2) } + + pub fn log_enabled(&self) -> bool { + self.log_num() > 0 + } + + pub fn log_file(&self) -> PathBuf { + match self.log_file.as_ref() { + Some(e) => e.to_owned(), + None => { + cfg_if! { + if #[cfg(target_os="windows")] { + let mut path = std::env::temp_dir(); + path.push("smartdns"); + path.push("smartdns.log"); + path + } else { + PathBuf::from(r"/var/log/smartdns/smartdns.log") + } + + } + } + } + } + + pub fn log_level(&self) -> tracing::Level { + use tracing::Level; + match self + .log_level + .as_ref() + .map(|s| s.as_str()) + .unwrap_or("error") + { + "tarce" => Level::TRACE, + "debug" => Level::DEBUG, + "info" | "notice" => Level::INFO, + "warn" => Level::WARN, + "error" | "fatal" => Level::ERROR, + _ => Level::ERROR, + } + } + + pub fn log_num(&self) -> u64 { + self.log_num.unwrap_or(2) + } + pub fn log_size(&self) -> u64 { + use byte_unit::n_kb_bytes; + self.audit_size.unwrap_or(n_kb_bytes(128) as u64) + } } pub trait DefaultSOA { diff --git a/src/dns_conf.rs b/src/dns_conf.rs index 35353eff..7e5d3ee1 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -13,7 +13,7 @@ use trust_dns_resolver::Name; use crate::dns::RecordType; use crate::dns_url::DnsUrl; -use crate::log::{error, info, warn}; +use crate::log::{debug, error, info, warn}; const DEFAULT_SERVER: &'static str = "https://cloudflare-dns.com/dns-query"; @@ -327,12 +327,6 @@ impl SmartDnsConfig { .push(DnsServer::from_str(DEFAULT_SERVER).unwrap()); } - if let Some(ss) = cfg.servers.get("default") { - for s in ss { - info!("default server: {}", s.url.to_string()); - } - } - cfg } } @@ -695,8 +689,6 @@ mod parse { use super::*; use std::{collections::hash_map::Entry, ffi::OsStr, net::AddrParseError}; - use crate::log::{info, warn}; - impl SmartDnsConfig { pub fn load_file>( &mut self, @@ -705,6 +697,7 @@ mod parse { let path = find_path(path, self.conf_file.as_ref()); if path.exists() { + debug!("loading extra configuration from {:?}", path); let file = File::open(path)?; let reader = BufReader::new(file); for line in reader.lines() { @@ -846,7 +839,7 @@ mod parse { } if server.group.is_some() { - info!( + debug!( "append server {} to group {}", server.url.to_string(), server.group.as_ref().unwrap() diff --git a/src/infra/mapped_file.rs b/src/infra/mapped_file.rs index a3c03f76..cfa3d655 100644 --- a/src/infra/mapped_file.rs +++ b/src/infra/mapped_file.rs @@ -3,6 +3,7 @@ use std::fs; use std::fs::File; use std::io::{self, Write}; use std::path::{Path, PathBuf}; +use std::sync::Mutex; use chrono::Local; @@ -64,6 +65,20 @@ impl MappedFile { } } + #[inline] + pub fn touch(&mut self) -> io::Result<()> { + if !self.path().exists() { + let dir = self + .path() + .parent() + .ok_or(io::Error::from(io::ErrorKind::NotFound))?; + fs::create_dir_all(dir)?; + } + let file = self.get_active_file()?; + file.sync_all()?; + Ok(()) + } + pub fn mapped_files(&self) -> io::Result> { match ( self.path @@ -204,6 +219,35 @@ impl Write for MappedFile { } } +pub struct MutexMappedFile(pub Mutex); + +impl MutexMappedFile { + #[inline] + pub fn open>(path: P, size: u64, num: Option) -> Self { + Self(Mutex::new(MappedFile::open(path, size, num))) + } +} + +impl io::Write for MutexMappedFile { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.get_mut().unwrap().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.get_mut().unwrap().flush() + } +} + +impl io::Write for &MutexMappedFile { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.lock().unwrap().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.lock().unwrap().flush() + } +} + #[cfg(test)] mod tests { diff --git a/src/log.rs b/src/log.rs index a323f018..447ac143 100644 --- a/src/log.rs +++ b/src/log.rs @@ -1,32 +1,111 @@ -use std::{env, fmt}; +use std::{env, fmt, io, path::Path}; use time::OffsetDateTime; -use tracing::{Event, Subscriber}; +use tracing::{ + dispatcher::{set_default, set_global_default}, + subscriber::DefaultGuard, + Dispatch, Event, Subscriber, +}; use tracing_subscriber::{ - fmt::{format, FmtContext, FormatEvent, FormatFields, FormattedFields}, + fmt::{ + format, writer::MakeWriterExt, FmtContext, FormatEvent, FormatFields, FormattedFields, + MakeWriter, + }, prelude::__tracing_subscriber_SubscriberExt, registry::LookupSpan, - util::SubscriberInitExt, + EnvFilter, }; -pub use tracing::{debug, error, info, trace, warn}; +pub use tracing::{debug, error, info, trace, warn, Level}; -pub fn logger(level: tracing::Level) { - // Setup tracing for logging based on input - let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(tracing::Level::WARN.into()) - .parse(all_trust_dns(level)) - .expect("failed to configure tracing/logging"); +type MappedFile = crate::infra::mapped_file::MutexMappedFile; + +pub fn init_global_default>( + path: P, + level: tracing::Level, + size: u64, + num: u64, +) -> DefaultGuard { + let file = MappedFile::open(path.as_ref(), size, Some(num as usize)); + + let writable = file + .0 + .lock() + .unwrap() + .touch() + .map(|_| true) + .unwrap_or_else(|err| { + warn!("{:?}, {:?}", path.as_ref(), err); + false + }); + + let console_level = console_level(); + let console_writer = io::stdout.with_max_level(console_level); + + let dispatch = if writable { + // log hello + { + let writer = file.with_max_level(level); + let dispatch = make_dispatch(level, writer); + + let _guard = set_default(&dispatch); + crate::hello_starting(); + } + + let file_writer = + MappedFile::open(path.as_ref(), size, Some(num as usize)).with_max_level(level); + + make_dispatch(level.max(console_level), file_writer.and(console_writer)) + } else { + make_dispatch(console_level, console_writer) + }; + + let guard = set_default(&dispatch); + + set_global_default(dispatch).expect(""); + guard +} + +pub fn default() -> DefaultGuard { + let console_level = console_level(); + let console_writer = io::stdout.with_max_level(console_level); + set_default(&make_dispatch(console_level, console_writer)) +} - let formatter = tracing_subscriber::fmt::layer().event_format(TdnsFormatter { level }); +#[inline] +fn make_dispatch MakeWriter<'writer> + 'static + Send + Sync>( + level: tracing::Level, + writer: W, +) -> Dispatch { + let layer = tracing_subscriber::fmt::layer() + .event_format(TdnsFormatter) + .with_writer(writer); + + Dispatch::from( + tracing_subscriber::registry() + .with(layer) + .with(make_filter(level)), + ) +} + +fn console_level() -> Level { + if std::env::args().any(|arg| arg == "-d" || arg == "--debug") { + tracing::Level::DEBUG + } else { + tracing::Level::INFO + } +} - tracing_subscriber::registry() - .with(formatter) - .with(filter) - .init(); +#[inline] +fn make_filter(level: tracing::Level) -> EnvFilter { + EnvFilter::builder() + .with_default_directive(tracing::Level::WARN.into()) + .parse(all_smart_dns(level)) + .expect("failed to configure tracing/logging") } -fn all_trust_dns(level: impl ToString) -> String { +#[inline] +fn all_smart_dns(level: impl ToString) -> String { format!( "named={level},smartdns={level},{env}", level = level.to_string().to_lowercase(), @@ -34,13 +113,12 @@ fn all_trust_dns(level: impl ToString) -> String { ) } +#[inline] fn get_env() -> String { env::var("RUST_LOG").unwrap_or_default() } -struct TdnsFormatter { - level: tracing::Level, -} +struct TdnsFormatter; impl FormatEvent for TdnsFormatter where @@ -59,7 +137,7 @@ where // Format values from the event's's metadata: let metadata = event.metadata(); - if self.level == tracing::Level::INFO { + if metadata.level() == &tracing::Level::INFO { write!(&mut writer, "{}:{}", now_secs, metadata.level())?; } else { write!( @@ -98,3 +176,10 @@ where writeln!(writer) } } + +impl<'a> MakeWriter<'a> for MappedFile { + type Writer = &'a MappedFile; + fn make_writer(&'a self) -> Self::Writer { + self + } +} diff --git a/src/main.rs b/src/main.rs index 4ab2d6bf..e3199f29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -37,7 +37,6 @@ use dns_mw_spdt::DnsSpeedTestMiddleware; use dns_mw_zone::DnsZoneMiddleware; use dns_server::{MiddlewareBasedRequestHandler, ServerFuture}; use infra::middleware; -use log::logger; use crate::log::{debug, info}; use crate::{ @@ -68,7 +67,7 @@ pub fn version() -> &'static str { #[cfg(not(windows))] fn main() { - run_command(Cli::parse()); + Cli::parse().run(); } #[cfg(windows)] @@ -77,54 +76,54 @@ fn main() -> windows_service::Result<()> { { return service::windows_service::run(); } - run_command(Cli::parse()); + Cli::parse().run(); Ok(()) } -fn run_command(cli: Cli) { - match cli.command { - Commands::Run { conf, debug } => { - run_server(conf, debug); - } - Commands::Service { - command: service_command, - } => { - use service::*; - use ServiceCommands::*; - match service_command { - Install => install(), - Uninstall { purge } => uninstall(purge), - Start => start(), - Stop => stop(), - Restart => restart(), - Status => status(), +impl Cli { + #[inline] + pub fn run(self) { + let _guard = log::default(); + + match self.command { + Commands::Run { conf, .. } => { + run_server(conf); + } + Commands::Service { + command: service_command, + } => { + use service::*; + use ServiceCommands::*; + match service_command { + Install => install(), + Uninstall { purge } => uninstall(purge), + Start => start(), + Stop => stop(), + Restart => restart(), + Status => status(), + } } } } } -fn run_server(conf: Option, debug: bool) { - logger(if debug { - tracing::Level::DEBUG - } else { - tracing::Level::INFO - }); - - info!("Smart-DNS 🐋 {} starting", version()); +fn run_server(conf: Option) { + hello_starting(); let cfg = SmartDnsConfig::load(conf); - info!(r#"whoami 👉 "{}""#, cfg.server_name()); + let _guard = if cfg.log_enabled() { + Some(log::init_global_default( + cfg.log_file(), + cfg.log_level(), + cfg.log_size(), + cfg.log_num(), + )) + } else { + None + }; - // if !args.debug { - // cfg.log_level.as_ref().map(|lvl| { - // if let Ok(lvl) = tracing::Level::from_str(lvl) { - // logger(lvl); - // } else { - // warn!("log-level expect: debug,info,warn,error"); - // } - // }); - // } + cfg.summary(); let runtime = runtime::Builder::new_multi_thread() .enable_all() @@ -200,7 +199,7 @@ fn run_server(conf: Option, debug: bool) { // and TCP as necessary for tcp_listener in tcp_socket_addrs { - info!("binding TCP to {:?}", tcp_listener); + debug!("binding TCP to {:?}", tcp_listener); let tcp_listener = runtime .block_on(TcpListener::bind(tcp_listener)) .unwrap_or_else(|_| panic!("could not bind to tcp: {}", tcp_listener)); @@ -232,3 +231,8 @@ fn run_server(conf: Option, debug: bool) { drop(runtime); } + +#[inline] +fn hello_starting() { + info!("Smart-DNS 🐋 {} starting", version()); +} diff --git a/src/service.rs b/src/service.rs index 827c1d86..97afabaa 100644 --- a/src/service.rs +++ b/src/service.rs @@ -228,7 +228,6 @@ pub mod windows_service { // Handle stop ServiceControl::Stop => { - std::fs::write("D:\\sss12366.txt", "即将发送 ctrl+c").unwrap(); unsafe { windows::Win32::System::Console::GenerateConsoleCtrlEvent( windows::Win32::System::Console::CTRL_C_EVENT, @@ -266,7 +265,7 @@ pub mod windows_service { let args = std::env::args() .filter(|s| s != "--ws7642ea814a90496daaa54f2820254f12") .collect::>(); - crate::run_command(Cli::parse_from(args)); + Cli::parse_from(args).run(); } // Tell the system that service has stopped. diff --git a/tests/example/smartdns.conf b/tests/example/smartdns.conf index 09a5471f..bb72454f 100644 --- a/tests/example/smartdns.conf +++ b/tests/example/smartdns.conf @@ -8,7 +8,8 @@ rr-ttl-min 10 rr-ttl-max 30 log-size 64K log-num 1 -log-level error +log-level debug +log-file ./logs/smartdns.log audit-enable yes audit-file ./logs/smartdns-audit.csv resolv-file /tmp/resolv.conf.d/resolv.conf.auto