Skip to content

Commit

Permalink
feat: extend sql generation via to_sql attribute
Browse files Browse the repository at this point in the history
This commit introduces a new `#[to_sql]` attribute that allows for
customizing the behavior of SQL generation on certain entities in two
different ways:

1. `#[to_sql(false)]` disables generation of the decorated item's SQL
2. `#[to_sql(path::to::function)]` delegates responsibility for
   generating SQL to the named function

In the latter case, the function has almost the same signature as the
`to_sql` function of the built-in `ToSql` trait, except the function
receives a reference to a `SqlGraphEntity` in place of `&self`. This
allows for extending any of the `SqlGraphEntity` variants using a single
function, as trying to use specific typed functions for each entity type
was deemed overly complex. This also works well with the way `ToSql` is
invoked today, which starts by calling the implementation on a value of
type `SqlGraphEntity`, so we can perform all the checks in a single
place.

As an aside, these custom callbacks can still delegate to the built-in
`ToSql` trait if desired. For example, if you wish to write the
generated SQL for specific entities to a different file.

The motivation for this is that in some edge cases, it is desirable to
elide or modify the SQL being generated. In our case specifically, we
need to chop up the generated SQL so that we can split out non-idempotent
operations separately from idempotent operations in order to properly
manage extension upgrades. Rather than lose the facilities pgx provides,
the `#[to_sql]` attribute allows us to make ad-hoc adjustments to what,
when and where things get generated without needing any upstream support
for those details.
  • Loading branch information
bitwalker committed Jan 28, 2022
1 parent 2452450 commit abebee0
Show file tree
Hide file tree
Showing 19 changed files with 475 additions and 55 deletions.
53 changes: 45 additions & 8 deletions pgx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,23 @@ pub fn merges(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}

/**
Helper attribute for non-derive contexts
Used to control the behavior of the SQL generator on the decorated item
Currently can be provided either a boolean to enable/disable SQL generation
for the item, or can be provided a path to a function with the expected signature.
Examples:
* Disable SQL generation with `#[to_sql(false)]`
* Call custom SQL generator function with `#[to_sql(path::to_function)]`
*/
#[proc_macro_attribute]
pub fn to_sql(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}

/**
Declare a Rust module and its contents to be in a schema.
Expand Down Expand Up @@ -597,7 +614,7 @@ enum DogNames {
```
*/
#[proc_macro_derive(PostgresEnum, attributes(requires))]
#[proc_macro_derive(PostgresEnum, attributes(requires, to_sql))]
pub fn postgres_enum(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -685,7 +702,10 @@ Optionally accepts the following attributes:
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires))]
#[proc_macro_derive(
PostgresType,
attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, to_sql)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -913,7 +933,7 @@ enum DogNames {
```
*/
#[proc_macro_derive(PostgresEq)]
#[proc_macro_derive(PostgresEq, attributes(to_sql))]
pub fn postgres_eq(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_eq(ast).into()
Expand All @@ -937,10 +957,19 @@ enum DogNames {
```
*/
#[proc_macro_derive(PostgresOrd)]
#[proc_macro_derive(PostgresOrd, attributes(to_sql))]
pub fn postgres_ord(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_ord(ast).into()
let to_sql_config = match sql_entity_graph::ToSqlConfig::from_attributes(ast.attrs.as_slice()) {
Err(e) => {
let msg = e.to_string();
return TokenStream::from(quote! {
compile_error!(#msg);
});
}
Ok(maybe_conf) => maybe_conf.unwrap_or_default(),
};
impl_postgres_ord(ast, to_sql_config).into()
}

/**
Expand All @@ -958,10 +987,19 @@ enum DogNames {
```
*/
#[proc_macro_derive(PostgresHash)]
#[proc_macro_derive(PostgresHash, attributes(to_sql))]
pub fn postgres_hash(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_hash(ast).into()
let to_sql_config = match sql_entity_graph::ToSqlConfig::from_attributes(ast.attrs.as_slice()) {
Err(e) => {
let msg = e.to_string();
return TokenStream::from(quote! {
compile_error!(#msg);
});
}
Ok(maybe_conf) => maybe_conf.unwrap_or_default(),
};
impl_postgres_hash(ast, to_sql_config).into()
}

/**
Expand Down Expand Up @@ -1002,4 +1040,3 @@ Used outside of a [`#[pg_aggregate]`](macro@pg_aggregate), this does nothing.
pub fn pgx(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}

16 changes: 12 additions & 4 deletions pgx-macros/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream {
stream
}

pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_ord(
ast: DeriveInput,
to_sql_config: sql_entity_graph::ToSqlConfig,
) -> proc_macro2::TokenStream {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(lt(&ast.ident));
Expand All @@ -20,18 +23,23 @@ pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
stream.extend(ge(&ast.ident));
stream.extend(cmp(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresOrd::new(ast.ident.clone());
let sql_graph_entity_item =
sql_entity_graph::PostgresOrd::new(ast.ident.clone(), to_sql_config);
sql_graph_entity_item.to_tokens(&mut stream);

stream
}

pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_hash(
ast: DeriveInput,
to_sql_config: sql_entity_graph::ToSqlConfig,
) -> proc_macro2::TokenStream {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(hash(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresHash::new(ast.ident.clone());
let sql_graph_entity_item =
sql_entity_graph::PostgresHash::new(ast.ident.clone(), to_sql_config);
sql_graph_entity_item.to_tokens(&mut stream);

stream
Expand Down
94 changes: 94 additions & 0 deletions pgx-tests/src/tests/schema_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,58 @@ use pgx::*;

#[pgx::pg_schema]
mod test_schema {
use pgx::datum::sql_entity_graph::{PgxSql, SqlGraphEntity};
use pgx::*;
use serde::{Deserialize, Serialize};

#[pg_extern]
fn func_in_diff_schema() {}

#[pg_extern]
#[to_sql(false)]
fn func_elided_from_schema() {}

#[pg_extern]
#[to_sql(generate_function)]
fn func_generated_with_custom_sql() {}

#[derive(Debug, PostgresType, Serialize, Deserialize)]
pub struct TestType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[to_sql(false)]
pub struct ElidedType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[to_sql(generate_type)]
pub struct OtherType(pub u64);

fn generate_function(entity: &SqlGraphEntity, _context: &PgxSql) -> eyre::Result<String> {
if let SqlGraphEntity::Function(ref func) = entity {
Ok(format!("\
CREATE OR REPLACE FUNCTION test_schema.\"func_generated_with_custom_name\"() RETURNS void\n\
LANGUAGE c /* Rust */\n\
AS 'MODULE_PATHNAME', '{unaliased_name}_wrapper';\
",
unaliased_name = func.unaliased_name,
))
} else {
panic!("expected extern function entity, got {:?}", entity);
}
}

