Skip to content

Commit

Permalink
chore(deps): upgrade async-io
Browse files Browse the repository at this point in the history
  • Loading branch information
hanabi1224 committed Nov 25, 2024
1 parent 3d5879c commit 0cb2ed9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 56 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ features = ["os-poll", "os-ext"]

[dependencies.async-io]
optional = true
version = "1.13"
version = "2"

[features]
default = []
mio_socket = ["mio"]
tokio_socket = ["tokio", "futures"]
smol_socket = ["async-io","futures"]
smol_socket = ["async-io", "futures"]

[dev-dependencies]
netlink-packet-audit = "0.4.1"
Expand Down
27 changes: 6 additions & 21 deletions src/smol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use std::{
io,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
task::{Context, Poll},
};

Expand All @@ -17,20 +16,6 @@ use crate::{AsyncSocket, Socket, SocketAddr};
/// An I/O object representing a Netlink socket.
pub struct SmolSocket(Async<Socket>);

impl FromRawFd for SmolSocket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
let socket = Socket::from_raw_fd(fd);
socket.set_non_blocking(true).unwrap();
SmolSocket(Async::new(socket).unwrap())
}
}

impl AsRawFd for SmolSocket {
fn as_raw_fd(&self) -> RawFd {
self.0.get_ref().as_raw_fd()
}
}

// async_io::Async<..>::{read,write}_with[_mut] functions try IO first,
// and only register context if it would block.
// replicate this in these poll functions:
Expand Down Expand Up @@ -79,7 +64,7 @@ impl AsyncSocket for SmolSocket {

/// Mutable access to underyling [`Socket`]
fn socket_mut(&mut self) -> &mut Socket {
self.0.get_mut()
unsafe { self.0.get_mut() }
}

fn new(protocol: isize) -> io::Result<Self> {
Expand All @@ -92,7 +77,7 @@ impl AsyncSocket for SmolSocket {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0))
self.poll_write_with(cx, |this| this.socket_mut().send(buf, 0))
}

fn poll_send_to(
Expand All @@ -101,7 +86,7 @@ impl AsyncSocket for SmolSocket {
buf: &[u8],
addr: &SocketAddr,
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0))
self.poll_write_with(cx, |this| this.socket_mut().send_to(buf, addr, 0))
}

fn poll_recv<B>(
Expand All @@ -113,7 +98,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
this.0.get_mut().recv(buf, 0).map(|_len| ())
this.socket_mut().recv(buf, 0).map(|_len| ())
})
}

Expand All @@ -126,7 +111,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
let x = this.0.get_mut().recv_from(buf, 0);
let x = this.socket_mut().recv_from(buf, 0);
trace!("poll_recv_from: {:?}", x);
x.map(|(_len, addr)| addr)
})
Expand All @@ -136,6 +121,6 @@ impl AsyncSocket for SmolSocket {
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full())
self.poll_read_with(cx, |this| this.socket_mut().recv_from_full())
}
}
101 changes: 68 additions & 33 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use std::{
io::{Error, Result},
mem,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::{
fd::{AsFd, BorrowedFd},
unix::io::{AsRawFd, RawFd},
},
};

use crate::SocketAddr;
Expand Down Expand Up @@ -51,7 +54,7 @@ use crate::SocketAddr;
/// }
/// }
/// ```
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct Socket(RawFd);

impl AsRawFd for Socket {
Expand All @@ -60,15 +63,15 @@ impl AsRawFd for Socket {
}
}

impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
impl AsFd for Socket {
fn as_fd(&self) -> BorrowedFd<'_> {
unsafe { BorrowedFd::borrow_raw(self.0) }
}
}

impl Drop for Socket {
fn drop(&mut self) {
unsafe { libc::close(self.0) };
unsafe { libc::close(self.as_raw_fd()) };
}
}

