Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ members = [
"src/common/display",
"src/common/error",
"src/common/io-config",
"src/common/macros",
"src/common/partitioning",
"src/common/scan-info",
"src/common/system-info",
Expand Down
14 changes: 14 additions & 0 deletions src/common/macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Dependencies used by the derive implementation in `src/function_args.rs`.
[dependencies]
# Resolves the name under which a crate (e.g. `daft-dsl`) is visible to the macro caller.
proc-macro-crate = "1.0"
# Provides `abort!` / `abort_call_site!` and the `#[proc_macro_error]` attribute.
proc-macro-error = "1.0"
proc-macro2 = "1.0"
# `quote!` / `format_ident!` for building the generated token stream.
quote = "1.0"
# Parses the derive input (`DeriveInput`, `Generics`, `Fields`, ...).
syn = "2.0"

[lib]
# This crate exports procedural macros.
proc-macro = true

[package]
edition = {workspace = true}
name = "common-macros"
version = {workspace = true}
205 changes: 205 additions & 0 deletions src/common/macros/src/function_args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
use proc_macro2::Ident;
use proc_macro_crate::{crate_name, FoundCrate};
use proc_macro_error::{abort, abort_call_site};
use quote::{format_ident, quote};
use syn::{
parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed, GenericArgument,
GenericParam, Generics, PathArguments, Type, TypePath,
};

/// How a struct field is populated from the function-call arguments.
///
/// The classification is derived purely from the field's declared type,
/// where `T` is the single generic parameter of the struct being derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldType {
    /// `T` — the argument must be supplied.
    Required,
    /// `Option<T>` — the argument may be omitted.
    Optional,
    /// `Vec<T>` — collects every remaining positional argument.
    Variadic,
}