fn generate_type(entity: &SqlGraphEntity, _context: &PgxSql) -> eyre::Result<String> {
if let SqlGraphEntity::Type(ref ty) = entity {
Ok(format!(
"\n\
CREATE TYPE test_schema.Custom{name};\
",
name = ty.name,
))
} else {
panic!("expected type entity, got {:?}", entity);
}
}
}

#[pg_extern(schema = "test_schema")]
Expand Down Expand Up @@ -44,4 +88,54 @@ mod tests {
fn test_type_in_different_schema() {
Spi::run("SELECT type_in_diff_schema();");
}

#[pg_test]
fn elided_extern_is_elided() {
// Validate that a function we know exists, exists
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_in_diff_schema');",
)
.expect("expected result");
assert_eq!(result, true);

// Validate that the function we expect not to exist, doesn't
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_elided_from_schema');",
)
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn elided_type_is_elided() {
// Validate that a type we know exists, exists
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'testtype');")
.expect("expected result");
assert_eq!(result, true);

// Validate that the type we expect not to exist, doesn't
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'elidedtype');")
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn custom_to_sql_extern() {
// Validate that the function we generated has the modifications we expect
let result: bool = Spi::get_one("SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_generated_with_custom_name');").expect("expected result");
assert_eq!(result, true);

Spi::run("SELECT test_schema.func_generated_with_custom_name();");
}

#[pg_test]
fn custom_to_sql_type() {
// Validate that the type we generated has the expected modifications
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'customothertype');")
.expect("expected result");
assert_eq!(result, true);
}
}
2 changes: 2 additions & 0 deletions pgx-utils/src/sql_entity_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod postgres_enum;
mod postgres_hash;
mod postgres_ord;
mod postgres_type;
mod to_sql;

pub use super::ExternArgs;
pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared};
Expand All @@ -18,6 +19,7 @@ pub use postgres_enum::PostgresEnum;
pub use postgres_hash::PostgresHash;
pub use postgres_ord::PostgresOrd;
pub use postgres_type::PostgresType;
pub use to_sql::ToSqlConfig;

/// Reexports for the pgx SQL generator binaries.
#[doc(hidden)]
Expand Down
7 changes: 7 additions & 0 deletions pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use syn::{
Expr,
};