Expand All @@ -94,7 +97,7 @@ impl Socket {
/// Bind the socket to the given address
pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
let (addr_ptr, addr_len) = addr.as_raw();
let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
let res = unsafe { libc::bind(self.as_raw_fd(), addr_ptr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -115,7 +118,9 @@ impl Socket {
let (addr_ptr, mut addr_len) = addr.as_raw_mut();
let addr_len_copy = addr_len;
let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
let res = unsafe {
libc::getsockname(self.as_raw_fd(), addr_ptr, addr_len_ptr)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -128,8 +133,9 @@ impl Socket {
/// Make this socket non-blocking
pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
let mut non_blocking = non_blocking as libc::c_int;
let res =
unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
let res = unsafe {
libc::ioctl(self.as_raw_fd(), libc::FIONBIO, &mut non_blocking)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -152,7 +158,7 @@ impl Socket {
/// 2. connect it to the kernel with [`Socket::connect`]
/// 3. send a request to the kernel with [`Socket::send`]
/// 4. read the response (which can span over several messages)
/// [`Socket::recv`]
/// [`Socket::recv`]
///
/// ```rust
/// use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
Expand Down Expand Up @@ -216,7 +222,7 @@ impl Socket {
// - https://stackoverflow.com/a/14046386/1836144
// - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html
let (addr, addr_len) = remote_addr.as_raw();
let res = unsafe { libc::connect(self.0, addr, addr_len) };
let res = unsafe { libc::connect(self.as_raw_fd(), addr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand Down Expand Up @@ -293,7 +299,7 @@ impl Socket {

let res = unsafe {
libc::recvfrom(
self.0,
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
Expand Down Expand Up @@ -324,7 +330,8 @@ impl Socket {
let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
let buf_len = chunk.len() as libc::size_t;

let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::recv(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
} else {
Expand Down Expand Up @@ -367,7 +374,14 @@ impl Socket {
let buf_len = buf.len() as libc::size_t;

let res = unsafe {
libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len)
libc::sendto(
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
addr_ptr,
addr_len,
)
};
if res < 0 {
return Err(Error::last_os_error());
Expand All @@ -382,7 +396,8 @@ impl Socket {
let buf_ptr = buf.as_ptr() as *const libc::c_void;
let buf_len = buf.len() as libc::size_t;

let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::send(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -391,12 +406,17 @@ impl Socket {

pub fn set_pktinfo(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
value,
)
}

pub fn get_pktinfo(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
)?;
Expand All @@ -405,7 +425,7 @@ impl Socket {

pub fn add_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_ADD_MEMBERSHIP,
group,
Expand All @@ -414,7 +434,7 @@ impl Socket {

pub fn drop_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_DROP_MEMBERSHIP,
group,
Expand All @@ -434,7 +454,7 @@ impl Socket {
pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
value,
Expand All @@ -443,7 +463,7 @@ impl Socket {

pub fn get_broadcast_error(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
)?;
Expand All @@ -454,12 +474,17 @@ impl Socket {
/// unicast and broadcast listeners to avoid receiving `ENOBUFS` errors.
pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
value,
)
}

pub fn get_no_enobufs(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
)?;
Expand All @@ -474,7 +499,7 @@ impl Socket {
pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
value,
Expand All @@ -483,7 +508,7 @@ impl Socket {

pub fn get_listen_all_namespaces(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
)?;
Expand All @@ -498,12 +523,17 @@ impl Socket {
/// acknowledgment.
pub fn set_cap_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
value,
)
}

pub fn get_cap_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
)?;
Expand All @@ -515,12 +545,17 @@ impl Socket {
/// NLMSG_ERROR and NLMSG_DONE messages.
pub fn set_ext_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_EXT_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
value,
)
}

pub fn get_ext_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
)?;
Expand All @@ -534,13 +569,13 @@ impl Socket {
/// the maximum allowed value is set by the /proc/sys/net/core/rmem_max
/// file. The minimum (doubled) value for this option is 256.
pub fn set_rx_buf_sz<T>(&self, size: T) -> Result<()> {
setsockopt(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF, size)
setsockopt(self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, size)
}

/// Gets socket receive buffer in bytes
pub fn get_rx_buf_sz(&self) -> Result<usize> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_RCVBUF,
)?;
Expand All @@ -552,7 +587,7 @@ impl Socket {
pub fn set_netlink_get_strict_chk(&self, value: bool) -> Result<()> {
let value: u32 = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_GET_STRICT_CHK,
value,
Expand Down

0 comments on commit 0cb2ed9

Please sign in to comment.