/// Implementation of `#[derive(FunctionArgs)]`.
///
/// Generates `impl TryFrom<daft_dsl::functions::FunctionArgs<T>> for Struct<T>`,
/// where `T` is the struct's single generic type parameter. Compilation is
/// aborted with a spanned error if the input is not a struct with named
/// fields, does not have exactly one generic type parameter, or has a field
/// whose type is not `T`, `Option<T>` or `Vec<T>`.
///
/// The generated `try_from` fills the fields in declaration order:
/// * `T`: next positional argument, else the named argument with the field's
///   name, else a `DaftError::ValueError`.
/// * `Option<T>`: same lookup, but a missing argument yields `None`.
/// * `Vec<T>`: all positional arguments that are still unconsumed.
///
/// Any positional or named argument left over afterwards is reported as a
/// `DaftError::ValueError`.
pub fn derive_function_args(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let daft_dsl = get_crate_name("daft-dsl");

    let input = parse_macro_input!(input as DeriveInput);

    let generic_ident = get_generic_ident(&input.generics);
    let fields = get_fields(&input.data);

    // Reuse the struct's own generics so that bounds and `where` clauses on
    // the type parameter are carried over to the generated impl. A hard-coded
    // `impl<T> ... for Name<T>` fails to compile for e.g. `struct Name<T: Bound>`.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let num_fields = fields.named.len();

    let name = &input.ident;
    let field_names = get_field_names(fields);
    let field_types = get_field_types(fields, generic_ident);

    // variadic args cannot be referenced by name
    let named_args = field_names
        .iter()
        .zip(field_types.iter())
        .filter_map(|(name, ty)| match ty {
            FieldType::Required | FieldType::Optional => Some(name),
            FieldType::Variadic => None,
        });

    // One initializer expression per field, in declaration order. The
    // expressions refer to the `unnamed` / `named` locals of the generated
    // `try_from` body below.
    let field_getters = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
        match ty {
            FieldType::Required => {
                quote! {
                    unnamed
                        .pop_front()
                        .or_else(|| named.remove(stringify!(#name)))
                        .ok_or_else(|| common_error::DaftError::ValueError(format!("Required argument `{}` not found", stringify!(#name))))?
                }
            },
            FieldType::Optional => {
                quote! {
                    unnamed
                        .pop_front()
                        .or_else(|| named.remove(stringify!(#name)))
                }
            },
            FieldType::Variadic => {
                quote! {
                    unnamed.drain(..).collect()
                }
            },
        }
    });

    // NOTE: `expected_names` is given an explicit `&[&str]` type in the
    // generated code so that it still type-checks when the list is empty
    // (i.e. when the struct has only variadic fields).
    let expanded = quote! {
        impl #impl_generics std::convert::TryFrom<#daft_dsl::functions::FunctionArgs<#generic_ident>>
            for #name #ty_generics
            #where_clause
        {
            type Error = common_error::DaftError;

            fn try_from(args: #daft_dsl::functions::FunctionArgs<#generic_ident>) -> common_error::DaftResult<Self> {
                let (unnamed, mut named) = args.into_unnamed_and_named()?;
                let mut unnamed = std::collections::VecDeque::from(unnamed);

                let parsed = Self {
                    #(
                        #field_names: #field_getters,
                    )*
                };

                if !unnamed.is_empty() {
                    return std::result::Result::Err(
                        common_error::DaftError::ValueError(format!("Expected {} arguments, received: {}", #num_fields, #num_fields + unnamed.len()))
                    );
                }

                if !named.is_empty() {
                    let expected_names: &[&str] = &[#(stringify!(#named_args)),*];

                    return std::result::Result::Err(
                        common_error::DaftError::ValueError(format!("Expected argument names {}, received: {}",
                            expected_names
                                .iter()
                                .map(|s| format!("`{}`", s))
                                .collect::<Vec<_>>()
                                .join(", "),
                            named.keys()
                                .map(|s| format!("`{}`", s))
                                .collect::<Vec<_>>()
                                .join(", "),
                        ))
                    );
                }

                std::result::Result::Ok(parsed)
            }
        }
    };

    proc_macro::TokenStream::from(expanded)
}

/// Returns the identifier of the struct's single generic type parameter.
///
/// Aborts macro expansion with a spanned error when the struct does not have
/// exactly one generic parameter, or when that parameter is not a type
/// parameter (for example a lifetime or a const generic).
fn get_generic_ident(generics: &Generics) -> &Ident {
    if generics.params.len() != 1 {
        abort!(generics.span(), "expected one generic")
    }

    // Length was checked above, so index 0 is always present.
    let param = &generics.params[0];

    match param {
        GenericParam::Type(type_param) => &type_param.ident,
        _ => abort!(
            param.span(),
            "expected generic to be a simple type, such as <T>"
        ),
    }
}

fn get_fields(data: &Data) -> &FieldsNamed {
let Data::Struct(data_struct) = data else {
abort_call_site!("can only derive structs")

Check warning on line 130 in src/common/macros/src/function_args.rs

View check run for this annotation

Codecov / codecov/patch

src/common/macros/src/function_args.rs#L130

Added line #L130 was not covered by tests
};

let Fields::Named(fields_named) = &data_struct.fields else {
abort!(
data_struct.fields.span(),
"can only derive structs with named fields"
)

Check warning on line 137 in src/common/macros/src/function_args.rs

View check run for this annotation

Codecov / codecov/patch

src/common/macros/src/function_args.rs#L134-L137

Added lines #L134 - L137 were not covered by tests
};

fields_named
}

/// Collects the identifier of every named field, in declaration order.
///
/// Panics only if a field has no identifier, which cannot happen for the
/// fields of a `FieldsNamed`.
fn get_field_names(fields: &FieldsNamed) -> Vec<&Ident> {
    let mut idents = Vec::with_capacity(fields.named.len());

    for field in &fields.named {
        let ident = field.ident.as_ref().expect("named fields have idents");
        idents.push(ident);
    }

    idents
}

fn get_field_types(fields: &FieldsNamed, generic: &Ident) -> Vec<FieldType> {
fields
.named
.iter()
.map(|f| {
if let Type::Path(TypePath { qself: None, path }) = &f.ty {
if path.is_ident(generic) {
return FieldType::Required;
}

// check if type in the shape `S<T>` where `T` is the struct generic
if path.segments.len() == 1 {
let segment = path.segments.first().unwrap();

if let PathArguments::AngleBracketed(args) = &segment.arguments {
if args.args.len() == 1 {
let arg = args.args.first().unwrap();

if let GenericArgument::Type(Type::Path(TypePath {
qself: None,
path: inner_path,
})) = arg
&& inner_path.is_ident(generic)
{
// check for the value of `S` in the above comment
match segment.ident.to_string().as_str() {
"Option" => return FieldType::Optional,
"Vec" => return FieldType::Variadic,
_ => {}

Check warning on line 179 in src/common/macros/src/function_args.rs

View check run for this annotation

Codecov / codecov/patch

src/common/macros/src/function_args.rs#L179

Added line #L179 was not covered by tests
}
}
}
}
}
}

Check warning on line 185 in src/common/macros/src/function_args.rs

View check run for this annotation

Codecov / codecov/patch

src/common/macros/src/function_args.rs#L181-L185

Added lines #L181 - L185 were not covered by tests

abort!(
f.span(),
"field type must be `T`, `Option<T>`, or `Vec<T>`, where `T` is the struct generic"
)

Check warning on line 190 in src/common/macros/src/function_args.rs

View check run for this annotation

Codecov / codecov/patch

src/common/macros/src/function_args.rs#L187-L190

Added lines #L187 - L190 were not covered by tests
})
.collect()
}

/// Resolves the identifier under which the crate `orig_name` can be referenced
/// from the code that invoked the macro.
///
/// Returns `crate` when the macro is expanded inside `orig_name` itself, and
/// otherwise the name reported by `proc_macro_crate::crate_name`. Aborts
/// macro expansion when the lookup fails.
fn get_crate_name(orig_name: &str) -> Ident {
    // Bound to a distinct local name so the imported `crate_name` function is
    // not shadowed for the rest of the body.
    let resolved_name = match crate_name(orig_name) {
        Ok(FoundCrate::Itself) => "crate".to_string(),
        Ok(FoundCrate::Name(name)) => name,
        Err(_) => {
            abort_call_site!("crate must be available in scope: `{}`", orig_name);
        }
    };

    format_ident!("{}", resolved_name)
}
14 changes: 14 additions & 0 deletions src/common/macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#![feature(let_chains)]

mod function_args;

use proc_macro_error::proc_macro_error;

#[proc_macro_error]
#[proc_macro_derive(FunctionArgs)]
/// Proc macro for deriving `TryFrom<FunctionArgs<T>>` for an argument struct.
///
/// The target must be a struct with named fields and exactly one generic type
/// parameter `T`; every field must have type `T`, `Option<T>` or `Vec<T>`.
/// Violations are reported as compile errors. The implementation lives in
/// the `function_args` module; this function only forwards to it.
///
/// Tests are located at `src/daft-dsl/src/functions/macro_tests.rs`
pub fn derive_function_args(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
function_args::derive_function_args(input)
}
4 changes: 4 additions & 0 deletions src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ bincode = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-io-config = {path = "../common/io-config", default-features = false}
common-macros = {path = "../common/macros"}
common-py-serde = {path = "../common/py-serde", default-features = false}
common-resource-request = {path = "../common/resource-request", default-features = false}
common-treenode = {path = "../common/treenode", default-features = false}
Expand All @@ -16,6 +17,9 @@ pyo3 = {workspace = true, optional = true}
serde = {workspace = true}
typetag = {workspace = true}

[dev-dependencies]
rstest = {workspace = true}

[features]
python = [
"dep:pyo3",
Expand Down
Loading
Loading