From 01d87228fde5053d293eb41de51cb0cb843bf70c Mon Sep 17 00:00:00 2001 From: Easyoakland <97992568+Easyoakland@users.noreply.github.com> Date: Tue, 22 Oct 2024 06:32:27 +0000 Subject: [PATCH] scroll_derive: Custom ctx override for derive macro (#106) * feat: attribute to override ctx for particular fields on derive(Pread, Pwrite) Enables using the syntax: ``` #[scroll(ctx = context_expr)] field: T ``` To use the `context_expr` expression as context when for `Pread` for `field`, regardless of the context for the rest of the struct. --- scroll_derive/examples/derive_custom_ctx.rs | 81 ++++++++++++++ scroll_derive/src/lib.rs | 116 ++++++++++++++++---- scroll_derive/tests/tests.rs | 39 ++++++- 3 files changed, 212 insertions(+), 24 deletions(-) create mode 100644 scroll_derive/examples/derive_custom_ctx.rs diff --git a/scroll_derive/examples/derive_custom_ctx.rs b/scroll_derive/examples/derive_custom_ctx.rs new file mode 100644 index 0000000..3de31ee --- /dev/null +++ b/scroll_derive/examples/derive_custom_ctx.rs @@ -0,0 +1,81 @@ +use scroll_derive::{Pread, Pwrite, SizeWith}; + +/// An example of using a method as the value for a ctx in a derive. +struct EndianDependent(Endian); +impl EndianDependent { + fn len(&self) -> usize { + match self.0 { + scroll::Endian::Little => 5, + scroll::Endian::Big => 6, + } + } +} + +#[derive(Debug, PartialEq)] +struct VariableLengthData { + buf: Vec, +} + +impl<'a> TryFromCtx<'a, usize> for VariableLengthData { + type Error = scroll::Error; + + fn try_from_ctx(from: &'a [u8], ctx: usize) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + let buf = from.gread_with::<&[u8]>(offset, ctx)?.to_owned(); + Ok((Self { buf }, *offset)) + } +} +impl<'a> TryIntoCtx for &'a VariableLengthData { + type Error = scroll::Error; + fn try_into_ctx(self, dst: &mut [u8], ctx: usize) -> Result { + let offset = &mut 0; + for i in 0..(ctx.min(self.buf.len())) { + dst.gwrite(self.buf[i], offset)?; + } + Ok(*offset) + } +} +impl SizeWith for VariableLengthData { + fn size_with(ctx: &usize) -> usize { + *ctx + } +} + +#[derive(Debug, PartialEq, Pread, Pwrite, SizeWith)] +#[repr(C)] +struct Data { + id: u32, + timestamp: f64, + // You can fix the ctx regardless of what is passed in. + #[scroll(ctx = BE)] + arr: [u16; 2], + // You can use arbitrary expressions for the ctx. + // You have access to the `ctx` parameter of the `{pread/gread}_with` inside the expression. + // TODO(implement) you have access to previous fields. + // TODO(check) will this break structs with fields named `ctx`?. + #[scroll(ctx = EndianDependent(ctx.clone()).len())] + custom_ctx: VariableLengthData, +} + +use scroll::{ + ctx::{SizeWith, TryFromCtx, TryIntoCtx}, + Endian, Pread, Pwrite, BE, LE, +}; + +fn main() { + let bytes = [ + 0xefu8, 0xbe, 0xad, 0xde, 0, 0, 0, 0, 0, 0, 224, 63, 0xad, 0xde, 0xef, 0xbe, 0xaa, 0xbb, + 0xcc, 0xdd, 0xee, + ]; + let data: Data = bytes.pread_with(0, LE).unwrap(); + println!("data: {data:?}"); + assert_eq!(data.id, 0xdeadbeefu32); + assert_eq!(data.arr, [0xadde, 0xefbe]); + let mut bytes2 = vec![0; ::std::mem::size_of::()]; + bytes2.pwrite_with(data, 0, LE).unwrap(); + let data: Data = bytes.pread_with(0, LE).unwrap(); + let data2: Data = bytes2.pread_with(0, LE).unwrap(); + assert_eq!(data, data2); + // Not enough bytes because of ctx dependent length being too long. + assert!(bytes.pread_with::(0, BE).is_err()) +} diff --git a/scroll_derive/src/lib.rs b/scroll_derive/src/lib.rs index 919517f..a5714dd 100644 --- a/scroll_derive/src/lib.rs +++ b/scroll_derive/src/lib.rs @@ -2,11 +2,17 @@ extern crate proc_macro; use proc_macro2; -use quote::quote; +use quote::{quote, ToTokens}; use proc_macro::TokenStream; -fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream { +fn impl_field( + ident: &proc_macro2::TokenStream, + ty: &syn::Type, + custom_ctx: Option<&proc_macro2::TokenStream>, +) -> proc_macro2::TokenStream { + let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream(); + let ctx = custom_ctx.unwrap_or(&default_ctx); match *ty { syn::Type::Array(ref array) => match array.len { syn::Expr::Lit(syn::ExprLit { @@ -15,20 +21,63 @@ fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2:: }) => { let size = int.base10_parse::().unwrap(); quote! { - #ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, ctx)?; __tmp } + #ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, #ctx)?; __tmp } } } _ => panic!("Pread derive with bad array constexpr"), }, - syn::Type::Group(ref group) => impl_field(ident, &group.elem), + syn::Type::Group(ref group) => impl_field(ident, &group.elem, custom_ctx), _ => { quote! { - #ident: src.gread_with::<#ty>(offset, ctx)? + #ident: src.gread_with::<#ty>(offset, #ctx)? } } } } +/// Retrieve the field attribute with given ident e.g: +/// ```ignore +/// #[attr_ident(..)] +/// field: T, +/// ``` +fn get_attr<'a>(attr_ident: &str, field: &'a syn::Field) -> Option<&'a syn::Attribute> { + field + .attrs + .iter() + .filter(|attr| attr.path().is_ident(attr_ident)) + .next() +} + +/// Gets the `TokenStream` for the custom ctx set in the `ctx` attribute. e.g. `expr` in the following +/// ```ignore +/// #[scroll(ctx = expr)] +/// field: T, +/// ``` +fn custom_ctx(field: &syn::Field) -> Option { + get_attr("scroll", field).and_then(|x| { + // parsed #[scroll..] + // `expr` is `None` if the `ctx` key is not used. + let mut expr = None; + let res = x.parse_nested_meta(|meta| { + // parsed #[scroll(..)] + if meta.path.is_ident("ctx") { + // parsed #[scroll(ctx..)] + let value = meta.value()?; // parsed #[scroll(ctx = ..)] + expr = Some(value.parse::()?.into_token_stream()); // parsed #[scroll(ctx = expr)] + return Ok(()); + } + Err(meta.error(match meta.path.get_ident() { + Some(ident) => format!("unrecognized attribute: {ident}"), + None => "unrecognized and invalid attribute".to_owned(), + })) + }); + match res { + Ok(()) => expr, + Err(e) => Some(e.into_compile_error()), + } + }) +} + fn impl_struct( name: &syn::Ident, fields: &syn::punctuated::Punctuated, @@ -43,7 +92,9 @@ fn impl_struct( quote! {#t} }); let ty = &f.ty; - impl_field(ident, ty) + // parse the `expr` out of #[scroll(ctx = expr)] + let custom_ctx = custom_ctx(f); + impl_field(ident, ty, custom_ctx.as_ref()) }) .collect(); @@ -104,14 +155,20 @@ fn impl_try_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } } -#[proc_macro_derive(Pread)] +#[proc_macro_derive(Pread, attributes(scroll))] pub fn derive_pread(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let gen = impl_try_from_ctx(&ast); gen.into() } -fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream { +fn impl_pwrite_field( + ident: &proc_macro2::TokenStream, + ty: &syn::Type, + custom_ctx: Option<&proc_macro2::TokenStream>, +) -> proc_macro2::TokenStream { + let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream(); + let ctx = custom_ctx.unwrap_or(&default_ctx); match ty { syn::Type::Array(ref array) => match array.len { syn::Expr::Lit(syn::ExprLit { @@ -121,24 +178,24 @@ fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_m let size = int.base10_parse::().unwrap(); quote! { for i in 0..#size { - dst.gwrite_with(&self.#ident[i], offset, ctx)?; + dst.gwrite_with(&self.#ident[i], offset, #ctx)?; } } } _ => panic!("Pwrite derive with bad array constexpr"), }, - syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem), + syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem, custom_ctx), syn::Type::Reference(reference) => match *reference.elem { syn::Type::Slice(_) => quote! { dst.gwrite_with(self.#ident, offset, ())? }, _ => quote! { - dst.gwrite_with(self.#ident, offset, ctx)? + dst.gwrite_with(self.#ident, offset, #ctx)? }, }, _ => { quote! { - dst.gwrite_with(&self.#ident, offset, ctx)? + dst.gwrite_with(&self.#ident, offset, #ctx)? } } } @@ -158,7 +215,8 @@ fn impl_try_into_ctx( quote! {#t} }); let ty = &f.ty; - impl_pwrite_field(ident, ty) + let custom_ctx = custom_ctx(f); + impl_pwrite_field(ident, ty, custom_ctx.as_ref()) }) .collect(); @@ -249,7 +307,7 @@ fn impl_pwrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } } -#[proc_macro_derive(Pwrite)] +#[proc_macro_derive(Pwrite, attributes(scroll))] pub fn derive_pwrite(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let gen = impl_pwrite(&ast); @@ -265,6 +323,10 @@ fn size_with( .iter() .map(|f| { let ty = &f.ty; + let custom_ctx = custom_ctx(f).map(|x| quote! {&#x}); + let default_ctx = + syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream(); + let ctx = custom_ctx.unwrap_or(default_ctx); match *ty { syn::Type::Array(ref array) => { let elem = &array.elem; @@ -275,7 +337,7 @@ fn size_with( }) => { let size = int.base10_parse::().unwrap(); quote! { - (#size * <#elem>::size_with(ctx)) + (#size * <#elem>::size_with(#ctx)) } } _ => panic!("Pread derive with bad array constexpr"), @@ -283,7 +345,7 @@ fn size_with( } _ => { quote! { - <#ty>::size_with(ctx) + <#ty>::size_with(#ctx) } } } @@ -341,7 +403,7 @@ fn impl_size_with(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } } -#[proc_macro_derive(SizeWith)] +#[proc_macro_derive(SizeWith, attributes(scroll))] pub fn derive_sizewith(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let gen = impl_size_with(&ast); @@ -356,6 +418,10 @@ fn impl_cread_struct( let items: Vec<_> = fields.iter().enumerate().map(|(i, f)| { let ident = &f.ident.as_ref().map(|i|quote!{#i}).unwrap_or({let t = proc_macro2::Literal::usize_unsuffixed(i); quote!{#t}}); let ty = &f.ty; + let custom_ctx = custom_ctx(f); + let default_ctx = + syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream(); + let ctx = custom_ctx.unwrap_or(default_ctx); match *ty { syn::Type::Array(ref array) => { let arrty = &array.elem; @@ -367,7 +433,7 @@ fn impl_cread_struct( #ident: { let mut __tmp: #ty = [0u8.into(); #size]; for i in 0..__tmp.len() { - __tmp[i] = src.cread_with(*offset, ctx); + __tmp[i] = src.cread_with(*offset, #ctx); *offset += #incr; } __tmp @@ -380,7 +446,7 @@ fn impl_cread_struct( _ => { let size = quote! { ::scroll::export::mem::size_of::<#ty>() }; quote! { - #ident: { let res = src.cread_with::<#ty>(*offset, ctx); *offset += #size; res } + #ident: { let res = src.cread_with::<#ty>(*offset, #ctx); *offset += #size; res } } } } @@ -440,7 +506,7 @@ fn impl_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } } -#[proc_macro_derive(IOread)] +#[proc_macro_derive(IOread, attributes(scroll))] pub fn derive_ioread(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let gen = impl_from_ctx(&ast); @@ -462,20 +528,24 @@ fn impl_into_ctx( }); let ty = &f.ty; let size = quote! { ::scroll::export::mem::size_of::<#ty>() }; + let custom_ctx = custom_ctx(f); + let default_ctx = + syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream(); + let ctx = custom_ctx.unwrap_or(default_ctx); match *ty { syn::Type::Array(ref array) => { let arrty = &array.elem; quote! { let size = ::scroll::export::mem::size_of::<#arrty>(); for i in 0..self.#ident.len() { - dst.cwrite_with(self.#ident[i], *offset, ctx); + dst.cwrite_with(self.#ident[i], *offset, #ctx); *offset += size; } } } _ => { quote! { - dst.cwrite_with(self.#ident, *offset, ctx); + dst.cwrite_with(self.#ident, *offset, #ctx); *offset += #size; } } @@ -544,7 +614,7 @@ fn impl_iowrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } } -#[proc_macro_derive(IOwrite)] +#[proc_macro_derive(IOwrite, attributes(scroll))] pub fn derive_iowrite(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let gen = impl_iowrite(&ast); diff --git a/scroll_derive/tests/tests.rs b/scroll_derive/tests/tests.rs index 3c7f806..14dec4e 100644 --- a/scroll_derive/tests/tests.rs +++ b/scroll_derive/tests/tests.rs @@ -1,4 +1,4 @@ -use scroll::{Cread, Cwrite, Pread, Pwrite, LE}; +use scroll::{Cread, Cwrite, IOread, IOwrite, Pread, Pwrite, BE, LE}; use scroll_derive::{IOread, IOwrite, Pread, Pwrite, SizeWith}; use scroll::ctx::SizeWith; @@ -232,3 +232,40 @@ fn test_reference() { assert_eq!(bytes.pwrite_with(&data, 0, LE).unwrap(), 7); assert_eq!(bytes[..7], *b"\xff\x01\x00name"); } + +#[derive(Debug, PartialEq, Pwrite, Pread, IOwrite, IOread, SizeWith)] +struct Data11 { + pub a: u16, + #[scroll(ctx = LE)] + pub b: u16, + #[scroll(ctx = BE)] + pub c: u16, +} + +#[test] +fn test_custom_ctx_derive() { + let buf = [1, 2, 3, 4, 5, 6]; + let data = buf.pread_with(0, LE).unwrap(); + let data2 = Data11 { + a: 0x0201, + b: 0x0403, + c: 0x0506, + }; + assert_eq!(data, data2); + let mut bytes = vec![0; 32]; + assert_eq!(bytes.pwrite_with::<&Data11>(&data, 0, LE).unwrap(), 6); + assert_eq!(bytes[..Data11::size_with(&LE)], buf[..]); + let mut bytes = std::io::Cursor::new(bytes); + assert_eq!(data2, bytes.ioread_with(LE).unwrap()); + bytes.set_position(0); + bytes.iowrite_with(data, BE).unwrap(); + bytes.set_position(0); + assert_eq!(data2, bytes.ioread_with(BE).unwrap()); + bytes.set_position(0); + let data3 = Data11 { + a: 0x0102, + b: 0x0403, + c: 0x0506, + }; + assert_eq!(data3, bytes.ioread_with(LE).unwrap()); +}