Skip to content

Commit

Permalink
chore(deps): upgrade async-io
Browse files Browse the repository at this point in the history
  • Loading branch information
hanabi1224 committed Nov 29, 2024
1 parent 3d5879c commit 91e555f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ features = ["os-poll", "os-ext"]

[dependencies.async-io]
optional = true
version = "1.13"
version = "2"

[features]
default = []
mio_socket = ["mio"]
tokio_socket = ["tokio", "futures"]
smol_socket = ["async-io","futures"]
smol_socket = ["async-io", "futures"]

[dev-dependencies]
netlink-packet-audit = "0.4.1"
Expand Down
20 changes: 13 additions & 7 deletions src/smol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{
io,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd},
task::{Context, Poll},
};

Expand Down Expand Up @@ -31,6 +31,12 @@ impl AsRawFd for SmolSocket {
}
}

impl AsFd for SmolSocket {
fn as_fd(&self) -> BorrowedFd<'_> {
self.0.get_ref().as_fd()
}
}

// async_io::Async<..>::{read,write}_with[_mut] functions try IO first,
// and only register context if it would block.
// replicate this in these poll functions:
Expand Down Expand Up @@ -79,7 +85,7 @@ impl AsyncSocket for SmolSocket {

/// Mutable access to underyling [`Socket`]
fn socket_mut(&mut self) -> &mut Socket {
self.0.get_mut()
unsafe { self.0.get_mut() }
}

fn new(protocol: isize) -> io::Result<Self> {
Expand All @@ -92,7 +98,7 @@ impl AsyncSocket for SmolSocket {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0))
self.poll_write_with(cx, |this| this.socket_mut().send(buf, 0))
}

fn poll_send_to(
Expand All @@ -101,7 +107,7 @@ impl AsyncSocket for SmolSocket {
buf: &[u8],
addr: &SocketAddr,
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0))
self.poll_write_with(cx, |this| this.socket_mut().send_to(buf, addr, 0))
}

fn poll_recv<B>(
Expand All @@ -113,7 +119,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
this.0.get_mut().recv(buf, 0).map(|_len| ())
this.socket_mut().recv(buf, 0).map(|_len| ())
})
}

Expand All @@ -126,7 +132,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
let x = this.0.get_mut().recv_from(buf, 0);
let x = this.socket_mut().recv_from(buf, 0);
trace!("poll_recv_from: {:?}", x);
x.map(|(_len, addr)| addr)
})
Expand All @@ -136,6 +142,6 @@ impl AsyncSocket for SmolSocket {
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full())
self.poll_read_with(cx, |this| this.socket_mut().recv_from_full())
}
}
99 changes: 70 additions & 29 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use std::{
io::{Error, Result},
mem,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::{
fd::{AsFd, BorrowedFd, FromRawFd},
unix::io::{AsRawFd, RawFd},
},
};

use crate::SocketAddr;
Expand Down Expand Up @@ -60,6 +63,12 @@ impl AsRawFd for Socket {
}
}

impl AsFd for Socket {
fn as_fd(&self) -> BorrowedFd<'_> {
unsafe { BorrowedFd::borrow_raw(self.0) }
}
}

impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
Expand All @@ -68,7 +77,7 @@ impl FromRawFd for Socket {

impl Drop for Socket {
fn drop(&mut self) {
unsafe { libc::close(self.0) };
unsafe { libc::close(self.as_raw_fd()) };
}
}

Expand All @@ -94,7 +103,7 @@ impl Socket {
/// Bind the socket to the given address
pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
let (addr_ptr, addr_len) = addr.as_raw();
let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
let res = unsafe { libc::bind(self.as_raw_fd(), addr_ptr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -115,7 +124,9 @@ impl Socket {
let (addr_ptr, mut addr_len) = addr.as_raw_mut();
let addr_len_copy = addr_len;
let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
let res = unsafe {
libc::getsockname(self.as_raw_fd(), addr_ptr, addr_len_ptr)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -128,8 +139,9 @@ impl Socket {
/// Make this socket non-blocking
pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
let mut non_blocking = non_blocking as libc::c_int;
let res =
unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
let res = unsafe {
libc::ioctl(self.as_raw_fd(), libc::FIONBIO, &mut non_blocking)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -152,7 +164,7 @@ impl Socket {
/// 2. connect it to the kernel with [`Socket::connect`]
/// 3. send a request to the kernel with [`Socket::send`]
/// 4. read the response (which can span over several messages)
/// [`Socket::recv`]
/// [`Socket::recv`]
///
/// ```rust
/// use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
Expand Down Expand Up @@ -216,7 +228,7 @@ impl Socket {
// - https://stackoverflow.com/a/14046386/1836144
// - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html
let (addr, addr_len) = remote_addr.as_raw();
let res = unsafe { libc::connect(self.0, addr, addr_len) };
let res = unsafe { libc::connect(self.as_raw_fd(), addr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand Down Expand Up @@ -293,7 +305,7 @@ impl Socket {

let res = unsafe {
libc::recvfrom(
self.0,
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
Expand Down Expand Up @@ -324,7 +336,8 @@ impl Socket {
let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
let buf_len = chunk.len() as libc::size_t;

let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::recv(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
} else {
Expand Down Expand Up @@ -367,7 +380,14 @@ impl Socket {
let buf_len = buf.len() as libc::size_t;

let res = unsafe {
libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len)
libc::sendto(
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
addr_ptr,
addr_len,
)
};
if res < 0 {
return Err(Error::last_os_error());
Expand All @@ -382,7 +402,8 @@ impl Socket {
let buf_ptr = buf.as_ptr() as *const libc::c_void;
let buf_len = buf.len() as libc::size_t;

let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::send(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -391,12 +412,17 @@ impl Socket {

pub fn set_pktinfo(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
value,
)
}

pub fn get_pktinfo(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
)?;
Expand All @@ -405,7 +431,7 @@ impl Socket {

pub fn add_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_ADD_MEMBERSHIP,
group,
Expand All @@ -414,7 +440,7 @@ impl Socket {

pub fn drop_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_DROP_MEMBERSHIP,
group,
Expand All @@ -434,7 +460,7 @@ impl Socket {
pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
value,
Expand All @@ -443,7 +469,7 @@ impl Socket {

pub fn get_broadcast_error(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
)?;
Expand All @@ -454,12 +480,17 @@ impl Socket {
/// unicast and broadcast listeners to avoid receiving `ENOBUFS` errors.
pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
value,
)
}

pub fn get_no_enobufs(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
)?;
Expand All @@ -474,7 +505,7 @@ impl Socket {
pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
value,
Expand All @@ -483,7 +514,7 @@ impl Socket {

pub fn get_listen_all_namespaces(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
)?;
Expand All @@ -498,12 +529,17 @@ impl Socket {
/// acknowledgment.
pub fn set_cap_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
value,
)
}

pub fn get_cap_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
)?;
Expand All @@ -515,12 +551,17 @@ impl Socket {
/// NLMSG_ERROR and NLMSG_DONE messages.
pub fn set_ext_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_EXT_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
value,
)
}

pub fn get_ext_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
)?;
Expand All @@ -534,13 +575,13 @@ impl Socket {
/// the maximum allowed value is set by the /proc/sys/net/core/rmem_max
/// file. The minimum (doubled) value for this option is 256.
pub fn set_rx_buf_sz<T>(&self, size: T) -> Result<()> {
setsockopt(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF, size)
setsockopt(self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, size)
}

/// Gets socket receive buffer in bytes
pub fn get_rx_buf_sz(&self) -> Result<usize> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_RCVBUF,
)?;
Expand All @@ -552,7 +593,7 @@ impl Socket {
pub fn set_netlink_get_strict_chk(&self, value: bool) -> Result<()> {
let value: u32 = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_GET_STRICT_CHK,
value,
Expand Down

0 comments on commit 91e555f

Please sign in to comment.