diff --git a/src/protocols/unix_sock/server.rs b/src/protocols/unix_sock/server.rs index 08082908..e53f6826 100644 --- a/src/protocols/unix_sock/server.rs +++ b/src/protocols/unix_sock/server.rs @@ -50,9 +50,14 @@ impl Stream for UnixListenerStream { pub async fn run_server(socket_path: &Path) -> Result { info!("Starting Unix socket server listening cnx on {:?}", socket_path); - let path_to_delete = !socket_path.exists(); + if socket_path.exists() { + std::fs::remove_file(socket_path) + .with_context(|| format!("Failed to delete existing Unix socket at {:?}", socket_path))?; + } + let listener = UnixListener::bind(socket_path) .with_context(|| format!("Cannot create Unix socket server {:?}", socket_path))?; + let path_to_delete = true; Ok(UnixListenerStream::new(listener, path_to_delete)) } diff --git a/src/tunnel/listeners/unix_sock.rs b/src/tunnel/listeners/unix_sock.rs index 392fd4f0..72e3735b 100644 --- a/src/tunnel/listeners/unix_sock.rs +++ b/src/tunnel/listeners/unix_sock.rs @@ -2,7 +2,7 @@ use crate::protocols::unix_sock; use crate::protocols::unix_sock::UnixListenerStream; use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::pin::Pin; use std::task::{ready, Poll}; use tokio::net::unix; @@ -13,6 +13,7 @@ pub struct UnixTunnelListener { listener: UnixListenerStream, dest: (Host, u16), proxy_protocol: bool, + path: PathBuf, } impl UnixTunnelListener { @@ -25,6 +26,7 @@ impl UnixTunnelListener { listener, dest, proxy_protocol, + path: path.to_path_buf(), }) } } @@ -55,3 +57,10 @@ impl Stream for UnixTunnelListener { Poll::Ready(ret) } } +impl Drop for UnixTunnelListener { + fn drop(&mut self) { + if let Err(err) = std::fs::remove_file(&self.path) { + log::error!("Cannot remove Unix domain socket file {}: {}", self.path.display(), err); + } + } +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index faa7e9c2..69fd0753 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -57,6 +57,12 @@ pub enum LocalProtocol { }, } +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +pub enum BindAddr { + Socket(SocketAddr), + Unix(String), // Unix socket path +} + impl LocalProtocol { pub const fn is_reverse_tunnel(&self) -> bool { matches!( diff --git a/src/tunnel/server/reverse_tunnel.rs b/src/tunnel/server/reverse_tunnel.rs index cb1b1f4b..48c2803d 100644 --- a/src/tunnel/server/reverse_tunnel.rs +++ b/src/tunnel/server/reverse_tunnel.rs @@ -1,4 +1,5 @@ use crate::tunnel::listeners::TunnelListener; +use crate::tunnel::BindAddr; use crate::tunnel::RemoteAddr; use ahash::AHashMap; use anyhow::anyhow; @@ -6,7 +7,6 @@ use futures_util::{pin_mut, StreamExt}; use log::warn; use parking_lot::Mutex; use std::future::Future; -use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -29,7 +29,7 @@ impl Clone for ReverseTunnelItem { } pub struct ReverseTunnelServer { - servers: Arc>>>, + servers: Arc>>>, } impl ReverseTunnelServer { @@ -41,7 +41,7 @@ impl ReverseTunnelServer { pub async fn run_listening_server( &self, - bind_addr: SocketAddr, + bind_addr: BindAddr, gen_listening_server: impl Future>, ) -> anyhow::Result<((::Reader, ::Writer), RemoteAddr)> where @@ -57,7 +57,7 @@ impl ReverseTunnelServer { let nb_seen_clients = Arc::new(AtomicUsize::new(0)); let seen_clients = nb_seen_clients.clone(); let server = self.servers.clone(); - let local_srv2 = bind_addr; + let local_srv2 = bind_addr.clone(); let fut = async move { scopeguard::defer!({ diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 357f6e8a..3d1f2f0b 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -13,7 +13,7 @@ use std::sync::{Arc, LazyLock}; use std::time::Duration; use crate::protocols; -use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr}; +use crate::tunnel::{try_to_sock_addr, BindAddr, LocalProtocol, RemoteAddr}; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; @@ -194,9 +194,10 @@ impl WsServer { let header = ppp::v2::Builder::with_addresses( ppp::v2::Version::Two | ppp::v2::Command::Proxy, ppp::v2::Protocol::Stream, - (client_address, tx.local_addr()?), + (client_address, tx.local_addr().unwrap()), ) - .build()?; + .build() + .unwrap(); let _ = tx.write_all(&header).await; } @@ -210,8 +211,9 @@ impl WsServer { let local_srv = (remote.host, remote_port); let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await }; - let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; - + let ((local_rx, local_tx), remote) = SERVERS + .run_listening_server(BindAddr::Socket(bind), listening_server) + .await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseUdp { timeout } => { @@ -222,7 +224,9 @@ impl WsServer { let local_srv = (remote.host, remote_port); let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await }; - let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; + let ((local_rx, local_tx), remote) = SERVERS + .run_listening_server(BindAddr::Socket(bind), listening_server) + .await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseSocks5 { timeout, credentials } => { @@ -233,7 +237,9 @@ impl WsServer { let local_srv = (remote.host, remote_port); let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await }; - let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; + let ((local_rx, local_tx), remote) = SERVERS + .run_listening_server(BindAddr::Socket(bind), listening_server) + .await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } @@ -245,7 +251,9 @@ impl WsServer { let local_srv = (remote.host, remote_port); let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await }; - let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; + let ((local_rx, local_tx), remote) = SERVERS + .run_listening_server(BindAddr::Socket(bind), listening_server) + .await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } @@ -255,11 +263,10 @@ impl WsServer { static SERVERS: LazyLock> = LazyLock::new(ReverseTunnelServer::new); - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let bind = try_to_sock_addr(local_srv.clone())?; - let listening_server = async { UnixTunnelListener::new(path, local_srv, false).await }; - let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; + let listening_server = async { UnixTunnelListener::new(path, (remote.host, remote.port), false).await }; + let ((local_rx, local_tx), remote) = SERVERS + .run_listening_server(BindAddr::Unix(path.to_str().unwrap().to_string()), listening_server) + .await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } @@ -291,7 +298,6 @@ impl WsServer { move |req: Request| { ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) .map::, _>(Ok) - .instrument(mk_span()) } }; @@ -302,7 +308,6 @@ impl WsServer { move |req: Request| { http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) .map::, _>(Ok) - .instrument(mk_span()) } }; @@ -337,7 +342,6 @@ impl WsServer { .unwrap()) } } - .instrument(mk_span()) } }; @@ -383,12 +387,20 @@ impl WsServer { } }; - let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),); - info!(parent: &span, "Accepting connection"); if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) { warn!("Error while configuring server socket {:?}", err); } + let span = span!( + Level::INFO, + "tunnel", + id = tracing::field::Empty, + remote = tracing::field::Empty, + peer = peer_addr.to_string(), + forwarded_for = tracing::field::Empty + ); + + info!("Accepting connection"); let server = self.clone(); let restrictions = restrictions.restrictions_rules().clone(); @@ -435,9 +447,7 @@ impl WsServer { mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); let conn_fut = http1::Builder::new() .timer(TokioTimer::new()) - // https://github.com/erebe/wstunnel/issues/358 - // disabled, to avoid conflict with --connection-min-idle flag, that open idle connections - .header_read_timeout(None) + .header_read_timeout(Duration::from_secs(10)) .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .with_upgrades(); @@ -450,6 +460,7 @@ impl WsServer { .instrument(span); tokio::spawn(fut); + // Normal } // HTTP without TLS None => { @@ -477,16 +488,6 @@ impl WsServer { } } -fn mk_span() -> Span { - span!( - Level::INFO, - "tunnel", - id = tracing::field::Empty, - remote = tracing::field::Empty, - forwarded_for = tracing::field::Empty - ) -} - impl Debug for WsServerConfig { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("WsServerConfig")