Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 35 additions & 59 deletions articles/postgresql-aggregates-with-rust.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ CREATE AGGREGATE example_sum(integer)
);

SELECT example_sum(value) FROM UNNEST(ARRAY [1, 2, 3]) as value;
-- example_sum
-- example_sum
-- -------------
-- 6
-- (1 row)
Expand Down Expand Up @@ -91,7 +91,7 @@ CREATE AGGREGATE example_sum(integer)
);

SELECT example_sum(value) FROM generate_series(0, 4000) as value;
-- example_sum
-- example_sum
-- -------------
-- 8002000
-- (1 row)
Expand Down Expand Up @@ -128,7 +128,7 @@ CREATE AGGREGATE example_uniq(text)
);

SELECT example_uniq(value) FROM UNNEST(ARRAY ['a', 'a', 'b']) as value;
-- example_uniq
-- example_uniq
-- --------------
-- 2
-- (1 row)
Expand Down Expand Up @@ -162,7 +162,7 @@ SELECT example_concat(first, second, third) FROM
UNNEST(ARRAY ['a', 'b', 'c']) as first,
UNNEST(ARRAY ['1', '2', '3']) as second,
UNNEST(ARRAY ['!', '@', '#']) as third;
-- example_concat
-- example_concat
-- ---------------------------------------------------------------------------------------------------------------
-- {a1!,a2!,a3!,b1!,b2!,b3!,c1!,c2!,c3!,a1@,a2@,a3@,b1@,b2@,b3@,c1@,c2@,c3@,a1#,a2#,a3#,b1#,b2#,b3#,c1#,c2#,c3#}
-- (1 row)
Expand All @@ -174,7 +174,7 @@ See how we see `a1`, `b1`, and `c1`? Multiple arguments might not work as you ex
SELECT UNNEST(ARRAY ['a', 'b', 'c']) as first,
UNNEST(ARRAY ['1', '2', '3']) as second,
UNNEST(ARRAY ['!', '@', '#']) as third;
-- first | second | third
-- first | second | third
-- -------+--------+-------
-- a | 1 | !
-- b | 2 | @
Expand Down Expand Up @@ -202,7 +202,7 @@ It includes:

If a Rust toolchain is not already installed, please follow the instructions on [rustup.rs][rustup-rs].

You'll also [need to make sure you have some development libraries][pgrx-system-requirements] like `zlib` and `libclang`, as
You'll also [need to make sure you have some development libraries][pgrx-system-requirements] like `zlib` and `libclang`, as
`cargo pgrx init` will, by default, build it's own development PostgreSQL installs. Usually it's possible to
figure out if something is missing from error messages and then discover the required package for the system.

Expand Down Expand Up @@ -244,7 +244,7 @@ running SQL generator
psql (13.5)
Type "help" for help.

exploring_aggregates=#
exploring_aggregates=#
```

Observing the start of the `src/lib.rs` file, we can see the `pg_module_magic!()` and a function `hello_exploring_aggregates`:
Expand All @@ -268,13 +268,13 @@ CREATE EXTENSION exploring_aggregates;

\dx+ exploring_aggregates
-- Objects in extension "exploring_aggregates"
-- Object description
-- Object description
-- ---------------------------------------
-- function hello_exploring_aggregates()
-- (1 row)

SELECT hello_exploring_aggregates();
-- hello_exploring_aggregates
-- hello_exploring_aggregates
-- -----------------------------
-- Hello, exploring_aggregates
-- (1 row)
Expand Down Expand Up @@ -337,7 +337,7 @@ running SQL generator
This creates `sql/exploring_aggregates-0.0.0.sql`:

```sql
/*
/*
This file is auto generated by pgrx.

The ordering of items is not stable, it is driven by a dependency graph.
Expand Down Expand Up @@ -378,6 +378,10 @@ but it should be flexible enough for any use.
Aggregates in `pgrx` are defined by creating a type (this doesn't necessarily need to be the state type), then using the [`#[pg_aggregate]`][pgrx-pg_aggregate]
procedural macro on an [`pgrx::Aggregate`][pgrx-aggregate-aggregate] implementation for that type.

The aggregate name is specified through a type parameter that implements the `ToAggregateName` trait.
Or you can use a derive-macro `AggregateName` to automatically implement this trait for some type.
The default name of the aggregate will be taken as the name of the structure, but you can change this with the `aggregate_name` attribute

The [`pgrx::Aggregate`][pgrx-aggregate-aggregate] trait has quite a few items (`fn`s, `const`s, `type`s) that you can implement, but the procedural macro can fill in
stubs for all non-essential items. The state type (the implementation target by default) must have a [`#[derive(PostgresType)]`][pgrx-postgrestype] declaration,
or be a type PostgreSQL already knows about.
Expand All @@ -390,13 +394,14 @@ use serde::{Serialize, Deserialize};

pg_module_magic!();

#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize)]
#[derive(Copy, Clone, Default, Debug, PostgresType, Serialize, Deserialize, AggregateName)]
#[aggregate_name = "DemoSum"]
pub struct DemoSum {
count: i32,
}

#[pg_aggregate]
impl Aggregate for DemoSum {
impl Aggregate<DemoSum> for DemoSum {
const INITIAL_CONDITION: Option<&'static str> = Some(r#"{ "count": 0 }"#);
type Args = i32;
fn state(
Expand All @@ -413,7 +418,7 @@ impl Aggregate for DemoSum {
We can review the generated SQL (generated via `cargo pgrx schema`):

```sql
/*
/*
This file is auto generated by pgrx.

The ordering of items is not stable, it is driven by a dependency graph.
Expand Down Expand Up @@ -495,7 +500,7 @@ running SQL generator
psql (13.5)
Type "help" for help.

exploring_aggregates=#
exploring_aggregates=#
```

Now we're connected via `psql`:
Expand All @@ -505,7 +510,7 @@ CREATE EXTENSION exploring_aggregates;
-- CREATE EXTENSION

SELECT DemoSum(value) FROM generate_series(0, 4000) as value;
-- demosum
-- demosum
-- -------------------
-- {"count":8002000}
-- (1 row)
Expand All @@ -518,11 +523,11 @@ Pretty cool!
Let's change the [`State`][pgrx-aggregate-aggregate-state] this time:

```rust
#[derive(Copy, Clone, Default, Debug)]
#[derive(Copy, Clone, Default, Debug, AggregateName)]
pub struct DemoSum;

#[pg_aggregate]
impl Aggregate for DemoSum {
impl Aggregate<DemoSum> for DemoSum {
const INITIAL_CONDITION: Option<&'static str> = Some(r#"0"#);
type Args = i32;
type State = i32;
Expand All @@ -542,7 +547,7 @@ Now when we run it:

```sql
SELECT DemoSum(value) FROM generate_series(0, 4000) as value;
-- demosum
-- demosum
-- ---------
-- 8002000
-- (1 row)
Expand All @@ -552,7 +557,7 @@ This is a fine reimplementation of `SUM` so far, but as we saw previously we nee

```rust
#[pg_aggregate]
impl Aggregate for DemoSum {
impl Aggregate<DemoSum> for DemoSum {
// ...
fn combine(
mut first: Self::State,
Expand All @@ -565,35 +570,6 @@ impl Aggregate for DemoSum {
}
```

We can also change the name of the generated aggregate, or set the [`PARALLEL`][pgrx-aggregate-aggregate-parallel] settings, for example:

```rust
#[pg_aggregate]
impl Aggregate for DemoSum {
// ...
const NAME: &'static str = "demo_sum";
const PARALLEL: Option<ParallelOption> = Some(pgrx::aggregate::ParallelOption::Unsafe);
// ...
}
```

This generates:

```sql
-- src/lib.rs:9
-- exploring_aggregates::DemoSum
CREATE AGGREGATE demo_sum (
integer /* i32 */
)
(
SFUNC = "demo_sum_state", /* exploring_aggregates::DemoSum::state */
STYPE = integer, /* i32 */
COMBINEFUNC = "demo_sum_combine", /* exploring_aggregates::DemoSum::combine */
INITCOND = '0', /* exploring_aggregates::DemoSum::INITIAL_CONDITION */
PARALLEL = UNSAFE /* exploring_aggregates::DemoSum::PARALLEL */
);
```

## Rust state types

It's possible to use a non-SQL (say, [`HashSet<String>`][std::collections::HashSet]) type as a state by using [`Internal`][pgrx::datum::Internal].
Expand All @@ -608,11 +584,11 @@ use std::collections::HashSet;

pg_module_magic!();

#[derive(Copy, Clone, Default, Debug)]
#[derive(Copy, Clone, Default, Debug, AggregateName)]
pub struct DemoUnique;

#[pg_aggregate]
impl Aggregate for DemoUnique {
impl Aggregate<DemoUnique> for DemoUnique {
type Args = &'static str;
type State = Internal;
type Finalize = i32;
Expand Down Expand Up @@ -656,7 +632,7 @@ We can test it:

```sql
SELECT DemoUnique(value) FROM UNNEST(ARRAY ['a', 'a', 'b']) as value;
-- demounique
-- demounique
-- ------------
-- 2
-- (1 row)
Expand All @@ -674,11 +650,11 @@ PostgreSQL also supports what are called [*Ordered-Set Aggregates*][postgresql-o
Let's create a simple `percentile_disc` reimplementation to get an idea of how to make one with `pgrx`. You'll notice we add [`ORDERED_SET = true`][pgrx::aggregate::Aggregate::ORDERED_SET] and set an (optional) [`OrderedSetArgs`][pgrx::aggregate::Aggregate::OrderedSetArgs], which determines the direct arguments.

```rust
#[derive(Copy, Clone, Default, Debug)]
#[derive(Copy, Clone, Default, Debug, AggregateName)]
pub struct DemoPercentileDisc;

#[pg_aggregate]
impl Aggregate for DemoPercentileDisc {
impl Aggregate<DemoPercentileDisc> for DemoPercentileDisc {
type Args = name!(input, i32);
type State = Internal;
type Finalize = i32;
Expand Down Expand Up @@ -732,13 +708,13 @@ We can test it like so:

```sql
SELECT DemoPercentileDisc(0.5) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [6000, 70000, 500]) as income;
-- demopercentiledisc
-- demopercentiledisc
-- --------------------
-- 6000
-- (1 row)