use super::ToSqlConfig;

// We support only 32 tuples...
const ARG_NAMES: [&str; 32] = [
"arg_one",
Expand Down Expand Up @@ -77,10 +79,12 @@ pub struct PgAggregate {
fn_moving_state_inverse: Option<Ident>,
fn_moving_finalize: Option<Ident>,
hypothetical: bool,
to_sql_config: ToSqlConfig,
}

impl PgAggregate {
pub fn new(mut item_impl: ItemImpl) -> Result<Self, syn::Error> {
let to_sql_config = ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
let target_path = get_target_path(&item_impl)?;
let target_ident = get_target_ident(&target_path)?;
let snake_case_target_ident = Ident::new(
Expand Down Expand Up @@ -495,6 +499,7 @@ impl PgAggregate {
} else {
false
},
to_sql_config,
})
}

Expand Down Expand Up @@ -534,6 +539,7 @@ impl PgAggregate {
let fn_moving_state_iter = self.fn_moving_state.iter();
let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
let to_sql_config = &self.to_sql_config;

let entity_item_fn: ItemFn = parse_quote! {
#[no_mangle]
Expand Down Expand Up @@ -569,6 +575,7 @@ impl PgAggregate {
sortop: None#( .unwrap_or(Some(#const_sort_operator_iter)) )*,
parallel: None#( .unwrap_or(#const_parallel_iter) )*,
hypothetical: #hypothetical,
to_sql_config: #to_sql_config,
};
pgx::datum::sql_entity_graph::SqlGraphEntity::Aggregate(submission)
}
Expand Down
19 changes: 15 additions & 4 deletions pgx-utils/src/sql_entity_graph/pg_extern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub use argument::Argument;
use attribute::{Attribute, PgxAttributes};
pub use operator::PgOperator;
use operator::{PgxOperatorAttributeWithIdent, PgxOperatorOpName};
use returning::Returning;
pub(crate) use returning::NameMacro;
use returning::Returning;
use search_path::SearchPathList;

use eyre::WrapErr;
Expand All @@ -19,6 +19,8 @@ use std::convert::TryFrom;
use syn::parse::{Parse, ParseStream};
use syn::Meta;

use crate::sql_entity_graph::ToSqlConfig;

/// A parsed `#[pg_extern]` item.
///
/// It should be used with [`syn::parse::Parse`] functions.
Expand All @@ -45,6 +47,7 @@ pub struct PgExtern {
attrs: Option<PgxAttributes>,
attr_tokens: proc_macro2::TokenStream,
func: syn::ItemFn,
to_sql_config: ToSqlConfig,
}

impl PgExtern {
Expand Down Expand Up @@ -193,10 +196,13 @@ impl PgExtern {
pub fn new(attr: TokenStream2, item: TokenStream2) -> Result<Self, syn::Error> {
let attrs = syn::parse2::<PgxAttributes>(attr.clone()).ok();
let func = syn::parse2::<syn::ItemFn>(item)?;
let to_sql_config =
ToSqlConfig::from_attributes(func.attrs.as_slice())?.unwrap_or_default();
Ok(Self {
attrs: attrs,
attrs,
attr_tokens: attr,
func: func,
func,
to_sql_config,
})
}
}
Expand All @@ -222,6 +228,7 @@ impl ToTokens for PgExtern {
};
let operator = self.operator().into_iter();
let overridden = self.overridden().into_iter();
let to_sql_config = &self.to_sql_config;

let sql_graph_entity_fn_name =
syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site());
Expand All @@ -243,6 +250,7 @@ impl ToTokens for PgExtern {
fn_return: #returns,
operator: None#( .unwrap_or(Some(#operator)) )*,
overridden: None#( .unwrap_or(Some(#overridden)) )*,
to_sql_config: #to_sql_config,
};
pgx::datum::sql_entity_graph::SqlGraphEntity::Function(submission)
}
Expand All @@ -254,12 +262,15 @@ impl ToTokens for PgExtern {
impl Parse for PgExtern {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
let attrs: Option<PgxAttributes> = input.parse().ok();
let func = input.parse()?;
let func: syn::ItemFn = input.parse()?;
let to_sql_config =
ToSqlConfig::from_attributes(func.attrs.as_slice())?.unwrap_or_default();
let attr_tokens: proc_macro2::TokenStream = attrs.clone().into_token_stream();
Ok(Self {
attrs,
attr_tokens,
func,
to_sql_config,
})
}
}
Loading

0 comments on commit abebee0

Please sign in to comment.