diff --git a/tests/integration_tests/src/lib.rs b/tests/integration_tests/src/lib.rs index 68bf837b6..2d3c3f3f8 100644 --- a/tests/integration_tests/src/lib.rs +++ b/tests/integration_tests/src/lib.rs @@ -5,6 +5,7 @@ pub mod pb { pub mod mock { use std::{ + io::IoSlice, pin::Pin, task::{Context, Poll}, }; @@ -51,6 +52,18 @@ pub mod mock { ) -> Poll> { Pin::new(&mut self.0).poll_shutdown(cx) } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } } } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2b465355f..2230b9b2e 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,6 +1,7 @@ use crate::transport::server::Connected; use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; +use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -65,6 +66,18 @@ impl AsyncWrite for BoxedIo { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_shutdown(cx) } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } } pub(crate) enum ServerIo { @@ -163,4 +176,24 @@ where Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), } } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Io(io) => io.is_write_vectored(), + #[cfg(feature = "tls")] + Self::TlsIo(io) => io.is_write_vectored(), + } + } }