Skip to content

Commit

Permalink
🎨 cleanup after shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish authored Apr 1, 2023
1 parent 6ca9e4a commit 462c910
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ once_cell = "1.16.0"
# async/await
futures = { version = "0.3.5", default-features = false, features = ["std"] }
futures-intrusive = "0.5"
futures-util = "0.3"
futures-util = { version = "0.3", default-features = false, features = ["std"]}
async-trait = "0.1.43"
tokio = { version = "1.21", features = ["time", "rt", "signal", "macros", "parking_lot"] }
tokio-rustls = "0.23.4"
socket2 = { version = "*", features = ["all"]}
socket2 = { version = "0.5", features = ["all"]}

# log
tracing = "0.1"
Expand Down
15 changes: 15 additions & 0 deletions src/dns_server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use cfg_if::cfg_if;
use futures::Future;
use futures_util::future::{abortable, Aborted};

use std::{
collections::{hash_map::Entry, HashMap},
Expand All @@ -11,6 +12,7 @@ use crate::{
dns_conf::ServerOpts,
dns_error::LookupError,
log::{debug, error, info, warn},
third_ext::FutureJoinAllExt,
};

use trust_dns_proto::{
Expand Down Expand Up @@ -53,6 +55,19 @@ impl ServerRegistry {
})),
}
}

pub async fn abort(self) -> Result<(), Aborted> {
let (server, abort_handle) = abortable(async move {
let _ = self
.servers
.into_values()
.map(|s| s.block_until_done())
.join_all()
.await;
});
abort_handle.abort();
server.await
}
}

pub struct ServerHandler {
Expand Down
39 changes: 17 additions & 22 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![allow(dead_code)]

use cfg_if::cfg_if;
use cli::*;
use dns_conf::BindServer;
use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
Expand Down Expand Up @@ -312,7 +311,10 @@ fn run_server(conf: Option<PathBuf>) {

info!("server starting up");

runtime.block_on(signal::terminate()).unwrap_or_default();
runtime.block_on(async move {
let _ = signal::terminate().await;
let _ = server.abort().await;
});

runtime.shutdown_timeout(Duration::from_secs(5));

Expand Down Expand Up @@ -511,19 +513,15 @@ fn tcp(
debug!("binding {} to {:?}{}", bind_type, sock_addr, device_note);
let tcp_listener = std::net::TcpListener::bind(sock_addr)?;

if let Some(device) = bind_device {
cfg_if! {
if #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] {
let sock_ref = socket2::SockRef::from(&tcp_listener);
sock_ref.bind_device(Some(device.as_bytes()))?;
} else {
drop(device)
}
{
let sock_ref = socket2::SockRef::from(&tcp_listener);
sock_ref.set_nonblocking(true)?;
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(device) = bind_device {
sock_ref.bind_device(Some(device.as_bytes()))?;
}
}

tcp_listener.set_nonblocking(true)?;

let tcp_listener = TcpListener::from_std(tcp_listener)?;

info!(
Expand All @@ -546,19 +544,16 @@ fn udp(sock_addr: SocketAddr, bind_device: Option<&str>, bind_type: &str) -> io:
debug!("binding {} to {:?}{}", bind_type, sock_addr, device_note);
let udp_socket = std::net::UdpSocket::bind(sock_addr)?;

if let Some(device) = bind_device {
cfg_if! {
if #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] {
let sock_ref = socket2::SockRef::from(&udp_socket);
sock_ref.bind_device(Some(device.as_bytes()))?;
} else {
drop(device)
}
{
let sock_ref = socket2::SockRef::from(&udp_socket);
sock_ref.set_nonblocking(true)?;

#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(device) = bind_device {
sock_ref.bind_device(Some(device.as_bytes()))?;
}
}

udp_socket.set_nonblocking(true)?;

let udp_socket = UdpSocket::from_std(udp_socket)?;

info!(
Expand Down

0 comments on commit 462c910

Please sign in to comment.