Skip to content

Commit

Permalink
Introduce a RawSocketAddr type
Browse files Browse the repository at this point in the history
It wraps SockAddrStorage, but includes a length, for when a sockaddr has
to be stored for a longer lifetime.
  • Loading branch information
colinmarc committed Oct 18, 2024
1 parent 22e9043 commit 666f5bd
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub use crate::maybe_polyfill::net::{
};
pub use send_recv::*;
pub use socket::*;
pub use socket_addr_any::{SocketAddrAny, SocketAddrStorage};
pub use socket_addr_any::{RawSocketAddr, SocketAddrAny, SocketAddrStorage};
#[cfg(not(any(windows, target_os = "wasi")))]
pub use socketpair::socketpair;
pub use types::*;
Expand Down
47 changes: 47 additions & 0 deletions src/net/socket_addr_any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::net::xdp::SocketAddrXdp;
#[cfg(unix)]
use crate::net::SocketAddrUnix;
use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::utils::{as_mut_ptr, as_ptr};
use crate::{backend, io};
#[cfg(feature = "std")]
use core::fmt;
Expand Down Expand Up @@ -83,6 +84,23 @@ impl SocketAddrAny {
}
}

/// Creates a platform-specific encoding of this socket address,
/// and returns it.
pub fn to_raw(&self) -> RawSocketAddr {
let mut raw = RawSocketAddr {
storage: unsafe { std::mem::zeroed() },
len: 0,
};

raw.len = unsafe { self.write(raw.as_mut_ptr()) };
raw
}

/// Reads a platform-specific encoding of a socket address.
pub fn from_raw(raw: RawSocketAddr) -> io::Result<Self> {
unsafe { Self::read(raw.as_ptr(), raw.len) }
}

/// Writes a platform-specific encoding of this socket address to
/// the memory pointed to by `storage`, and returns the number of
/// bytes used.
Expand All @@ -107,6 +125,35 @@ impl SocketAddrAny {
}
}

/// A raw sockaddr and its length.
#[repr(C)]
pub struct RawSocketAddr {
pub(crate) storage: SocketAddrStorage,
pub(crate) len: usize,
}

impl RawSocketAddr {
/// Creates a raw encoded sockaddr from the given address.
pub fn new(addr: impl Into<SocketAddrAny>) -> Self {
addr.into().to_raw()
}

/// Returns a raw pointer to the sockaddr.
pub fn as_ptr(&self) -> *const SocketAddrStorage {
as_ptr(&self.storage)
}

/// Returns a raw mutable pointer to the sockaddr.
pub fn as_mut_ptr(&mut self) -> *mut SocketAddrStorage {
as_mut_ptr(&mut self.storage)
}

/// Returns the length of the encoded sockaddr.
pub fn namelen(&self) -> usize {
self.len
}
}

#[cfg(feature = "std")]
impl fmt::Debug for SocketAddrAny {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
15 changes: 15 additions & 0 deletions tests/net/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,34 @@ fn encode_decode() {
let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap();
assert_eq!(decoded, SocketAddrAny::V4(orig));

let orig = SocketAddrV4::new(Ipv4Addr::new(2, 3, 5, 6), 33);
let encoded = SocketAddrAny::V4(orig).to_raw();
let decoded = SocketAddrAny::from_raw(encoded).unwrap();
assert_eq!(decoded, SocketAddrAny::V4(orig));

let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36);
let mut encoded = std::mem::MaybeUninit::<SocketAddrStorage>::uninit();
let len = SocketAddrAny::V6(orig).write(encoded.as_mut_ptr());
let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap();
assert_eq!(decoded, SocketAddrAny::V6(orig));

let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36);
let encoded = SocketAddrAny::V6(orig).to_raw();
let decoded = SocketAddrAny::from_raw(encoded).unwrap();
assert_eq!(decoded, SocketAddrAny::V6(orig));

#[cfg(not(windows))]
{
let orig = SocketAddrUnix::new("/path/to/socket").unwrap();
let mut encoded = std::mem::MaybeUninit::<SocketAddrStorage>::uninit();
let len = SocketAddrAny::Unix(orig.clone()).write(encoded.as_mut_ptr());
let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap();
assert_eq!(decoded, SocketAddrAny::Unix(orig));

let orig = SocketAddrUnix::new("/path/to/socket").unwrap();
let encoded = SocketAddrAny::Unix(orig.clone()).to_raw();
let decoded = SocketAddrAny::from_raw(encoded).unwrap();
assert_eq!(decoded, SocketAddrAny::Unix(orig));
}
}
}
Expand Down

0 comments on commit 666f5bd

Please sign in to comment.