Skip to content

Commit

Permalink
feat: extend sql generation via #[pgx(sql)] (#410)
Browse files Browse the repository at this point in the history
This commit extends the `#[pgx]` and `#[pg_extern]` attributes with a new
`sql` argument type that allows for customizing the behavior of SQL generation
on certain entities in three different ways:

1. `#[pgx(sql = false)]` disables generation of the decorated item's SQL
2. `#[pgx(sql = path::to::function)]` delegates responsibility for
   generating SQL to the named function
3. `#[pgx(sql = "RAW SQL CODE;")]` uses the provided SQL string in place
   of the SQL that would have been generated by default

In the 2nd case, the function has almost the same signature as the
`to_sql` function of the built-in `ToSql` trait, except the function
receives a reference to a `SqlGraphEntity` in place of `&self`. This
allows for extending any of the `SqlGraphEntity` variants using a single
function, as trying to use specific typed functions for each entity type
was deemed overly complex. This also works well with the way `ToSql` is
invoked today, which starts by calling the implementation on a value of
type `SqlGraphEntity`, so we can perform all the checks in a single
place.

As an aside, these custom callbacks can still delegate to the built-in
`ToSql` trait if desired. For example, if you wish to write the
generated SQL for specific entities to a different file.

The motivation for this is that in some edge cases, it is desirable to
elide or modify the SQL being generated. In our case specifically, we
need to chop up the generated SQL so that we can split out non-idempotent
operations separately from idempotent operations in order to properly
manage extension upgrades. Rather than lose the facilities pgx provides,
the `#[pgx(sql)]` attribute allows us to make ad-hoc adjustments to what,
when and where things get generated without needing any upstream support
for those details.
  • Loading branch information
bitwalker authored Feb 4, 2022
1 parent 982ab5e commit 6edfe10
Show file tree
Hide file tree
Showing 23 changed files with 850 additions and 235 deletions.
51 changes: 39 additions & 12 deletions pgx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ Optionally accepts the following attributes:
* `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
* `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html).
* `no_guard`: Do not use `#[pg_guard]` with the function.
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
Functions can accept and return any type which `pgx` supports. `pgx` supports many PostgreSQL types by default.
New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`].
Expand Down Expand Up @@ -597,7 +598,7 @@ enum DogNames {
```
*/
#[proc_macro_derive(PostgresEnum, attributes(requires))]
#[proc_macro_derive(PostgresEnum, attributes(requires, pgx))]
pub fn postgres_enum(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -683,9 +684,12 @@ Optionally accepts the following attributes:
* `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type.
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires))]
#[proc_macro_derive(
PostgresType,
attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -911,12 +915,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresEq)]
#[proc_macro_derive(PostgresEq, attributes(pgx))]
pub fn postgres_eq(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_eq(ast).into()
impl_postgres_eq(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand All @@ -935,12 +943,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresOrd)]
#[proc_macro_derive(PostgresOrd, attributes(pgx))]
pub fn postgres_ord(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_ord(ast).into()
impl_postgres_ord(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand All @@ -956,12 +968,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresHash)]
#[proc_macro_derive(PostgresHash, attributes(pgx))]
pub fn postgres_hash(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_hash(ast).into()
impl_postgres_hash(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand Down Expand Up @@ -991,15 +1007,26 @@ pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream {
}

/**
An inner attribute for [`#[pg_aggregate]`](macro@pg_aggregate).
A helper attribute for various contexts.
## Usage with [`#[pg_aggregate]`](macro@pg_aggregate).
It can be decorated on functions inside a [`#[pg_aggregate]`](macro@pg_aggregate) implementation.
In this position, it takes the same args as [`#[pg_extern]`](macro@pg_extern), and those args have the same effect.
Used outside of a [`#[pg_aggregate]`](macro@pg_aggregate), this does nothing.
## Usage for configuring SQL generation
This attribute can be used to control the behavior of the SQL generator on a decorated item,
e.g. `#[pgx(sql = false)]`
Currently `sql` can be provided one of the following:
* Disable SQL generation with `#[pgx(sql = false)]`
* Call custom SQL generator function with `#[pgx(sql = path::to_function)]`
* Render a specific fragment of SQL with a string `#[pgx(sql = "CREATE OR REPLACE FUNCTION ...")]`
*/
#[proc_macro_attribute]
pub fn pgx(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}

17 changes: 9 additions & 8 deletions pgx-macros/src/operators.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use pgx_utils::{operator_common::*, sql_entity_graph};

use quote::ToTokens;
use syn::DeriveInput;

pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(eq(&ast.ident));
stream.extend(ne(&ast.ident));

stream
Ok(stream)
}

pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(lt(&ast.ident));
Expand All @@ -20,19 +21,19 @@ pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
stream.extend(ge(&ast.ident));
stream.extend(cmp(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresOrd::new(ast.ident.clone());
let sql_graph_entity_item = sql_entity_graph::PostgresOrd::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

stream
Ok(stream)
}

pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(hash(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresHash::new(ast.ident.clone());
let sql_graph_entity_item = sql_entity_graph::PostgresHash::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

stream
Ok(stream)
}
17 changes: 13 additions & 4 deletions pgx-macros/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens};
use std::ops::Deref;
use std::str::FromStr;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
FnArg, ForeignItem, ForeignItemFn, Generics, ItemFn, ItemForeignMod, Pat, ReturnType,
Signature, Type, Visibility,
Signature, Token, Type, Visibility,
};

