@@ -4,11 +4,13 @@ use std::{
4
4
time,
5
5
} ;
6
6
7
+ use anyhow:: { anyhow, bail} ;
7
8
use crossterm:: event:: { read, Event } ;
9
+ use itertools:: Itertools ;
8
10
use pnet:: datalink:: { self , Channel :: Ethernet , Config , DataLinkReceiver , NetworkInterface } ;
9
11
use tokio:: runtime:: Runtime ;
10
12
11
- use crate :: { network:: dns, os:: errors:: GetInterfaceError , OsInputOutput } ;
13
+ use crate :: { mt_log , network:: dns, os:: errors:: GetInterfaceError , OsInputOutput } ;
12
14
13
15
#[ cfg( target_os = "linux" ) ]
14
16
use crate :: os:: linux:: get_open_sockets;
@@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> {
63
65
}
64
66
65
67
fn create_write_to_stdout ( ) -> Box < dyn FnMut ( String ) + Send > {
68
+ let mut stdout = io:: stdout ( ) ;
66
69
Box :: new ( {
67
- let mut stdout = io:: stdout ( ) ;
68
70
move |output : String | {
69
71
writeln ! ( stdout, "{}" , output) . unwrap ( ) ;
70
72
}
71
73
} )
72
74
}
73
75
74
- #[ derive( Debug ) ]
75
- pub struct UserErrors {
76
- permission : Option < String > ,
77
- other : Option < String > ,
78
- }
79
-
80
- pub fn collect_errors < ' a , I > ( network_frames : I ) -> String
81
- where
82
- I : Iterator <
83
- Item = (
84
- & ' a NetworkInterface ,
85
- Result < Box < dyn DataLinkReceiver > , GetInterfaceError > ,
86
- ) ,
87
- > ,
88
- {
89
- let errors = network_frames. fold (
90
- UserErrors {
91
- permission : None ,
92
- other : None ,
93
- } ,
94
- |acc, ( _, elem) | {
95
- if let Some ( iface_error) = elem. err ( ) {
96
- match iface_error {
97
- GetInterfaceError :: PermissionError ( interface_name) => {
98
- if let Some ( prev_interface) = acc. permission {
99
- return UserErrors {
100
- permission : Some ( format ! ( "{prev_interface}, {interface_name}" ) ) ,
101
- ..acc
102
- } ;
103
- } else {
104
- return UserErrors {
105
- permission : Some ( interface_name) ,
106
- ..acc
107
- } ;
108
- }
109
- }
110
- error => {
111
- if let Some ( prev_errors) = acc. other {
112
- return UserErrors {
113
- other : Some ( format ! ( "{prev_errors} \n {error}" ) ) ,
114
- ..acc
115
- } ;
116
- } else {
117
- return UserErrors {
118
- other : Some ( format ! ( "{error}" ) ) ,
119
- ..acc
120
- } ;
121
- }
122
- }
123
- } ;
124
- }
125
- acc
126
- } ,
127
- ) ;
128
- if let Some ( interface_name) = errors. permission {
129
- if let Some ( other_errors) = errors. other {
130
- format ! (
131
- "\n \n {interface_name}: {} \n Additional Errors: \n {other_errors}" ,
132
- eperm_message( ) ,
133
- )
134
- } else {
135
- format ! ( "\n \n {interface_name}: {}" , eperm_message( ) )
136
- }
137
- } else {
138
- let other_errors = errors
139
- . other
140
- . expect ( "asked to collect errors but found no errors" ) ;
141
- format ! ( "\n \n {other_errors}" )
142
- }
143
- }
144
-
145
76
pub fn get_input (
146
77
interface_name : Option < & str > ,
147
78
resolve : bool ,
148
79
dns_server : Option < Ipv4Addr > ,
149
80
) -> anyhow:: Result < OsInputOutput > {
150
- let network_interfaces = if let Some ( name) = interface_name {
151
- match get_interface ( name) {
152
- Some ( interface) => vec ! [ interface] ,
153
- None => {
154
- anyhow:: bail!( "Cannot find interface {name}" ) ;
155
- // the homebrew formula relies on this wording, please be careful when changing
156
- }
157
- }
158
- } else {
159
- datalink:: interfaces ( )
160
- } ;
161
-
162
- #[ cfg( target_os = "windows" ) ]
163
- let network_frames = network_interfaces
164
- . iter ( )
165
- . filter ( |iface| !iface. ips . is_empty ( ) )
166
- . map ( |iface| ( iface, get_datalink_channel ( iface) ) ) ;
167
- #[ cfg( not( target_os = "windows" ) ) ]
168
- let network_frames = network_interfaces
169
- . iter ( )
170
- . filter ( |iface| iface. is_up ( ) && !iface. ips . is_empty ( ) )
171
- . map ( |iface| ( iface, get_datalink_channel ( iface) ) ) ;
172
-
173
- let ( available_network_frames, network_interfaces) = {
174
- let network_frames = network_frames. clone ( ) ;
175
- let mut available_network_frames = Vec :: new ( ) ;
176
- let mut available_interfaces: Vec < NetworkInterface > = Vec :: new ( ) ;
177
- for ( iface, rx) in network_frames. filter_map ( |( iface, channel) | {
178
- if let Ok ( rx) = channel {
179
- Some ( ( iface, rx) )
81
+ // get the user's requested interface, if any
82
+ // IDEA: allow requesting multiple interfaces
83
+ let requested_interfaces = interface_name
84
+ . map ( |name| get_interface ( name) . ok_or_else ( || anyhow ! ( "Cannot find interface {name}" ) ) )
85
+ . transpose ( ) ?
86
+ . map ( |interface| vec ! [ interface] ) ;
87
+
88
+ // take the user's requested interfaces (or all interfaces), and filter for up ones
89
+ let available_interfaces = requested_interfaces
90
+ . unwrap_or_else ( datalink:: interfaces)
91
+ . into_iter ( )
92
+ . filter ( |interface| {
93
+ // see https://github.com/libpnet/libpnet/issues/564
94
+ let keep = if cfg ! ( target_os = "windows" ) {
95
+ !interface. ips . is_empty ( )
180
96
} else {
181
- None
97
+ interface. is_up ( ) && !interface. ips . is_empty ( )
98
+ } ;
99
+ if !keep {
100
+ mt_log ! ( debug, "{} is down. Skipping it." , interface. name) ;
182
101
}
183
- } ) {
184
- available_interfaces. push ( iface. clone ( ) ) ;
185
- available_network_frames. push ( rx) ;
186
- }
187
- ( available_network_frames, available_interfaces)
188
- } ;
102
+ keep
103
+ } )
104
+ . collect_vec ( ) ;
189
105
190
- if available_network_frames. is_empty ( ) {
191
- let all_errors = collect_errors ( network_frames. clone ( ) ) ;
192
- if !all_errors. is_empty ( ) {
193
- anyhow:: bail!( all_errors) ;
194
- }
106
+ // bail if no interfaces are up
107
+ if available_interfaces. is_empty ( ) {
108
+ bail ! ( "Failed to find any network interface to listen on." ) ;
109
+ }
195
110
196
- anyhow:: bail!( "Failed to find any network interface to listen on." ) ;
111
+ // try to get a frame receiver for each interface
112
+ let interfaces_with_frames_res = available_interfaces
113
+ . into_iter ( )
114
+ . map ( |interface| {
115
+ let frames_res = get_datalink_channel ( & interface) ;
116
+ ( interface, frames_res)
117
+ } )
118
+ . collect_vec ( ) ;
119
+
120
+ // warn for all frame receivers we failed to acquire
121
+ interfaces_with_frames_res
122
+ . iter ( )
123
+ . filter_map ( |( interface, frames_res) | frames_res. as_ref ( ) . err ( ) . map ( |err| ( interface, err) ) )
124
+ . for_each ( |( interface, err) | {
125
+ mt_log ! (
126
+ warn,
127
+ "Failed to acquire a frame receiver for {}: {err}" ,
128
+ interface. name
129
+ )
130
+ } ) ;
131
+
132
+ // bail if all of them fail
133
+ // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle
134
+ // that failure mode separately, which we already have
135
+ if interfaces_with_frames_res
136
+ . iter ( )
137
+ . all ( |( _, frames) | frames. is_err ( ) )
138
+ {
139
+ let ( permission_err_interfaces, other_errs) = interfaces_with_frames_res. iter ( ) . fold (
140
+ ( vec ! [ ] , vec ! [ ] ) ,
141
+ |( mut perms, mut others) , ( _, res) | {
142
+ match res {
143
+ Ok ( _) => ( ) ,
144
+ Err ( GetInterfaceError :: PermissionError ( interface) ) => {
145
+ perms. push ( interface. as_str ( ) )
146
+ }
147
+ Err ( GetInterfaceError :: OtherError ( err) ) => others. push ( err. as_str ( ) ) ,
148
+ }
149
+ ( perms, others)
150
+ } ,
151
+ ) ;
152
+
153
+ let err_msg = match ( permission_err_interfaces. is_empty ( ) , other_errs. is_empty ( ) ) {
154
+ ( false , false ) => format ! (
155
+ "\n \n {}: {}\n Additional errors:\n {}" ,
156
+ permission_err_interfaces. join( ", " ) ,
157
+ eperm_message( ) ,
158
+ other_errs. join( "\n " )
159
+ ) ,
160
+ ( false , true ) => format ! (
161
+ "\n \n {}: {}" ,
162
+ permission_err_interfaces. join( ", " ) ,
163
+ eperm_message( )
164
+ ) ,
165
+ ( true , false ) => format ! ( "\n \n {}" , other_errs. join( "\n " ) ) ,
166
+ ( true , true ) => unreachable ! ( "Found no errors in error handling code path." ) ,
167
+ } ;
168
+ bail ! ( err_msg) ;
197
169
}
198
170
199
- let keyboard_events = Box :: new ( TerminalEvents ) ;
200
- let write_to_stdout = create_write_to_stdout ( ) ;
171
+ // filter out interfaces for which we failed to acquire a frame receiver
172
+ let interfaces_with_frames = interfaces_with_frames_res
173
+ . into_iter ( )
174
+ . filter_map ( |( interface, res) | res. ok ( ) . map ( |frames| ( interface, frames) ) )
175
+ . collect ( ) ;
176
+
201
177
let dns_client = if resolve {
202
178
let runtime = Runtime :: new ( ) ?;
203
- let resolver = match runtime. block_on ( dns:: Resolver :: new ( dns_server) ) {
204
- Ok ( resolver) => resolver,
205
- Err ( err) => anyhow:: bail!(
206
- "Could not initialize the DNS resolver. Are you offline?\n \n Reason: {err:?}"
207
- ) ,
208
- } ;
179
+ let resolver = runtime
180
+ . block_on ( dns:: Resolver :: new ( dns_server) )
181
+ . map_err ( |err| {
182
+ anyhow ! ( "Could not initialize the DNS resolver. Are you offline?\n \n Reason: {err}" )
183
+ } ) ?;
209
184
let dns_client = dns:: Client :: new ( resolver, runtime) ?;
210
185
Some ( dns_client)
211
186
} else {
212
187
None
213
188
} ;
214
189
190
+ let write_to_stdout = create_write_to_stdout ( ) ;
191
+
215
192
Ok ( OsInputOutput {
216
- network_interfaces,
217
- network_frames : available_network_frames,
193
+ interfaces_with_frames,
218
194
get_open_sockets,
219
- terminal_events : keyboard_events ,
195
+ terminal_events : Box :: new ( TerminalEvents ) ,
220
196
dns_client,
221
197
write_to_stdout,
222
198
} )
0 commit comments