diff --git a/CHANGELOG.md b/CHANGELOG.md index 440d1ca..506de09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Versioning](https://semver.org/spec/v2.0.0.html). - `RequestCode::Forward` and `message::send_forward_request`, since forwarding messages between agents is no longer supported. +- `client_handshake`, `server_handshake`, and `AgentInfo`. These belong to the + `review-protocol` crate. ## [0.11.0] - 2024-03-25 diff --git a/Cargo.toml b/Cargo.toml index b1ed0a3..6a8d389 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ futures = "0.3" ipnet = { version = "2", features = ["serde"] } num_enum = "0.7" quinn = "0.10" -semver = "1" serde = { version = "1", features = ["derive"] } thiserror = "1" tokio = "1" diff --git a/src/message.rs b/src/message.rs index 3830ce5..6294cdb 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,13 +2,9 @@ use crate::frame::{self, RecvError, SendError}; use bincode::Options; -use quinn::{Connection, ConnectionError, RecvStream, SendStream}; -use semver::{Version, VersionReq}; -use serde::{Deserialize, Serialize}; -use std::{ - fmt, mem, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, -}; +use quinn::{ConnectionError, RecvStream, SendStream}; +use serde::Serialize; +use std::{fmt, mem}; use thiserror::Error; /// Receives a message as a stream of bytes with a big-endian 4-byte length @@ -70,142 +66,6 @@ impl From<SendError> for HandshakeError { } } -/// Properties of an agent. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct AgentInfo { - pub app_name: String, - pub version: String, - pub protocol_version: String, - pub addr: SocketAddr, -} - -impl std::fmt::Display for AgentInfo { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}, {}", self.app_name, self.addr) - } -} - -/// Sends a handshake request and processes the response. -/// -/// # Errors -/// -/// Returns `HandshakeError` if the handshake failed. -pub async fn client_handshake( - conn: &Connection, - app_name: &str, - app_version: &str, - protocol_version: &str, -) -> Result<(SendStream, RecvStream), HandshakeError> { - // A placeholder for the address of this agent. Will be replaced by the - // server. - // - // TODO: This is unnecessary in handshake, and thus should be removed in the - // future. - let addr = if conn.remote_address().is_ipv6() { - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) - } else { - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) - }; - - let agent_info = AgentInfo { - app_name: app_name.to_string(), - version: app_version.to_string(), - protocol_version: protocol_version.to_string(), - addr, - }; - - let (mut send, mut recv) = conn.open_bi().await?; - let mut buf = Vec::new(); - if let Err(e) = frame::send(&mut send, &mut buf, &agent_info).await { - match e { - SendError::SerializationFailure(e) => { - return Err(HandshakeError::SerializationFailure(e)) - } - SendError::MessageTooLarge(_) => return Err(HandshakeError::MessageTooLarge), - SendError::WriteError(e) => return Err(HandshakeError::WriteError(e)), - } - } - - match frame::recv_raw(&mut recv, &mut buf).await { - Ok(()) => {} - Err(quinn::ReadExactError::FinishedEarly) => { - return Err(HandshakeError::ConnectionClosed); - } - Err(quinn::ReadExactError::ReadError(e)) => { - return Err(HandshakeError::ReadError(e)); - } - } - let de = bincode::DefaultOptions::new(); - de.deserialize::<Result<&str, &str>>(&buf) - .map_err(|_| HandshakeError::InvalidMessage)? - .map_err(|e| { - HandshakeError::IncompatibleProtocol(protocol_version.to_string(), e.to_string()) - })?; - - Ok((send, recv)) -} - -/// Processes a handshake message and sends a response. -/// -/// # Errors -/// -/// Returns `HandshakeError` if the handshake failed. -/// -/// # Panics -/// -/// * panic if it failed to parse version requirement string. -pub async fn server_handshake( - conn: &Connection, - addr: SocketAddr, - version_req: &str, - highest_protocol_version: &str, -) -> Result<AgentInfo, HandshakeError> { - let (mut send, mut recv) = conn - .accept_bi() - .await - .map_err(HandshakeError::ConnectionLost)?; - let mut buf = Vec::new(); - let mut agent_info = frame::recv::<AgentInfo>(&mut recv, &mut buf) - .await - .map_err(|_| HandshakeError::InvalidMessage)?; - agent_info.addr = addr; - let version_req = VersionReq::parse(version_req).expect("valid version requirement"); - let protocol_version = Version::parse(&agent_info.protocol_version).map_err(|_| { - HandshakeError::IncompatibleProtocol( - agent_info.protocol_version.clone(), - version_req.to_string(), - ) - })?; - if version_req.matches(&protocol_version) { - let highest_protocol_version = - Version::parse(highest_protocol_version).expect("valid semver"); - if protocol_version <= highest_protocol_version { - send_ok(&mut send, &mut buf, highest_protocol_version.to_string()) - .await - .map_err(HandshakeError::from)?; - Ok(agent_info) - } else { - send_err(&mut send, &mut buf, &highest_protocol_version) - .await - .map_err(HandshakeError::from)?; - send.finish().await.ok(); - Err(HandshakeError::IncompatibleProtocol( - protocol_version.to_string(), - version_req.to_string(), - )) - } - } else { - send_err(&mut send, &mut buf, version_req.to_string()) - .await - .map_err(HandshakeError::from)?; - send.finish().await.ok(); - Err(HandshakeError::IncompatibleProtocol( - protocol_version.to_string(), - version_req.to_string(), - )) - } -} - /// Sends a request with a big-endian 4-byte length header. /// /// `buf` will be cleared after the response is sent. @@ -285,115 +145,6 @@ mod tests { use crate::test::{channel, TOKEN}; use crate::{frame, RequestCode}; - #[tokio::test] - async fn handshake() { - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - - const APP_NAME: &str = "oinq"; - const APP_VERSION: &str = "1.0.0"; - const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION"); - - let _lock = TOKEN.lock().await; - let channel = channel().await; - let (server, client) = (channel.server, channel.client); - - let handle = tokio::spawn(async move { - super::client_handshake(&client.conn, APP_NAME, APP_VERSION, PROTOCOL_VERSION).await - }); - - let agent_info = super::server_handshake( - &server.conn, - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - PROTOCOL_VERSION, - PROTOCOL_VERSION, - ) - .await - .unwrap(); - - assert_eq!(agent_info.app_name, APP_NAME); - assert_eq!(agent_info.version, APP_VERSION); - assert_eq!(agent_info.protocol_version, PROTOCOL_VERSION); - - assert_eq!( - agent_info.to_string(), - format!("{}, {}", agent_info.app_name, agent_info.addr) - ); - let res = tokio::join!(handle).0.unwrap(); - assert!(res.is_ok()); - } - - #[tokio::test] - async fn handshake_version_incompatible_err() { - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - - const APP_NAME: &str = "oinq"; - const APP_VERSION: &str = "1.0.0"; - const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION"); - - let _lock = TOKEN.lock().await; - let channel = channel().await; - let (server, client) = (channel.server, channel.client); - - let handle = tokio::spawn(async move { - super::client_handshake(&client.conn, APP_NAME, APP_VERSION, PROTOCOL_VERSION).await - }); - - let res = super::server_handshake( - &server.conn, - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - &format!("<{PROTOCOL_VERSION}"), - PROTOCOL_VERSION, - ) - .await; - - assert!(res.is_err()); - - let res = tokio::join!(handle).0.unwrap(); - assert!(res.is_err()); - } - - #[tokio::test] - async fn handshake_incompatible_err() { - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - - const APP_NAME: &str = "oinq"; - const APP_VERSION: &str = "1.0.0"; - const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION"); - - let version_req = semver::VersionReq::parse(&format!(">={PROTOCOL_VERSION}")).unwrap(); - let mut highest_version = semver::Version::parse(PROTOCOL_VERSION).unwrap(); - highest_version.patch += 1; - let mut protocol_version = highest_version.clone(); - protocol_version.minor += 1; - - let _lock = TOKEN.lock().await; - let channel = channel().await; - let (server, client) = (channel.server, channel.client); - - let handle = tokio::spawn(async move { - super::client_handshake( - &client.conn, - APP_NAME, - APP_VERSION, - &protocol_version.to_string(), - ) - .await - }); - - let res = super::server_handshake( - &server.conn, - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - &version_req.to_string(), - &highest_version.to_string(), - ) - .await; - - assert!(res.is_err()); - - let res = tokio::join!(handle).0.unwrap(); - assert!(res.is_err()); - } - #[tokio::test] async fn send_and_recv() { let _lock = TOKEN.lock().await; diff --git a/src/test.rs b/src/test.rs index 97652ff..00febf6 100644 --- a/src/test.rs +++ b/src/test.rs @@ -10,7 +10,7 @@ pub(crate) struct Channel { } pub(crate) struct Endpoint { - pub(crate) conn: Connection, + pub(crate) _conn: Connection, pub(crate) send: SendStream, pub(crate) recv: RecvStream, } @@ -80,12 +80,12 @@ pub(crate) async fn channel() -> Channel { Channel { server: self::Endpoint { - conn: server_connection, + _conn: server_connection, send: server_send, recv: server_recv, }, client: self::Endpoint { - conn: client_connection, + _conn: client_connection, send: client_send, recv: client_recv, },