Skip to content

Commit

Permalink
Instrument bytes copied to and from remotes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ongy committed Jan 7, 2025
1 parent 67506c0 commit a3375f3
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 28 deletions.
23 changes: 18 additions & 5 deletions src/tunnel/server/handler_http2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use http_body_util::{BodyStream, Either, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::header::CONTENT_TYPE;
use hyper::{Request, Response, StatusCode};
use opentelemetry::KeyValue;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
Expand Down Expand Up @@ -43,16 +44,28 @@ pub(super) async fn http_server_upgrade(
.body(Either::Right(body))
.expect("bug: failed to build response");

let attributes = [
KeyValue::new("remote_host", format!("{:}", remote_addr.host)),
KeyValue::new("remote_port", i64::from(remote_addr.port)),
];
tokio::spawn(
async move {
// We need a copy of the attributes that can be owned by the closure...
let to_remote_attributes = attributes.clone();
let to_metric = server.metrics.bytes_to_remote.clone();
let from_metric = server.metrics.bytes_from_remote.clone();
let (close_tx, close_rx) = oneshot::channel::<()>();
tokio::task::spawn(
transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx)
.instrument(Span::current()),
);
tokio::task::spawn(async move {
let bytes_written =
transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx)
.instrument(Span::current())
.await;
to_metric.record(bytes_written as u64, &to_remote_attributes);
});

let _ =
let bytes_written =
transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None).await;
from_metric.record(bytes_written as u64, &attributes);
}
.instrument(Span::current()),
);
Expand Down
22 changes: 18 additions & 4 deletions src/tunnel/server/handler_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use http_body_util::Either;
use hyper::body::Incoming;
use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};
use hyper::{Request, Response};
use opentelemetry::KeyValue;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -43,6 +44,10 @@ pub(super) async fn ws_server_upgrade(
return bad_request();
}
};
let attributes = [
KeyValue::new("remote_host", format!("{:}", remote_addr.host)),
KeyValue::new("remote_port", i64::from(remote_addr.port)),
];

