Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework validation of names and aliases for aggregate UDFs #60

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,9 @@

## [Unreleased] - ReleaseDate

### Added

### Changed

### Removed



## [0.5.4] - 2023-09-10

### Added

### Changed

### Removed

Rework the validation of names and aliases for aggregate UDFs. This fixes an
issue where aliases could not be used for aggregate UDFs, and provides better
error messages.


## [0.5.4] - 2023-09-10
Expand Down
217 changes: 130 additions & 87 deletions udf-macros/src/register.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(unused_imports)]

use std::iter;

use heck::AsSnakeCase;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
Expand All @@ -8,24 +10,17 @@ use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, DeriveInput, Error, Expr, ExprLit, Ident, ImplItem,
ImplItemType, Item, ItemImpl, Lit, Meta, Path, PathSegment, Token, Type, TypePath,
ImplItemType, Item, ItemImpl, Lit, LitStr, Meta, Path, PathSegment, Token, Type, TypePath,
TypeReference,
};

use crate::match_variant;
use crate::types::{make_type_list, ImplType, RetType, TypeClass};

/// Build an [`Ident`] at the call site from a format string and one in-scope
/// value, e.g. `format_ident_str!("{}_init", base_fn_name)`.
///
/// The formatted text is used verbatim as the identifier; no case conversion
/// is applied here.
macro_rules! format_ident_str {
    ($formatter:tt, $ident:ident) => {{
        let rendered = format!($formatter, $ident);
        Ident::new(&rendered, Span::call_site())
    }};
}

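/// Determine which UDF trait an `ItemImpl` implements
///
/// Returns `Some(ImplType::Basic)` for `BasicUdf` and `Some(ImplType::Aggregate)` for
/// `AggregateUdf` (in any of their pathing options), or `None` for any other trait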
fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
fn impl_type(itemimpl: &ItemImpl) -> Option<ImplType> {
let implemented = &itemimpl.trait_.as_ref().unwrap().1.segments;

let basic_paths: [Punctuated<PathSegment, Token![::]>; 3] = [
Expand All @@ -39,9 +34,12 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
parse_quote! {AggregateUdf},
];

match expected {
ImplType::Basic => basic_paths.contains(implemented),
ImplType::Aggregate => arg_paths.contains(implemented),
if basic_paths.contains(implemented) {
Some(ImplType::Basic)
} else if arg_paths.contains(implemented) {
Some(ImplType::Aggregate)
} else {
None
}
}

Expand All @@ -57,14 +55,11 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
let parsed = parse_macro_input!(input as ItemImpl);

let impls_basic = impls_path(&parsed, ImplType::Basic);
let impls_agg = impls_path(&parsed, ImplType::Aggregate);

if !(impls_basic || impls_agg) {
let Some(impl_ty) = impl_type(&parsed) else {
return Error::new_spanned(&parsed, "Expected trait `BasicUdf` or `AggregateUdf`")
.into_compile_error()
.into();
}
};

// Full type path of our data struct
let Type::Path(dstruct_path) = parsed.self_ty.as_ref() else {
Expand All @@ -73,7 +68,7 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
.into();
};

let base_fn_names = match parse_args(args, dstruct_path) {
let parsed_meta = match ParsedMeta::parse(args, dstruct_path) {
Ok(v) => v,
Err(e) => return e.into_compile_error().into(),
};
Expand All @@ -89,91 +84,110 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
Span::call_site(),
);

let (ret_ty, wrapper_def) = if impls_basic {
match get_rt_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
let (ret_ty, wrapper_def) = match impl_ty {
ImplType::Basic => match get_ret_ty_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
Ok((r, w)) => (Some(r), w),
Err(e) => return e.into_compile_error().into(),
}
} else {
(None, TokenStream2::new())
},
ImplType::Aggregate => (None, TokenStream2::new()),
};

let content_iter = base_fn_names.iter().map(|base_fn_name| {
if impls_basic {
make_basic_fns(
ret_ty.as_ref().unwrap(),
base_fn_name,
dstruct_path,
&wrapper_ident,
)
} else {
make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident)
}
let helper_traits = make_helper_trait_impls(dstruct_path, &parsed_meta, impl_ty);

let fn_items_iter = parsed_meta.all_names().map(|base_fn_name| match impl_ty {
ImplType::Basic => make_basic_fns(
ret_ty.as_ref().unwrap(),
base_fn_name,
dstruct_path,
&wrapper_ident,
),
ImplType::Aggregate => make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident),
});

quote! {
#parsed

#wrapper_def

#( #content_iter )*
#helper_traits

#( #fn_items_iter )*
}
.into()
}

/// Parse attribute arguments. Returns an iterator of names
fn parse_args(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Vec<String>> {
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
let mut base_fn_names: Vec<String> = vec![];
let mut primary_name_specified = false;

for m in meta {
let Meta::NameValue(mval) = m else {
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
};
/// Arguments we parse from metadata or default to
struct ParsedMeta {
/// Primary function name: the `name = "..."` attribute value, or the
/// snake-case type name when no `name` attribute was given
name: String,
/// Every `alias = "..."` attribute value, in the order they were written
aliases: Vec<String>,
/// `true` when `name` was derived from the type name rather than specified
default_name_used: bool,
}

if !mval.path.segments.iter().count() == 1 {
return Err(Error::new_spanned(mval.path, "unexpected path"));
}
impl ParsedMeta {
/// Parse attribute arguments, falling back to the snake-case type name when `name` is not given
fn parse(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Self> {
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
let mut name_from_attributes = None;
let mut aliases = Vec::new();

let key = mval.path.segments.first().unwrap();
for m in meta {
let Meta::NameValue(mval) = m else {
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
};

let Expr::Lit(ExprLit {
lit: Lit::Str(value),
..
}) = mval.value
else {
return Err(Error::new_spanned(mval.value, "expected a literal string"));
};
// Keys must be a single bare identifier (`name` / `alias`), so reject
// multi-segment paths such as `foo::name`. The previous condition,
// `!count == 1`, parsed as `(!count) == 1` (bitwise NOT on `usize`) and
// therefore never rejected anything.
if mval.path.segments.len() != 1 {
    return Err(Error::new_spanned(mval.path, "unexpected path"));
}

if key.ident == "name" {
if primary_name_specified {
return Err(Error::new_spanned(key, "`name` can only be specified once"));
let key = mval.path.segments.first().unwrap();

let Expr::Lit(ExprLit {
lit: Lit::Str(value),
..
}) = mval.value
else {
return Err(Error::new_spanned(mval.value, "expected a literal string"));
};

if key.ident == "name" {
if name_from_attributes.is_some() {
return Err(Error::new_spanned(key, "`name` can only be specified once"));
}
name_from_attributes = Some(value.value());
} else if key.ident == "alias" {
aliases.push(value.value());
} else {
return Err(Error::new_spanned(
key,
"unexpected key (only `name` and `alias` are accepted)",
));
}
base_fn_names.push(value.value());
primary_name_specified = true;
} else if key.ident == "alias" {
base_fn_names.push(value.value());
} else {
return Err(Error::new_spanned(
key,
"unexpected key (only `name` and `alias` are accepted)",
));
}
}

if !primary_name_specified {
// If we don't have a name specified, use the type name as snake case
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
base_fn_names.push(fn_name);
let mut default_name_used = false;
let name = name_from_attributes.unwrap_or_else(|| {
// If we don't have a name specified, use the type name as snake case
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
default_name_used = true;
fn_name
});

Ok(Self {
name,
aliases,
default_name_used,
})
}

Ok(base_fn_names)
/// Iterate the basic name and all aliases
fn all_names(&self) -> impl Iterator<Item = &String> {
iter::once(&self.name).chain(self.aliases.iter())
}
}

/// Get the return type to use and a wrapper. Once per impl setup.
fn get_rt_and_wrapper(
fn get_ret_ty_and_wrapper(
parsed: &ItemImpl,
dstruct_path: &TypePath,
wrapper_ident: &Ident,
Expand Down Expand Up @@ -209,16 +223,50 @@ fn get_rt_and_wrapper(
Ok((ret_ty, wrapper_struct))
}

/// Generate the `RegisteredBasicUdf` or `RegisteredAggregateUdf` impl for the
/// UDF data struct.
///
/// The generated impl exposes the primary name, the alias list, and whether
/// the name was defaulted as associated constants. For aggregate impls an
/// anonymous `const` that invokes `verify_aggregate_attributes` for the struct
/// is emitted after the impl; basic impls emit nothing extra.
fn make_helper_trait_impls(
    dstruct_path: &TypePath,
    meta: &ParsedMeta,
    impl_ty: ImplType,
) -> TokenStream2 {
    let span = Span::call_site();
    let name = LitStr::new(meta.name.as_str(), span);
    let alias_lits = meta.aliases.iter().map(|a| LitStr::new(a.as_str(), span));
    let default_name_used = meta.default_name_used;

    // Which marker trait to implement
    let trait_name = match impl_ty {
        ImplType::Basic => quote! { ::udf::wrapper::RegisteredBasicUdf },
        ImplType::Aggregate => quote! { ::udf::wrapper::RegisteredAggregateUdf },
    };

    // Extra item emitted only for aggregate impls
    let check_expr = match impl_ty {
        ImplType::Basic => TokenStream2::new(),
        ImplType::Aggregate => {
            quote! { const _: () = ::udf::wrapper::verify_aggregate_attributes::<#dstruct_path>(); }
        }
    };

    quote! {
        impl #trait_name for #dstruct_path {
            const NAME: &'static str = #name;
            const ALIASES: &'static [&'static str] = &[#( #alias_lits ),*];
            const DEFAULT_NAME_USED: bool = #default_name_used;
        }

        #check_expr
    }
}

/// Create the basic function signatures (`xxx_init`, `xxx_deinit`, `xxx`)
fn make_basic_fns(
rt: &RetType,
base_fn_name: &str,
dstruct_path: &TypePath,
wrapper_ident: &Ident,
) -> TokenStream2 {
let init_fn_name = format_ident_str!("{}_init", base_fn_name);
let deinit_fn_name = format_ident_str!("{}_deinit", base_fn_name);
let process_fn_name = format_ident_str!("{}", base_fn_name);
let init_fn_name = format_ident!("{}_init", base_fn_name);
let deinit_fn_name = format_ident!("{}_deinit", base_fn_name);
let process_fn_name = format_ident!("{}", base_fn_name);

let init_fn = make_init_fn(dstruct_path, wrapper_ident, &init_fn_name);
let deinit_fn = make_deinit_fn(dstruct_path, wrapper_ident, &deinit_fn_name);
Expand Down Expand Up @@ -269,9 +317,9 @@ fn make_agg_fns(
dstruct_path: &TypePath, // Name of the data structure
wrapper_ident: &Ident,
) -> TokenStream2 {
let clear_fn_name = format_ident_str!("{}_clear", base_fn_name);
let add_fn_name = format_ident_str!("{}_add", base_fn_name);
let remove_fn_name = format_ident_str!("{}_remove", base_fn_name);
let clear_fn_name = format_ident!("{}_clear", base_fn_name);
let add_fn_name = format_ident!("{}_add", base_fn_name);
let remove_fn_name = format_ident!("{}_remove", base_fn_name);

// Determine whether this re-implements `remove`
let impls_remove = &parsed
Expand All @@ -280,7 +328,6 @@ fn make_agg_fns(
.filter_map(match_variant!(ImplItem::Fn))
.map(|m| &m.sig.ident)
.any(|id| *id == "remove");
let base_fn_ident = Ident::new(base_fn_name, Span::call_site());

let clear_fn = make_clear_fn(dstruct_path, wrapper_ident, &clear_fn_name);
let add_fn = make_add_fn(dstruct_path, wrapper_ident, &add_fn_name);
Expand All @@ -295,10 +342,6 @@ fn make_agg_fns(
};

quote! {
// Sanity check that we implemented
#[allow(dead_code, non_upper_case_globals)]
const did_you_apply_the_same_aliases_to_the_BasicUdf_impl: *const () = #base_fn_ident as _;

#clear_fn

#add_fn
Expand Down
22 changes: 22 additions & 0 deletions udf-macros/tests/fail/agg_missing_basic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#![allow(unused)]

use udf::prelude::*;

struct MyUdf; // no `BasicUdf` impl in this file; the expected stderr relies on that

impl AggregateUdf for MyUdf {
// NOTE: the expected stderr for this fixture pins the `impl` header above to
// line 7, column 23 of the file, so do not insert lines at or above it.
// Required methods
// Bodies are `todo!()` placeholders; the fixture is expected to fail to compile
// on the missing `BasicUdf` bound.
fn clear(&mut self, cfg: &UdfCfg<Process>, error: Option<NonZeroU8>) -> Result<(), NonZeroU8> {
todo!()
}
fn add(
&mut self,
cfg: &UdfCfg<Process>,
args: &ArgList<'_, Process>,
error: Option<NonZeroU8>,
) -> Result<(), NonZeroU8> {
todo!()
}
}

// Empty entry point so the fixture is a complete binary crate
fn main() {}
11 changes: 11 additions & 0 deletions udf-macros/tests/fail/agg_missing_basic.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error[E0277]: the trait bound `MyUdf: BasicUdf` is not satisfied
--> tests/fail/agg_missing_basic.rs:7:23
|
7 | impl AggregateUdf for MyUdf {
| ^^^^^ the trait `BasicUdf` is not implemented for `MyUdf`
|
note: required by a bound in `udf::AggregateUdf`
--> $WORKSPACE/udf/src/traits.rs
|
| pub trait AggregateUdf: BasicUdf {
| ^^^^^^^^ required by this bound in `AggregateUdf`
Loading
Loading