From 8aa90e49d7b149238ae2f0a78246b39954ec5fee Mon Sep 17 00:00:00 2001 From: Felix Wirth Date: Wed, 26 Jun 2024 13:56:41 +0200 Subject: [PATCH] transport: split encoder into own module --- src/transport/encode.rs | 325 ++++++++++++++++++++++++++++++++++++++++ src/transport/mod.rs | 318 +-------------------------------------- 2 files changed, 328 insertions(+), 315 deletions(-) create mode 100644 src/transport/encode.rs diff --git a/src/transport/encode.rs b/src/transport/encode.rs new file mode 100644 index 0000000..196b904 --- /dev/null +++ b/src/transport/encode.rs @@ -0,0 +1,325 @@ +use core::borrow::Borrow; + +use crate::util::{Buffer, OutOfMemory, CRC_X25}; + +struct Padding(u8); + +impl Padding { + const fn new() -> Self { + Padding(0) + } + + fn bump(&mut self) { + self.0 = self.0.wrapping_sub(1); + } + + const fn get(&self) -> u8 { + self.0 & 0x3 + } +} + +#[derive(Debug, Clone, Copy)] +enum EncoderState { + Init(u8), + LookingForEscape(u8), + HandlingEscape(u8), + End(i8), +} + +/// An iterator that encodes the bytes of an underlying iterator using the SML Transport Protocol v1. +pub struct Encoder +where + I: Iterator, +{ + state: EncoderState, + crc: crc::Digest<'static, u16>, + padding: Padding, + iter: I, +} + +impl Encoder +where + I: Iterator, +{ + /// Creates an `Encoder` from a byte iterator. 
+ pub fn new(iter: I) -> Self { + let mut crc = CRC_X25.digest(); + crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]); + Encoder { + state: EncoderState::Init(0), + crc, + padding: Padding::new(), + iter, + } + } + + fn read_from_iter(&mut self) -> Option { + let ret = self.iter.next(); + if ret.is_some() { + self.padding.bump(); + } + ret + } + + fn next_from_state(&mut self, state: EncoderState) -> (Option, EncoderState) { + self.state = state; + let out = self.next(); + (out, self.state) + } +} + +impl Iterator for Encoder +where + I: Iterator, +{ + type Item = u8; + + fn next(&mut self) -> Option { + use EncoderState::*; + let (out, state) = match self.state { + Init(n) if n < 4 => (Some(0x1b), Init(n + 1)), + Init(n) if n < 8 => (Some(0x01), Init(n + 1)), + Init(n) => { + assert_eq!(n, 8); + self.next_from_state(LookingForEscape(0)) + } + LookingForEscape(n) if n < 4 => { + match self.read_from_iter() { + Some(b) => { + self.crc.update(&[b]); + (Some(b), LookingForEscape((n + 1) * u8::from(b == 0x1b))) + } + None => { + let padding = self.padding.get(); + // finalize crc + for _ in 0..padding { + self.crc.update(&[0x00]); + } + self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, padding]); + self.next_from_state(End(-(padding as i8))) + } + } + } + LookingForEscape(n) => { + assert_eq!(n, 4); + self.crc.update(&[0x1b; 4]); + self.next_from_state(HandlingEscape(0)) + } + HandlingEscape(n) if n < 4 => (Some(0x1b), HandlingEscape(n + 1)), + HandlingEscape(n) => { + assert_eq!(n, 4); + self.next_from_state(LookingForEscape(0)) + } + End(n) => { + let out = match n { + n if n < 0 => 0x00, + n if n < 4 => 0x1b, + 4 => 0x1a, + 5 => self.padding.get(), + n if n < 8 => { + let crc_bytes = self.crc.clone().finalize().to_le_bytes(); + crc_bytes[(n - 6) as usize] + } + 8 => { + return None; + } + _ => unreachable!(), + }; + (Some(out), End(n + 1)) + } + }; + self.state = state; + out + } +} + +/// Takes a slice of bytes as input and returns a buffer containing the 
encoded message. +/// +/// Returns `Err(OutOfMemory)` when the buffer can't be grown to hold the entire output. +/// +/// # Examples +/// +/// ``` +/// // example data +/// let bytes = [0x12, 0x34, 0x56, 0x78]; +/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; +/// ``` +/// +/// ### Using alloc::Vec +/// +/// ``` +/// # #[cfg(feature = "alloc")] { +/// # use sml_rs::transport::encode; +/// # let bytes = [0x12, 0x34, 0x56, 0x78]; +/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; +/// let encoded = encode::>(&bytes); +/// assert!(encoded.is_ok()); +/// assert_eq!(encoded.unwrap().as_slice(), &expected); +/// # } +/// ``` +/// +/// ### Using `ArrayBuf` +/// +/// ``` +/// # use sml_rs::{util::{ArrayBuf, OutOfMemory}, transport::encode}; +/// # let bytes = [0x12, 0x34, 0x56, 0x78]; +/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; +/// let encoded = encode::>(&bytes); +/// assert!(encoded.is_ok()); +/// assert_eq!(&*encoded.unwrap(), &expected); +/// +/// // encoding returns `Err(OutOfMemory)` if the encoded message does not fit into the buffer +/// let encoded = encode::>(&bytes); +/// assert_eq!(encoded, Err(OutOfMemory)); +/// ``` +/// +pub fn encode( + iter: impl IntoIterator>, +) -> Result { + let mut res: B = Default::default(); + + // start escape sequence + res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01])?; + + // encode data + let mut num_1b = 0; + for b in iter.into_iter() { + let b = *b.borrow(); + if b == 0x1b { + num_1b += 1; + } else { + num_1b = 0; + } + + res.push(b)?; + + if num_1b == 4 { + res.extend_from_slice(&[0x1b; 4])?; + num_1b = 0; + } + } + + // padding bytes + let num_padding_bytes = (4 - (res.len() % 4)) % 4; + res.extend_from_slice(&[0x0; 
3][..num_padding_bytes])?; + + res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, num_padding_bytes as u8])?; + let crc = CRC_X25.checksum(&res[..]); + + res.extend_from_slice(&crc.to_le_bytes())?; + + Ok(res) +} + +/// Takes an iterator over bytes and returns an iterator that produces the encoded message. +/// +/// # Examples +/// ``` +/// # use sml_rs::transport::encode_streaming; +/// // example data +/// let bytes = [0x12, 0x34, 0x56, 0x78]; +/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; +/// let iter = encode_streaming(bytes); +/// assert!(iter.eq(expected)); +/// ``` +pub fn encode_streaming( + iter: impl IntoIterator>, +) -> Encoder> { + Encoder::new(iter.into_iter().map(|x| *x.borrow())) +} + +#[cfg(test)] +mod tests { + use crate::transport::decode; + + use super::*; + use hex_literal::hex; + + // assert_eq macro that prints its arguments as hex when they don't match. + // (adapted from the `assert_hex` crate) + macro_rules! assert_eq_hex { + ($left:expr, $right:expr $(,)?) => {{ + match (&$left, &$right) { + (left_val, right_val) => { + if !(*left_val == *right_val) { + // The reborrows below are intentional. Without them, the stack slot for the + // borrow is initialized even before the values are compared, leading to a + // noticeable slow down. 
+ panic!( + "assertion failed: `(left == right)`\n left: `{:02x?}`,\n right: `{:02x?}`", + &*left_val, &*right_val + ) + } + } + } + }}; + } + + fn test_encoding(bytes: &[u8], exp_encoded_bytes: &[u8; N]) { + // test that: encode(bytes) == exp_encoded_bytes + compare_encoded_bytes( + exp_encoded_bytes, + &encode::>(bytes).expect("ran out of memory"), + ); + + // test that: encode_streaming(bytes).collect() == exp_encoded_bytes + compare_encoded_bytes( + exp_encoded_bytes, + &encode_streaming(bytes).collect::>(), + ); + + // reverse direction: + // test that: decode(exp_encoded_bytes) == Ok(bytes) + #[cfg(feature = "alloc")] + assert_eq_hex!(alloc::vec![Ok(bytes.to_vec())], decode(exp_encoded_bytes)); + } + + fn compare_encoded_bytes(expected: &[u8], actual: &[u8]) { + assert_eq_hex!(expected, actual); + } + + #[test] + fn basic() { + test_encoding( + &hex!("12345678"), + &hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b"), + ); + } + + #[test] + fn empty() { + test_encoding(&hex!(""), &hex!("1b1b1b1b 01010101 1b1b1b1b 1a00c6e5")); + } + + #[test] + fn padding() { + test_encoding( + &hex!("123456"), + &hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5"), + ); + } + + #[test] + fn escape_in_user_data() { + test_encoding( + &hex!("121b1b1b1b"), + &hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25"), + ); + } + + #[test] + fn almost_escape_in_user_data() { + test_encoding( + &hex!("121b1b1bFF"), + &hex!("1b1b1b1b 01010101 12 1b1b1bFF 000000 1b1b1b1b 1a0324d9"), + ); + } + + #[test] + fn ending_with_1b_no_padding() { + test_encoding( + &hex!("12345678 12341b1b"), + &hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5"), + ); + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 5a84480..75320e7 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -29,245 +29,21 @@ //! - using `Decoder` directly: instantiate a `Decoder` manually, call `push_byte()` on it when data becomes available. 
Call `finalize()` when all data has been pushed. mod decoder_reader; +mod encode; pub use decoder_reader::{DecoderReader, ReadDecodedError}; +pub use encode::{Encoder, encode, encode_streaming}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use core::{borrow::Borrow, fmt}; -use crate::util::{Buffer, OutOfMemory, CRC_X25}; +use crate::util::{Buffer, CRC_X25}; #[cfg(feature = "alloc")] use alloc::vec::Vec; -struct Padding(u8); - -impl Padding { - const fn new() -> Self { - Padding(0) - } - - fn bump(&mut self) { - self.0 = self.0.wrapping_sub(1); - } - - const fn get(&self) -> u8 { - self.0 & 0x3 - } -} - -#[derive(Debug, Clone, Copy)] -enum EncoderState { - Init(u8), - LookingForEscape(u8), - HandlingEscape(u8), - End(i8), -} - -/// An iterator that encodes the bytes of an underlying iterator using the SML Transport Protocol v1. -pub struct Encoder -where - I: Iterator, -{ - state: EncoderState, - crc: crc::Digest<'static, u16>, - padding: Padding, - iter: I, -} - -impl Encoder -where - I: Iterator, -{ - /// Creates an `Encoder` from a byte iterator. 
- pub fn new(iter: I) -> Self { - let mut crc = CRC_X25.digest(); - crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]); - Encoder { - state: EncoderState::Init(0), - crc, - padding: Padding::new(), - iter, - } - } - - fn read_from_iter(&mut self) -> Option { - let ret = self.iter.next(); - if ret.is_some() { - self.padding.bump(); - } - ret - } - - fn next_from_state(&mut self, state: EncoderState) -> (Option, EncoderState) { - self.state = state; - let out = self.next(); - (out, self.state) - } -} - -impl Iterator for Encoder -where - I: Iterator, -{ - type Item = u8; - - fn next(&mut self) -> Option { - use EncoderState::*; - let (out, state) = match self.state { - Init(n) if n < 4 => (Some(0x1b), Init(n + 1)), - Init(n) if n < 8 => (Some(0x01), Init(n + 1)), - Init(n) => { - assert_eq!(n, 8); - self.next_from_state(LookingForEscape(0)) - } - LookingForEscape(n) if n < 4 => { - match self.read_from_iter() { - Some(b) => { - self.crc.update(&[b]); - (Some(b), LookingForEscape((n + 1) * u8::from(b == 0x1b))) - } - None => { - let padding = self.padding.get(); - // finalize crc - for _ in 0..padding { - self.crc.update(&[0x00]); - } - self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, padding]); - self.next_from_state(End(-(padding as i8))) - } - } - } - LookingForEscape(n) => { - assert_eq!(n, 4); - self.crc.update(&[0x1b; 4]); - self.next_from_state(HandlingEscape(0)) - } - HandlingEscape(n) if n < 4 => (Some(0x1b), HandlingEscape(n + 1)), - HandlingEscape(n) => { - assert_eq!(n, 4); - self.next_from_state(LookingForEscape(0)) - } - End(n) => { - let out = match n { - n if n < 0 => 0x00, - n if n < 4 => 0x1b, - 4 => 0x1a, - 5 => self.padding.get(), - n if n < 8 => { - let crc_bytes = self.crc.clone().finalize().to_le_bytes(); - crc_bytes[(n - 6) as usize] - } - 8 => { - return None; - } - _ => unreachable!(), - }; - (Some(out), End(n + 1)) - } - }; - self.state = state; - out - } -} - -/// Takes a slice of bytes as input and returns a buffer containing the 
encoded message. -/// -/// Returns `Err(())` when the buffer can't be grown to hold the entire output. -/// -/// # Examples -/// -/// ``` -/// // example data -/// let bytes = [0x12, 0x34, 0x56, 0x78]; -/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; -/// ``` -/// -/// ### Using alloc::Vec -/// -/// ``` -/// # #[cfg(feature = "alloc")] { -/// # use sml_rs::transport::encode; -/// # let bytes = [0x12, 0x34, 0x56, 0x78]; -/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; -/// let encoded = encode::>(&bytes); -/// assert!(encoded.is_ok()); -/// assert_eq!(encoded.unwrap().as_slice(), &expected); -/// # } -/// ``` -/// -/// ### Using `ArrayBuf` -/// -/// ``` -/// # use sml_rs::{util::{ArrayBuf, OutOfMemory}, transport::encode}; -/// # let bytes = [0x12, 0x34, 0x56, 0x78]; -/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; -/// let encoded = encode::>(&bytes); -/// assert!(encoded.is_ok()); -/// assert_eq!(&*encoded.unwrap(), &expected); -/// -/// // encoding returns `Err(())` if the encoded message does not fit into the vector -/// let encoded = encode::>(&bytes); -/// assert_eq!(encoded, Err(OutOfMemory)); -/// ``` -/// -pub fn encode( - iter: impl IntoIterator>, -) -> Result { - let mut res: B = Default::default(); - - // start escape sequence - res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01])?; - - // encode data - let mut num_1b = 0; - for b in iter.into_iter() { - let b = *b.borrow(); - if b == 0x1b { - num_1b += 1; - } else { - num_1b = 0; - } - - res.push(b)?; - - if num_1b == 4 { - res.extend_from_slice(&[0x1b; 4])?; - num_1b = 0; - } - } - - // padding bytes - let num_padding_bytes = (4 - (res.len() % 4)) % 4; - res.extend_from_slice(&[0x0; 
3][..num_padding_bytes])?; - - res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, num_padding_bytes as u8])?; - let crc = CRC_X25.checksum(&res[..]); - - res.extend_from_slice(&crc.to_le_bytes())?; - - Ok(res) -} - -/// Takes an iterator over bytes and returns an iterator that produces the encoded message. -/// -/// # Examples -/// ``` -/// # use sml_rs::transport::encode_streaming; -/// // example data -/// let bytes = [0x12, 0x34, 0x56, 0x78]; -/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b]; -/// let iter = encode_streaming(bytes); -/// assert!(iter.eq(expected)); -/// ``` -pub fn encode_streaming( - iter: impl IntoIterator>, -) -> Encoder> { - Encoder::new(iter.into_iter().map(|x| *x.borrow())) -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, PartialEq, Eq, Clone)] /// An error which can be returned when decoding an sml message. @@ -805,94 +581,6 @@ pub fn decode_streaming( DecodeIterator::new(iter.into_iter().map(|x| *x.borrow())) } -#[cfg(test)] -mod tests { - use super::*; - use hex_literal::hex; - - // assert_eq macro that prints its arguments as hex when they don't match. - // (adapted from the `assert_hex` crate) - macro_rules! assert_eq_hex { - ($left:expr, $right:expr $(,)?) => {{ - match (&$left, &$right) { - (left_val, right_val) => { - if !(*left_val == *right_val) { - // The reborrows below are intentional. Without them, the stack slot for the - // borrow is initialized even before the values are compared, leading to a - // noticeable slow down. 
- panic!( - "assertion failed: `(left == right)`\n left: `{:02x?}`,\n right: `{:02x?}`", - &*left_val, &*right_val - ) - } - } - } - }}; - } - - fn test_encoding(bytes: &[u8], exp_encoded_bytes: &[u8; N]) { - compare_encoded_bytes( - exp_encoded_bytes, - &encode::>(bytes).expect("ran out of memory"), - ); - compare_encoded_bytes( - exp_encoded_bytes, - &encode_streaming(bytes).collect::>(), - ); - #[cfg(feature = "alloc")] - assert_eq_hex!(alloc::vec![Ok(bytes.to_vec())], decode(exp_encoded_bytes)); - } - - fn compare_encoded_bytes(expected: &[u8], actual: &[u8]) { - assert_eq_hex!(expected, actual); - } - - #[test] - fn basic() { - test_encoding( - &hex!("12345678"), - &hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b"), - ); - } - - #[test] - fn empty() { - test_encoding(&hex!(""), &hex!("1b1b1b1b 01010101 1b1b1b1b 1a00c6e5")); - } - - #[test] - fn padding() { - test_encoding( - &hex!("123456"), - &hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5"), - ); - } - - #[test] - fn escape_in_user_data() { - test_encoding( - &hex!("121b1b1b1b"), - &hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25"), - ); - } - - #[test] - fn almost_escape_in_user_data() { - test_encoding( - &hex!("121b1b1bFF"), - &hex!("1b1b1b1b 01010101 12 1b1b1bFF 000000 1b1b1b1b 1a0324d9"), - ); - } - - #[test] - fn ending_with_1b_no_padding() { - test_encoding( - &hex!("12345678 12341b1b"), - &hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5"), - ); - } -} - #[cfg(test)] mod decode_tests { use super::*;