diff --git a/CHANGELOG.md b/CHANGELOG.md index b197ad2..a0dd598 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,22 +4,9 @@ ## [Unreleased] - ReleaseDate -### Added - -### Changed - -### Removed - - - -## [0.5.4] - 2023-09-10 - -### Added - -### Changed - -### Removed - +Rework the validation of names and aliases for aggregate UDFs. This fixes an +issue where aliases could not be used for aggregate UDFs, and provides better +error messages. ## [0.5.4] - 2023-09-10 diff --git a/udf-macros/src/register.rs b/udf-macros/src/register.rs index 42c4320..32420fa 100644 --- a/udf-macros/src/register.rs +++ b/udf-macros/src/register.rs @@ -1,5 +1,7 @@ #![allow(unused_imports)] +use std::iter; + use heck::AsSnakeCase; use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; @@ -8,24 +10,17 @@ use syn::parse::{Parse, ParseStream, Parser}; use syn::punctuated::Punctuated; use syn::{ parse_macro_input, parse_quote, DeriveInput, Error, Expr, ExprLit, Ident, ImplItem, - ImplItemType, Item, ItemImpl, Lit, Meta, Path, PathSegment, Token, Type, TypePath, + ImplItemType, Item, ItemImpl, Lit, LitStr, Meta, Path, PathSegment, Token, Type, TypePath, TypeReference, }; use crate::match_variant; use crate::types::{make_type_list, ImplType, RetType, TypeClass}; -/// Create an identifier from another identifier, changing the name to snake case -macro_rules! format_ident_str { - ($formatter:tt, $ident:ident) => { - Ident::new(format!($formatter, $ident).as_str(), Span::call_site()) - }; -} - /// Verify that an `ItemImpl` matches the end of any given path /// /// implements `BasicUdf` (in any of its pathing options) -fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool { +fn impl_type(itemimpl: &ItemImpl) -> Option { let implemented = &itemimpl.trait_.as_ref().unwrap().1.segments; let basic_paths: [Punctuated; 3] = [ @@ -39,9 +34,12 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool { parse_quote! {AggregateUdf}, ]; - match expected { - ImplType::Basic => basic_paths.contains(implemented), - ImplType::Aggregate => arg_paths.contains(implemented), + if basic_paths.contains(implemented) { + Some(ImplType::Basic) + } else if arg_paths.contains(implemented) { + Some(ImplType::Aggregate) + } else { + None } } @@ -57,14 +55,11 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool { pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream { let parsed = parse_macro_input!(input as ItemImpl); - let impls_basic = impls_path(&parsed, ImplType::Basic); - let impls_agg = impls_path(&parsed, ImplType::Aggregate); - - if !(impls_basic || impls_agg) { + let Some(impl_ty) = impl_type(&parsed) else { return Error::new_spanned(&parsed, "Expected trait `BasicUdf` or `AggregateUdf`") .into_compile_error() .into(); - } + }; // Full type path of our data struct let Type::Path(dstruct_path) = parsed.self_ty.as_ref() else { @@ -73,7 +68,7 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream { .into(); }; - let base_fn_names = match parse_args(args, dstruct_path) { + let parsed_meta = match ParsedMeta::parse(args, dstruct_path) { Ok(v) => v, Err(e) => return e.into_compile_error().into(), }; @@ -89,26 +84,24 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream { Span::call_site(), ); - let (ret_ty, wrapper_def) = if impls_basic { - match get_rt_and_wrapper(&parsed, dstruct_path, &wrapper_ident) { + let (ret_ty, wrapper_def) = match impl_ty { + ImplType::Basic => match get_ret_ty_and_wrapper(&parsed, dstruct_path, &wrapper_ident) { Ok((r, w)) => (Some(r), w), Err(e) => return e.into_compile_error().into(), - } - } else { - (None, TokenStream2::new()) + }, + ImplType::Aggregate => (None, TokenStream2::new()), }; - let content_iter = base_fn_names.iter().map(|base_fn_name| { - if impls_basic { - make_basic_fns( - ret_ty.as_ref().unwrap(), - base_fn_name, - dstruct_path, - &wrapper_ident, - ) - } else { - make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident) - } + let helper_traits = make_helper_trait_impls(dstruct_path, &parsed_meta, impl_ty); + + let fn_items_iter = parsed_meta.all_names().map(|base_fn_name| match impl_ty { + ImplType::Basic => make_basic_fns( + ret_ty.as_ref().unwrap(), + base_fn_name, + dstruct_path, + &wrapper_ident, + ), + ImplType::Aggregate => make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident), }); quote! { @@ -116,64 +109,85 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream { #wrapper_def - #( #content_iter )* + #helper_traits + + #( #fn_items_iter )* } .into() } -/// Parse attribute arguments. Returns an iterator of names -fn parse_args(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result> { - let meta = Punctuated::::parse_terminated.parse(args.clone())?; - let mut base_fn_names: Vec = vec![]; - let mut primary_name_specified = false; - - for m in meta { - let Meta::NameValue(mval) = m else { - return Err(Error::new_spanned(m, "expected `a = b atributes`")); - }; +/// Arguments we parse from metadata or default to +struct ParsedMeta { + name: String, + aliases: Vec, + default_name_used: bool, +} - if !mval.path.segments.iter().count() == 1 { - return Err(Error::new_spanned(mval.path, "unexpected path")); - } +impl ParsedMeta { + /// Parse attribute arguments. Returns an iterator of names + fn parse(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result { + let meta = Punctuated::::parse_terminated.parse(args.clone())?; + let mut name_from_attributes = None; + let mut aliases = Vec::new(); - let key = mval.path.segments.first().unwrap(); + for m in meta { + let Meta::NameValue(mval) = m else { + return Err(Error::new_spanned(m, "expected `a = b atributes`")); + }; - let Expr::Lit(ExprLit { - lit: Lit::Str(value), - .. - }) = mval.value - else { - return Err(Error::new_spanned(mval.value, "expected a literal string")); - }; + if !mval.path.segments.iter().count() == 1 { + return Err(Error::new_spanned(mval.path, "unexpected path")); + } - if key.ident == "name" { - if primary_name_specified { - return Err(Error::new_spanned(key, "`name` can only be specified once")); + let key = mval.path.segments.first().unwrap(); + + let Expr::Lit(ExprLit { + lit: Lit::Str(value), + .. + }) = mval.value + else { + return Err(Error::new_spanned(mval.value, "expected a literal string")); + }; + + if key.ident == "name" { + if name_from_attributes.is_some() { + return Err(Error::new_spanned(key, "`name` can only be specified once")); + } + name_from_attributes = Some(value.value()); + } else if key.ident == "alias" { + aliases.push(value.value()); + } else { + return Err(Error::new_spanned( + key, + "unexpected key (only `name` and `alias` are accepted)", + )); } - base_fn_names.push(value.value()); - primary_name_specified = true; - } else if key.ident == "alias" { - base_fn_names.push(value.value()); - } else { - return Err(Error::new_spanned( - key, - "unexpected key (only `name` and `alias` are accepted)", - )); } - } - if !primary_name_specified { - // If we don't have a name specified, use the type name as snake case - let ty_ident = &dstruct_path.path.segments.last().unwrap().ident; - let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string(); - base_fn_names.push(fn_name); + let mut default_name_used = false; + let name = name_from_attributes.unwrap_or_else(|| { + // If we don't have a name specified, use the type name as snake case + let ty_ident = &dstruct_path.path.segments.last().unwrap().ident; + let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string(); + default_name_used = true; + fn_name + }); + + Ok(Self { + name, + aliases, + default_name_used, + }) } - Ok(base_fn_names) + /// Iterate the basic name and all aliases + fn all_names(&self) -> impl Iterator { + iter::once(&self.name).chain(self.aliases.iter()) + } } /// Get the return type to use and a wrapper. Once per impl setup. -fn get_rt_and_wrapper( +fn get_ret_ty_and_wrapper( parsed: &ItemImpl, dstruct_path: &TypePath, wrapper_ident: &Ident, @@ -209,6 +223,40 @@ fn get_rt_and_wrapper( Ok((ret_ty, wrapper_struct)) } +/// Make implementations for our helper/metadata traits +fn make_helper_trait_impls( + dstruct_path: &TypePath, + meta: &ParsedMeta, + impl_ty: ImplType, +) -> TokenStream2 { + let name = LitStr::new(&meta.name, Span::call_site()); + let aliases = meta + .aliases + .iter() + .map(|alias| LitStr::new(alias.as_ref(), Span::call_site())); + let (trait_name, check_expr) = match impl_ty { + ImplType::Basic => ( + quote! { ::udf::wrapper::RegisteredBasicUdf }, + TokenStream2::new(), + ), + ImplType::Aggregate => ( + quote! { ::udf::wrapper::RegisteredAggregateUdf }, + quote! { const _: () = ::udf::wrapper::verify_aggregate_attributes::<#dstruct_path>(); }, + ), + }; + let default_name_used = meta.default_name_used; + + quote! { + impl #trait_name for #dstruct_path { + const NAME: &'static str = #name; + const ALIASES: &'static [&'static str] = &[#( #aliases ),*]; + const DEFAULT_NAME_USED: bool = #default_name_used; + } + + #check_expr + } +} + /// Create the basic function signatures (`xxx_init`, `xxx_deinit`, `xxx`) fn make_basic_fns( rt: &RetType, @@ -216,9 +264,9 @@ fn make_basic_fns( dstruct_path: &TypePath, wrapper_ident: &Ident, ) -> TokenStream2 { - let init_fn_name = format_ident_str!("{}_init", base_fn_name); - let deinit_fn_name = format_ident_str!("{}_deinit", base_fn_name); - let process_fn_name = format_ident_str!("{}", base_fn_name); + let init_fn_name = format_ident!("{}_init", base_fn_name); + let deinit_fn_name = format_ident!("{}_deinit", base_fn_name); + let process_fn_name = format_ident!("{}", base_fn_name); let init_fn = make_init_fn(dstruct_path, wrapper_ident, &init_fn_name); let deinit_fn = make_deinit_fn(dstruct_path, wrapper_ident, &deinit_fn_name); @@ -269,9 +317,9 @@ fn make_agg_fns( dstruct_path: &TypePath, // Name of the data structure wrapper_ident: &Ident, ) -> TokenStream2 { - let clear_fn_name = format_ident_str!("{}_clear", base_fn_name); - let add_fn_name = format_ident_str!("{}_add", base_fn_name); - let remove_fn_name = format_ident_str!("{}_remove", base_fn_name); + let clear_fn_name = format_ident!("{}_clear", base_fn_name); + let add_fn_name = format_ident!("{}_add", base_fn_name); + let remove_fn_name = format_ident!("{}_remove", base_fn_name); // Determine whether this re-implements `remove` let impls_remove = &parsed @@ -280,7 +328,6 @@ fn make_agg_fns( .filter_map(match_variant!(ImplItem::Fn)) .map(|m| &m.sig.ident) .any(|id| *id == "remove"); - let base_fn_ident = Ident::new(base_fn_name, Span::call_site()); let clear_fn = make_clear_fn(dstruct_path, wrapper_ident, &clear_fn_name); let add_fn = make_add_fn(dstruct_path, wrapper_ident, &add_fn_name); @@ -295,10 +342,6 @@ fn make_agg_fns( }; quote! { - // Sanity check that we implemented - #[allow(dead_code, non_upper_case_globals)] - const did_you_apply_the_same_aliases_to_the_BasicUdf_impl: *const () = #base_fn_ident as _; - #clear_fn #add_fn diff --git a/udf-macros/tests/fail/agg_missing_basic.rs b/udf-macros/tests/fail/agg_missing_basic.rs new file mode 100644 index 0000000..88bfc8a --- /dev/null +++ b/udf-macros/tests/fail/agg_missing_basic.rs @@ -0,0 +1,22 @@ +#![allow(unused)] + +use udf::prelude::*; + +struct MyUdf; + +impl AggregateUdf for MyUdf { + // Required methods + fn clear(&mut self, cfg: &UdfCfg, error: Option) -> Result<(), NonZeroU8> { + todo!() + } + fn add( + &mut self, + cfg: &UdfCfg, + args: &ArgList<'_, Process>, + error: Option, + ) -> Result<(), NonZeroU8> { + todo!() + } +} + +fn main() {} diff --git a/udf-macros/tests/fail/agg_missing_basic.stderr b/udf-macros/tests/fail/agg_missing_basic.stderr new file mode 100644 index 0000000..b944721 --- /dev/null +++ b/udf-macros/tests/fail/agg_missing_basic.stderr @@ -0,0 +1,11 @@ +error[E0277]: the trait bound `MyUdf: BasicUdf` is not satisfied + --> tests/fail/agg_missing_basic.rs:7:23 + | +7 | impl AggregateUdf for MyUdf { + | ^^^^^ the trait `BasicUdf` is not implemented for `MyUdf` + | +note: required by a bound in `udf::AggregateUdf` + --> $WORKSPACE/udf/src/traits.rs + | + | pub trait AggregateUdf: BasicUdf { + | ^^^^^^^^ required by this bound in `AggregateUdf` diff --git a/udf-macros/tests/fail/missing_rename.stderr b/udf-macros/tests/fail/missing_rename.stderr index 36c4bd5..0e1a69f 100644 --- a/udf-macros/tests/fail/missing_rename.stderr +++ b/udf-macros/tests/fail/missing_rename.stderr @@ -1,7 +1,22 @@ -error[E0425]: cannot find value `my_udf` in this scope +error[E0080]: evaluation of constant value failed + --> $WORKSPACE/udf/src/wrapper.rs + | + | panic!("{}", msg); + | ^^^^^^^^^^^^^^^^^ the evaluated program panicked at '`#[register]` on `BasicUdf` and `AggregateUdf` must have the same `name` argument; got `foo` and `my_udf` (default from struct name)', $WORKSPACE/udf/src/wrapper.rs:83:5 + | +note: inside `wrapper::verify_aggregate_attributes_name::` + --> $WORKSPACE/udf/src/wrapper.rs + | + | panic!("{}", msg); + | ^^^^^^^^^^^^^^^^^ +note: inside `verify_aggregate_attributes::` + --> $WORKSPACE/udf/src/wrapper.rs + | + | verify_aggregate_attributes_name::(); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +note: inside `_` --> tests/fail/missing_rename.rs:25:1 | 25 | #[register] - | ^^^^^^^^^^^ not found in this scope - | - = note: this error originates in the attribute macro `register` (in Nightly builds, run with -Z macro-backtrace for more info) + | ^^^^^^^^^^^ + = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the attribute macro `register` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/udf-macros/tests/ok_agg_alias.rs b/udf-macros/tests/ok_agg_alias.rs new file mode 100644 index 0000000..da0026c --- /dev/null +++ b/udf-macros/tests/ok_agg_alias.rs @@ -0,0 +1,51 @@ +#![allow(unused)] + +use udf::prelude::*; + +struct MyUdf; + +#[register(name = "foo", alias = "bar")] +impl BasicUdf for MyUdf { + type Returns<'a> = Option; + + fn init(cfg: &UdfCfg, args: &ArgList) -> Result { + todo!(); + } + + fn process<'a>( + &'a mut self, + cfg: &UdfCfg, + args: &ArgList, + error: Option, + ) -> Result, ProcessError> { + todo!(); + } +} + +#[register(name = "foo", alias = "bar")] +impl AggregateUdf for MyUdf { + fn clear(&mut self, cfg: &UdfCfg, error: Option) -> Result<(), NonZeroU8> { + todo!() + } + fn add( + &mut self, + cfg: &UdfCfg, + args: &ArgList<'_, Process>, + error: Option, + ) -> Result<(), NonZeroU8> { + todo!() + } +} + +fn main() { + let _ = foo as *const (); + let _ = foo_init as *const (); + let _ = foo_deinit as *const (); + let _ = foo_add as *const (); + let _ = foo_clear as *const (); + let _ = bar as *const (); + let _ = bar_init as *const (); + let _ = bar_deinit as *const (); + let _ = bar_add as *const (); + let _ = bar_clear as *const (); +} diff --git a/udf/src/types/arg_list.rs b/udf/src/types/arg_list.rs index 6975df3..e47a303 100644 --- a/udf/src/types/arg_list.rs +++ b/udf/src/types/arg_list.rs @@ -19,7 +19,7 @@ use crate::{Init, SqlArg, UdfState}; #[repr(transparent)] pub struct ArgList<'a, S: UdfState>( /// `UnsafeCell` indicates to the compiler that this struct may have interior - /// mutability (i.e., cannot make som optimizations) + /// mutability (i.e., cannot make some optimizations) pub(super) UnsafeCell, /// We use this zero-sized marker to hold our state PhantomData<&'a S>, diff --git a/udf/src/types/sql_types.rs b/udf/src/types/sql_types.rs index 13a904e..e09d381 100644 --- a/udf/src/types/sql_types.rs +++ b/udf/src/types/sql_types.rs @@ -149,7 +149,7 @@ impl<'a> SqlResult<'a> { ptr: *const u8, tag: Item_result, len: usize, - ) -> Result, String> { + ) -> Result { // Handle nullptr right away here let marker = diff --git a/udf/src/wrapper.rs b/udf/src/wrapper.rs index 87799e0..01e69f4 100644 --- a/udf/src/wrapper.rs +++ b/udf/src/wrapper.rs @@ -3,11 +3,16 @@ //! Warning: This module should be considered unstable and generally not for //! public use +#[macro_use] +mod const_helpers; mod functions; mod helpers; mod modded_types; mod process; +use std::str; + +use const_helpers::{const_slice_eq, const_slice_to_str, const_str_eq}; pub use functions::{wrap_add, wrap_clear, wrap_deinit, wrap_init, wrap_remove, BufConverter}; pub(crate) use helpers::*; pub use modded_types::UDF_ARGSx; @@ -16,5 +21,107 @@ pub use process::{ wrap_process_buf_option_ref, }; +/// A trait implemented by the proc macro +// FIXME: on unimplemented +pub trait RegisteredBasicUdf { + /// The main function name + const NAME: &'static str; + /// Aliases, if any + const ALIASES: &'static [&'static str]; + /// True if `NAME` comes from the default value for the struct + const DEFAULT_NAME_USED: bool; +} + +/// Implemented by the proc macro. This is used to enforce that the basic UDF and aggregate +/// UDF have the same name and aliases. +pub trait RegisteredAggregateUdf: RegisteredBasicUdf { + /// The main function name + const NAME: &'static str; + /// Aliases, if any + const ALIASES: &'static [&'static str]; + /// True if `NAME` comes from the default value for the struct + const DEFAULT_NAME_USED: bool; +} + +const NAME_MSG: &str = "`#[register]` on `BasicUdf` and `AggregateUdf` must have the same "; + +/// Enforce that a struct has the same basic and aggregate UDF names. +pub const fn verify_aggregate_attributes() { + verify_aggregate_attributes_name::(); + verify_aggregate_attribute_aliases::(); +} + +const fn verify_aggregate_attributes_name() { + let basic_name = ::NAME; + let agg_name = ::NAME; + let basic_default_name = ::DEFAULT_NAME_USED; + let agg_default_name = ::DEFAULT_NAME_USED; + + if const_str_eq(basic_name, agg_name) { + return; + } + + let mut msg_buf = [0u8; 512]; + let mut curs = 0; + curs += const_write_all!( + msg_buf, + [NAME_MSG, "`name` argument; got `", basic_name, "`",], + curs + ); + + if basic_default_name { + curs += const_write_all!(msg_buf, [" (default from struct name)"], curs); + } + + curs += const_write_all!(msg_buf, [" and `", agg_name, "`"], curs); + + if agg_default_name { + curs += const_write_all!(msg_buf, [" (default from struct name)"], curs); + } + + let msg = const_slice_to_str(msg_buf.as_slice(), curs); + panic!("{}", msg); +} + +#[allow(clippy::cognitive_complexity)] +const fn verify_aggregate_attribute_aliases() { + let basic_aliases = ::ALIASES; + let agg_aliases = ::ALIASES; + + if const_slice_eq(basic_aliases, agg_aliases) { + return; + } + + let mut msg_buf = [0u8; 512]; + let mut curs = 0; + + curs += const_write_all!(msg_buf, [NAME_MSG, "`alias` arguments; got [",], 0); + + let mut i = 0; + while i < basic_aliases.len() { + if i > 0 { + curs += const_write_all!(msg_buf, [", "], curs); + } + curs += const_write_all!(msg_buf, ["`", basic_aliases[i], "`",], curs); + i += 1; + } + + curs += const_write_all!(msg_buf, ["] and ["], curs); + + let mut i = 0; + while i < agg_aliases.len() { + if i > 0 { + curs += const_write_all!(msg_buf, [", "], curs); + } + curs += const_write_all!(msg_buf, ["`", agg_aliases[i], "`",], curs); + i += 1; + } + + curs += const_write_all!(msg_buf, ["]"], curs); + + let msg = const_slice_to_str(msg_buf.as_slice(), curs); + panic!("{}", msg); +} + #[cfg(test)] mod tests; diff --git a/udf/src/wrapper/const_helpers.rs b/udf/src/wrapper/const_helpers.rs new file mode 100644 index 0000000..0050ffd --- /dev/null +++ b/udf/src/wrapper/const_helpers.rs @@ -0,0 +1,127 @@ +use std::str; + +/// Similar to `copy_from_slice` but works at comptime. +/// +/// Takes a `start` offset so we can index into an existing slice. +macro_rules! const_arr_copy { + ($dst:expr, $src:expr, $start:expr) => {{ + let max_idx = $dst.len() - $start; + let (to_write, add_ellipsis) = if $src.len() <= (max_idx.saturating_sub($start)) { + ($src.len(), false) + } else { + ($src.len().saturating_sub(4), true) + }; + + let mut i = 0; + while i < to_write { + $dst[i + $start] = $src[i]; + i += 1; + } + + if add_ellipsis { + while i < $dst.len() - $start { + $dst[i + $start] = b'.'; + i += 1; + } + } + + i + }}; +} + +macro_rules! const_write_all { + ($dst:expr, $src_arr:expr, $start:expr) => {{ + let mut offset = $start; + + let mut i = 0; + while i < $src_arr.len() && offset < $dst.len() { + offset += const_arr_copy!($dst, $src_arr[i].as_bytes(), offset); + i += 1; + } + + offset - $start + }}; +} + +pub const fn const_str_eq(a: &str, b: &str) -> bool { + let a = a.as_bytes(); + let b = b.as_bytes(); + if a.len() != b.len() { + return false; + } + + let mut i = 0; + while i < a.len() { + if a[i] != b[i] { + return false; + } + + i += 1; + } + + true +} + +pub const fn const_slice_eq(a: &[&str], b: &[&str]) -> bool { + if a.len() != b.len() { + return false; + } + + let mut i = 0; + while i < a.len() { + if !const_str_eq(a[i], b[i]) { + return false; + } + + i += 1; + } + + true +} + +pub const fn const_slice_to_str(s: &[u8], len: usize) -> &str { + assert!(len <= s.len()); + // FIXME(msrv): use const `split_at` once our MSRV gets to 1.71 + // SAFETY: validated inbounds above + let buf = unsafe { std::slice::from_raw_parts(s.as_ptr(), len) }; + + match str::from_utf8(buf) { + Ok(v) => v, + Err(_e) => panic!("utf8 error"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_arr_copy() { + let mut x = [0u8; 20]; + let w1 = const_arr_copy!(x, b"foobar", 0); + let s = const_slice_to_str(x.as_slice(), w1); + assert_eq!(s, "foobar"); + + let w2 = const_arr_copy!(x, b"foobar", w1); + let s = const_slice_to_str(x.as_slice(), w1 + w2); + assert_eq!(s, "foobarfoobar"); + + let mut x = [0u8; 6]; + let written = const_arr_copy!(x, b"foobar", 0); + let s = const_slice_to_str(x.as_slice(), written); + assert_eq!(s, "foobar"); + + let mut x = [0u8; 5]; + let written = const_arr_copy!(x, b"foobar", 0); + let s = const_slice_to_str(x.as_slice(), written); + assert_eq!(s, "fo..."); + } + + #[test] + fn test_const_write_all() { + let mut x = [0u8; 20]; + let w1 = const_write_all!(x, ["foo", "bar", "baz"], 0); + let s = const_slice_to_str(x.as_slice(), w1); + assert_eq!(s, "foobarbaz"); + } +} diff --git a/udf/src/wrapper/tests.rs b/udf/src/wrapper/tests.rs index 02c121d..e99a9c0 100644 --- a/udf/src/wrapper/tests.rs +++ b/udf/src/wrapper/tests.rs @@ -159,3 +159,58 @@ fn test_wrapper_bufwrapper() { wrap_init::(todo!(), todo!(), todo!()); } } + +#[test] +fn test_verify_aggregate_attributes() { + struct Foo; + impl RegisteredBasicUdf for Foo { + const NAME: &'static str = "foo"; + const ALIASES: &'static [&'static str] = &["foo", "bar"]; + const DEFAULT_NAME_USED: bool = false; + } + impl RegisteredAggregateUdf for Foo { + const NAME: &'static str = "foo"; + const ALIASES: &'static [&'static str] = &["foo", "bar"]; + const DEFAULT_NAME_USED: bool = false; + } + + verify_aggregate_attributes::(); +} + +#[test] +#[should_panic = "#[register]` on `BasicUdf` and `AggregateUdf` must have the same `name` \ + argument; got `foo` and `bar`"] +fn test_verify_aggregate_attributes_mismatch_name() { + struct Foo; + impl RegisteredBasicUdf for Foo { + const NAME: &'static str = "foo"; + const ALIASES: &'static [&'static str] = &["foo", "bar"]; + const DEFAULT_NAME_USED: bool = false; + } + impl RegisteredAggregateUdf for Foo { + const NAME: &'static str = "bar"; + const ALIASES: &'static [&'static str] = &["foo", "bar"]; + const DEFAULT_NAME_USED: bool = false; + } + + verify_aggregate_attributes::(); +} + +#[test] +#[should_panic = "`#[register]` on `BasicUdf` and `AggregateUdf` must have the same `alias` \ + arguments; got [`foo`, `bar`, `baz`] and [`foo`, `bar`]"] +fn test_verify_aggregate_attributes_mismatch_aliases() { + struct Foo; + impl RegisteredBasicUdf for Foo { + const NAME: &'static str = "foo"; + const ALIASES: &'static [&'static str] = &["foo", "bar", "baz"]; + const DEFAULT_NAME_USED: bool = false; + } + impl RegisteredAggregateUdf for Foo { + const NAME: &'static str = "foo"; + const ALIASES: &'static [&'static str] = &["foo", "bar"]; + const DEFAULT_NAME_USED: bool = false; + } + + verify_aggregate_attributes::(); +}