From 1ec955bc0effa473479fd29c0fb447660d37153a Mon Sep 17 00:00:00 2001 From: Ingvar Stepanyan Date: Tue, 27 Aug 2024 12:36:56 +0100 Subject: [PATCH] Tweak Windows implementation - Fewer deps for Windows (doesn't need libc and memalloc anymore). - Somewhat safer allocation - relying on Vec + Drop which is guaranteed to run on any panic / return / etc instead of manually tracking a raw pointer for the buffer. - Increase initial buffer size as per official recommendations. - Add safe iteration helper for Windows linked lists (assuming that, if the first pointer is valid, then all others should be as well and will have the same lifetime). This significantly reduces number of individual `unsafe` usages and paves path for #45 as the entire function is now just an iterator + `.collect()`. - Simplify socket address conversions by relying on Rust's built-in endianness helpers and the SOCKADDR_INET helper union. --- Cargo.toml | 9 +- src/interface/windows.rs | 271 +++++++++++++++++---------------------- 2 files changed, 122 insertions(+), 158 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0adf3ef..0467a60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,11 @@ categories = ["network-programming"] license = "MIT" [dependencies] -libc = "0.2" serde = { version = "1", features = ["derive"], optional = true } +[target.'cfg(unix)'.dependencies] +libc = "0.2" + [target.'cfg(target_os = "android")'.dependencies] # DL Open dlopen2 = { version = "0.5", default-features = false } @@ -23,12 +25,9 @@ netlink-packet-core = "0.7" netlink-packet-route = "0.17" netlink-sys = "0.8" -[target.'cfg(windows)'.dependencies] -memalloc = "0.1.0" - [target.'cfg(windows)'.dependencies.windows-sys] version = "0.52" -features = ["Win32_Foundation","Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_Ndis"] +features = ["Win32_Foundation", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_Ndis"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] system-configuration = "0.6" diff --git a/src/interface/windows.rs b/src/interface/windows.rs index 35b4f7b..8899f89 100644 --- a/src/interface/windows.rs +++ b/src/interface/windows.rs @@ -1,18 +1,13 @@ -use core::ffi::c_void; -use libc::{c_char, strlen, wchar_t, wcslen}; -use memalloc::{allocate, deallocate}; use std::convert::TryFrom; -use std::convert::TryInto; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr}; use windows_sys::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, NO_ERROR}; use windows_sys::Win32::NetworkManagement::IpHelper::{ GetAdaptersAddresses, GetIfEntry2, SendARP, GAA_FLAG_INCLUDE_GATEWAYS, IP_ADAPTER_ADDRESSES_LH, MIB_IF_ROW2, MIB_IF_ROW2_0, }; -use windows_sys::Win32::NetworkManagement::Ndis::{IF_OPER_STATUS, NET_IF_OPER_STATUS_UP}; -use windows_sys::Win32::Networking::WinSock::SOCKET_ADDRESS; +use windows_sys::Win32::NetworkManagement::Ndis::NET_IF_OPER_STATUS_UP; use windows_sys::Win32::Networking::WinSock::{ - AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN, SOCKADDR_IN6, + AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_INET, SOCKET_ADDRESS, }; use crate::device::NetworkDevice; @@ -20,6 +15,8 @@ use crate::interface::{Interface, InterfaceType}; use crate::ip::{Ipv4Net, Ipv6Net}; use crate::mac::MacAddr; use crate::sys; +use std::ffi::CStr; +use std::mem::MaybeUninit; //const IFF_HARDWARE_INTERFACE: u8 = 0b0000_0001; //const IFF_FILTER_INTERFACE: u8 = 0b0000_0010; @@ -30,58 +27,37 @@ const IFF_CONNECTOR_PRESENT: u8 = 0b0000_0100; //const IFF_LOW_POWER: u8 = 0b0100_0000; //const IFF_END_POINT_INTERFACE: u8 = 0b1000_0000; -#[cfg(target_endian = "little")] -fn htonl(val: u32) -> u32 { - let o3 = (val >> 24) as u8; - let o2 = (val >> 16) as u8; - let o1 = (val >> 8) as u8; - let o0 = val as u8; - (o0 as u32) << 24 | (o1 as u32) << 16 | (o2 as u32) << 8 | (o3 as u32) -} - -#[cfg(target_endian = "big")] -fn htonl(val: u32) -> u32 { - val -} - fn get_mac_through_arp(src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> MacAddr { - let src_ip_int: u32 = htonl(u32::from(src_ip)); - let dst_ip_int: u32 = htonl(u32::from(dst_ip)); - let mut out_buf_len: u32 = 6; - let mut target_mac_addr: [u8; 6] = [0; 6]; + let src_ip_int = u32::from_ne_bytes(src_ip.octets()); + let dst_ip_int = u32::from_ne_bytes(dst_ip.octets()); + let mut out_buf_len = 6; + let mut target_mac_addr = MaybeUninit::<[u8; 6]>::uninit(); let res = unsafe { SendARP( dst_ip_int, src_ip_int, - target_mac_addr.as_mut_ptr() as *mut c_void, + target_mac_addr.as_mut_ptr().cast(), &mut out_buf_len, ) }; if res == NO_ERROR { - MacAddr::from_octets(target_mac_addr) + assert_eq!(out_buf_len, 6); + MacAddr::from_octets(unsafe { target_mac_addr.assume_init() }) } else { MacAddr::zero() } } unsafe fn socket_address_to_ipaddr(addr: &SOCKET_ADDRESS) -> Option { - let sockaddr = unsafe { *addr.lpSockaddr }; - if sockaddr.sa_family == AF_INET { - let sockaddr: *mut SOCKADDR_IN = addr.lpSockaddr as *mut SOCKADDR_IN; - let a = unsafe { (*sockaddr).sin_addr.S_un.S_addr }; - let ipv4 = if cfg!(target_endian = "little") { - Ipv4Addr::from(a.swap_bytes()) - } else { - Ipv4Addr::from(a) - }; - return Some(IpAddr::V4(ipv4)); - } else if sockaddr.sa_family == AF_INET6 { - let sockaddr: *mut SOCKADDR_IN6 = addr.lpSockaddr as *mut SOCKADDR_IN6; - let a = unsafe { (*sockaddr).sin6_addr.u.Byte }; - let ipv6 = Ipv6Addr::from(a); - return Some(IpAddr::V6(ipv6)); - } - None + let sockaddr = addr.lpSockaddr.cast::().as_ref()?; + + Some(match sockaddr.si_family { + AF_INET => unsafe { sockaddr.Ipv4.sin_addr.S_un.S_addr } + .to_ne_bytes() + .into(), + AF_INET6 => unsafe { sockaddr.Ipv6.sin6_addr.u.Byte }.into(), + _ => return None, + }) } pub fn is_running(interface: &Interface) -> bool { @@ -113,6 +89,33 @@ pub fn is_physical_interface(interface: &Interface) -> bool { && !interface.is_loopback()) } +unsafe fn from_wide_string(ptr: *const u16) -> String { + let mut len = 0; + while *ptr.add(len) != 0 { + len += 1; + } + String::from_utf16_lossy(std::slice::from_raw_parts(ptr, len)) +} + +// Note: We take `&*mut T` instead of just `*mut T` to tie the lifetime of all the returned items +// to the lifetime of the pointer for some extra safety. +unsafe fn linked_list_iter(ptr: &*mut T, next: fn(&T) -> *mut T) -> impl Iterator { + let mut ptr = ptr.cast_const(); + + std::iter::from_fn(move || { + let cur = ptr.as_ref()?; + ptr = next(cur); + Some(cur) + }) +} + +// The `Next` element is always the same, so use a macro to avoid the repetition. +macro_rules! linked_list_iter { + ($ptr:expr) => { + linked_list_iter($ptr, |cur| cur.Next) + }; +} + // Get network interfaces using the IP Helper API // Reference: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses pub fn interfaces() -> Vec { @@ -120,49 +123,52 @@ pub fn interfaces() -> Vec { Some(local_ip) => local_ip, None => IpAddr::V4(Ipv4Addr::LOCALHOST), }; - let mut interfaces: Vec = vec![]; - let mut dwsize: u32 = 2000; - let mut mem = unsafe { allocate(dwsize as usize) } as *mut IP_ADAPTER_ADDRESSES_LH; + // "The recommended method of calling the GetAdaptersAddresses function is to pre-allocate a 15KB working buffer pointed to by the AdapterAddresses parameter." + // (c) https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses + let mut mem = Vec::::with_capacity(15000); let mut retries = 3; - let mut ret_val; loop { - let old_size = dwsize as usize; - ret_val = unsafe { + let mut dwsize = mem.capacity() as u32; + let ret_val = unsafe { GetAdaptersAddresses( AF_UNSPEC as u32, GAA_FLAG_INCLUDE_GATEWAYS, - std::ptr::null_mut::(), - mem, + std::ptr::null_mut(), + mem.as_mut_ptr().cast(), &mut dwsize, ) }; - if ret_val != ERROR_BUFFER_OVERFLOW || retries <= 0 { - break; + match ret_val { + NO_ERROR => { + unsafe { + mem.set_len(dwsize as usize); + } + break; + } + ERROR_BUFFER_OVERFLOW if retries > 0 => { + mem.reserve(dwsize as usize); + retries -= 1; + } + _ => { + // TODO: return errors as a Result someday? + return vec![]; + } } - unsafe { deallocate(mem as *mut u8, old_size as usize) }; - mem = unsafe { allocate(dwsize as usize) as *mut IP_ADAPTER_ADDRESSES_LH }; - retries -= 1; } - if ret_val == NO_ERROR { - // Enumerate all adapters - let mut cur = mem; - while !cur.is_null() { - let if_type_int: u32 = unsafe { (*cur).IfType }; - let if_type = match InterfaceType::try_from(if_type_int) { - Ok(if_type) => if_type, - Err(_) => { - cur = unsafe { (*cur).Next }; - continue; - } - }; + // Enumerate all adapters + let mem = mem.as_mut_ptr().cast::(); + unsafe { linked_list_iter!(&mem) } + .filter_map(|cur| { + let if_type = InterfaceType::try_from(cur.IfType).ok()?; // Index - let anon1 = unsafe { (*cur).Anonymous1 }; - let anon = unsafe { anon1.Anonymous }; - let index = anon.IfIndex; + let index = { + let anon1 = cur.Anonymous1; + let anon = unsafe { &anon1.Anonymous }; + anon.IfIndex + }; // Flags and Status let mut flags: u32 = 0; - let status: IF_OPER_STATUS = unsafe { (*cur).OperStatus }; - if status == NET_IF_OPER_STATUS_UP { + if cur.OperStatus == NET_IF_OPER_STATUS_UP { flags |= sys::IFF_UP; } match if_type { @@ -184,74 +190,48 @@ pub fn interfaces() -> Vec { _ => {} } // Name - let p_aname = unsafe { (*cur).AdapterName }; - let aname_len = unsafe { strlen(p_aname as *const c_char) }; - let aname_slice = unsafe { std::slice::from_raw_parts(p_aname, aname_len) }; - let adapter_name = String::from_utf8(aname_slice.to_vec()).unwrap(); - // Friendly Name - let p_fname = unsafe { (*cur).FriendlyName }; - let fname_len = unsafe { wcslen(p_fname as *const wchar_t) }; - let fname_slice = unsafe { std::slice::from_raw_parts(p_fname, fname_len) }; - let friendly_name = String::from_utf16(fname_slice).unwrap(); - // Description - let p_desc = unsafe { (*cur).Description }; - let desc_len = unsafe { wcslen(p_desc as *const wchar_t) }; - let desc_slice = unsafe { std::slice::from_raw_parts(p_desc, desc_len) }; - let description = String::from_utf16(desc_slice).unwrap(); + let adapter_name = unsafe { CStr::from_ptr(cur.AdapterName.cast()) } + .to_string_lossy() + .into_owned(); // MAC address - let mac_addr_arr: [u8; 6] = unsafe { (*cur).PhysicalAddress }[..6] - .try_into() - .unwrap_or([0, 0, 0, 0, 0, 0]); + let mac_addr_arr: [u8; 6] = cur.PhysicalAddress[..6].try_into().unwrap_or_default(); let mac_addr: MacAddr = MacAddr::from_octets(mac_addr_arr); - // TransmitLinkSpeed (bits per second) - let transmit_speed = unsafe { (*cur).TransmitLinkSpeed }; - // ReceiveLinkSpeed (bits per second) - let receive_speed = unsafe { (*cur).ReceiveLinkSpeed }; let mut ipv4_vec: Vec = vec![]; let mut ipv6_vec: Vec = vec![]; // Enumerate all IPs - let mut cur_a = unsafe { (*cur).FirstUnicastAddress }; - while !cur_a.is_null() { - let addr: SOCKET_ADDRESS = unsafe { (*cur_a).Address }; - let ip_addr = unsafe { socket_address_to_ipaddr(&addr) }; - let prefix_len = unsafe { (*cur_a).OnLinkPrefixLength }; - if let Some(ip_addr) = ip_addr { - match ip_addr { - IpAddr::V4(ipv4) => { - let ipv4_net: Ipv4Net = Ipv4Net::new(ipv4, prefix_len); - ipv4_vec.push(ipv4_net); - } - IpAddr::V6(ipv6) => { - let ipv6_net: Ipv6Net = Ipv6Net::new(ipv6, prefix_len); - ipv6_vec.push(ipv6_net); - } + for cur_a in unsafe { linked_list_iter!(&cur.FirstUnicastAddress) } { + let Some(ip_addr) = (unsafe { socket_address_to_ipaddr(&cur_a.Address) }) else { + continue; + }; + let prefix_len = cur_a.OnLinkPrefixLength; + match ip_addr { + IpAddr::V4(ipv4) => { + let ipv4_net: Ipv4Net = Ipv4Net::new(ipv4, prefix_len); + ipv4_vec.push(ipv4_net); + } + IpAddr::V6(ipv6) => { + let ipv6_net: Ipv6Net = Ipv6Net::new(ipv6, prefix_len); + ipv6_vec.push(ipv6_net); } } - cur_a = unsafe { (*cur_a).Next }; } // Gateway - let mut gateway_ips: Vec = vec![]; - let mut cur_g = unsafe { (*cur).FirstGatewayAddress }; - while !cur_g.is_null() { - let addr: SOCKET_ADDRESS = unsafe { (*cur_g).Address }; - if let Some(ip_addr) = unsafe { socket_address_to_ipaddr(&addr) } { - gateway_ips.push(ip_addr); - } - cur_g = unsafe { (*cur_g).Next }; - } + let gateway_ips: Vec = unsafe { linked_list_iter!(&cur.FirstGatewayAddress) } + .filter_map(|cur_g| unsafe { socket_address_to_ipaddr(&cur_g.Address) }) + .collect(); let mut default_gateway: NetworkDevice = NetworkDevice::new(); if flags & sys::IFF_UP != 0 { for gateway_ip in gateway_ips { match gateway_ip { IpAddr::V4(ipv4) => { - if let Some(ip_net) = ipv4_vec.get(0) { + if let Some(ip_net) = ipv4_vec.first() { let mac_addr = get_mac_through_arp(ip_net.addr, ipv4); default_gateway.mac_addr = mac_addr; default_gateway.ipv4.push(ipv4); } } IpAddr::V6(ipv6) => { - if let Some(_ip_net) = ipv6_vec.get(0) { + if !ipv6_vec.is_empty() { default_gateway.ipv6.push(ipv6); } } @@ -259,49 +239,34 @@ pub fn interfaces() -> Vec { } } // DNS Servers - let mut dns_servers: Vec = vec![]; - let mut cur_d = unsafe { (*cur).FirstDnsServerAddress }; - while !cur_d.is_null() { - let addr: SOCKET_ADDRESS = unsafe { (*cur_d).Address }; - if let Some(ip_addr) = unsafe { socket_address_to_ipaddr(&addr) } { - dns_servers.push(ip_addr); - } - cur_d = unsafe { (*cur_d).Next }; - } + let dns_servers: Vec = unsafe { linked_list_iter!(&cur.FirstDnsServerAddress) } + .filter_map(|cur_d| unsafe { socket_address_to_ipaddr(&cur_d.Address) }) + .collect(); let default: bool = match local_ip { IpAddr::V4(local_ipv4) => ipv4_vec.iter().any(|x| x.addr == local_ipv4), IpAddr::V6(local_ipv6) => ipv6_vec.iter().any(|x| x.addr == local_ipv6), }; let interface: Interface = Interface { - index: index, + index, name: adapter_name, - friendly_name: Some(friendly_name), - description: Some(description), - if_type: if_type, + friendly_name: Some(unsafe { from_wide_string(cur.FriendlyName) }), + description: Some(unsafe { from_wide_string(cur.Description) }), + if_type, mac_addr: Some(mac_addr), ipv4: ipv4_vec, ipv6: ipv6_vec, - flags: flags, - transmit_speed: Some(transmit_speed), - receive_speed: Some(receive_speed), + flags, + transmit_speed: Some(cur.TransmitLinkSpeed), + receive_speed: Some(cur.ReceiveLinkSpeed), gateway: if default_gateway.mac_addr == MacAddr::zero() { None } else { Some(default_gateway) }, - dns_servers: dns_servers, - default: default, + dns_servers, + default, }; - interfaces.push(interface); - cur = unsafe { (*cur).Next }; - } - } else { - unsafe { - deallocate(mem as *mut u8, dwsize as usize); - } - } - unsafe { - deallocate(mem as *mut u8, dwsize as usize); - } - return interfaces; + Some(interface) + }) + .collect() }