diff --git a/pgx-examples/aggregate/src/lib.rs b/pgx-examples/aggregate/src/lib.rs index fb6c5593f..39dab287e 100644 --- a/pgx-examples/aggregate/src/lib.rs +++ b/pgx-examples/aggregate/src/lib.rs @@ -90,7 +90,7 @@ impl Aggregate for IntegerAvgState { // const HYPOTHETICAL: bool = true; // You can skip all these: - fn finalize(current: Self::State, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { + fn finalize(current: Self::State, _direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { Self::finalize(current) } diff --git a/pgx-macros/src/rewriter.rs b/pgx-macros/src/rewriter.rs index faeb05b1b..646175716 100644 --- a/pgx-macros/src/rewriter.rs +++ b/pgx-macros/src/rewriter.rs @@ -723,7 +723,31 @@ impl FunctionSignatureRewriter { let mut stream = proc_macro2::TokenStream::new(); let mut i = 0usize; - let mut have_fcinfo = false; + let mut fcinfo_ident = None; + + // Get the fcinfo ident, if it exists. + // We do this because we need to get the **right** ident, if it exists, so Rustc + // doesn't think we're pointing at the fcinfo module path. + for arg in &self.func.sig.inputs { + match arg { + FnArg::Typed(ty) => match ty.pat.deref() { + Pat::Ident(ident) => { + if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") + || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") + { + if fcinfo_ident.is_some() { + panic!("When using `pg_sys::FunctionCallInfo` as an argument it must be the last argument"); + } + fcinfo_ident = Some(ident.ident.clone()); + } + }, + _ => (), + }, + _ => () + } + } + let fcinfo_ident = fcinfo_ident.unwrap_or(syn::Ident::new("fcinfo", Span::call_site())); + for arg in &self.func.sig.inputs { match arg { FnArg::Receiver(_) => panic!("Functions that take self are not supported"), @@ -733,33 +757,28 @@ impl FunctionSignatureRewriter { let mut type_ = ty.ty.clone(); let is_option = type_matches(&type_, "Option"); - if have_fcinfo { - panic!("When using `pg_sys::FunctionCallInfo` as an argument it must be the last argument") - } - let ts = if is_option { let option_type = extract_option_type(&type_); let mut option_type = syn::parse2::(option_type).unwrap(); pgx_utils::anonymonize_lifetimes(&mut option_type); quote_spanned! {ident.span()=> - let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i); + let #name = pgx::pg_getarg::<#option_type>(#fcinfo_ident, #i); } } else if type_matches(&type_, "pg_sys :: FunctionCallInfo") || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") { - have_fcinfo = true; quote_spanned! {ident.span()=> - let #name = fcinfo; + let #name = #fcinfo_ident; } } else if is_raw { quote_spanned! {ident.span()=> - let #name = pgx::pg_getarg_datum_raw(fcinfo, #i) as #type_; + let #name = pgx::pg_getarg_datum_raw(#fcinfo_ident, #i) as #type_; } } else { pgx_utils::anonymonize_lifetimes(&mut type_); quote_spanned! {ident.span()=> - let #name = pgx::pg_getarg::<#type_>(fcinfo, #i).unwrap_or_else(|| panic!("{} is null", stringify!{#ident})); + let #name = pgx::pg_getarg::<#type_>(#fcinfo_ident, #i).unwrap_or_else(|| panic!("{} is null", stringify!{#ident})); } }; diff --git a/pgx-tests/src/tests/aggregate_tests.rs b/pgx-tests/src/tests/aggregate_tests.rs new file mode 100644 index 000000000..01a2c509f --- /dev/null +++ b/pgx-tests/src/tests/aggregate_tests.rs @@ -0,0 +1,183 @@ +use pgx::*; +use serde::{Serialize, Deserialize}; +use std::collections::HashSet; + +#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize)] +pub struct DemoSum { + count: i32, +} + +#[pg_aggregate] +impl Aggregate for DemoSum { + const NAME: &'static str = "demo_sum"; + const PARALLEL: Option = Some(pgx::aggregate::ParallelOption::Unsafe); + const INITIAL_CONDITION: Option<&'static str> = Some(r#"0"#); + const MOVING_INITIAL_CONDITION: Option<&'static str> = Some(r#"0"#); + + type Args = i32; + type State = i32; + type MovingState = i32; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + current += arg; + current + } + + fn moving_state( + current: Self::State, + arg: Self::Args, + fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::MovingState { + Self::state(current, arg, fcinfo) + } + + fn moving_state_inverse( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::MovingState { + current -= arg; + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + first += second; + first + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct DemoUnique; + +#[pg_aggregate] +impl Aggregate for DemoUnique { + type Args = &'static str; + type State = Internal; + type Finalize = i32; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + let inner = unsafe { current.get_or_insert_default::>() }; + + inner.insert(arg.to_string()); + current + } + + fn combine( + mut first: Self::State, + mut second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + let first_inner = unsafe { first.get_or_insert_default::>() }; + let second_inner = unsafe { second.get_or_insert_default::>() }; + + let unioned: HashSet<_> = first_inner.union(second_inner).collect(); + Internal::new(unioned) + } + + fn finalize( + mut current: Self::State, + _direct_args: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::Finalize { + let inner = unsafe { current.get_or_insert_default::>() }; + + inner.len() as i32 + } +} + +#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize)] +pub struct DemoPercentileDisc; + +#[pg_aggregate] +impl Aggregate for DemoPercentileDisc { + type Args = name!(input, i32); + type State = Internal; + type Finalize = i32; + const ORDERED_SET: bool = true; + type OrderedSetArgs = name!(percentile, f64); + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + let inner = unsafe { current.get_or_insert_default::>() }; + + inner.push(arg); + current + } + + fn finalize( + mut current: Self::State, + direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = unsafe { current.get_or_insert_default::>() }; + // This isn't done for us. + inner.sort(); + + let target_index = (inner.len() as f64 * direct_arg).round() as usize; + inner[target_index.saturating_sub(1)] + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pgx::pg_schema] +mod tests { + #[allow(unused_imports)] + use crate as pgx_tests; + use pgx::*; + + + #[pg_test] + fn aggregate_demo_sum() { + let retval = Spi::get_one::( + "SELECT demo_sum(value) FROM UNNEST(ARRAY [1, 1, 2]) as value;" + ).expect("SQL select failed"); + assert_eq!(retval, 4); + + // Moving-aggregate mode + let retval = Spi::get_one::>(" + SELECT array_agg(calculated) FROM ( + SELECT demo_sum(value) OVER ( + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) as calculated FROM UNNEST(ARRAY [1, 20, 300, 4000]) as value + ) as results; + ").expect("SQL select failed"); + assert_eq!(retval, vec![1, 21, 320, 4300]); + } + + #[pg_test] + fn aggregate_demo_unique() { + let retval = Spi::get_one::( + "SELECT DemoUnique(value) FROM UNNEST(ARRAY ['a', 'a', 'b']) as value;" + ).expect("SQL select failed"); + assert_eq!(retval, 2); + } + + #[pg_test] + fn aggregate_demo_percentile_disc() { + // Example from https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES + let retval = Spi::get_one::( + "SELECT DemoPercentileDisc(0.5) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [6000, 70000, 500]) as income;" + ).expect("SQL select failed"); + assert_eq!(retval, 6000); + + let retval = Spi::get_one::( + "SELECT DemoPercentileDisc(0.05) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [5, 100000000, 6000, 70000, 500]) as income;" + ).expect("SQL select failed"); + assert_eq!(retval, 5); + } +} \ No newline at end of file diff --git a/pgx-tests/src/tests/mod.rs b/pgx-tests/src/tests/mod.rs index eaf485f99..5de04f223 100644 --- a/pgx-tests/src/tests/mod.rs +++ b/pgx-tests/src/tests/mod.rs @@ -1,6 +1,7 @@ // Copyright 2020 ZomboDB, LLC . All rights reserved. Use of this source code is // governed by the MIT license that can be found in the LICENSE file. +mod aggregate_tests; mod anyarray_tests; mod array_tests; mod bytea_tests; diff --git a/pgx-utils/src/sql_entity_graph/pg_aggregate/aggregate_type.rs b/pgx-utils/src/sql_entity_graph/pg_aggregate/aggregate_type.rs index ae0626708..bb2af8d19 100644 --- a/pgx-utils/src/sql_entity_graph/pg_aggregate/aggregate_type.rs +++ b/pgx-utils/src/sql_entity_graph/pg_aggregate/aggregate_type.rs @@ -4,6 +4,7 @@ use syn::{ parse::{Parse, ParseStream}, parse_quote, Expr, Type, }; +use super::get_pgx_attr_macro; #[derive(Debug, Clone)] pub(crate) struct AggregateTypeList { @@ -55,23 +56,38 @@ impl ToTokens for AggregateTypeList { #[derive(Debug, Clone)] pub(crate) struct AggregateType { pub(crate) ty: Type, + /// The name, if it exists. + pub(crate) name: Option, } impl AggregateType { pub(crate) fn new(ty: syn::Type) -> Result { - let retval = Self { ty }; + let name_tokens = get_pgx_attr_macro("name", &ty); + let name = match name_tokens { + Some(tokens) => { + let name_macro = syn::parse2::(tokens) + .expect("Could not parse `name!()` macro"); + Some(name_macro.ident) + }, + None => None, + }; + let retval = Self { + name, + ty, + }; Ok(retval) } pub(crate) fn entity_tokens(&self) -> Expr { let ty = &self.ty; let ty_string = ty.to_token_stream().to_string().replace(" ", ""); + let name = self.name.iter(); parse_quote! { pgx::datum::sql_entity_graph::aggregate::AggregateType { ty_source: #ty_string, ty_id: core::any::TypeId::of::<#ty>(), full_path: core::any::type_name::<#ty>(), - name: None, + name: None#( .unwrap_or(Some(#name)) )*, } } } diff --git a/pgx-utils/src/sql_entity_graph/pg_aggregate/maybe_variadic_type.rs b/pgx-utils/src/sql_entity_graph/pg_aggregate/maybe_variadic_type.rs index 0a6fbba63..3100e839b 100644 --- a/pgx-utils/src/sql_entity_graph/pg_aggregate/maybe_variadic_type.rs +++ b/pgx-utils/src/sql_entity_graph/pg_aggregate/maybe_variadic_type.rs @@ -4,6 +4,7 @@ use syn::{ parse::{Parse, ParseStream}, parse_quote, Expr, Type, }; +use super::get_pgx_attr_macro; #[derive(Debug, Clone)] pub(crate) struct MaybeNamedVariadicTypeList { @@ -119,32 +120,6 @@ impl Parse for MaybeNamedVariadicType { } } - -fn get_pgx_attr_macro(attr_name: impl AsRef, ty: &syn::Type) -> Option { - match &ty { - syn::Type::Macro(ty_macro) => { - let mut found_pgx = false; - let mut found_attr = false; - // We don't actually have type resolution here, this is a "Best guess". - for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() { - match segment.ident.to_string().as_str() { - "pgx" if idx == 0 => found_pgx = true, - attr if attr == attr_name.as_ref() => found_attr = true, - _ => (), - } - } - if (ty_macro.mac.path.segments.len() == 1 && found_attr) - || (found_pgx && found_attr) - { - Some(ty_macro.mac.tokens.clone()) - } else { - None - } - } - _ => None, - } -} - #[cfg(test)] mod tests { use super::MaybeNamedVariadicTypeList; diff --git a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs index f7cd9671c..d2a11a612 100644 --- a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs +++ b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs @@ -61,9 +61,10 @@ pub struct PgAggregate { pg_externs: Vec, // Note these should not be considered *writable*, they're snapshots from construction. type_args: MaybeNamedVariadicTypeList, - type_order_by: Option, + type_ordered_set_args: Option, type_moving_state: Option, type_stype: AggregateType, + const_ordered_set: bool, const_parallel: Option, const_finalize_modify: Option, const_moving_finalize_modify: Option, @@ -145,7 +146,7 @@ impl PgAggregate { remap_self_to_target(&mut remapped, &target_ident); remapped }; - let type_stype = AggregateType { ty: type_state_without_self.clone(), }; + let type_stype = AggregateType { ty: type_state_without_self.clone(), name: Some("state".into()), }; // `MovingState` is an optional value, we default to nothing. let type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState"); @@ -157,15 +158,37 @@ impl PgAggregate { } // `OrderBy` is an optional value, we default to nothing. - let type_order_by = get_impl_type_by_name(&item_impl_snapshot, "OrderBy"); - let type_order_by_value = type_order_by + let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs"); + let type_ordered_set_args_value = type_ordered_set_args .map(|v| AggregateTypeList::new(v.ty.clone())) .transpose()?; - if type_order_by.is_none() { + if type_ordered_set_args.is_none() { item_impl.items.push(parse_quote! { - type OrderBy = (); + type OrderedSetArgs = (); }) } + let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) = type_ordered_set_args_value { + let direct_args = order_by_direct_args + .found + .iter() + .map(|x| (x.name.clone(), x.ty.clone())) + .collect::>(); + let direct_arg_names = ARG_NAMES[0..direct_args.len()] + .iter() + .zip(direct_args.iter()) + .map(|(default_name, (custom_name, _ty))| + Ident::new(&custom_name.clone().unwrap_or_else(|| default_name.to_string()), Span::mixed_site()) + ).collect::>(); + let direct_args_with_names = direct_args.iter().zip(direct_arg_names.iter()).map(|(arg, name)| { + let arg_ty = &arg.1; + parse_quote! { + #name: #arg_ty + } + }).collect::>(); + (direct_args_with_names, direct_arg_names) + } else { + (Vec::default(), Vec::default()) + }; // `Args` is an optional value, we default to nothing. let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| { @@ -262,20 +285,34 @@ impl PgAggregate { found.sig.ident.span(), ); let pg_extern_attr = pg_extern_attr(found); - pg_externs.push(parse_quote! { - #[allow(non_snake_case, clippy::too_many_arguments)] - #pg_extern_attr - fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> <#target_path as pgx::Aggregate>::Finalize { - <#target_path as pgx::Aggregate>::in_memory_context( - fcinfo, - move |_context| <#target_path as pgx::Aggregate>::finalize(this, fcinfo) - ) - } - }); + + if direct_args_with_names.len() > 0 { + pg_externs.push(parse_quote! { + #[allow(non_snake_case, clippy::too_many_arguments)] + #pg_extern_attr + fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: pgx::pg_sys::FunctionCallInfo) -> <#target_path as pgx::Aggregate>::Finalize { + <#target_path as pgx::Aggregate>::in_memory_context( + fcinfo, + move |_context| <#target_path as pgx::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo) + ) + } + }); + } else { + pg_externs.push(parse_quote! { + #[allow(non_snake_case, clippy::too_many_arguments)] + #pg_extern_attr + fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> <#target_path as pgx::Aggregate>::Finalize { + <#target_path as pgx::Aggregate>::in_memory_context( + fcinfo, + move |_context| <#target_path as pgx::Aggregate>::finalize(this, (), fcinfo) + ) + } + }); + }; Some(fn_name) } else { item_impl.items.push(parse_quote! { - fn finalize(current: #type_state_without_self, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { + fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { unimplemented!("Call to finalize on an aggregate which does not support it.") } }); @@ -403,7 +440,7 @@ impl PgAggregate { ) -> <#target_path as pgx::Aggregate>::MovingState { <#target_path as pgx::Aggregate>::in_memory_context( fcinfo, - move |_context| <#target_path as pgx::Aggregate>::moving_state(mstate, v, fcinfo) + move |_context| <#target_path as pgx::Aggregate>::moving_state_inverse(mstate, v, fcinfo) ) } }); @@ -428,13 +465,19 @@ impl PgAggregate { found.sig.ident.span(), ); let pg_extern_attr = pg_extern_attr(found); + let maybe_comma: Option = if direct_args_with_names.len() > 0 { + Some(parse_quote! {,}) + } else { + None + }; + pg_externs.push(parse_quote! { #[allow(non_snake_case, clippy::too_many_arguments)] #pg_extern_attr - fn #fn_name(mstate: <#target_path as pgx::Aggregate>::MovingState, fcinfo: pgx::pg_sys::FunctionCallInfo) -> <#target_path as pgx::Aggregate>::Finalize { + fn #fn_name(mstate: <#target_path as pgx::Aggregate>::MovingState, #(#direct_args_with_names),* #maybe_comma fcinfo: pgx::pg_sys::FunctionCallInfo) -> <#target_path as pgx::Aggregate>::Finalize { <#target_path as pgx::Aggregate>::in_memory_context( fcinfo, - move |_context| <#target_path as pgx::Aggregate>::moving_finalize(mstate, fcinfo) + move |_context| <#target_path as pgx::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo) ) } @@ -442,7 +485,7 @@ impl PgAggregate { Some(fn_name) } else { item_impl.items.push(parse_quote! { - fn moving_finalize(_mstate: Self::MovingState, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { + fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize { unimplemented!("Call to moving_finalize on an aggregate which does not support it.") } }); @@ -454,7 +497,7 @@ impl PgAggregate { pg_externs, name, type_args: type_args_value, - type_order_by: type_order_by_value, + type_ordered_set_args: type_ordered_set_args_value, type_moving_state: type_moving_state_value, type_stype: type_stype, const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL") @@ -471,6 +514,8 @@ impl PgAggregate { "INITIAL_CONDITION", ) .and_then(get_const_litstr), + const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET") + .and_then(get_const_litbool).unwrap_or(false), const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR") .and_then(get_const_litstr), const_moving_intial_condition: get_impl_const_by_name( @@ -520,10 +565,11 @@ impl PgAggregate { let name = &self.name; let type_args_iter = &self.type_args.entity_tokens(); - let type_order_by_iter = self.type_order_by.iter().map(|x| x.entity_tokens()); + let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens()); let type_moving_state_iter = self.type_moving_state.iter(); let type_moving_state_string = self.type_moving_state.as_ref().map(|t| { t.to_token_stream().to_string().replace(" ", "") }); let type_stype = self.type_stype.entity_tokens(); + let const_ordered_set = self.const_ordered_set; let const_parallel_iter = self.const_parallel.iter(); let const_finalize_modify_iter = self.const_finalize_modify.iter(); let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter(); @@ -550,9 +596,10 @@ impl PgAggregate { file: file!(), line: line!(), name: #name, + ordered_set: #const_ordered_set, ty_id: core::any::TypeId::of::<#target_ident>(), args: #type_args_iter, - order_by: None#( .unwrap_or(Some(#type_order_by_iter)) )*, + direct_args: None#( .unwrap_or(Some(#type_order_by_args_iter)) )*, stype: #type_stype, sfunc: stringify!(#fn_state), combinefunc: None#( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*, @@ -728,6 +775,16 @@ fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a needle } +fn get_const_litbool<'a>(item: &'a ImplItemConst) -> Option { + match &item.expr { + syn::Expr::Lit(expr_lit) => match &expr_lit.lit { + syn::Lit::Bool(lit) => Some(lit.value()), + _ => None, + }, + _ => None, + } +} + fn get_const_litstr<'a>(item: &'a ImplItemConst) -> Option { match &item.expr { syn::Expr::Lit(expr_lit) => match &expr_lit.lit { @@ -780,6 +837,31 @@ fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) { } } +fn get_pgx_attr_macro(attr_name: impl AsRef, ty: &syn::Type) -> Option { + match &ty { + syn::Type::Macro(ty_macro) => { + let mut found_pgx = false; + let mut found_attr = false; + // We don't actually have type resolution here, this is a "Best guess". + for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() { + match segment.ident.to_string().as_str() { + "pgx" if idx == 0 => found_pgx = true, + attr if attr == attr_name.as_ref() => found_attr = true, + _ => (), + } + } + if (ty_macro.mac.path.segments.len() == 1 && found_attr) + || (found_pgx && found_attr) + { + Some(ty_macro.mac.tokens.clone()) + } else { + None + } + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::PgAggregate; diff --git a/pgx/src/aggregate.rs b/pgx/src/aggregate.rs index 3492205a3..8a875f406 100644 --- a/pgx/src/aggregate.rs +++ b/pgx/src/aggregate.rs @@ -303,18 +303,22 @@ where /// with [`pgx::name!()`](crate::name), it must be used **inside** the [`pgx::name!()`](crate::name) macro. type Args; - /// The types of the order argument(s). + /// The types of the direct argument(s) to an ordered-set aggregate's `finalize`. + /// + /// **Only effective if `ORDERED_SET` is `true`.** /// /// For a single argument, provide the type directly. /// /// For multiple arguments, provide a tuple. + /// + /// For no arguments, don't set this, or set it to `()` (the default). /// /// `pgx` does not support `argname` as it is only used for documentation purposes. /// /// If the final argument is to be variadic, use `pgx::Variadic`. /// /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. - type OrderBy; + type OrderedSetArgs; /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. type Finalize; @@ -325,31 +329,44 @@ where /// The name of the aggregate. (eg. What you'd pass to `SELECT agg(col) FROM tab`.) const NAME: &'static str; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// Set to true if this is an ordered set aggregate. + /// + /// If set, the `OrderedSetArgs` associated type becomes effective, this allows for + /// direct arguments to the `finalize` function. + /// + /// See https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES + /// for more information. + /// + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. + const ORDERED_SET: bool = false; + + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const PARALLEL: Option = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const FINALIZE_MODIFY: Option = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const MOVING_FINALIZE_MODIFY: Option = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const INITIAL_CONDITION: Option<&'static str> = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const SORT_OPERATOR: Option<&'static str> = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const MOVING_INITIAL_CONDITION: Option<&'static str> = None; - /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. + /// **Optional:** This const can be skipped, `#[pg_aggregate]` will create a stub. const HYPOTHETICAL: bool = false; fn state(current: Self::State, v: Self::Args, fcinfo: FunctionCallInfo) -> Self::State; + /// The `OrderedSetArgs` is `()` unless `ORDERED_SET` is `true` and `OrderedSetArgs` is configured. + /// /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. - fn finalize(current: Self::State, fcinfo: FunctionCallInfo) -> Self::Finalize; + fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, fcinfo: FunctionCallInfo) -> Self::Finalize; /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. fn combine(current: Self::State, _other: Self::State, fcinfo: FunctionCallInfo) -> Self::State; @@ -366,8 +383,10 @@ where /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args, fcinfo: FunctionCallInfo) -> Self::MovingState; + /// The `OrderedSetArgs` is `()` unless `ORDERED_SET` is `true` and `OrderedSetArgs` is configured. + /// /// **Optional:** This function can be skipped, `#[pg_aggregate]` will create a stub. - fn moving_finalize(_mstate: Self::MovingState, fcinfo: FunctionCallInfo) -> Self::Finalize; + fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, fcinfo: FunctionCallInfo) -> Self::Finalize; unsafe fn memory_context(fcinfo: FunctionCallInfo) -> Option { if fcinfo.is_null() { diff --git a/pgx/src/datum/sql_entity_graph/aggregate.rs b/pgx/src/datum/sql_entity_graph/aggregate.rs index 64ecc287d..63fb08531 100644 --- a/pgx/src/datum/sql_entity_graph/aggregate.rs +++ b/pgx/src/datum/sql_entity_graph/aggregate.rs @@ -29,15 +29,20 @@ pub struct PgAggregateEntity { pub name: &'static str, + /// If the aggregate is an ordered set aggregate. + /// + /// See [the PostgreSQL ordered set docs](https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES). + pub ordered_set: bool, + /// The `arg_data_type` list. /// /// Corresponds to `Args` in [`crate::aggregate::Aggregate`]. pub args: Vec, - /// The `ORDER BY arg_data_type` list. + /// The direct argument list, appearing before `ORDER BY` in ordered set aggregates. /// /// Corresponds to `OrderBy` in [`crate::aggregate::Aggregate`]. - pub order_by: Option>, + pub direct_args: Option>, /// The `STYPE` and `name` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html) /// @@ -310,7 +315,7 @@ impl ToSql for PgAggregateEntity { "\n\ -- {file}:{line}\n\ -- {full_path}\n\ - CREATE AGGREGATE {schema}{name} ({args}{maybe_order_by})\n\ + CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\ (\n\ \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\ \tSTYPE = {schema}{stype}{maybe_comma_after_stype} /* {stype_full_path} */\ @@ -367,13 +372,11 @@ impl ToSql for PgAggregateEntity { ); args.push(buf); } - String::from("\n") - + &args.join("\n") - + if self.order_by.is_none() { "\n" } else { "" } + "\n".to_string() + &args.join("\n") + "\n" }, - maybe_order_by = if let Some(order_by) = &self.order_by { + direct_args = if let Some(direct_args) = &self.direct_args { let mut args = Vec::new(); - for (idx, arg) in order_by.iter().enumerate() { + for (idx, arg) in direct_args.iter().enumerate() { let graph_index = context .graph .neighbors_undirected(self_index) @@ -386,9 +389,9 @@ impl ToSql for PgAggregateEntity { .ok_or_else(|| { eyre_err!("Could not find arg type in graph. Got: {:?}", arg) })?; - let needs_comma = idx < (order_by.len() - 1); + let needs_comma = idx < (direct_args.len() - 1); let buf = format!("\ - {schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\ + \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\ ", schema_prefix = context.schema_prefix_for(&graph_index), // First try to match on [`TypeId`] since it's most reliable. @@ -397,15 +400,23 @@ impl ToSql for PgAggregateEntity { arg.full_path, self.name ))?, + maybe_name = if let Some(name) = arg.name { + "\"".to_string() + name + "\" " + } else { "".to_string() }, maybe_comma = if needs_comma { ", " } else { " " }, full_path = arg.full_path, ); args.push(buf); } - String::from("\n\tORDER BY ") + &args.join("\n,") + "\n" + "\n".to_string() + &args.join("\n,") + "\n" } else { String::default() }, + maybe_order_by = if self.ordered_set { + "\tORDER BY" + } else { + "" + }, optional_attributes = if optional_attributes.len() == 0 { String::from("\n") } else { diff --git a/pgx/src/datum/sql_entity_graph/pgx_sql.rs b/pgx/src/datum/sql_entity_graph/pgx_sql.rs index 869a25525..2271a56b6 100644 --- a/pgx/src/datum/sql_entity_graph/pgx_sql.rs +++ b/pgx/src/datum/sql_entity_graph/pgx_sql.rs @@ -1236,7 +1236,7 @@ fn connect_aggregates( } } - for arg in item.order_by.as_ref().unwrap_or(&vec![]) { + for arg in item.direct_args.as_ref().unwrap_or(&vec![]) { let found = make_type_or_enum_connection( graph, "Aggregate",