diff --git a/CHANGELOG.md b/CHANGELOG.md index c7f0a17..fe59dae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added + +- `request::parse_args` to parse arguments for a request handler. + ### Changed - `SendError::MessageTooLarge` no longer contains the underlying error, diff --git a/src/request.rs b/src/request.rs index e76cd16..7162a40 100644 --- a/src/request.rs +++ b/src/request.rs @@ -205,7 +205,6 @@ pub async fn handle( recv: &mut RecvStream, ) -> Result<(), HandlerError> { let mut buf = Vec::new(); - let codec = bincode::DefaultOptions::new(); loop { let (code, body) = match message::recv_request_raw(recv, &mut buf).await { Ok(res) => res, @@ -232,9 +231,7 @@ pub async fn handle( send_response(send, &mut buf, handler.reload_config().await).await?; } RequestCode::ReloadTi => { - let version = codec - .deserialize::<&str>(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let version = parse_args::<&str>(body)?; let result = handler.reload_ti(version).await; send_response(send, &mut buf, result).await?; } @@ -242,9 +239,7 @@ pub async fn handle( send_response(send, &mut buf, handler.resource_usage().await).await?; } RequestCode::TorExitNodeList => { - let nodes = codec - .deserialize::>(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let nodes = parse_args::>(body)?; let result = handler.tor_exit_node_list(&nodes).await; send_response(send, &mut buf, result).await?; } @@ -257,10 +252,7 @@ pub async fn handle( send_response(send, &mut buf, result).await?; } RequestCode::TrustedDomainList => { - let domains = codec - .deserialize::, String>>(body) - .map_err(frame::RecvError::DeserializationFailure)?; - + let domains = parse_args::, String>>(body)?; let result = if let Ok(domains) = domains { handler.trusted_domain_list(&domains).await } else { @@ -269,23 +261,17 @@ pub async fn handle( send_response(send, &mut buf, result).await?; } RequestCode::InternalNetworkList => { - let network_list = codec - .deserialize::(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let network_list = parse_args::(body)?; let result = handler.internal_network_list(network_list).await; send_response(send, &mut buf, result).await?; } RequestCode::AllowList => { - let allow_list = codec - .deserialize::(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let allow_list = parse_args::(body)?; let result = handler.allow_list(allow_list).await; send_response(send, &mut buf, result).await?; } RequestCode::BlockList => { - let block_list = codec - .deserialize::(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let block_list = parse_args::(body)?; let result = handler.block_list(block_list).await; send_response(send, &mut buf, result).await?; } @@ -293,16 +279,12 @@ pub async fn handle( send_response(send, &mut buf, Ok::<(), String>(())).await?; } RequestCode::TrustedUserAgentList => { - let user_agent_list = codec - .deserialize::>(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let user_agent_list = parse_args::>(body)?; let result = handler.trusted_user_agent_list(&user_agent_list).await; send_response(send, &mut buf, result).await?; } RequestCode::ReloadFilterRule => { - let rules = codec - .deserialize::>(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let rules = parse_args::>(body)?; let result = handler.update_traffic_filter_rules(&rules).await; send_response(send, &mut buf, result).await?; } @@ -310,9 +292,7 @@ pub async fn handle( send_response(send, &mut buf, handler.get_config().await).await?; } RequestCode::SetConfig => { - let conf = codec - .deserialize::(body) - .map_err(frame::RecvError::DeserializationFailure)?; + let conf = parse_args::(body)?; let result = handler.set_config(conf).await; send_response(send, &mut buf, result).await?; } @@ -335,7 +315,25 @@ pub async fn handle( Ok(()) } -async fn send_response( +/// Parses the arguments of a request. +/// +/// # Errors +/// +/// Returns `frame::RecvError::DeserializationFailure`: if the arguments could +/// not be deserialized. +pub fn parse_args<'de, T: Deserialize<'de>>(args: &'de [u8]) -> Result { + bincode::DefaultOptions::new() + .deserialize::(args) + .map_err(frame::RecvError::DeserializationFailure) +} + +/// Sends a response to a request. +/// +/// # Errors +/// +/// * `SendError::MessageTooLarge` if `e` is too large to be serialized +/// * `SendError::WriteError` if the message could not be written +pub async fn send_response( send: &mut SendStream, buf: &mut Vec, body: T,