pub struct PgGuardRewriter();
Expand Down Expand Up @@ -224,7 +225,11 @@ impl PgGuardRewriter {
let return_type =
proc_macro2::TokenStream::from_str(return_type.trim_start_matches("->")).unwrap();
let return_type = quote! {impl std::iter::Iterator<Item = #return_type>};
let attrs = entity_submission.unwrap().extern_attr_tokens();
let attrs = entity_submission
.unwrap()
.extern_attrs()
.iter()
.collect::<Punctuated<_, Token![,]>>();

func.sig.output = ReturnType::Default;
let sig = func.sig;
Expand Down Expand Up @@ -708,7 +713,9 @@ impl FunctionSignatureRewriter {
fn args(&self, is_raw: bool) -> proc_macro2::TokenStream {
if self.func.sig.inputs.len() == 1 && self.return_type_is_datum() {
if let FnArg::Typed(ty) = self.func.sig.inputs.first().unwrap() {
if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") {
if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo")
|| type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo")
{
return proc_macro2::TokenStream::new();
}
}
Expand Down Expand Up @@ -738,7 +745,9 @@ impl FunctionSignatureRewriter {
quote_spanned! {ident.span()=>
let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i);
}
} else if type_matches(&type_, "pg_sys :: FunctionCallInfo") || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") {
} else if type_matches(&type_, "pg_sys :: FunctionCallInfo")
|| type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo")
{
have_fcinfo = true;
quote_spanned! {ident.span()=>
let #name = fcinfo;
Expand Down
112 changes: 112 additions & 0 deletions pgx-tests/src/tests/schema_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,66 @@ use pgx::*;

#[pgx::pg_schema]
mod test_schema {
use pgx::datum::sql_entity_graph::{PgxSql, SqlGraphEntity};
use pgx::*;
use serde::{Deserialize, Serialize};

#[pg_extern]
fn func_in_diff_schema() {}

#[pg_extern(sql = false)]
fn func_elided_from_schema() {}

#[pg_extern(sql = generate_function)]
fn func_generated_with_custom_sql() {}

#[derive(Debug, PostgresType, Serialize, Deserialize)]
pub struct TestType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = false)]
pub struct ElidedType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = generate_type)]
pub struct OtherType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = "CREATE TYPE test_schema.ManuallyRenderedType;")]
pub struct OverriddenType(pub u64);

fn generate_function(
entity: &SqlGraphEntity,
_context: &PgxSql,
) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
if let SqlGraphEntity::Function(ref func) = entity {
Ok(format!("\
CREATE OR REPLACE FUNCTION test_schema.\"func_generated_with_custom_name\"() RETURNS void\n\
LANGUAGE c /* Rust */\n\
AS 'MODULE_PATHNAME', '{unaliased_name}_wrapper';\
",
unaliased_name = func.unaliased_name,
))
} else {
panic!("expected extern function entity, got {:?}", entity);
}
}

fn generate_type(
entity: &SqlGraphEntity,
_context: &PgxSql,
) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
if let SqlGraphEntity::Type(ref ty) = entity {
Ok(format!(
"\n\
CREATE TYPE test_schema.Custom{name};\
",
name = ty.name,
))
} else {
panic!("expected type entity, got {:?}", entity);
}
}
}

#[pg_extern(schema = "test_schema")]
Expand Down Expand Up @@ -44,4 +96,64 @@ mod tests {
fn test_type_in_different_schema() {
Spi::run("SELECT type_in_diff_schema();");
}

#[pg_test]
fn elided_extern_is_elided() {
// Validate that a function we know exists, exists
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_in_diff_schema');",
)
.expect("expected result");
assert_eq!(result, true);

// Validate that the function we expect not to exist, doesn't
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_elided_from_schema');",
)
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn elided_type_is_elided() {
// Validate that a type we know exists, exists
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'testtype');")
.expect("expected result");
assert_eq!(result, true);

// Validate that the type we expect not to exist, doesn't
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'elidedtype');")
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn custom_to_sql_extern() {
// Validate that the function we generated has the modifications we expect
let result: bool = Spi::get_one("SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_generated_with_custom_name');").expect("expected result");
assert_eq!(result, true);

Spi::run("SELECT test_schema.func_generated_with_custom_name();");
}

#[pg_test]
fn custom_to_sql_type() {
// Validate that the type we generated has the expected modifications
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'customothertype');")
.expect("expected result");
assert_eq!(result, true);
}

#[pg_test]
fn custom_handwritten_to_sql_type() {
// Validate that the SQL we provided was used
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'manuallyrenderedtype');",
)
.expect("expected result");
assert_eq!(result, true);
}
}
6 changes: 6 additions & 0 deletions pgx-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
let name = name[1..name.len() - 1].to_string();
args.insert(ExternArgs::Name(name.to_string()))
}
// Recognized, but not handled as an extern argument
"sql" => {
let _punc = itr.next().unwrap();
let _value = itr.next().unwrap();
false
}
_ => false,
};
}
Expand Down
4 changes: 4 additions & 0 deletions pgx-utils/src/sql_entity_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@ mod extension_sql;
mod pg_aggregate;
mod pg_extern;
mod pg_schema;
mod pgx_attribute;
mod positioning_ref;
mod postgres_enum;
mod postgres_hash;
mod postgres_ord;
mod postgres_type;
mod to_sql;

pub use super::ExternArgs;
pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared};
pub use pg_aggregate::PgAggregate;
pub use pg_extern::{Argument, PgExtern, PgOperator};
pub use pg_schema::Schema;
pub use pgx_attribute::{ArgValue, NameValueArg, PgxArg, PgxAttribute};
pub use positioning_ref::PositioningRef;
pub use postgres_enum::PostgresEnum;
pub use postgres_hash::PostgresHash;
pub use postgres_ord::PostgresOrd;
pub use postgres_type::PostgresType;
pub use to_sql::ToSqlConfig;

/// Reexports for the pgx SQL generator binaries.
#[doc(hidden)]
Expand Down
Loading

0 comments on commit 6edfe10

Please sign in to comment.