Skip to content

Commit

Permalink
Remove handshake implementation
Browse files Browse the repository at this point in the history
They are protocol-specific and should be implemented by a higher-level
crate.
  • Loading branch information
msk committed Apr 1, 2024
1 parent 9351072 commit 97c20d0
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 256 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Versioning](https://semver.org/spec/v2.0.0.html).

- `RequestCode::Forward` and `message::send_forward_request`, since forwarding
messages between agents is no longer supported.
- `client_handshake`, `server_handshake`, and `AgentInfo`. These belong to the
`review-protocol` crate.

## [0.11.0] - 2024-03-25

Expand Down
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ futures = "0.3"
ipnet = { version = "2", features = ["serde"] }
num_enum = "0.7"
quinn = "0.10"
semver = "1"
serde = { version = "1", features = ["derive"] }
thiserror = "1"
tokio = "1"
Expand Down
255 changes: 3 additions & 252 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
use crate::frame::{self, RecvError, SendError};
use bincode::Options;
use quinn::{Connection, ConnectionError, RecvStream, SendStream};
use semver::{Version, VersionReq};
use serde::{Deserialize, Serialize};
use std::{
fmt, mem,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
use quinn::{ConnectionError, RecvStream, SendStream};
use serde::Serialize;
use std::{fmt, mem};
use thiserror::Error;

/// Receives a message as a stream of bytes with a big-endian 4-byte length
Expand Down Expand Up @@ -70,142 +66,6 @@ impl From<SendError> for HandshakeError {
}
}

/// Properties of an agent.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentInfo {
pub app_name: String,
pub version: String,
pub protocol_version: String,
pub addr: SocketAddr,
}

impl std::fmt::Display for AgentInfo {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}, {}", self.app_name, self.addr)
}
}

/// Sends a handshake request and processes the response.
///
/// # Errors
///
/// Returns `HandshakeError` if the handshake failed.
pub async fn client_handshake(
conn: &Connection,
app_name: &str,
app_version: &str,
protocol_version: &str,
) -> Result<(SendStream, RecvStream), HandshakeError> {
// A placeholder for the address of this agent. Will be replaced by the
// server.
//
// TODO: This is unnecessary in handshake, and thus should be removed in the
// future.
let addr = if conn.remote_address().is_ipv6() {
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
};

let agent_info = AgentInfo {
app_name: app_name.to_string(),
version: app_version.to_string(),
protocol_version: protocol_version.to_string(),
addr,
};

let (mut send, mut recv) = conn.open_bi().await?;
let mut buf = Vec::new();
if let Err(e) = frame::send(&mut send, &mut buf, &agent_info).await {
match e {
SendError::SerializationFailure(e) => {
return Err(HandshakeError::SerializationFailure(e))
}
SendError::MessageTooLarge(_) => return Err(HandshakeError::MessageTooLarge),
SendError::WriteError(e) => return Err(HandshakeError::WriteError(e)),
}
}

match frame::recv_raw(&mut recv, &mut buf).await {
Ok(()) => {}
Err(quinn::ReadExactError::FinishedEarly) => {
return Err(HandshakeError::ConnectionClosed);
}
Err(quinn::ReadExactError::ReadError(e)) => {
return Err(HandshakeError::ReadError(e));
}
}
let de = bincode::DefaultOptions::new();
de.deserialize::<Result<&str, &str>>(&buf)
.map_err(|_| HandshakeError::InvalidMessage)?
.map_err(|e| {
HandshakeError::IncompatibleProtocol(protocol_version.to_string(), e.to_string())
})?;

Ok((send, recv))
}

/// Processes a handshake message and sends a response.
///
/// # Errors
///
/// Returns `HandshakeError` if the handshake failed.
///
/// # Panics
///
/// * panic if it failed to parse version requirement string.
pub async fn server_handshake(
conn: &Connection,
addr: SocketAddr,
version_req: &str,
highest_protocol_version: &str,
) -> Result<AgentInfo, HandshakeError> {
let (mut send, mut recv) = conn
.accept_bi()
.await
.map_err(HandshakeError::ConnectionLost)?;
let mut buf = Vec::new();
let mut agent_info = frame::recv::<AgentInfo>(&mut recv, &mut buf)
.await
.map_err(|_| HandshakeError::InvalidMessage)?;
agent_info.addr = addr;
let version_req = VersionReq::parse(version_req).expect("valid version requirement");
let protocol_version = Version::parse(&agent_info.protocol_version).map_err(|_| {
HandshakeError::IncompatibleProtocol(
agent_info.protocol_version.clone(),
version_req.to_string(),
)
})?;
if version_req.matches(&protocol_version) {
let highest_protocol_version =
Version::parse(highest_protocol_version).expect("valid semver");
if protocol_version <= highest_protocol_version {
send_ok(&mut send, &mut buf, highest_protocol_version.to_string())
.await
.map_err(HandshakeError::from)?;
Ok(agent_info)
} else {
send_err(&mut send, &mut buf, &highest_protocol_version)
.await
.map_err(HandshakeError::from)?;
send.finish().await.ok();
Err(HandshakeError::IncompatibleProtocol(
protocol_version.to_string(),
version_req.to_string(),
))
}
} else {
send_err(&mut send, &mut buf, version_req.to_string())
.await
.map_err(HandshakeError::from)?;
send.finish().await.ok();
Err(HandshakeError::IncompatibleProtocol(
protocol_version.to_string(),
version_req.to_string(),
))
}
}

