@@ -4,11 +4,13 @@ use std::{
44 time,
55} ;
66
7+ use anyhow:: { anyhow, bail} ;
78use crossterm:: event:: { read, Event } ;
9+ use itertools:: Itertools ;
810use pnet:: datalink:: { self , Channel :: Ethernet , Config , DataLinkReceiver , NetworkInterface } ;
911use tokio:: runtime:: Runtime ;
1012
11- use crate :: { network:: dns, os:: errors:: GetInterfaceError , OsInputOutput } ;
13+ use crate :: { mt_log , network:: dns, os:: errors:: GetInterfaceError , OsInputOutput } ;
1214
1315#[ cfg( target_os = "linux" ) ]
1416use crate :: os:: linux:: get_open_sockets;
@@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> {
6365}
6466
6567fn create_write_to_stdout ( ) -> Box < dyn FnMut ( String ) + Send > {
68+ let mut stdout = io:: stdout ( ) ;
6669 Box :: new ( {
67- let mut stdout = io:: stdout ( ) ;
6870 move |output : String | {
6971 writeln ! ( stdout, "{}" , output) . unwrap ( ) ;
7072 }
7173 } )
7274}
7375
74- #[ derive( Debug ) ]
75- pub struct UserErrors {
76- permission : Option < String > ,
77- other : Option < String > ,
78- }
79-
80- pub fn collect_errors < ' a , I > ( network_frames : I ) -> String
81- where
82- I : Iterator <
83- Item = (
84- & ' a NetworkInterface ,
85- Result < Box < dyn DataLinkReceiver > , GetInterfaceError > ,
86- ) ,
87- > ,
88- {
89- let errors = network_frames. fold (
90- UserErrors {
91- permission : None ,
92- other : None ,
93- } ,
94- |acc, ( _, elem) | {
95- if let Some ( iface_error) = elem. err ( ) {
96- match iface_error {
97- GetInterfaceError :: PermissionError ( interface_name) => {
98- if let Some ( prev_interface) = acc. permission {
99- return UserErrors {
100- permission : Some ( format ! ( "{prev_interface}, {interface_name}" ) ) ,
101- ..acc
102- } ;
103- } else {
104- return UserErrors {
105- permission : Some ( interface_name) ,
106- ..acc
107- } ;
108- }
109- }
110- error => {
111- if let Some ( prev_errors) = acc. other {
112- return UserErrors {
113- other : Some ( format ! ( "{prev_errors} \n {error}" ) ) ,
114- ..acc
115- } ;
116- } else {
117- return UserErrors {
118- other : Some ( format ! ( "{error}" ) ) ,
119- ..acc
120- } ;
121- }
122- }
123- } ;
124- }
125- acc
126- } ,
127- ) ;
128- if let Some ( interface_name) = errors. permission {
129- if let Some ( other_errors) = errors. other {
130- format ! (
131- "\n \n {interface_name}: {} \n Additional Errors: \n {other_errors}" ,
132- eperm_message( ) ,
133- )
134- } else {
135- format ! ( "\n \n {interface_name}: {}" , eperm_message( ) )
136- }
137- } else {
138- let other_errors = errors
139- . other
140- . expect ( "asked to collect errors but found no errors" ) ;
141- format ! ( "\n \n {other_errors}" )
142- }
143- }
144-
14576pub fn get_input (
14677 interface_name : Option < & str > ,
14778 resolve : bool ,
14879 dns_server : Option < Ipv4Addr > ,
14980) -> anyhow:: Result < OsInputOutput > {
150- let network_interfaces = if let Some ( name) = interface_name {
151- match get_interface ( name) {
152- Some ( interface) => vec ! [ interface] ,
153- None => {
154- anyhow:: bail!( "Cannot find interface {name}" ) ;
155- // the homebrew formula relies on this wording, please be careful when changing
156- }
157- }
158- } else {
159- datalink:: interfaces ( )
160- } ;
161-
162- #[ cfg( target_os = "windows" ) ]
163- let network_frames = network_interfaces
164- . iter ( )
165- . filter ( |iface| !iface. ips . is_empty ( ) )
166- . map ( |iface| ( iface, get_datalink_channel ( iface) ) ) ;
167- #[ cfg( not( target_os = "windows" ) ) ]
168- let network_frames = network_interfaces
169- . iter ( )
170- . filter ( |iface| iface. is_up ( ) && !iface. ips . is_empty ( ) )
171- . map ( |iface| ( iface, get_datalink_channel ( iface) ) ) ;
172-
173- let ( available_network_frames, network_interfaces) = {
174- let network_frames = network_frames. clone ( ) ;
175- let mut available_network_frames = Vec :: new ( ) ;
176- let mut available_interfaces: Vec < NetworkInterface > = Vec :: new ( ) ;
177- for ( iface, rx) in network_frames. filter_map ( |( iface, channel) | {
178- if let Ok ( rx) = channel {
179- Some ( ( iface, rx) )
81+ // get the user's requested interface, if any
82+ // IDEA: allow requesting multiple interfaces
83+ let requested_interfaces = interface_name
84+ . map ( |name| get_interface ( name) . ok_or_else ( || anyhow ! ( "Cannot find interface {name}" ) ) )
85+ . transpose ( ) ?
86+ . map ( |interface| vec ! [ interface] ) ;
87+
88+ // take the user's requested interfaces (or all interfaces), and filter for up ones
89+ let available_interfaces = requested_interfaces
90+ . unwrap_or_else ( datalink:: interfaces)
91+ . into_iter ( )
92+ . filter ( |interface| {
93+ // see https://github.com/libpnet/libpnet/issues/564
94+ let keep = if cfg ! ( target_os = "windows" ) {
95+ !interface. ips . is_empty ( )
18096 } else {
181- None
97+ interface. is_up ( ) && !interface. ips . is_empty ( )
98+ } ;
99+ if !keep {
100+ mt_log ! ( debug, "{} is down. Skipping it." , interface. name) ;
182101 }
183- } ) {
184- available_interfaces. push ( iface. clone ( ) ) ;
185- available_network_frames. push ( rx) ;
186- }
187- ( available_network_frames, available_interfaces)
188- } ;
102+ keep
103+ } )
104+ . collect_vec ( ) ;
189105
190- if available_network_frames. is_empty ( ) {
191- let all_errors = collect_errors ( network_frames. clone ( ) ) ;
192- if !all_errors. is_empty ( ) {
193- anyhow:: bail!( all_errors) ;
194- }
106+ // bail if no interfaces are up
107+ if available_interfaces. is_empty ( ) {
108+ bail ! ( "Failed to find any network interface to listen on." ) ;
109+ }
195110
196- anyhow:: bail!( "Failed to find any network interface to listen on." ) ;
111+ // try to get a frame receiver for each interface
112+ let interfaces_with_frames_res = available_interfaces
113+ . into_iter ( )
114+ . map ( |interface| {
115+ let frames_res = get_datalink_channel ( & interface) ;
116+ ( interface, frames_res)
117+ } )
118+ . collect_vec ( ) ;
119+
120+ // warn for all frame receivers we failed to acquire
121+ interfaces_with_frames_res
122+ . iter ( )
123+ . filter_map ( |( interface, frames_res) | frames_res. as_ref ( ) . err ( ) . map ( |err| ( interface, err) ) )
124+ . for_each ( |( interface, err) | {
125+ mt_log ! (
126+ warn,
127+ "Failed to acquire a frame receiver for {}: {err}" ,
128+ interface. name
129+ )
130+ } ) ;
131+
132+ // bail if all of them fail
133+ // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle
134+ // that failure mode separately, which we already have
135+ if interfaces_with_frames_res
136+ . iter ( )
137+ . all ( |( _, frames) | frames. is_err ( ) )
138+ {
139+ let ( permission_err_interfaces, other_errs) = interfaces_with_frames_res. iter ( ) . fold (
140+ ( vec ! [ ] , vec ! [ ] ) ,
141+ |( mut perms, mut others) , ( _, res) | {
142+ match res {
143+ Ok ( _) => ( ) ,
144+ Err ( GetInterfaceError :: PermissionError ( interface) ) => {
145+ perms. push ( interface. as_str ( ) )
146+ }
147+ Err ( GetInterfaceError :: OtherError ( err) ) => others. push ( err. as_str ( ) ) ,
148+ }
149+ ( perms, others)
150+ } ,
151+ ) ;
152+
153+ let err_msg = match ( permission_err_interfaces. is_empty ( ) , other_errs. is_empty ( ) ) {
154+ ( false , false ) => format ! (
155+ "\n \n {}: {}\n Additional errors:\n {}" ,
156+ permission_err_interfaces. join( ", " ) ,
157+ eperm_message( ) ,
158+ other_errs. join( "\n " )
159+ ) ,
160+ ( false , true ) => format ! (
161+ "\n \n {}: {}" ,
162+ permission_err_interfaces. join( ", " ) ,
163+ eperm_message( )
164+ ) ,
165+ ( true , false ) => format ! ( "\n \n {}" , other_errs. join( "\n " ) ) ,
166+ ( true , true ) => unreachable ! ( "Found no errors in error handling code path." ) ,
167+ } ;
168+ bail ! ( err_msg) ;
197169 }
198170
199- let keyboard_events = Box :: new ( TerminalEvents ) ;
200- let write_to_stdout = create_write_to_stdout ( ) ;
171+ // filter out interfaces for which we failed to acquire a frame receiver
172+ let interfaces_with_frames = interfaces_with_frames_res
173+ . into_iter ( )
174+ . filter_map ( |( interface, res) | res. ok ( ) . map ( |frames| ( interface, frames) ) )
175+ . collect ( ) ;
176+
201177 let dns_client = if resolve {
202178 let runtime = Runtime :: new ( ) ?;
203- let resolver = match runtime. block_on ( dns:: Resolver :: new ( dns_server) ) {
204- Ok ( resolver) => resolver,
205- Err ( err) => anyhow:: bail!(
206- "Could not initialize the DNS resolver. Are you offline?\n \n Reason: {err:?}"
207- ) ,
208- } ;
179+ let resolver = runtime
180+ . block_on ( dns:: Resolver :: new ( dns_server) )
181+ . map_err ( |err| {
182+ anyhow ! ( "Could not initialize the DNS resolver. Are you offline?\n \n Reason: {err}" )
183+ } ) ?;
209184 let dns_client = dns:: Client :: new ( resolver, runtime) ?;
210185 Some ( dns_client)
211186 } else {
212187 None
213188 } ;
214189
190+ let write_to_stdout = create_write_to_stdout ( ) ;
191+
215192 Ok ( OsInputOutput {
216- network_interfaces,
217- network_frames : available_network_frames,
193+ interfaces_with_frames,
218194 get_open_sockets,
219- terminal_events : keyboard_events ,
195+ terminal_events : Box :: new ( TerminalEvents ) ,
220196 dns_client,
221197 write_to_stdout,
222198 } )
0 commit comments