Skip to content

Commit 6fa77d2

Browse files
authored
Refactor OsInputOutput (combine interfaces & frames into single Vec) (#310)
* Refactor `OsInputOutput` (combine interfaces & frames into single Vec) * Add note on handling a separate failure case * Reduce code duplication
1 parent 89e1140 commit 6fa77d2

File tree

5 files changed

+131
-145
lines changed

5 files changed

+131
-145
lines changed

src/main.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ pub struct OpenSockets {
9393
}
9494

9595
pub struct OsInputOutput {
96-
pub network_interfaces: Vec<NetworkInterface>,
97-
pub network_frames: Vec<Box<dyn DataLinkReceiver>>,
96+
pub interfaces_with_frames: Vec<(NetworkInterface, Box<dyn DataLinkReceiver>)>,
9897
pub get_open_sockets: fn() -> OpenSockets,
9998
pub terminal_events: Box<dyn Iterator<Item = Event> + Send>,
10099
pub dns_client: Option<dns::Client>,
@@ -281,9 +280,8 @@ where
281280
active_threads.push(terminal_event_handler);
282281

283282
let sniffer_threads = os_input
284-
.network_interfaces
283+
.interfaces_with_frames
285284
.into_iter()
286-
.zip(os_input.network_frames)
287285
.map(|(iface, frames)| {
288286
let name = format!("sniffing_handler_{}", iface.name);
289287
let running = running.clone();

src/os/shared.rs

+103-127
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ use std::{
44
time,
55
};
66

7+
use anyhow::{anyhow, bail};
78
use crossterm::event::{read, Event};
9+
use itertools::Itertools;
810
use pnet::datalink::{self, Channel::Ethernet, Config, DataLinkReceiver, NetworkInterface};
911
use tokio::runtime::Runtime;
1012

11-
use crate::{network::dns, os::errors::GetInterfaceError, OsInputOutput};
13+
use crate::{mt_log, network::dns, os::errors::GetInterfaceError, OsInputOutput};
1214

1315
#[cfg(target_os = "linux")]
1416
use crate::os::linux::get_open_sockets;
@@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> {
6365
}
6466

6567
fn create_write_to_stdout() -> Box<dyn FnMut(String) + Send> {
68+
let mut stdout = io::stdout();
6669
Box::new({
67-
let mut stdout = io::stdout();
6870
move |output: String| {
6971
writeln!(stdout, "{}", output).unwrap();
7072
}
7173
})
7274
}
7375

74-
#[derive(Debug)]
75-
pub struct UserErrors {
76-
permission: Option<String>,
77-
other: Option<String>,
78-
}
79-
80-
pub fn collect_errors<'a, I>(network_frames: I) -> String
81-
where
82-
I: Iterator<
83-
Item = (
84-
&'a NetworkInterface,
85-
Result<Box<dyn DataLinkReceiver>, GetInterfaceError>,
86-
),
87-
>,
88-
{
89-
let errors = network_frames.fold(
90-
UserErrors {
91-
permission: None,
92-
other: None,
93-
},
94-
|acc, (_, elem)| {
95-
if let Some(iface_error) = elem.err() {
96-
match iface_error {
97-
GetInterfaceError::PermissionError(interface_name) => {
98-
if let Some(prev_interface) = acc.permission {
99-
return UserErrors {
100-
permission: Some(format!("{prev_interface}, {interface_name}")),
101-
..acc
102-
};
103-
} else {
104-
return UserErrors {
105-
permission: Some(interface_name),
106-
..acc
107-
};
108-
}
109-
}
110-
error => {
111-
if let Some(prev_errors) = acc.other {
112-
return UserErrors {
113-
other: Some(format!("{prev_errors} \n {error}")),
114-
..acc
115-
};
116-
} else {
117-
return UserErrors {
118-
other: Some(format!("{error}")),
119-
..acc
120-
};
121-
}
122-
}
123-
};
124-
}
125-
acc
126-
},
127-
);
128-
if let Some(interface_name) = errors.permission {
129-
if let Some(other_errors) = errors.other {
130-
format!(
131-
"\n\n{interface_name}: {} \nAdditional Errors: \n {other_errors}",
132-
eperm_message(),
133-
)
134-
} else {
135-
format!("\n\n{interface_name}: {}", eperm_message())
136-
}
137-
} else {
138-
let other_errors = errors
139-
.other
140-
.expect("asked to collect errors but found no errors");
141-
format!("\n\n {other_errors}")
142-
}
143-
}
144-
14576
pub fn get_input(
14677
interface_name: Option<&str>,
14778
resolve: bool,
14879
dns_server: Option<Ipv4Addr>,
14980
) -> anyhow::Result<OsInputOutput> {
150-
let network_interfaces = if let Some(name) = interface_name {
151-
match get_interface(name) {
152-
Some(interface) => vec![interface],
153-
None => {
154-
anyhow::bail!("Cannot find interface {name}");
155-
// the homebrew formula relies on this wording, please be careful when changing
156-
}
157-
}
158-
} else {
159-
datalink::interfaces()
160-
};
161-
162-
#[cfg(target_os = "windows")]
163-
let network_frames = network_interfaces
164-
.iter()
165-
.filter(|iface| !iface.ips.is_empty())
166-
.map(|iface| (iface, get_datalink_channel(iface)));
167-
#[cfg(not(target_os = "windows"))]
168-
let network_frames = network_interfaces
169-
.iter()
170-
.filter(|iface| iface.is_up() && !iface.ips.is_empty())
171-
.map(|iface| (iface, get_datalink_channel(iface)));
172-
173-
let (available_network_frames, network_interfaces) = {
174-
let network_frames = network_frames.clone();
175-
let mut available_network_frames = Vec::new();
176-
let mut available_interfaces: Vec<NetworkInterface> = Vec::new();
177-
for (iface, rx) in network_frames.filter_map(|(iface, channel)| {
178-
if let Ok(rx) = channel {
179-
Some((iface, rx))
81+
// get the user's requested interface, if any
82+
// IDEA: allow requesting multiple interfaces
83+
let requested_interfaces = interface_name
84+
.map(|name| get_interface(name).ok_or_else(|| anyhow!("Cannot find interface {name}")))
85+
.transpose()?
86+
.map(|interface| vec![interface]);
87+
88+
// take the user's requested interfaces (or all interfaces), and filter for up ones
89+
let available_interfaces = requested_interfaces
90+
.unwrap_or_else(datalink::interfaces)
91+
.into_iter()
92+
.filter(|interface| {
93+
// see https://github.com/libpnet/libpnet/issues/564
94+
let keep = if cfg!(target_os = "windows") {
95+
!interface.ips.is_empty()
18096
} else {
181-
None
97+
interface.is_up() && !interface.ips.is_empty()
98+
};
99+
if !keep {
100+
mt_log!(debug, "{} is down. Skipping it.", interface.name);
182101
}
183-
}) {
184-
available_interfaces.push(iface.clone());
185-
available_network_frames.push(rx);
186-
}
187-
(available_network_frames, available_interfaces)
188-
};
102+
keep
103+
})
104+
.collect_vec();
189105

190-
if available_network_frames.is_empty() {
191-
let all_errors = collect_errors(network_frames.clone());
192-
if !all_errors.is_empty() {
193-
anyhow::bail!(all_errors);
194-
}
106+
// bail if no interfaces are up
107+
if available_interfaces.is_empty() {
108+
bail!("Failed to find any network interface to listen on.");
109+
}
195110

196-
anyhow::bail!("Failed to find any network interface to listen on.");
111+
// try to get a frame receiver for each interface
112+
let interfaces_with_frames_res = available_interfaces
113+
.into_iter()
114+
.map(|interface| {
115+
let frames_res = get_datalink_channel(&interface);
116+
(interface, frames_res)
117+
})
118+
.collect_vec();
119+
120+
// warn for all frame receivers we failed to acquire
121+
interfaces_with_frames_res
122+
.iter()
123+
.filter_map(|(interface, frames_res)| frames_res.as_ref().err().map(|err| (interface, err)))
124+
.for_each(|(interface, err)| {
125+
mt_log!(
126+
warn,
127+
"Failed to acquire a frame receiver for {}: {err}",
128+
interface.name
129+
)
130+
});
131+
132+
// bail if all of them fail
133+
// note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle
134+
// that failure mode separately, which we already have
135+
if interfaces_with_frames_res
136+
.iter()
137+
.all(|(_, frames)| frames.is_err())
138+
{
139+
let (permission_err_interfaces, other_errs) = interfaces_with_frames_res.iter().fold(
140+
(vec![], vec![]),
141+
|(mut perms, mut others), (_, res)| {
142+
match res {
143+
Ok(_) => (),
144+
Err(GetInterfaceError::PermissionError(interface)) => {
145+
perms.push(interface.as_str())
146+
}
147+
Err(GetInterfaceError::OtherError(err)) => others.push(err.as_str()),
148+
}
149+
(perms, others)
150+
},
151+
);
152+
153+
let err_msg = match (permission_err_interfaces.is_empty(), other_errs.is_empty()) {
154+
(false, false) => format!(
155+
"\n\n{}: {}\nAdditional errors:\n{}",
156+
permission_err_interfaces.join(", "),
157+
eperm_message(),
158+
other_errs.join("\n")
159+
),
160+
(false, true) => format!(
161+
"\n\n{}: {}",
162+
permission_err_interfaces.join(", "),
163+
eperm_message()
164+
),
165+
(true, false) => format!("\n\n{}", other_errs.join("\n")),
166+
(true, true) => unreachable!("Found no errors in error handling code path."),
167+
};
168+
bail!(err_msg);
197169
}
198170

199-
let keyboard_events = Box::new(TerminalEvents);
200-
let write_to_stdout = create_write_to_stdout();
171+
// filter out interfaces for which we failed to acquire a frame receiver
172+
let interfaces_with_frames = interfaces_with_frames_res
173+
.into_iter()
174+
.filter_map(|(interface, res)| res.ok().map(|frames| (interface, frames)))
175+
.collect();
176+
201177
let dns_client = if resolve {
202178
let runtime = Runtime::new()?;
203-
let resolver = match runtime.block_on(dns::Resolver::new(dns_server)) {
204-
Ok(resolver) => resolver,
205-
Err(err) => anyhow::bail!(
206-
"Could not initialize the DNS resolver. Are you offline?\n\nReason: {err:?}"
207-
),
208-
};
179+
let resolver = runtime
180+
.block_on(dns::Resolver::new(dns_server))
181+
.map_err(|err| {
182+
anyhow!("Could not initialize the DNS resolver. Are you offline?\n\nReason: {err}")
183+
})?;
209184
let dns_client = dns::Client::new(resolver, runtime)?;
210185
Some(dns_client)
211186
} else {
212187
None
213188
};
214189

190+
let write_to_stdout = create_write_to_stdout();
191+
215192
Ok(OsInputOutput {
216-
network_interfaces,
217-
network_frames: available_network_frames,
193+
interfaces_with_frames,
218194
get_open_sockets,
219-
terminal_events: keyboard_events,
195+
terminal_events: Box::new(TerminalEvents),
220196
dns_client,
221197
write_to_stdout,
222198
})

src/tests/cases/test_utils.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use rstest::fixture;
1414
use crate::{
1515
network::dns::Client,
1616
tests::fakes::{
17-
create_fake_dns_client, get_interfaces, get_open_sockets, NetworkFrames, TerminalEvent,
18-
TerminalEvents, TestBackend,
17+
create_fake_dns_client, get_interfaces_with_frames, get_open_sockets, NetworkFrames,
18+
TerminalEvent, TerminalEvents, TestBackend,
1919
},
2020
Opt, OsInputOutput,
2121
};
@@ -248,11 +248,13 @@ pub fn os_input_output_dns(
248248
}
249249

250250
pub fn os_input_output_factory(
251-
network_frames: Vec<Box<dyn DataLinkReceiver>>,
251+
network_frames: impl IntoIterator<Item = Box<dyn DataLinkReceiver>>,
252252
stdout: Option<Arc<Mutex<Vec<u8>>>>,
253253
dns_client: Option<Client>,
254254
keyboard_events: Box<dyn Iterator<Item = Event> + Send>,
255255
) -> OsInputOutput {
256+
let interfaces_with_frames = get_interfaces_with_frames(network_frames);
257+
256258
let write_to_stdout: Box<dyn FnMut(String) + Send> = match stdout {
257259
Some(stdout) => Box::new({
258260
move |output: String| {
@@ -264,8 +266,7 @@ pub fn os_input_output_factory(
264266
};
265267

266268
OsInputOutput {
267-
network_interfaces: get_interfaces(),
268-
network_frames,
269+
interfaces_with_frames,
269270
get_open_sockets,
270271
terminal_events: keyboard_events,
271272
dns_client,

0 commit comments

Comments
 (0)