/// Sends a request with a big-endian 4-byte length header.
///
/// `buf` will be cleared after the response is sent.
Expand Down Expand Up @@ -285,115 +145,6 @@ mod tests {
use crate::test::{channel, TOKEN};
use crate::{frame, RequestCode};

#[tokio::test]
async fn handshake() {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

const APP_NAME: &str = "oinq";
const APP_VERSION: &str = "1.0.0";
const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

let _lock = TOKEN.lock().await;
let channel = channel().await;
let (server, client) = (channel.server, channel.client);

let handle = tokio::spawn(async move {
super::client_handshake(&client.conn, APP_NAME, APP_VERSION, PROTOCOL_VERSION).await
});

let agent_info = super::server_handshake(
&server.conn,
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
PROTOCOL_VERSION,
PROTOCOL_VERSION,
)
.await
.unwrap();

assert_eq!(agent_info.app_name, APP_NAME);
assert_eq!(agent_info.version, APP_VERSION);
assert_eq!(agent_info.protocol_version, PROTOCOL_VERSION);

assert_eq!(
agent_info.to_string(),
format!("{}, {}", agent_info.app_name, agent_info.addr)
);
let res = tokio::join!(handle).0.unwrap();
assert!(res.is_ok());
}

#[tokio::test]
async fn handshake_version_incompatible_err() {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

const APP_NAME: &str = "oinq";
const APP_VERSION: &str = "1.0.0";
const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

let _lock = TOKEN.lock().await;
let channel = channel().await;
let (server, client) = (channel.server, channel.client);

let handle = tokio::spawn(async move {
super::client_handshake(&client.conn, APP_NAME, APP_VERSION, PROTOCOL_VERSION).await
});

let res = super::server_handshake(
&server.conn,
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
&format!("<{PROTOCOL_VERSION}"),
PROTOCOL_VERSION,
)
.await;

assert!(res.is_err());

let res = tokio::join!(handle).0.unwrap();
assert!(res.is_err());
}

#[tokio::test]
async fn handshake_incompatible_err() {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

const APP_NAME: &str = "oinq";
const APP_VERSION: &str = "1.0.0";
const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

let version_req = semver::VersionReq::parse(&format!(">={PROTOCOL_VERSION}")).unwrap();
let mut highest_version = semver::Version::parse(PROTOCOL_VERSION).unwrap();
highest_version.patch += 1;
let mut protocol_version = highest_version.clone();
protocol_version.minor += 1;

let _lock = TOKEN.lock().await;
let channel = channel().await;
let (server, client) = (channel.server, channel.client);

let handle = tokio::spawn(async move {
super::client_handshake(
&client.conn,
APP_NAME,
APP_VERSION,
&protocol_version.to_string(),
)
.await
});

let res = super::server_handshake(
&server.conn,
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
&version_req.to_string(),
&highest_version.to_string(),
)
.await;

assert!(res.is_err());

let res = tokio::join!(handle).0.unwrap();
assert!(res.is_err());
}

#[tokio::test]
async fn send_and_recv() {
let _lock = TOKEN.lock().await;
Expand Down
6 changes: 3 additions & 3 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) struct Channel {
}

pub(crate) struct Endpoint {
pub(crate) conn: Connection,
pub(crate) _conn: Connection,
pub(crate) send: SendStream,
pub(crate) recv: RecvStream,
}
Expand Down Expand Up @@ -80,12 +80,12 @@ pub(crate) async fn channel() -> Channel {

Channel {
server: self::Endpoint {
conn: server_connection,
_conn: server_connection,
send: server_send,
recv: server_recv,
},
client: self::Endpoint {
conn: client_connection,
_conn: client_connection,
send: client_send,
recv: client_recv,
},
Expand Down

0 comments on commit 97c20d0

Please sign in to comment.