From e527b7c3d6c1e9f48448e135836fd97e20313e46 Mon Sep 17 00:00:00 2001 From: Josh Robson Chase Date: Mon, 12 Feb 2024 13:18:16 -0500 Subject: [PATCH 1/2] ngrok: update hyper and axum --- .gitignore | 1 + cargo-doc-ngrok/Cargo.toml | 10 +-- cargo-doc-ngrok/src/main.rs | 57 ++++++++++++---- ngrok/Cargo.toml | 16 ++++- ngrok/examples/axum.rs | 79 +++++++++++++++------ ngrok/examples/labeled.rs | 75 ++++++++++++++------ ngrok/examples/tls.rs | 74 ++++++++++++++------ ngrok/src/conn.rs | 64 ++++++++++++++++- ngrok/src/forwarder.rs | 10 +-- ngrok/src/online_tests.rs | 133 +++++++++++++++++++++++------------- ngrok/src/proxy_proto.rs | 50 ++++++++++++++ ngrok/src/session.rs | 2 +- ngrok/src/tunnel.rs | 33 --------- ngrok/src/tunnel_ext.rs | 35 ++++------ 14 files changed, 446 insertions(+), 193 deletions(-) diff --git a/.gitignore b/.gitignore index 6393a36..9541188 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env /target /Cargo.lock .direnv diff --git a/cargo-doc-ngrok/Cargo.toml b/cargo-doc-ngrok/Cargo.toml index c675084..79afb41 100644 --- a/cargo-doc-ngrok/Cargo.toml +++ b/cargo-doc-ngrok/Cargo.toml @@ -7,14 +7,16 @@ description = "A cargo subcommand to build and serve documentation via ngrok" repository = "https://github.com/ngrok/ngrok-rust" [dependencies] -axum = "0.6.1" +awaitdrop = "0.1.2" +axum = "0.7.4" bstr = "1.4.0" cargo_metadata = "0.15.2" clap = { version = "4.0.29", features = ["derive"] } futures = "0.3.25" -http = "0.2.8" -hyper = { version = "0.14.23", features = ["server"] } -hyper-staticfile = "0.9.2" +http = "1.0.0" +hyper = { version = "1.1.0", features = ["server"] } +hyper-staticfile = "0.10.0" +hyper-util = { version = "0.1.3", features = ["server", "tokio", "server-auto", "http1"] } ngrok = { path = "../ngrok", version = "0.14.0-pre.1", features = ["hyper"] } tokio = { version = "1.23.0", features = ["full"] } watchexec = "2.3.0" diff --git a/cargo-doc-ngrok/src/main.rs b/cargo-doc-ngrok/src/main.rs index d40f6cc..599373b 100644 --- a/cargo-doc-ngrok/src/main.rs +++ b/cargo-doc-ngrok/src/main.rs @@ -1,17 +1,22 @@ use std::{ - error::Error, io, path::PathBuf, process::Stdio, sync::Arc, }; +use axum::BoxError; use clap::{ Args, Parser, Subcommand, }; -use hyper::service::make_service_fn; +use futures::TryStreamExt; +use hyper::service::service_fn; +use hyper_util::{ + rt::TokioExecutor, + server, +}; use ngrok::prelude::*; use watchexec::{ action::{ @@ -56,7 +61,7 @@ struct DocNgrok { } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), BoxError> { let Cmd::DocNgrok(args) = Cargo::parse().cmd; std::process::Command::new("cargo") @@ -82,31 +87,57 @@ async fn main() -> Result<(), Box> { .connect() .await?; - let mut tunnel_cfg = sess.http_endpoint(); + let mut listen_cfg = sess.http_endpoint(); if let Some(domain) = args.domain { - tunnel_cfg.domain(domain); + listen_cfg.domain(domain); } - let tunnel = tunnel_cfg.listen().await?; + let mut listener = listen_cfg.listen().await?; + + let service = service_fn(move |req| { + let stat = hyper_staticfile::Static::new(&doc_dir); + stat.serve(req) + }); println!( "serving docs on: {}/{}/", - tunnel.url(), + listener.url(), default_package.replace('-', "_") ); - let srv = hyper::server::Server::builder(tunnel).serve(make_service_fn(move |_| { - let stat = hyper_staticfile::Static::new(&doc_dir); - async move { Result::<_, String>::Ok(stat) } - })); + let server = async move { + let (dropref, waiter) = awaitdrop::awaitdrop(); + + // Continuously accept new connections. + while let Some(conn) = listener.try_next().await? { + let service = service.clone(); + let dropref = dropref.clone(); + // Spawn a task to handle the connection. That way we can multiple connections + // concurrently. + tokio::spawn(async move { + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection(conn, service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + drop(dropref); + }); + } + + // Wait until all children have finished, not just the listener. + drop(dropref); + waiter.await; + + Ok::<(), BoxError>(()) + }; if args.watch { let we = make_watcher(args.doc_args, root_dir, target_dir)?; - tokio::spawn(srv); we.main().await??; } else { - srv.await?; + server.await?; } Ok(()) diff --git a/ngrok/Cargo.toml b/ngrok/Cargo.toml index de7210d..e32298e 100644 --- a/ngrok/Cargo.toml +++ b/ngrok/Cargo.toml @@ -23,8 +23,9 @@ tracing = "0.1.37" futures-rustls = { version = "0.25.1" } tokio-util = { version = "0.7.4", features = ["compat"] } futures = "0.3.25" -hyper = { version = "0.14.23" } -axum = { version = "0.6.1", features = ["tokio"], optional = true } +hyper-0-14 = { package = "hyper", version = "0.14" } +hyper = { version = "1.1.0", optional = true } +axum = { version = "0.7.4", features = ["tokio"], optional = true } rustls-pemfile = "2.0.0" async-trait = "0.1.59" bytes = "1.3.0" @@ -43,11 +44,20 @@ url = "2.4.0" rustls-native-certs = "0.7.0" proxy-protocol = "0.5.0" pin-project = "1.1.3" +axum-core = "0.4.3" +futures-util = "0.3.30" [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] } [dev-dependencies] +hyper = "1.1.0" +hyper-util = { version = "0.1.3", features = [ + "tokio", + "server", + "http1", + "http2", +] } tokio = { version = "1.23.0", features = ["full"] } anyhow = "1.0.66" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } @@ -60,6 +70,8 @@ tokio-tungstenite = { version = "0.18.0", features = [ "rustls", "rustls-tls-webpki-roots", ] } +tower = "0.4.13" +axum = { version = "0.7.4", features = ["tokio"] } [[example]] name = "tls" diff --git a/ngrok/examples/axum.rs b/ngrok/examples/axum.rs index b5281ae..571458b 100644 --- a/ngrok/examples/axum.rs +++ b/ngrok/examples/axum.rs @@ -1,17 +1,31 @@ -use std::net::SocketAddr; +use std::{ + convert::Infallible, + net::SocketAddr, +}; use axum::{ extract::ConnectInfo, routing::get, Router, }; -use ngrok::{ - prelude::*, - tunnel::HttpTunnel, +use axum_core::BoxError; +use futures::stream::TryStreamExt; +use hyper::{ + body::Incoming, + Request, +}; +use hyper_util::{ + rt::TokioExecutor, + server, +}; +use ngrok::prelude::*; +use tower::{ + Service, + ServiceExt, }; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<(), BoxError> { // build our application with a single route let app = Router::new().route( "/", @@ -22,21 +36,7 @@ async fn main() -> anyhow::Result<()> { ), ); - // run it with hyper on localhost:8000 - // axum::Server::bind(&"0.0.0.0:8000".parse().unwrap()) - // Or with an ngrok tunnel - axum::Server::builder(start_tunnel().await?) - .serve(app.into_make_service_with_connect_info::()) - .await - .unwrap(); - - Ok(()) -} - -// const CA_CERT: &[u8] = include_bytes!("ca.crt"); - -async fn start_tunnel() -> anyhow::Result { - let tun = ngrok::Session::builder() + let mut listener = ngrok::Session::builder() .authtoken_from_env() .connect() .await? @@ -74,9 +74,35 @@ async fn start_tunnel() -> anyhow::Result { .listen() .await?; - println!("Tunnel started on URL: {:?}", tun.url()); + println!("Listener started on URL: {:?}", listener.url()); + + let mut make_service = app.into_make_service_with_connect_info::(); - Ok(tun) + let server = async move { + while let Some(conn) = listener.try_next().await? { + let remote_addr = conn.remote_addr(); + let tower_service = unwrap_infallible(make_service.call(remote_addr).await); + + tokio::spawn(async move { + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().oneshot(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(conn, hyper_service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + }); + } + Ok::<(), BoxError>(()) + }; + + server.await?; + + Ok(()) } #[allow(dead_code)] @@ -103,3 +129,12 @@ fn create_policy() -> Result { ) .to_owned()) } + +// const CA_CERT: &[u8] = include_bytes!("ca.crt"); + +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } +} diff --git a/ngrok/examples/labeled.rs b/ngrok/examples/labeled.rs index 745abae..b71915e 100644 --- a/ngrok/examples/labeled.rs +++ b/ngrok/examples/labeled.rs @@ -1,17 +1,32 @@ -use std::net::SocketAddr; +use std::{ + convert::Infallible, + error::Error, + net::SocketAddr, +}; use axum::{ extract::ConnectInfo, routing::get, + BoxError, Router, }; -use ngrok::{ - prelude::*, - tunnel::LabeledTunnel, +use futures::TryStreamExt; +use hyper::{ + body::Incoming, + Request, +}; +use hyper_util::{ + rt::TokioExecutor, + server, +}; +use ngrok::prelude::*; +use tower::{ + Service, + ServiceExt, }; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<(), Box> { // build our application with a single route let app = Router::new().route( "/", @@ -22,24 +37,12 @@ async fn main() -> anyhow::Result<()> { ), ); - // run it with hyper on localhost:8000 - // axum::Server::bind(&"0.0.0.0:8000".parse().unwrap()) - // Or with an ngrok tunnel - axum::Server::builder(start_tunnel().await?) - .serve(app.into_make_service_with_connect_info::()) - .await - .unwrap(); - - Ok(()) -} - -async fn start_tunnel() -> anyhow::Result { let sess = ngrok::Session::builder() .authtoken_from_env() .connect() .await?; - let tun = sess + let mut listener = sess .labeled_tunnel() //.app_protocol("http2") .label("edge", "edghts_") @@ -47,7 +50,39 @@ async fn start_tunnel() -> anyhow::Result { .listen() .await?; - println!("Labeled tunnel started!"); + println!("Labeled listener started!"); + + let mut make_service = app.into_make_service_with_connect_info::(); + + let server = async move { + while let Some(conn) = listener.try_next().await? { + let remote_addr = conn.remote_addr(); + let tower_service = unwrap_infallible(make_service.call(remote_addr).await); + + tokio::spawn(async move { + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().oneshot(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(conn, hyper_service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + }); + } + Ok::<(), BoxError>(()) + }; + + server.await?; + Ok(()) +} - Ok(tun) +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } } diff --git a/ngrok/examples/tls.rs b/ngrok/examples/tls.rs index 098f279..e2a71bb 100644 --- a/ngrok/examples/tls.rs +++ b/ngrok/examples/tls.rs @@ -1,13 +1,28 @@ -use std::net::SocketAddr; +use std::{ + convert::Infallible, + error::Error, + net::SocketAddr, +}; use axum::{ extract::ConnectInfo, routing::get, + BoxError, Router, }; -use ngrok::{ - prelude::*, - tunnel::TlsTunnel, +use futures::TryStreamExt; +use hyper::{ + body::Incoming, + Request, +}; +use hyper_util::{ + rt::TokioExecutor, + server, +}; +use ngrok::prelude::*; +use tower::{ + Service, + ServiceExt, }; const CERT: &[u8] = include_bytes!("domain.crt"); @@ -15,7 +30,7 @@ const KEY: &[u8] = include_bytes!("domain.key"); // const CA_CERT: &[u8] = include_bytes!("ca.crt"); #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<(), Box> { // build our application with a single route let app = Router::new().route( "/", @@ -26,24 +41,12 @@ async fn main() -> anyhow::Result<()> { ), ); - // run it with hyper on localhost:8000 - // axum::Server::bind(&"0.0.0.0:8000".parse().unwrap()) - // Or with an ngrok tunnel - axum::Server::builder(start_tunnel().await?) - .serve(app.into_make_service_with_connect_info::()) - .await - .unwrap(); - - Ok(()) -} - -async fn start_tunnel() -> anyhow::Result { let sess = ngrok::Session::builder() .authtoken_from_env() .connect() .await?; - let tun = sess + let mut listener = sess .tls_endpoint() // .allow_cidr("0.0.0.0/0") // .deny_cidr("10.1.1.1/32") @@ -56,7 +59,38 @@ async fn start_tunnel() -> anyhow::Result { .listen() .await?; - println!("Tunnel started on URL: {:?}", tun.url()); + let mut make_service = app.into_make_service_with_connect_info::(); + + let server = async move { + while let Some(conn) = listener.try_next().await? { + let remote_addr = conn.remote_addr(); + let tower_service = unwrap_infallible(make_service.call(remote_addr).await); + + tokio::spawn(async move { + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().oneshot(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(conn, hyper_service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + }); + } + Ok::<(), BoxError>(()) + }; + + server.await?; + + Ok(()) +} - Ok(tun) +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } } diff --git a/ngrok/src/conn.rs b/ngrok/src/conn.rs index 7c11399..f3a6d88 100644 --- a/ngrok/src/conn.rs +++ b/ngrok/src/conn.rs @@ -10,6 +10,11 @@ use std::{ // Support for axum's connection info trait. #[cfg(feature = "axum")] use axum::extract::connect_info::Connected; +#[cfg(feature = "hyper")] +use hyper::rt::{ + Read as HyperRead, + Write as HyperWrite, +}; use muxado::typed::TypedStream; use tokio::io::{ AsyncRead, @@ -61,9 +66,24 @@ impl EndpointConnInfo for Info { } } -/// An incoming connection over an ngrok tunnel. -/// Effectively a trait alias for async read+write, plus connection info. -pub trait Conn: ConnInfo + AsyncRead + AsyncWrite + Unpin + Send + 'static {} +// This codgen indirect is required to make the hyper io trait bounds +// dependent on the hyper feature. You can't put a #[cfg] on a single bound, so +// we're putting the whole trait def in a macro. Gross, but gets the job done. +macro_rules! conn_trait { + ($($hyper_bound:tt)*) => { + /// An incoming connection over an ngrok tunnel. + /// Effectively a trait alias for async read+write, plus connection info. + pub trait Conn: ConnInfo + AsyncRead + AsyncWrite $($hyper_bound)* + Unpin + Send + 'static {} + } +} + +#[cfg(not(feature = "hyper"))] +conn_trait!(); + +#[cfg(feature = "hyper")] +conn_trait! { + + hyper::rt::Read + hyper::rt::Write +} /// Information common to all ngrok connections. pub trait ConnInfo { @@ -129,6 +149,44 @@ macro_rules! make_conn_type { } } + #[cfg(feature = "hyper")] + impl HyperRead for $wrapper { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe{ buf.as_mut() }); + let res = std::task::ready!(Pin::new(&mut *self.inner.stream).poll_read(cx, &mut tokio_buf)); + let filled = tokio_buf.filled().len(); + unsafe { buf.advance(filled) }; + Poll::Ready(res) + } + } + + #[cfg(feature = "hyper")] + impl HyperWrite for $wrapper { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut *self.inner.stream).poll_write(cx, buf) + } + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.inner.stream).poll_flush(cx) + } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.inner.stream).poll_shutdown(cx) + } + } + impl AsyncWrite for $wrapper { fn poll_write( mut self: Pin<&mut Self>, diff --git a/ngrok/src/forwarder.rs b/ngrok/src/forwarder.rs index 3282dea..bcdd361 100644 --- a/ngrok/src/forwarder.rs +++ b/ngrok/src/forwarder.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - io, + error::Error as StdError, }; use async_trait::async_trait; @@ -22,13 +22,13 @@ use crate::{ /// /// Represents a tunnel that is being forwarded to a URL. pub struct Forwarder { - pub(crate) join: JoinHandle>, + pub(crate) join: JoinHandle>>, pub(crate) inner: T, } impl Forwarder { /// Wait for the forwarding task to exit. - pub fn join(&mut self) -> &mut JoinHandle> { + pub fn join(&mut self) -> &mut JoinHandle>> { &mut self.join } } @@ -88,7 +88,9 @@ where ::Conn: crate::tunnel_ext::ConnExt, { let handle = - tokio::spawn(async move { crate::tunnel_ext::forward_tunnel(&mut listener, to_url).await }); + tokio::spawn( + async move { Ok(crate::tunnel_ext::forward_tunnel(&mut listener, to_url).await?) }, + ); Ok(Forwarder { join: handle, diff --git a/ngrok/src/online_tests.rs b/ngrok/src/online_tests.rs index e3fcbf4..5322af9 100644 --- a/ngrok/src/online_tests.rs +++ b/ngrok/src/online_tests.rs @@ -1,4 +1,5 @@ use std::{ + convert::Infallible, io, io::prelude::*, net::SocketAddr, @@ -13,13 +14,10 @@ use std::{ time::Duration, }; -use anyhow::{ - anyhow, - Error, -}; +use anyhow::anyhow; use axum::{ - extract::connect_info::Connected, routing::get, + BoxError, Router, }; use bytes::Bytes; @@ -28,6 +26,7 @@ use futures::{ channel::oneshot, prelude::*, stream::FuturesUnordered, + TryStreamExt, }; use futures_rustls::rustls::{ pki_types, @@ -35,11 +34,19 @@ use futures_rustls::rustls::{ RootCertStore, }; use hyper::{ - header, + body::Incoming, HeaderMap, + Request, +}; +use hyper_0_14::{ + header, StatusCode, Uri, }; +use hyper_util::{ + rt::TokioExecutor, + server, +}; use once_cell::sync::Lazy; use paste::paste; use proxy_protocol::ProxyHeader; @@ -62,6 +69,10 @@ use tokio_tungstenite::{ tungstenite::Message, }; use tokio_util::compat::*; +use tower::{ + Service, + ServiceExt, +}; use tracing_test::traced_test; use url::Url; @@ -74,13 +85,13 @@ use crate::{ Session, }; -async fn setup_session() -> Result { +async fn setup_session() -> Result { Ok(Session::builder().authtoken_from_env().connect().await?) } #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn listen() -> Result<(), Error> { +async fn listen() -> Result<(), BoxError> { let _ = Session::builder() .authtoken_from_env() .connect() @@ -93,7 +104,7 @@ async fn listen() -> Result<(), Error> { #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn tunnel() -> Result<(), Error> { +async fn tunnel() -> Result<(), BoxError> { let tun = setup_session() .await? .http_endpoint() @@ -126,7 +137,7 @@ async fn serve_http( build_session: impl FnOnce(&mut SessionBuilder) -> &mut SessionBuilder, build_tunnel: impl FnOnce(&mut HttpTunnelBuilder) -> &mut HttpTunnelBuilder, router: axum::Router, -) -> Result { +) -> Result { let sess = build_session(Session::builder().authtoken_from_env()) .connect() .await?; @@ -136,20 +147,40 @@ async fn serve_http( Ok(start_http_server(tun, router)) } -fn start_http_server(tun: T, router: Router) -> TunnelGuard +fn start_http_server(mut tun: T, router: Router) -> TunnelGuard where - T: EndpointInfo + Tunnel, - for<'a> SocketAddr: Connected<&'a ::Conn>, + T: EndpointInfo + Tunnel + 'static, + T::Conn: crate::tunnel_ext::ConnExt, { let url = tun.url().into(); let (tx, rx) = oneshot::channel::<()>(); - tokio::spawn(futures::future::select( - axum::Server::builder(tun) - .serve(router.into_make_service_with_connect_info::()), - rx, - )); + let mut make_service = router.into_make_service_with_connect_info::(); + + let server = async move { + while let Some(conn) = tun.try_next().await? { + let remote_addr = conn.remote_addr(); + let tower_service = unwrap_infallible(make_service.call(remote_addr).await); + + tokio::spawn(async move { + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().oneshot(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(conn, hyper_service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + }); + } + Ok::<(), BoxError>(()) + }; + + tokio::spawn(futures::future::select(Box::pin(server), rx)); TunnelGuard { tx: tx.into(), url } } @@ -161,7 +192,7 @@ fn hello_router() -> Router { Router::new().route("/", get(|| async { "Hello, world!" })) } -async fn check_body(url: impl AsRef, expected: impl AsRef) -> Result<(), Error> { +async fn check_body(url: impl AsRef, expected: impl AsRef) -> Result<(), BoxError> { let body: String = reqwest::get(url.as_ref()).await?.text().await?; assert_eq!(body, expected.as_ref()); Ok(()) @@ -169,7 +200,7 @@ async fn check_body(url: impl AsRef, expected: impl AsRef) -> Result<( #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn https() -> Result<(), Error> { +async fn https() -> Result<(), BoxError> { let tun = serve_http(defaults, defaults, hello_router()).await?; let url = tun.url.as_str(); @@ -182,7 +213,7 @@ async fn https() -> Result<(), Error> { #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn http() -> Result<(), Error> { +async fn http() -> Result<(), BoxError> { let tun = serve_http(defaults, |tun| tun.scheme(Scheme::HTTP), hello_router()).await?; let url = tun.url.as_str(); @@ -195,7 +226,7 @@ async fn http() -> Result<(), Error> { #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn http_compression() -> Result<(), Error> { +async fn http_compression() -> Result<(), BoxError> { let tun = serve_http(defaults, |tun| tun.compression(), hello_router()).await?; let url = tun.url.as_str(); @@ -224,8 +255,8 @@ async fn http_compression() -> Result<(), Error> { #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn http_headers() -> Result<(), Error> { - let (tx, mut rx) = mpsc::channel::(16); +async fn http_headers() -> Result<(), BoxError> { + let (tx, mut rx) = mpsc::channel::(16); // For some reason, the hyper machinery keeps a clone of the `tx`, which // causes it to never look closed, even when we drop the tunnel guard, which // shuts down the hyper server. Maybe a leaked task? Work around it by @@ -240,17 +271,14 @@ async fn http_headers() -> Result<(), Error> { if let Some(bar) = headers.get("foo") { if bar != "bar" { let _ = tx - .send(anyhow!( - "unexpected value for 'foo' request header: {:?}", - bar - )) + .send(format!("unexpected value for 'foo' request header: {:?}", bar).into()) .await; } } else { - let _ = tx.send(anyhow!("missing 'foo' request header")).await; + let _ = tx.send("missing 'foo' request header".into()).await; } if headers.get("baz").is_some() { - let _ = tx.send(anyhow!("got 'baz' request header")).await; + let _ = tx.send("got 'baz' request header".into()).await; } ([("python", "lolnope")], "Hello, world!") @@ -292,7 +320,7 @@ async fn http_headers() -> Result<(), Error> { #[traced_test] #[cfg_attr(not(feature = "authenticated-tests"), ignore)] #[test] -async fn user_agent() -> Result<(), Error> { +async fn user_agent() -> Result<(), BoxError> { let tun = serve_http( defaults, |tun| tun.allow_user_agent("foo.*").deny_user_agent(".*"), @@ -319,7 +347,7 @@ async fn user_agent() -> Result<(), Error> { #[traced_test] #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn basic_auth() -> Result<(), Error> { +async fn basic_auth() -> Result<(), BoxError> { let tun = serve_http( defaults, |tun| tun.basic_auth("user", "foobarbaz"), @@ -345,7 +373,7 @@ async fn basic_auth() -> Result<(), Error> { #[traced_test] #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn oauth() -> Result<(), Error> { +async fn oauth() -> Result<(), BoxError> { let tun = serve_http( defaults, |tun| tun.oauth(OauthOptions::new("google")), @@ -366,7 +394,7 @@ async fn oauth() -> Result<(), Error> { #[traced_test] #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn custom_domain() -> Result<(), Error> { +async fn custom_domain() -> Result<(), BoxError> { let mut rng = thread_rng(); let subdomain = (0..7) .map(|_| rng.sample(Alphanumeric) as char) @@ -387,7 +415,7 @@ async fn custom_domain() -> Result<(), Error> { #[traced_test] #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn policy() -> Result<(), Error> { +async fn policy() -> Result<(), BoxError> { let tun = serve_http( defaults, |tun| tun.policy(create_policy()).unwrap(), @@ -423,7 +451,7 @@ fn create_policy() -> Result { #[traced_test] #[cfg_attr(not(all(feature = "paid-tests", feature = "long-tests")), ignore)] #[test] -async fn circuit_breaker() -> Result<(), Error> { +async fn circuit_breaker() -> Result<(), BoxError> { let ctr = Arc::new(AtomicUsize::new(0)); let tun = serve_http( defaults, @@ -434,7 +462,7 @@ async fn circuit_breaker() -> Result<(), Error> { let ctr = ctr.clone(); move || { ctr.fetch_add(1, Ordering::SeqCst); - async { StatusCode::INTERNAL_SERVER_ERROR } + async { hyper::StatusCode::INTERNAL_SERVER_ERROR } } }), ), @@ -452,7 +480,7 @@ async fn circuit_breaker() -> Result<(), Error> { let resp = reqwest::get(url).await?; let status = resp.status(); tracing::debug!(?status); - Result::<_, Error>::Ok(resp.status()) + Result::<_, BoxError>::Ok(resp.status()) }); } let mut done = false; @@ -496,7 +524,7 @@ macro_rules! proxy_proto_test { #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] #[allow(non_snake_case)] - async fn []() -> Result<(), Error> { + async fn []() -> Result<(), BoxError> { let sess = Session::builder().authtoken_from_env().connect().await?; let mut $tun = sess .[<$ept _endpoint>]() @@ -549,7 +577,7 @@ proxy_proto_test!( #[traced_test] #[test] #[cfg_attr(not(feature = "paid-tests"), ignore)] -async fn http_ip_restriction() -> Result<(), Error> { +async fn http_ip_restriction() -> Result<(), BoxError> { let tun = serve_http( defaults, |tun| tun.allow_cidr("127.0.0.1/32").deny_cidr("0.0.0.0/0"), @@ -567,7 +595,7 @@ async fn http_ip_restriction() -> Result<(), Error> { #[traced_test] #[test] #[cfg_attr(not(feature = "paid-tests"), ignore)] -async fn tcp_ip_restriction() -> Result<(), Error> { +async fn tcp_ip_restriction() -> Result<(), BoxError> { let tun = Session::builder() .authtoken_from_env() .connect() @@ -590,7 +618,7 @@ async fn tcp_ip_restriction() -> Result<(), Error> { #[traced_test] #[test] #[cfg_attr(not(feature = "paid-tests"), ignore)] -async fn websocket_conversion() -> Result<(), Error> { +async fn websocket_conversion() -> Result<(), BoxError> { let mut tun = Session::builder() .authtoken_from_env() .connect() @@ -606,7 +634,7 @@ async fn websocket_conversion() -> Result<(), Error> { while let Some(mut conn) = tun.try_next().await? { conn.write_all("Hello, websockets!".as_bytes()).await?; } - Result::<_, Error>::Ok(()) + Result::<_, BoxError>::Ok(()) }); let mut wss = connect_async(url).await.expect("connect").0; @@ -627,7 +655,7 @@ async fn websocket_conversion() -> Result<(), Error> { wss.send(Message::Pong(b)).await?; } Message::Close(_) => { - anyhow::bail!("didn't get message before close"); + return Err(BoxError::from("didn't get message before close")); } _ => {} } @@ -639,7 +667,7 @@ async fn websocket_conversion() -> Result<(), Error> { #[traced_test] #[test] #[cfg_attr(not(feature = "authenticated-tests"), ignore)] -async fn tcp() -> Result<(), Error> { +async fn tcp() -> Result<(), BoxError> { let tun = Session::builder() .authtoken_from_env() .connect() @@ -663,7 +691,7 @@ const KEY: &[u8] = include_bytes!("../examples/domain.key"); #[traced_test] #[test] #[cfg_attr(not(feature = "authenticated-tests"), ignore)] -async fn tls() -> Result<(), Error> { +async fn tls() -> Result<(), BoxError> { let tun = Session::builder() .authtoken_from_env() .connect() @@ -690,7 +718,7 @@ async fn tls() -> Result<(), Error> { #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn session_ca_cert() -> Result<(), Error> { +async fn session_ca_cert() -> Result<(), BoxError> { // invalid cert let resp = Session::builder() .authtoken_from_env() @@ -715,7 +743,7 @@ async fn session_ca_cert() -> Result<(), Error> { #[cfg_attr(not(feature = "online-tests"), ignore)] #[test] -async fn session_tls_config() -> Result<(), Error> { +async fn session_tls_config() -> Result<(), BoxError> { let default_tls_config = Session::builder().get_or_create_tls_config(); // invalid cert, but valid tls_config overrides @@ -748,7 +776,7 @@ fn tls_client_config() -> Result, &'static io::Error> { #[traced_test] #[cfg_attr(not(feature = "paid-tests"), ignore)] #[test] -async fn forward_proxy_protocol_tls() -> Result<(), Error> { +async fn forward_proxy_protocol_tls() -> Result<(), BoxError> { let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; @@ -799,3 +827,10 @@ async fn forward_proxy_protocol_tls() -> Result<(), Error> { Ok(()) } + +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } +} diff --git a/ngrok/src/proxy_proto.rs b/ngrok/src/proxy_proto.rs index 830892a..70a2eb3 100644 --- a/ngrok/src/proxy_proto.rs +++ b/ngrok/src/proxy_proto.rs @@ -322,6 +322,56 @@ where } } +#[cfg(feature = "hyper")] +mod hyper { + use ::hyper::rt::{ + Read as HyperRead, + Write as HyperWrite, + }; + + use super::*; + + impl HyperWrite for Stream + where + S: AsyncWrite, + { + #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + ::poll_write(self, cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ::poll_flush(self, cx) + } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + ::poll_shutdown(self, cx) + } + } + + impl HyperRead for Stream + where + S: AsyncRead, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: ::hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() }); + let res = ready!(::poll_read(self, cx, &mut tokio_buf)); + let filled = tokio_buf.filled().len(); + unsafe { buf.advance(filled) }; + Poll::Ready(res) + } + } +} + #[cfg(test)] mod test { use std::{ diff --git a/ngrok/src/session.rs b/ngrok/src/session.rs index db1fb91..e7aa6e9 100644 --- a/ngrok/src/session.rs +++ b/ngrok/src/session.rs @@ -26,7 +26,7 @@ use futures_rustls::rustls::{ self, pki_types, }; -use hyper::{ +use hyper_0_14::{ client::HttpConnector, service::Service, }; diff --git a/ngrok/src/tunnel.rs b/ngrok/src/tunnel.rs index edac41b..d29f701 100644 --- a/ngrok/src/tunnel.rs +++ b/ngrok/src/tunnel.rs @@ -10,8 +10,6 @@ use std::{ use async_trait::async_trait; use futures::Stream; -#[cfg(feature = "hyper")] -use hyper::server::accept::Accept; use muxado::Error as MuxadoError; use thiserror::Error; use tokio::sync::mpsc::Receiver; @@ -138,12 +136,8 @@ macro_rules! tunnel_trait { } } -#[cfg(not(feature = "hyper"))] tunnel_trait!(); -#[cfg(feature = "hyper")] -tunnel_trait!(+ Accept::Conn, Error = AcceptError>); - /// An ngrok tunnel backing a simple endpoint. /// Most agent-configured tunnels fall into this category, with the exception of /// labeled tunnels. @@ -174,19 +168,6 @@ impl Stream for TunnelInner { } } -#[cfg(feature = "hyper")] -impl Accept for TunnelInner { - type Conn = ConnInner; - type Error = AcceptError; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - self.poll_next(cx) - } -} - impl TunnelInner { /// Get this tunnel's ID as returned by the ngrok server. pub fn id(&self) -> &str { @@ -302,20 +283,6 @@ macro_rules! make_tunnel_type { Pin::new(&mut self.inner).poll_next(cx).map(|o| o.map(|r| r.map(|c| $conn { inner: c }))) } } - - #[cfg(feature = "hyper")] - #[cfg_attr(all(feature = "hyper", docsrs), doc(cfg(feature = "hyper")))] - impl Accept for $wrapper { - type Conn = $conn; - type Error = AcceptError; - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_accept(cx).map(|o| o.map(|r| r.map(|c| $conn { inner: c }))) - } - } }; (endpoint; $wrapper:ty) => { impl EndpointInfo for $wrapper { diff --git a/ngrok/src/tunnel_ext.rs b/ngrok/src/tunnel_ext.rs index 035dcec..09f1f60 100644 --- a/ngrok/src/tunnel_ext.rs +++ b/ngrok/src/tunnel_ext.rs @@ -14,6 +14,7 @@ use std::{ }; use async_trait::async_trait; +use axum_core::response::Response; use futures::stream::TryStreamExt; use futures_rustls::rustls::{ pki_types, @@ -22,19 +23,12 @@ use futures_rustls::rustls::{ }; #[cfg(feature = "hyper")] use hyper::{ - server::conn::Http, + server::conn::http1, service::service_fn, - Body, - Response, StatusCode, }; use once_cell::sync::Lazy; use proxy_protocol::ProxyHeader; -#[cfg(feature = "hyper")] -use tokio::io::{ - AsyncRead, - AsyncWrite, -}; #[cfg(target_os = "windows")] use tokio::net::windows::named_pipe::ClientOptions; #[cfg(not(target_os = "windows"))] @@ -382,23 +376,20 @@ async fn connect_tcp(host: &str, port: u16) -> Result { #[cfg(feature = "hyper")] fn serve_gateway_error( err: impl fmt::Display + Send + 'static, - conn: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, + conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, ) -> JoinHandle<()> { tokio::spawn( async move { - let res = Http::new() - .http1_only(true) - .http1_keep_alive(false) - .serve_connection( - conn, - service_fn(move |_req| { - debug!("serving bad gateway error"); - let mut resp = - Response::new(Body::from(format!("failed to dial backend: {err}"))); - *resp.status_mut() = StatusCode::BAD_GATEWAY; - futures::future::ok::<_, Infallible>(resp) - }), - ) + let service = service_fn(move |_req| { + debug!("serving bad gateway error"); + let mut resp = Response::new(format!("failed to dial backend: {err}")); + *resp.status_mut() = StatusCode::BAD_GATEWAY; + futures::future::ok::<_, Infallible>(resp) + }); + + let res = http1::Builder::new() + .keep_alive(false) + .serve_connection(conn, service) .await; debug!(?res, "connection closed"); } From cc9947f0a5674d6959fa22f907e504053ed8ada4 Mon Sep 17 00:00:00 2001 From: Josh Robson Chase Date: Tue, 13 Feb 2024 11:29:29 -0500 Subject: [PATCH 2/2] ngrok: update tokio-tungstenite --- ngrok/Cargo.toml | 2 +- ngrok/src/online_tests.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ngrok/Cargo.toml b/ngrok/Cargo.toml index e32298e..325cdb0 100644 --- a/ngrok/Cargo.toml +++ b/ngrok/Cargo.toml @@ -66,7 +66,7 @@ flate2 = "1.0.25" tracing-test = "0.2.3" rand = "0.8.5" paste = "1.0.11" -tokio-tungstenite = { version = "0.18.0", features = [ +tokio-tungstenite = { version = "0.21.0", features = [ "rustls", "rustls-tls-webpki-roots", ] } diff --git a/ngrok/src/online_tests.rs b/ngrok/src/online_tests.rs index 5322af9..f00a797 100644 --- a/ngrok/src/online_tests.rs +++ b/ngrok/src/online_tests.rs @@ -37,11 +37,11 @@ use hyper::{ body::Incoming, HeaderMap, Request, + Uri, }; use hyper_0_14::{ header, StatusCode, - Uri, }; use hyper_util::{ rt::TokioExecutor,