tokio::spawn(
async move {
Expand All @@ -61,17 +66,26 @@ pub(super) async fn ws_server_upgrade(
};
let (close_tx, close_rx) = oneshot::channel::<()>();

tokio::task::spawn(
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
// We need a copy of the attributes that can be owned by the closure...
let to_remote_attributes = attributes.clone();
let to_metric = server.metrics.bytes_to_remote.clone();
let from_metric = server.metrics.bytes_from_remote.clone();
tokio::task::spawn(async move {
let bytes_written = transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx)
.instrument(Span::current())
.await;

let _ = transport::io::propagate_local_to_remote(
to_metric.record(bytes_written as u64, &to_remote_attributes);
});

let bytes_written = transport::io::propagate_local_to_remote(
local_rx,
ws_tx,
close_tx,
server.config.websocket_ping_frequency,
)
.await;
from_metric.record(bytes_written as u64, &attributes);
Ok(())
}
.instrument(Span::current()),
Expand Down
10 changes: 10 additions & 0 deletions src/tunnel/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub struct WsServerConfig {
pub struct WsServerMetrics {
pub connections: Counter<u64>,
pub connect_latencies: Histogram<u64>,
pub bytes_to_remote: Histogram<u64>,
pub bytes_from_remote: Histogram<u64>,
}

#[derive(Clone)]
Expand All @@ -93,6 +95,14 @@ impl WsServer {
.u64_histogram("connect_latency")
.with_description("Provides a latency histogram per target")
.build(),
bytes_to_remote: meter
.u64_histogram("bytes_to_remote")
.with_description("Provides information about how many bytes were proxied from the websocket to the target")
.build(),
bytes_from_remote: meter
.u64_histogram("bytes_from_remote")
.with_description("Provides information about how many bytes were proxied from the target to the websocket")
.build(),
}),
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/tunnel/transport/http2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ impl Http2TunnelRead {
}

impl TunnelRead for Http2TunnelRead {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<usize, io::Error> {
loop {
match self.inner.next().await {
Some(Ok(frame)) => match frame.into_data() {
Ok(data) => {
return match writer.write_all(data.as_ref()).await {
Ok(_) => Ok(()),
Ok(_) => Ok(data.len()),
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
}
}
Expand Down
44 changes: 29 additions & 15 deletions src/tunnel/transport/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub trait TunnelRead: Send + 'static {
fn copy(
&mut self,
writer: impl AsyncWrite + Unpin + Send,
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
) -> impl Future<Output = Result<usize, std::io::Error>> + Send;
}

pub enum TunnelReader {
Expand All @@ -38,7 +38,7 @@ pub enum TunnelReader {
}

impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<usize, std::io::Error> {
match self {
Self::Websocket(s) => s.copy(writer).await,
Self::Http2(s) => s.copy(writer).await,
Expand Down Expand Up @@ -100,7 +100,7 @@ pub async fn propagate_local_to_remote(
mut ws_tx: impl TunnelWrite,
mut close_tx: oneshot::Sender<()>,
ping_frequency: Option<Duration>,
) -> anyhow::Result<()> {
) -> usize {
let _guard = scopeguard::guard((), |_| {
info!("Closing local => remote tunnel");
});
Expand All @@ -116,6 +116,7 @@ pub async fn propagate_local_to_remote(
let notify = ws_tx.pending_operations_notify();
let mut has_pending_operations = notify.notified();
let mut has_pending_operations_pin = unsafe { Pin::new_unchecked(&mut has_pending_operations) };
let mut bytes_sent = 0;

pin_mut!(timeout);
pin_mut!(should_close);
Expand Down Expand Up @@ -147,14 +148,19 @@ pub async fn propagate_local_to_remote(

_ = timeout.tick(), if ping_frequency.is_some() => {
debug!("sending ping to keep connection alive");
ws_tx.ping().await?;
if let Err(err) = ws_tx.ping().await {
warn!("error while sending ping {err}");
break;
}
continue;
}
};

let _read_len = match read_len {
match read_len {
Ok(0) => break,
Ok(read_len) => read_len,
Ok(read_len) => {
bytes_sent = bytes_sent + read_len as usize;
}
Err(err) => {
warn!("error while reading incoming bytes from local tx tunnel: {}", err);
break;
Expand All @@ -171,35 +177,43 @@ pub async fn propagate_local_to_remote(
// Send normal close
let _ = ws_tx.close().await;

Ok(())
bytes_sent
}

/// Read incoming bytes on the websocket and write them to the local connection.
/// I.e. this function moves data out of the websocket and into the network.
pub async fn propagate_remote_to_local(
local_tx: impl AsyncWrite + Send,
mut ws_rx: impl TunnelRead,
mut close_rx: oneshot::Receiver<()>,
) -> anyhow::Result<()> {
) -> usize {
let _guard = scopeguard::guard((), |_| {
info!("Closing local <= remote tunnel");
});

pin_mut!(local_tx);
let mut bytes_copied = 0;
loop {
let msg = select! {
biased;
msg = ws_rx.copy(&mut local_tx) => msg,
_ = &mut close_rx => break,
};

if let Err(err) = msg {
match err.kind() {
ErrorKind::NotConnected => debug!("Connection closed frame received"),
ErrorKind::BrokenPipe => debug!("Remote side closed connection"),
_ => error!("error while reading from tunnel rx {err}"),
match msg {
Err(err) => {
match err.kind() {
ErrorKind::NotConnected => debug!("Connection closed frame received"),
ErrorKind::BrokenPipe => debug!("Remote side closed connection"),
_ => error!("error while reading from tunnel rx {err}"),
}
break;
}
Ok(v) => {
bytes_copied = bytes_copied + v;
}
break;
}
}

Ok(())
bytes_copied
}
4 changes: 2 additions & 2 deletions src/tunnel/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ fn frame_reader(_: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>>
}

impl TunnelRead for WebsocketTunnelRead {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<usize, io::Error> {
loop {
let msg = match self.inner.read_frame(&mut frame_reader).await {
Ok(msg) => msg,
Expand All @@ -193,7 +193,7 @@ impl TunnelRead for WebsocketTunnelRead {
match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
return match writer.write_all(msg.payload.as_ref()).await {
Ok(_) => Ok(()),
Ok(_) => Ok(msg.payload.len()),
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
}
}
Expand Down

0 comments on commit a3375f3

Please sign in to comment.