SELECT DemoPercentileDisc(0.05) WITHIN GROUP (ORDER BY income) FROM UNNEST(ARRAY [5, 100000000, 6000, 70000, 500]) as income;
-- demopercentiledisc
-- demopercentiledisc
-- --------------------
-- 5
-- (1 row)
Expand Down Expand Up @@ -845,7 +821,7 @@ SELECT demo_sum(value) OVER (
-- LOG: moving_state(0, 20)
-- LOG: moving_state(0, 300)
-- LOG: moving_state(0, 4000)
-- demo_sum
-- demo_sum
-- ----------
-- 1
-- 20
Expand All @@ -864,7 +840,7 @@ SELECT demo_sum(value) OVER (
-- LOG: moving_state(1, 20)
-- LOG: moving_state(21, 300)
-- LOG: moving_state(321, 4000)
-- demo_sum
-- demo_sum
-- ----------
-- 4321
-- 4321
Expand All @@ -881,7 +857,7 @@ SELECT demo_sum(value) OVER (
-- LOG: moving_state(20, 300)
-- LOG: moving_state_inverse(320, 20)
-- LOG: moving_state(300, 4000)
-- demo_sum
-- demo_sum
-- ----------
-- 1
-- 21
Expand All @@ -899,7 +875,7 @@ SELECT demo_sum(value) OVER (
-- LOG: moving_state(0, 10000)
-- LOG: moving_state(10000, 1)
-- LOG: moving_state_inverse(10001, 10000)
-- demo_sum
-- demo_sum
-- ----------
-- 10001
-- 1
Expand Down
14 changes: 7 additions & 7 deletions pgrx-examples/aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use std::str::FromStr;

pgrx::pg_module_magic!(c"aggregate", pgrx::pg_sys::PG_VERSION);

#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize, AggregateName)]
#[aggregate_name = "DEMOAVG"]
#[pgvarlena_inoutfuncs]
#[derive(Default)]
pub struct IntegerAvgState {
Expand All @@ -27,9 +28,9 @@ pub struct IntegerAvgState {
impl IntegerAvgState {
#[inline(always)]
fn state(
mut current: <Self as Aggregate>::State,
arg: <Self as Aggregate>::Args,
) -> <Self as Aggregate>::State {
mut current: <Self as Aggregate<Self>>::State,
arg: <Self as Aggregate<Self>>::Args,
) -> <Self as Aggregate<Self>>::State {
if let Some(arg) = arg {
current.sum += arg;
current.n += 1;
Expand All @@ -38,7 +39,7 @@ impl IntegerAvgState {
}

#[inline(always)]
fn finalize(current: <Self as Aggregate>::State) -> <Self as Aggregate>::Finalize {
fn finalize(current: <Self as Aggregate<Self>>::State) -> <Self as Aggregate<Self>>::Finalize {
current.sum / current.n
}
}
Expand Down Expand Up @@ -74,10 +75,9 @@ impl PgVarlenaInOutFuncs for IntegerAvgState {
// In order to improve the testability of your code, it's encouraged to make this implementation
// call to your own functions which don't require a PostgreSQL made [`pgrx::pg_sys::FunctionCallInfo`].
#[pg_aggregate]
impl Aggregate for IntegerAvgState {
impl Aggregate<IntegerAvgState> for IntegerAvgState {
type State = PgVarlena<Self>;
type Args = pgrx::name!(value, Option<i32>);
const NAME: &'static str = "DEMOAVG";

const INITIAL_CONDITION: Option<&'static str> = Some("0,0");

Expand Down
45 changes: 45 additions & 0 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,51 @@ pub fn derive_postgres_hash(input: TokenStream) -> TokenStream {
deriving_postgres_hash(ast).unwrap_or_else(syn::Error::into_compile_error).into()
}

/// Derives the `ToAggregateName` trait.
#[proc_macro_derive(AggregateName, attributes(aggregate_name))]
pub fn derive_aggregate_name(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

impl_aggregate_name(ast).unwrap_or_else(|e| e.into_compile_error()).into()
}

fn impl_aggregate_name(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let name = &ast.ident;

let mut custom_name_value: Option<String> = None;

for attr in &ast.attrs {
if attr.path().is_ident("aggregate_name") {
let meta = &attr.meta;
match meta {
syn::Meta::NameValue(syn::MetaNameValue {
value: syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(s), .. }),
..
}) => {
custom_name_value = Some(s.value());
break;
}
_ => {
return Err(syn::Error::new_spanned(
attr,
"#[aggregate_name] must be in the form `#[aggregate_name = \"string_literal\"]`",
));
}
}
}
}

let name_str = custom_name_value.unwrap_or(name.to_string());

let expanded = quote! {
impl ::pgrx::aggregate::ToAggregateName for #name {
const NAME: &'static str = #name_str;
}
};

Ok(expanded)
}

/**
Declare a `pgrx::Aggregate` implementation on a type as able to used by Postgres as an aggregate.

Expand Down
Loading