Skip to content

Commit

Permalink
feat(awc): allow to set a specific sni host on the request
Browse files Browse the repository at this point in the history
  • Loading branch information
joelwurtz committed Dec 9, 2024
1 parent 002c1b5 commit 50f296f
Show file tree
Hide file tree
Showing 12 changed files with 382 additions and 103 deletions.
1 change: 1 addition & 0 deletions awc/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Update `brotli` dependency to `7`.
- Prevent panics on connection pool drop when Tokio runtime is shutdown early.
- Minimum supported Rust version (MSRV) is now 1.75.
- Allow to set a specific SNI hostname on the request for TLS connections.

## 3.5.1

Expand Down
22 changes: 14 additions & 8 deletions awc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
use actix_http::{
error::HttpError,
header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
Uri,
};
use actix_rt::net::{ActixStream, TcpStream};
use actix_service::{boxed, Service};
use base64::prelude::*;

use crate::{
client::{
ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection,
ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError,
TcpConnection,
},
connect::DefaultConnector,
error::SendRequestError,
Expand Down Expand Up @@ -46,8 +46,8 @@ impl ClientBuilder {
#[allow(clippy::new_ret_no_self)]
pub fn new() -> ClientBuilder<
impl Service<
ConnectInfo<Uri>,
Response = TcpConnection<Uri, TcpStream>,
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = TcpConnectError,
> + Clone,
(),
Expand All @@ -69,16 +69,22 @@ impl ClientBuilder {

impl<S, Io, M> ClientBuilder<S, M>
where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
Io: ActixStream + fmt::Debug + 'static,
{
/// Use custom connector service.
pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
where
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
+ Clone
S1: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone
+ 'static,
Io1: ActixStream + fmt::Debug + 'static,
{
Expand Down
126 changes: 91 additions & 35 deletions awc/src/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,51 @@ use actix_rt::{
use actix_service::Service;
use actix_tls::connect::{
ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection,
Connector as TcpConnector, Resolver,
Connector as TcpConnector, Host, Resolver,
};
use futures_core::{future::LocalBoxFuture, ready};
use http::Uri;
use pin_project_lite::pin_project;

use super::{
config::ConnectorConfig,
connection::{Connection, ConnectionIo},
error::ConnectError,
pool::ConnectionPool,
Connect,
Connect, ServerName,
};

pub enum HostnameWithSni {
ForTcp(String, u16, Option<ServerName>),
ForTls(String, u16, Option<ServerName>),
}

impl Host for HostnameWithSni {
fn hostname(&self) -> &str {
match self {
HostnameWithSni::ForTcp(hostname, _, _) => hostname,
HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname),
}
}

fn port(&self) -> Option<u16> {
match self {
HostnameWithSni::ForTcp(_, port, _) => Some(*port),
HostnameWithSni::ForTls(_, port, _) => Some(*port),
}
}
}

impl HostnameWithSni {
pub fn to_tls(self) -> Self {
match self {
HostnameWithSni::ForTcp(hostname, port, sni) => {
HostnameWithSni::ForTls(hostname, port, sni)
}
HostnameWithSni::ForTls(_, _, _) => self,
}
}
}

enum OurTlsConnector {
#[allow(dead_code)] // only dead when no TLS feature is enabled
None,
Expand Down Expand Up @@ -95,8 +126,8 @@ impl Connector<()> {
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
pub fn new() -> Connector<
impl Service<
ConnectInfo<Uri>,
Response = TcpConnection<Uri, TcpStream>,
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = actix_tls::connect::ConnectError,
> + Clone,
> {
Expand Down Expand Up @@ -214,8 +245,11 @@ impl<S> Connector<S> {
pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
where
Io1: ActixStream + fmt::Debug + 'static,
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
+ Clone,
S1: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone,
{
Connector {
connector,
Expand All @@ -235,8 +269,11 @@ where
// This remap is to hide ActixStream's trait methods. They are not meant to be called
// from user code.
IO: ActixStream + fmt::Debug + 'static,
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, IO>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, IO>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
/// Sets TCP connection timeout.
Expand Down Expand Up @@ -454,7 +491,7 @@ where
use actix_utils::future::{ready, Ready};

#[allow(non_local_definitions)]
impl IntoConnectionIo for TcpConnection<Uri, Box<dyn ConnectionIo>> {
impl IntoConnectionIo for TcpConnection<HostnameWithSni, Box<dyn ConnectionIo>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let io = self.into_parts().0;
(io, Protocol::Http2)
Expand Down Expand Up @@ -505,7 +542,7 @@ where
use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector};

#[allow(non_local_definitions)]
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncSslStream<IO>> {
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncSslStream<IO>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -543,7 +580,7 @@ where
use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -577,7 +614,7 @@ where
use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -614,7 +651,7 @@ where
use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -648,7 +685,7 @@ where
use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -688,15 +725,17 @@ where
}
}

/// tcp service for map `TcpConnection<Uri, Io>` type to `(Io, Protocol)`
/// tcp service for map `TcpConnection<HostnameWithSni, Io>` type to `(Io, Protocol)`
#[derive(Clone)]
pub struct TcpConnectorService<S: Clone> {
service: S,
}

impl<S, Io> Service<Connect> for TcpConnectorService<S>
where
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError> + Clone + 'static,
S: Service<Connect, Response = TcpConnection<HostnameWithSni, Io>, Error = ConnectError>
+ Clone
+ 'static,
{
type Response = (Io, Protocol);
type Error = ConnectError;
Expand All @@ -721,7 +760,7 @@ pin_project! {

impl<Fut, Io> Future for TcpConnectorFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
{
type Output = Result<(Io, Protocol), ConnectError>;

Expand Down Expand Up @@ -767,9 +806,10 @@ struct TlsConnectorService<Tcp, Tls> {
))]
impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls>
where
Tcp:
Service<Connect, Response = TcpConnection<Uri, IO>, Error = ConnectError> + Clone + 'static,
Tls: Service<TcpConnection<Uri, IO>, Error = std::io::Error> + Clone + 'static,
Tcp: Service<Connect, Response = TcpConnection<HostnameWithSni, IO>, Error = ConnectError>
+ Clone
+ 'static,
Tls: Service<TcpConnection<HostnameWithSni, IO>, Error = std::io::Error> + Clone + 'static,
Tls::Response: IntoConnectionIo,
IO: ConnectionIo,
{
Expand Down Expand Up @@ -822,9 +862,14 @@ trait IntoConnectionIo {

impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
where
S: Service<TcpConnection<Uri, Io>, Response = Res, Error = std::io::Error, Future = Fut2>,
S: Service<
TcpConnection<HostnameWithSni, Io>,
Response = Res,
Error = std::io::Error,
Future = Fut2,
>,
S::Response: IntoConnectionIo,
Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
Fut1: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
Fut2: Future<Output = Result<S::Response, S::Error>>,
Io: ConnectionIo,
{
Expand All @@ -838,10 +883,11 @@ where
timeout,
} => {
let res = ready!(fut.poll(cx))?;
let (io, hostname_with_sni) = res.into_parts();
let fut = tls_service
.take()
.expect("TlsConnectorFuture polled after complete")
.call(res);
.call(TcpConnection::new(hostname_with_sni.to_tls(), io));
let timeout = sleep(*timeout);
self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
self.poll(cx)
Expand Down Expand Up @@ -875,8 +921,11 @@ impl<S: Clone> TcpConnectorInnerService<S> {

impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
type Response = S::Response;
Expand All @@ -886,7 +935,13 @@ where
actix_service::forward_ready!(service);

fn call(&self, req: Connect) -> Self::Future {
let mut req = ConnectInfo::new(req.uri).set_addr(req.addr);
let mut req = ConnectInfo::new(HostnameWithSni::ForTcp(
req.hostname,
req.port,
req.sni_host,
))
.set_addr(req.addr)
.set_port(req.port);

if let Some(local_addr) = self.local_address {
req = req.set_local_addr(local_addr);
Expand All @@ -911,9 +966,9 @@ pin_project! {

impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>,
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, TcpConnectError>>,
{
type Output = Result<TcpConnection<Uri, Io>, ConnectError>;
type Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
Expand Down Expand Up @@ -973,16 +1028,17 @@ where
}

fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => match self.tls_pool {
if req.tls {
match &self.tls_pool {
None => ConnectorServiceFuture::SslIsNotSupported,
Some(ref pool) => ConnectorServiceFuture::Tls {
Some(pool) => ConnectorServiceFuture::Tls {
fut: pool.call(req),
},
},
_ => ConnectorServiceFuture::Tcp {
}
} else {
ConnectorServiceFuture::Tcp {
fut: self.tcp_pool.call(req),
},
}
}
}
}
Expand Down
Loading

0 comments on commit 50f296f

Please sign in to comment.