Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Ordered Sets better #428

Merged
merged 5 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgx-examples/aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl Aggregate for IntegerAvgState {
// const HYPOTHETICAL: bool = true;

// You can skip all these:
fn finalize(current: Self::State, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
fn finalize(current: Self::State, _direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
Self::finalize(current)
}

Expand Down
39 changes: 29 additions & 10 deletions pgx-macros/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,31 @@ impl FunctionSignatureRewriter {

let mut stream = proc_macro2::TokenStream::new();
let mut i = 0usize;
let mut have_fcinfo = false;
let mut fcinfo_ident = None;

// Get the fcinfo ident, if it exists.
// We do this because we need to get the **right** ident, if it exists, so Rustc
// doesn't think we're pointing at the fcinfo module path.
for arg in &self.func.sig.inputs {
match arg {
FnArg::Typed(ty) => match ty.pat.deref() {
Pat::Ident(ident) => {
if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo")
|| type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo")
{
if fcinfo_ident.is_some() {
panic!("When using `pg_sys::FunctionCallInfo` as an argument it must be the last argument");
}
fcinfo_ident = Some(ident.ident.clone());
}
},
_ => (),
},
_ => ()
}
}
let fcinfo_ident = fcinfo_ident.unwrap_or(syn::Ident::new("fcinfo", Span::call_site()));

for arg in &self.func.sig.inputs {
match arg {
FnArg::Receiver(_) => panic!("Functions that take self are not supported"),
Expand All @@ -733,33 +757,28 @@ impl FunctionSignatureRewriter {
let mut type_ = ty.ty.clone();
let is_option = type_matches(&type_, "Option");

if have_fcinfo {
panic!("When using `pg_sys::FunctionCallInfo` as an argument it must be the last argument")
}

let ts = if is_option {
let option_type = extract_option_type(&type_);
let mut option_type = syn::parse2::<syn::Type>(option_type).unwrap();
pgx_utils::anonymonize_lifetimes(&mut option_type);

quote_spanned! {ident.span()=>
let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i);
let #name = pgx::pg_getarg::<#option_type>(#fcinfo_ident, #i);
}
} else if type_matches(&type_, "pg_sys :: FunctionCallInfo")
|| type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo")
{
have_fcinfo = true;
quote_spanned! {ident.span()=>
let #name = fcinfo;
let #name = #fcinfo_ident;
}
} else if is_raw {
quote_spanned! {ident.span()=>
let #name = pgx::pg_getarg_datum_raw(fcinfo, #i) as #type_;
let #name = pgx::pg_getarg_datum_raw(#fcinfo_ident, #i) as #type_;
}
} else {
pgx_utils::anonymonize_lifetimes(&mut type_);
quote_spanned! {ident.span()=>
let #name = pgx::pg_getarg::<#type_>(fcinfo, #i).unwrap_or_else(|| panic!("{} is null", stringify!{#ident}));
let #name = pgx::pg_getarg::<#type_>(#fcinfo_ident, #i).unwrap_or_else(|| panic!("{} is null", stringify!{#ident}));
}
};

Expand Down
183 changes: 183 additions & 0 deletions pgx-tests/src/tests/aggregate_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
use pgx::*;
use serde::{Serialize, Deserialize};
use std::collections::HashSet;

#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize)]
pub struct DemoSum {
count: i32,
}

#[pg_aggregate]
impl Aggregate for DemoSum {
const NAME: &'static str = "demo_sum";
const PARALLEL: Option<ParallelOption> = Some(pgx::aggregate::ParallelOption::Unsafe);
const INITIAL_CONDITION: Option<&'static str> = Some(r#"0"#);
const MOVING_INITIAL_CONDITION: Option<&'static str> = Some(r#"0"#);

type Args = i32;
type State = i32;
type MovingState = i32;

fn state(
mut current: Self::State,
arg: Self::Args,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::State {
current += arg;
current
}

fn moving_state(
current: Self::State,
arg: Self::Args,
fcinfo: pg_sys::FunctionCallInfo,
) -> Self::MovingState {
Self::state(current, arg, fcinfo)
}

fn moving_state_inverse(
mut current: Self::State,
arg: Self::Args,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::MovingState {
current -= arg;
current
}

fn combine(
mut first: Self::State,
second: Self::State,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::State {
first += second;
first
}
}

#[derive(Copy, Clone, Default, Debug)]
pub struct DemoUnique;

#[pg_aggregate]
impl Aggregate for DemoUnique {
type Args = &'static str;
type State = Internal;
type Finalize = i32;

fn state(
mut current: Self::State,
arg: Self::Args,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::State {
let inner = unsafe { current.get_or_insert_default::<HashSet<String>>() };

inner.insert(arg.to_string());
current
}

fn combine(
mut first: Self::State,
mut second: Self::State,
_fcinfo: pg_sys::FunctionCallInfo
) -> Self::State {
let first_inner = unsafe { first.get_or_insert_default::<HashSet<String>>() };
let second_inner = unsafe { second.get_or_insert_default::<HashSet<String>>() };

let unioned: HashSet<_> = first_inner.union(second_inner).collect();
Internal::new(unioned)
}

fn finalize(
mut current: Self::State,
_direct_args: Self::OrderedSetArgs,
_fcinfo: pg_sys::FunctionCallInfo
) -> Self::Finalize {
let inner = unsafe { current.get_or_insert_default::<HashSet<String>>() };

inner.len() as i32
}
}

#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize)]
pub struct DemoPercentileDisc;

#[pg_aggregate]
impl Aggregate for DemoPercentileDisc {
type Args = name!(input, i32);
type State = Internal;
type Finalize = i32;
const ORDERED_SET: bool = true;
type OrderedSetArgs = name!(percentile, f64);

fn state(
mut current: Self::State,
arg: Self::Args,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::State {
let inner = unsafe { current.get_or_insert_default::<Vec<i32>>() };

inner.push(arg);
current
}

fn finalize(
mut current: Self::State,
direct_arg: Self::OrderedSetArgs,
_fcinfo: pg_sys::FunctionCallInfo,
) -> Self::Finalize {
let inner = unsafe { current.get_or_insert_default::<Vec<i32>>() };
// This isn't done for us.
inner.sort();

let target_index = (inner.len() as f64 * direct_arg).round() as usize;
inner[target_index.saturating_sub(1)]
}
}

#[cfg(any(test, feature = "pg_test"))]
#[pgx::pg_schema]
mod tests {
#[allow(unused_imports)]
use crate as pgx_tests;
use pgx::*;


#[pg_test]
fn aggregate_demo_sum() {
let retval = Spi::get_one::<i32>(
"SELECT demo_sum(value) FROM UNNEST(ARRAY [1, 1, 2]) as value;"
).expect("SQL select failed");
assert_eq!(retval, 4);

// Moving-aggregate mode
let retval = Spi::get_one::<Vec<i32>>("
SELECT array_agg(calculated) FROM (
SELECT demo_sum(value) OVER (
ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
) as calculated FROM UNNEST(ARRAY [1, 20, 300, 4000]) as value
) as results;
").expect("SQL select failed");
assert_eq!(retval, vec![1, 21, 320, 4300]);
}

#[pg_test]
fn aggregate_demo_unique() {
let retval = Spi::get_one::<i32>(
"SELECT DemoUnique(value) FROM UNNEST(ARRAY ['a', 'a', 'b']) as value;"
).expect("SQL select failed");
assert_eq!(retval, 2);
}

#[pg_test]
fn aggregate_demo_percentile_disc() {
// Example from https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES
let retval = Spi::get_one::<i32>(
"SELECT DemoPercentileDisc(0.5) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [6000, 70000, 500]) as income;"
).expect("SQL select failed");
assert_eq!(retval, 6000);

let retval = Spi::get_one::<i32>(
"SELECT DemoPercentileDisc(0.05) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [5, 100000000, 6000, 70000, 500]) as income;"
).expect("SQL select failed");
assert_eq!(retval, 5);
}
}
1 change: 1 addition & 0 deletions pgx-tests/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2020 ZomboDB, LLC <[email protected]>. All rights reserved. Use of this source code is
// governed by the MIT license that can be found in the LICENSE file.

mod aggregate_tests;
mod anyarray_tests;
mod array_tests;
mod bytea_tests;
Expand Down
20 changes: 18 additions & 2 deletions pgx-utils/src/sql_entity_graph/pg_aggregate/aggregate_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
parse_quote, Expr, Type,
};
use super::get_pgx_attr_macro;

#[derive(Debug, Clone)]
pub(crate) struct AggregateTypeList {
Expand Down Expand Up @@ -55,23 +56,38 @@ impl ToTokens for AggregateTypeList {
#[derive(Debug, Clone)]
pub(crate) struct AggregateType {
pub(crate) ty: Type,
/// The name, if it exists.
pub(crate) name: Option<String>,
}

impl AggregateType {
pub(crate) fn new(ty: syn::Type) -> Result<Self, syn::Error> {
let retval = Self { ty };
let name_tokens = get_pgx_attr_macro("name", &ty);
let name = match name_tokens {
Some(tokens) => {
let name_macro = syn::parse2::<crate::sql_entity_graph::pg_extern::NameMacro>(tokens)
.expect("Could not parse `name!()` macro");
Some(name_macro.ident)
},
None => None,
};
let retval = Self {
name,
ty,
};
Ok(retval)
}

pub(crate) fn entity_tokens(&self) -> Expr {
let ty = &self.ty;
let ty_string = ty.to_token_stream().to_string().replace(" ", "");
let name = self.name.iter();
parse_quote! {
pgx::datum::sql_entity_graph::aggregate::AggregateType {
ty_source: #ty_string,
ty_id: core::any::TypeId::of::<#ty>(),
full_path: core::any::type_name::<#ty>(),
name: None,
name: None#( .unwrap_or(Some(#name)) )*,
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
parse_quote, Expr, Type,
};
use super::get_pgx_attr_macro;

#[derive(Debug, Clone)]
pub(crate) struct MaybeNamedVariadicTypeList {
Expand Down Expand Up @@ -119,32 +120,6 @@ impl Parse for MaybeNamedVariadicType {
}
}


fn get_pgx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
match &ty {
syn::Type::Macro(ty_macro) => {
let mut found_pgx = false;
let mut found_attr = false;
// We don't actually have type resolution here, this is a "Best guess".
for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
match segment.ident.to_string().as_str() {
"pgx" if idx == 0 => found_pgx = true,
attr if attr == attr_name.as_ref() => found_attr = true,
_ => (),
}
}
if (ty_macro.mac.path.segments.len() == 1 && found_attr)
|| (found_pgx && found_attr)
{
Some(ty_macro.mac.tokens.clone())
} else {
None
}
}
_ => None,
}
}

#[cfg(test)]
mod tests {
use super::MaybeNamedVariadicTypeList;
Expand Down
Loading