From 6edfe10f8a6c2abcfbcb49d8002d65190b8f388d Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Feb 2022 21:38:45 +0000 Subject: [PATCH] feat: extend sql generation via #[pgx(sql)] (#410) This commit extends the `#[pgx]` and `#[pg_extern]` attributes with a new `sql` argument type that allows for customizing the behavior of SQL generation on certain entities in three different ways: 1. `#[pgx(sql = false)]` disables generation of the decorated item's SQL 2. `#[pgx(sql = path::to::function)]` delegates responsibility for generating SQL to the named function 3. `#[pgx(sql = "RAW SQL CODE;")]` uses the provided SQL string in place of the SQL that would have been generated by default In the 2nd case, the function has almost the same signature as the `to_sql` function of the built-in `ToSql` trait, except the function receives a reference to a `SqlGraphEntity` in place of `&self`. This allows for extending any of the `SqlGraphEntity` variants using a single function, as trying to use specific typed functions for each entity type was deemed overly complex. This also works well with the way `ToSql` is invoked today, which starts by calling the implementation on a value of type `SqlGraphEntity`, so we can perform all the checks in a single place. As an aside, these custom callbacks can still delegate to the built-in `ToSql` trait if desired. For example, if you wish to write the generated SQL for specific entities to a different file. The motivation for this is that in some edge cases, it is desirable to elide or modify the SQL being generated. In our case specifically, we need to chop up the generated SQL so that we can split out non-idempotent operations separately from idempotent operations in order to properly manage extension upgrades. Rather than lose the facilities pgx provides, the `#[pgx(sql)]` attribute allows us to make ad-hoc adjustments to what, when and where things get generated without needing any upstream support for those details. --- pgx-macros/src/lib.rs | 51 ++++-- pgx-macros/src/operators.rs | 17 +- pgx-macros/src/rewriter.rs | 17 +- pgx-tests/src/tests/schema_tests.rs | 112 +++++++++++++ pgx-utils/src/lib.rs | 6 + pgx-utils/src/sql_entity_graph/mod.rs | 4 + .../src/sql_entity_graph/pg_aggregate/mod.rs | 7 + .../sql_entity_graph/pg_extern/attribute.rs | 47 +++--- .../src/sql_entity_graph/pg_extern/mod.rs | 109 ++++++++----- .../src/sql_entity_graph/pgx_attribute.rs | 80 +++++++++ .../src/sql_entity_graph/postgres_enum.rs | 20 ++- .../src/sql_entity_graph/postgres_hash.rs | 35 ++-- .../src/sql_entity_graph/postgres_ord.rs | 35 ++-- .../src/sql_entity_graph/postgres_type.rs | 23 ++- pgx-utils/src/sql_entity_graph/to_sql.rs | 148 +++++++++++++++++ pgx/src/datum/sql_entity_graph/aggregate.rs | 23 ++- pgx/src/datum/sql_entity_graph/mod.rs | 105 ++++++++++++ .../datum/sql_entity_graph/pg_extern/mod.rs | 152 ++++++++---------- .../datum/sql_entity_graph/postgres_enum.rs | 3 +- .../datum/sql_entity_graph/postgres_hash.rs | 3 +- .../datum/sql_entity_graph/postgres_ord.rs | 3 +- .../datum/sql_entity_graph/postgres_type.rs | 3 +- .../sql_entity_graph/sql_graph_entity.rs | 82 ++++++---- 23 files changed, 850 insertions(+), 235 deletions(-) create mode 100644 pgx-utils/src/sql_entity_graph/pgx_attribute.rs create mode 100644 pgx-utils/src/sql_entity_graph/to_sql.rs diff --git a/pgx-macros/src/lib.rs b/pgx-macros/src/lib.rs index 7824fc416..7ed1ddb68 100644 --- a/pgx-macros/src/lib.rs +++ b/pgx-macros/src/lib.rs @@ -402,6 +402,7 @@ Optionally accepts the following attributes: * `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `no_guard`: Do not use `#[pg_guard]` with the function. +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). Functions can accept and return any type which `pgx` supports. `pgx` supports many PostgreSQL types by default. New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`]. @@ -597,7 +598,7 @@ enum DogNames { ``` */ -#[proc_macro_derive(PostgresEnum, attributes(requires))] +#[proc_macro_derive(PostgresEnum, attributes(requires, pgx))] pub fn postgres_enum(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -683,9 +684,12 @@ Optionally accepts the following attributes: * `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type. * `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type. - +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires))] +#[proc_macro_derive( + PostgresType, + attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx) +)] pub fn postgres_type(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -911,12 +915,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresEq)] +#[proc_macro_derive(PostgresEq, attributes(pgx))] pub fn postgres_eq(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_eq(ast).into() + impl_postgres_eq(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -935,12 +943,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresOrd)] +#[proc_macro_derive(PostgresOrd, attributes(pgx))] pub fn postgres_ord(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_ord(ast).into() + impl_postgres_ord(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -956,12 +968,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresHash)] +#[proc_macro_derive(PostgresHash, attributes(pgx))] pub fn postgres_hash(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_hash(ast).into() + impl_postgres_hash(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -991,15 +1007,26 @@ pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream { } /** -An inner attribute for [`#[pg_aggregate]`](macro@pg_aggregate). +A helper attribute for various contexts. + +## Usage with [`#[pg_aggregate]`](macro@pg_aggregate). It can be decorated on functions inside a [`#[pg_aggregate]`](macro@pg_aggregate) implementation. In this position, it takes the same args as [`#[pg_extern]`](macro@pg_extern), and those args have the same effect. -Used outside of a [`#[pg_aggregate]`](macro@pg_aggregate), this does nothing. +## Usage for configuring SQL generation + +This attribute can be used to control the behavior of the SQL generator on a decorated item, +e.g. `#[pgx(sql = false)]` + +Currently `sql` can be provided one of the following: + +* Disable SQL generation with `#[pgx(sql = false)]` +* Call custom SQL generator function with `#[pgx(sql = path::to_function)]` +* Render a specific fragment of SQL with a string `#[pgx(sql = "CREATE OR REPLACE FUNCTION ...")]` + */ #[proc_macro_attribute] pub fn pgx(_attr: TokenStream, item: TokenStream) -> TokenStream { item } - diff --git a/pgx-macros/src/operators.rs b/pgx-macros/src/operators.rs index 8e5bb9852..fb17c4447 100644 --- a/pgx-macros/src/operators.rs +++ b/pgx-macros/src/operators.rs @@ -1,17 +1,18 @@ use pgx_utils::{operator_common::*, sql_entity_graph}; + use quote::ToTokens; use syn::DeriveInput; -pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(eq(&ast.ident)); stream.extend(ne(&ast.ident)); - stream + Ok(stream) } -pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(lt(&ast.ident)); @@ -20,19 +21,19 @@ pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream { stream.extend(ge(&ast.ident)); stream.extend(cmp(&ast.ident)); - let sql_graph_entity_item = sql_entity_graph::PostgresOrd::new(ast.ident.clone()); + let sql_graph_entity_item = sql_entity_graph::PostgresOrd::from_derive_input(ast)?; sql_graph_entity_item.to_tokens(&mut stream); - stream + Ok(stream) } -pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(hash(&ast.ident)); - let sql_graph_entity_item = sql_entity_graph::PostgresHash::new(ast.ident.clone()); + let sql_graph_entity_item = sql_entity_graph::PostgresHash::from_derive_input(ast)?; sql_graph_entity_item.to_tokens(&mut stream); - stream + Ok(stream) } diff --git a/pgx-macros/src/rewriter.rs b/pgx-macros/src/rewriter.rs index c435173aa..faeb05b1b 100644 --- a/pgx-macros/src/rewriter.rs +++ b/pgx-macros/src/rewriter.rs @@ -8,10 +8,11 @@ use proc_macro2::{Ident, Span}; use quote::{quote, quote_spanned, ToTokens}; use std::ops::Deref; use std::str::FromStr; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ FnArg, ForeignItem, ForeignItemFn, Generics, ItemFn, ItemForeignMod, Pat, ReturnType, - Signature, Type, Visibility, + Signature, Token, Type, Visibility, }; pub struct PgGuardRewriter(); @@ -224,7 +225,11 @@ impl PgGuardRewriter { let return_type = proc_macro2::TokenStream::from_str(return_type.trim_start_matches("->")).unwrap(); let return_type = quote! {impl std::iter::Iterator}; - let attrs = entity_submission.unwrap().extern_attr_tokens(); + let attrs = entity_submission + .unwrap() + .extern_attrs() + .iter() + .collect::>(); func.sig.output = ReturnType::Default; let sig = func.sig; @@ -708,7 +713,9 @@ impl FunctionSignatureRewriter { fn args(&self, is_raw: bool) -> proc_macro2::TokenStream { if self.func.sig.inputs.len() == 1 && self.return_type_is_datum() { if let FnArg::Typed(ty) = self.func.sig.inputs.first().unwrap() { - if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") { + if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") + || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") + { return proc_macro2::TokenStream::new(); } } @@ -738,7 +745,9 @@ impl FunctionSignatureRewriter { quote_spanned! {ident.span()=> let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i); } - } else if type_matches(&type_, "pg_sys :: FunctionCallInfo") || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") { + } else if type_matches(&type_, "pg_sys :: FunctionCallInfo") + || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") + { have_fcinfo = true; quote_spanned! {ident.span()=> let #name = fcinfo; diff --git a/pgx-tests/src/tests/schema_tests.rs b/pgx-tests/src/tests/schema_tests.rs index 61bc6fe10..66ae5bede 100644 --- a/pgx-tests/src/tests/schema_tests.rs +++ b/pgx-tests/src/tests/schema_tests.rs @@ -4,14 +4,66 @@ use pgx::*; #[pgx::pg_schema] mod test_schema { + use pgx::datum::sql_entity_graph::{PgxSql, SqlGraphEntity}; use pgx::*; use serde::{Deserialize, Serialize}; #[pg_extern] fn func_in_diff_schema() {} + #[pg_extern(sql = false)] + fn func_elided_from_schema() {} + + #[pg_extern(sql = generate_function)] + fn func_generated_with_custom_sql() {} + #[derive(Debug, PostgresType, Serialize, Deserialize)] pub struct TestType(pub u64); + + #[derive(Debug, PostgresType, Serialize, Deserialize)] + #[pgx(sql = false)] + pub struct ElidedType(pub u64); + + #[derive(Debug, PostgresType, Serialize, Deserialize)] + #[pgx(sql = generate_type)] + pub struct OtherType(pub u64); + + #[derive(Debug, PostgresType, Serialize, Deserialize)] + #[pgx(sql = "CREATE TYPE test_schema.ManuallyRenderedType;")] + pub struct OverriddenType(pub u64); + + fn generate_function( + entity: &SqlGraphEntity, + _context: &PgxSql, + ) -> Result> { + if let SqlGraphEntity::Function(ref func) = entity { + Ok(format!("\ + CREATE OR REPLACE FUNCTION test_schema.\"func_generated_with_custom_name\"() RETURNS void\n\ + LANGUAGE c /* Rust */\n\ + AS 'MODULE_PATHNAME', '{unaliased_name}_wrapper';\ + ", + unaliased_name = func.unaliased_name, + )) + } else { + panic!("expected extern function entity, got {:?}", entity); + } + } + + fn generate_type( + entity: &SqlGraphEntity, + _context: &PgxSql, + ) -> Result> { + if let SqlGraphEntity::Type(ref ty) = entity { + Ok(format!( + "\n\ + CREATE TYPE test_schema.Custom{name};\ + ", + name = ty.name, + )) + } else { + panic!("expected type entity, got {:?}", entity); + } + } } #[pg_extern(schema = "test_schema")] @@ -44,4 +96,64 @@ mod tests { fn test_type_in_different_schema() { Spi::run("SELECT type_in_diff_schema();"); } + + #[pg_test] + fn elided_extern_is_elided() { + // Validate that a function we know exists, exists + let result: bool = Spi::get_one( + "SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_in_diff_schema');", + ) + .expect("expected result"); + assert_eq!(result, true); + + // Validate that the function we expect not to exist, doesn't + let result: bool = Spi::get_one( + "SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_elided_from_schema');", + ) + .expect("expected result"); + assert_eq!(result, false); + } + + #[pg_test] + fn elided_type_is_elided() { + // Validate that a type we know exists, exists + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'testtype');") + .expect("expected result"); + assert_eq!(result, true); + + // Validate that the type we expect not to exist, doesn't + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'elidedtype');") + .expect("expected result"); + assert_eq!(result, false); + } + + #[pg_test] + fn custom_to_sql_extern() { + // Validate that the function we generated has the modifications we expect + let result: bool = Spi::get_one("SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_generated_with_custom_name');").expect("expected result"); + assert_eq!(result, true); + + Spi::run("SELECT test_schema.func_generated_with_custom_name();"); + } + + #[pg_test] + fn custom_to_sql_type() { + // Validate that the type we generated has the expected modifications + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'customothertype');") + .expect("expected result"); + assert_eq!(result, true); + } + + #[pg_test] + fn custom_handwritten_to_sql_type() { + // Validate that the SQL we provided was used + let result: bool = Spi::get_one( + "SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'manuallyrenderedtype');", + ) + .expect("expected result"); + assert_eq!(result, true); + } } diff --git a/pgx-utils/src/lib.rs b/pgx-utils/src/lib.rs index 3b06d392f..8057f024f 100644 --- a/pgx-utils/src/lib.rs +++ b/pgx-utils/src/lib.rs @@ -311,6 +311,12 @@ pub fn parse_extern_attributes(attr: TokenStream) -> HashSet { let name = name[1..name.len() - 1].to_string(); args.insert(ExternArgs::Name(name.to_string())) } + // Recognized, but not handled as an extern argument + "sql" => { + let _punc = itr.next().unwrap(); + let _value = itr.next().unwrap(); + false + } _ => false, }; } diff --git a/pgx-utils/src/sql_entity_graph/mod.rs b/pgx-utils/src/sql_entity_graph/mod.rs index 04d4394a4..899e2a355 100644 --- a/pgx-utils/src/sql_entity_graph/mod.rs +++ b/pgx-utils/src/sql_entity_graph/mod.rs @@ -2,22 +2,26 @@ mod extension_sql; mod pg_aggregate; mod pg_extern; mod pg_schema; +mod pgx_attribute; mod positioning_ref; mod postgres_enum; mod postgres_hash; mod postgres_ord; mod postgres_type; +mod to_sql; pub use super::ExternArgs; pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared}; pub use pg_aggregate::PgAggregate; pub use pg_extern::{Argument, PgExtern, PgOperator}; pub use pg_schema::Schema; +pub use pgx_attribute::{ArgValue, NameValueArg, PgxArg, PgxAttribute}; pub use positioning_ref::PositioningRef; pub use postgres_enum::PostgresEnum; pub use postgres_hash::PostgresHash; pub use postgres_ord::PostgresOrd; pub use postgres_type::PostgresType; +pub use to_sql::ToSqlConfig; /// Reexports for the pgx SQL generator binaries. #[doc(hidden)] diff --git a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs index 223072d1f..f7cd9671c 100644 --- a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs +++ b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs @@ -14,6 +14,8 @@ use syn::{ Expr, }; +use super::ToSqlConfig; + // We support only 32 tuples... const ARG_NAMES: [&str; 32] = [ "arg_one", @@ -77,10 +79,12 @@ pub struct PgAggregate { fn_moving_state_inverse: Option, fn_moving_finalize: Option, hypothetical: bool, + to_sql_config: ToSqlConfig, } impl PgAggregate { pub fn new(mut item_impl: ItemImpl) -> Result { + let to_sql_config = ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default(); let target_path = get_target_path(&item_impl)?; let target_ident = get_target_ident(&target_path)?; let snake_case_target_ident = Ident::new( @@ -495,6 +499,7 @@ impl PgAggregate { } else { false }, + to_sql_config, }) } @@ -534,6 +539,7 @@ impl PgAggregate { let fn_moving_state_iter = self.fn_moving_state.iter(); let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter(); let fn_moving_finalize_iter = self.fn_moving_finalize.iter(); + let to_sql_config = &self.to_sql_config; let entity_item_fn: ItemFn = parse_quote! { #[no_mangle] @@ -569,6 +575,7 @@ impl PgAggregate { sortop: None#( .unwrap_or(Some(#const_sort_operator_iter)) )*, parallel: None#( .unwrap_or(#const_parallel_iter) )*, hypothetical: #hypothetical, + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Aggregate(submission) } diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs index 5e932acfe..c0a6367f8 100644 --- a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs +++ b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs @@ -1,4 +1,4 @@ -use crate::sql_entity_graph::PositioningRef; +use crate::sql_entity_graph::{PositioningRef, ToSqlConfig}; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{ @@ -7,29 +7,6 @@ use syn::{ Token, }; -#[derive(Debug, Clone)] -pub struct PgxAttributes { - pub attrs: Punctuated, -} - -impl Parse for PgxAttributes { - fn parse(input: ParseStream) -> Result { - Ok(Self { - attrs: input.parse_terminated(Attribute::parse)?, - }) - } -} - -impl ToTokens for PgxAttributes { - fn to_tokens(&self, tokens: &mut TokenStream2) { - let attrs = &self.attrs; - let quoted = quote! { - vec![#attrs] - }; - tokens.append_all(quoted); - } -} - #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum Attribute { Immutable, @@ -46,6 +23,7 @@ pub enum Attribute { Name(syn::LitStr), Cost(syn::Expr), Requires(Punctuated), + Sql(ToSqlConfig), } impl ToTokens for Attribute { @@ -85,6 +63,10 @@ impl ToTokens for Attribute { .collect::>(); quote! { pgx::datum::sql_entity_graph::ExternArgs::Requires(vec![#(#items_iter),*],) } } + // This attribute is handled separately + Attribute::Sql(_) => { + return; + } }; tokens.append_all(quoted); } @@ -129,6 +111,23 @@ impl Parse for Attribute { let _bracket = syn::bracketed!(content in input); Self::Requires(content.parse_terminated(PositioningRef::parse)?) } + "sql" => { + use crate::sql_entity_graph::ArgValue; + use syn::Lit; + + let _eq: Token![=] = input.parse()?; + match input.parse::()? { + ArgValue::Path(p) => Self::Sql(ToSqlConfig::from(p)), + ArgValue::Lit(Lit::Bool(b)) => Self::Sql(ToSqlConfig::from(b.value)), + ArgValue::Lit(Lit::Str(s)) => Self::Sql(ToSqlConfig::from(s)), + ArgValue::Lit(other) => { + return Err(syn::Error::new( + other.span(), + "expected boolean, path, or string literal", + )) + } + } + } _ => return Err(syn::Error::new(Span::call_site(), "Invalid option")), }; Ok(found) diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs index 227659f07..6728f2d74 100644 --- a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs +++ b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs @@ -5,19 +5,22 @@ mod returning; mod search_path; pub use argument::Argument; -use attribute::{Attribute, PgxAttributes}; +use attribute::Attribute; pub use operator::PgOperator; use operator::{PgxOperatorAttributeWithIdent, PgxOperatorOpName}; -use returning::Returning; pub(crate) use returning::NameMacro; +use returning::Returning; use search_path::SearchPathList; use eyre::WrapErr; use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use std::convert::TryFrom; -use syn::parse::{Parse, ParseStream}; -use syn::Meta; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::punctuated::Punctuated; +use syn::{Meta, Token}; + +use crate::sql_entity_graph::ToSqlConfig; /// A parsed `#[pg_extern]` item. /// @@ -42,42 +45,35 @@ use syn::Meta; /// ``` #[derive(Debug, Clone)] pub struct PgExtern { - attrs: Option, - attr_tokens: proc_macro2::TokenStream, + attrs: Vec, func: syn::ItemFn, + to_sql_config: ToSqlConfig, } impl PgExtern { fn name(&self) -> String { self.attrs - .as_ref() - .and_then(|a| { - a.attrs.iter().find_map(|candidate| match candidate { - Attribute::Name(name) => Some(name.value()), - _ => None, - }) + .iter() + .find_map(|a| match a { + Attribute::Name(name) => Some(name.value()), + _ => None, }) .unwrap_or_else(|| self.func.sig.ident.to_string()) } fn schema(&self) -> Option { - self.attrs.as_ref().and_then(|a| { - a.attrs.iter().find_map(|candidate| match candidate { - Attribute::Schema(name) => Some(name.value()), - _ => None, - }) + self.attrs.iter().find_map(|a| match a { + Attribute::Schema(name) => Some(name.value()), + _ => None, }) } - fn extern_attrs(&self) -> Option<&PgxAttributes> { - self.attrs.as_ref() + pub fn extern_attrs(&self) -> &[Attribute] { + self.attrs.as_slice() } - pub fn extern_attr_tokens(&self) -> &proc_macro2::TokenStream { - &self.attr_tokens - } - - fn overridden(&self) -> Option { + fn overridden(&self) -> Option { + let mut span = None; let mut retval = None; let mut in_commented_sql_block = false; for attr in &self.func.attrs { @@ -88,7 +84,8 @@ impl PgExtern { Meta::Path(_) | Meta::List(_) => continue, Meta::NameValue(mnv) => mnv, }; - if let syn::Lit::Str(inner) = content.lit { + if let syn::Lit::Str(ref inner) = content.lit { + span.get_or_insert(content.lit.span()); if !in_commented_sql_block && inner.value().trim() == "```pgxsql" { in_commented_sql_block = true; } else if in_commented_sql_block && inner.value().trim() == "```" { @@ -105,7 +102,7 @@ impl PgExtern { } } } - retval + retval.map(|s| syn::LitStr::new(s.as_ref(), span.unwrap())) } fn operator(&self) -> Option { @@ -191,12 +188,27 @@ impl PgExtern { } pub fn new(attr: TokenStream2, item: TokenStream2) -> Result { - let attrs = syn::parse2::(attr.clone()).ok(); + let mut attrs = Vec::new(); + let mut to_sql_config: Option = None; + + let parser = Punctuated::::parse_terminated; + let punctuated_attrs = parser.parse2(attr)?; + for pair in punctuated_attrs.into_pairs() { + match pair.into_value() { + Attribute::Sql(config) => { + to_sql_config.get_or_insert(config); + } + attr => { + attrs.push(attr); + } + } + } + let func = syn::parse2::(item)?; Ok(Self { - attrs: attrs, - attr_tokens: attr, - func: func, + attrs, + func, + to_sql_config: to_sql_config.unwrap_or_default(), }) } } @@ -207,7 +219,7 @@ impl ToTokens for PgExtern { let name = self.name(); let schema = self.schema(); let schema_iter = schema.iter(); - let extern_attrs = self.extern_attrs(); + let extern_attrs = self.attrs.iter().collect::>(); let search_path = self.search_path().into_iter(); let inputs = self.inputs().unwrap(); let returns = match self.returns() { @@ -221,7 +233,14 @@ impl ToTokens for PgExtern { } }; let operator = self.operator().into_iter(); - let overridden = self.overridden().into_iter(); + let to_sql_config = match self.overridden() { + None => self.to_sql_config.clone(), + Some(content) => { + let mut config = self.to_sql_config.clone(); + config.content = Some(content); + config + } + }; let sql_graph_entity_fn_name = syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site()); @@ -240,12 +259,12 @@ impl ToTokens for PgExtern { line: line!(), module_path: core::module_path!(), full_path: concat!(core::module_path!(), "::", stringify!(#ident)), - extern_attrs: #extern_attrs, + extern_attrs: vec![#extern_attrs], search_path: None#( .unwrap_or(Some(vec![#search_path])) )*, fn_args: vec![#(#inputs),*], fn_return: #returns, operator: None#( .unwrap_or(Some(#operator)) )*, - overridden: None#( .unwrap_or(Some(#overridden)) )*, + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Function(submission) } @@ -256,13 +275,27 @@ impl ToTokens for PgExtern { impl Parse for PgExtern { fn parse(input: ParseStream) -> Result { - let attrs: Option = input.parse().ok(); - let func = input.parse()?; - let attr_tokens: proc_macro2::TokenStream = attrs.clone().into_token_stream(); + let mut attrs = Vec::new(); + let mut to_sql_config: Option = None; + + let parser = Punctuated::::parse_terminated; + let punctuated_attrs = input.call(parser).ok().unwrap_or_default(); + for pair in punctuated_attrs.into_pairs() { + match pair.into_value() { + Attribute::Sql(config) => { + to_sql_config.get_or_insert(config); + } + attr => { + attrs.push(attr); + } + } + } + + let func: syn::ItemFn = input.parse()?; Ok(Self { attrs, - attr_tokens, func, + to_sql_config: to_sql_config.unwrap_or_default(), }) } } diff --git a/pgx-utils/src/sql_entity_graph/pgx_attribute.rs b/pgx-utils/src/sql_entity_graph/pgx_attribute.rs new file mode 100644 index 000000000..3fcf3faee --- /dev/null +++ b/pgx-utils/src/sql_entity_graph/pgx_attribute.rs @@ -0,0 +1,80 @@ +use syn::parse::{Parse, ParseStream}; +use syn::{parenthesized, punctuated::Punctuated}; +use syn::{token, Token}; + +/// This struct is intented to represent the contents of the `#[pgx]` attribute when parsed. +/// +/// The intended usage is to parse an `Attribute`, then use `attr.parse_args::()?` to +/// parse the contents of the attribute into this struct. +/// +/// We use this rather than `Attribute::parse_meta` because it is not supported to parse bare paths +/// as values of a `NameValueMeta`, and we want to support that to avoid conflating SQL strings with +/// paths-as-strings. We re-use as much of the standard `parse_meta` structure types as possible though. +pub struct PgxAttribute { + pub args: Vec, +} + +impl Parse for PgxAttribute { + fn parse(input: ParseStream<'_>) -> syn::Result { + let parser = Punctuated::::parse_terminated; + let punctuated = input.call(parser)?; + let args = punctuated + .into_pairs() + .map(|p| p.into_value()) + .collect::>(); + Ok(Self { args }) + } +} + +/// This enum is akin to `syn::Meta`, but supports a custom `NameValue` variant which allows +/// for bare paths in the value position. +pub enum PgxArg { + Path(syn::Path), + List(syn::MetaList), + NameValue(NameValueArg), +} + +impl Parse for PgxArg { + fn parse(input: ParseStream<'_>) -> syn::Result { + let path = input.parse::()?; + if input.peek(token::Paren) { + let content; + Ok(Self::List(syn::MetaList { + path, + paren_token: parenthesized!(content in input), + nested: content.parse_terminated(syn::NestedMeta::parse)?, + })) + } else if input.peek(Token![=]) { + Ok(Self::NameValue(NameValueArg { + path, + eq_token: input.parse()?, + value: input.parse()?, + })) + } else { + Ok(Self::Path(path)) + } + } +} + +/// This struct is akin to `syn::NameValueMeta`, but allows for more than just `syn::Lit` as a value. +pub struct NameValueArg { + pub path: syn::Path, + pub eq_token: syn::token::Eq, + pub value: ArgValue, +} + +/// This is the type of a value that can be used in the value position of a `name = value` attribute argument. +pub enum ArgValue { + Path(syn::Path), + Lit(syn::Lit), +} + +impl Parse for ArgValue { + fn parse(input: ParseStream<'_>) -> syn::Result { + if input.peek(syn::Lit) { + return Ok(Self::Lit(input.parse()?)); + } + + Ok(Self::Path(input.parse()?)) + } +} diff --git a/pgx-utils/src/sql_entity_graph/postgres_enum.rs b/pgx-utils/src/sql_entity_graph/postgres_enum.rs index 6c43880a9..b2c0bd1ef 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_enum.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_enum.rs @@ -6,6 +6,8 @@ use syn::{ }; use syn::{punctuated::Punctuated, Ident, Token}; +use super::ToSqlConfig; + /// A parsed `#[derive(PostgresEnum)]` item. /// /// It should be used with [`syn::parse::Parse`] functions. @@ -33,6 +35,7 @@ pub struct PostgresEnum { name: Ident, generics: Generics, variants: Punctuated, + to_sql_config: ToSqlConfig, } impl PostgresEnum { @@ -40,15 +43,19 @@ impl PostgresEnum { name: Ident, generics: Generics, variants: Punctuated, + to_sql_config: ToSqlConfig, ) -> Self { Self { name, generics, variants, + to_sql_config, } } pub fn from_derive_input(derive_input: DeriveInput) -> Result { + let to_sql_config = + ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default(); let data_enum = match derive_input.data { syn::Data::Enum(data_enum) => data_enum, syn::Data::Union(_) | syn::Data::Struct(_) => { @@ -59,6 +66,7 @@ impl PostgresEnum { derive_input.ident, derive_input.generics, data_enum.variants, + to_sql_config, )) } } @@ -66,7 +74,14 @@ impl PostgresEnum { impl Parse for PostgresEnum { fn parse(input: ParseStream) -> Result { let parsed: ItemEnum = input.parse()?; - Ok(Self::new(parsed.ident, parsed.generics, parsed.variants)) + let to_sql_config = + ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default(); + Ok(Self::new( + parsed.ident, + parsed.generics, + parsed.variants, + to_sql_config, + )) } } @@ -84,6 +99,8 @@ impl ToTokens for PostgresEnum { let sql_graph_entity_fn_name = syn::Ident::new(&format!("__pgx_internals_enum_{}", name), Span::call_site()); + let to_sql_config = &self.to_sql_config; + let inv = quote! { #[no_mangle] pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity { @@ -104,6 +121,7 @@ impl ToTokens for PostgresEnum { full_path: core::any::type_name::<#name #ty_generics>(), mappings, variants: vec![ #( stringify!(#variants) ),* ], + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Enum(submission) } diff --git a/pgx-utils/src/sql_entity_graph/postgres_hash.rs b/pgx-utils/src/sql_entity_graph/postgres_hash.rs index c06e50206..68a2919f5 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_hash.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_hash.rs @@ -2,9 +2,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{ parse::{Parse, ParseStream}, - DeriveInput, Ident, ItemEnum, ItemStruct, + DeriveInput, Ident, }; +use super::ToSqlConfig; + /// A parsed `#[derive(PostgresHash)]` item. /// /// It should be used with [`syn::parse::Parse`] functions. @@ -51,27 +53,36 @@ use syn::{ #[derive(Debug, Clone)] pub struct PostgresHash { pub name: Ident, + pub to_sql_config: ToSqlConfig, } impl PostgresHash { - pub fn new(name: Ident) -> Self { - Self { name } + pub fn new(name: Ident, to_sql_config: ToSqlConfig) -> Self { + Self { + name, + to_sql_config, + } } pub fn from_derive_input(derive_input: DeriveInput) -> Result { - Ok(Self::new(derive_input.ident)) + let to_sql_config = + ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default(); + Ok(Self::new(derive_input.ident, to_sql_config)) } } impl Parse for PostgresHash { fn parse(input: ParseStream) -> Result { - let parsed_enum: Result = input.parse(); - let parsed_struct: Result = input.parse(); - let ident = parsed_enum - .map(|x| x.ident) - .or_else(|_| parsed_struct.map(|x| x.ident)) - .map_err(|_| syn::Error::new(input.span(), "expected enum or struct"))?; - Ok(Self::new(ident)) + use syn::Item; + + let parsed = input.parse()?; + let (ident, attrs) = match &parsed { + Item::Enum(item) => (item.ident.clone(), item.attrs.as_slice()), + Item::Struct(item) => (item.ident.clone(), item.attrs.as_slice()), + _ => return Err(syn::Error::new(input.span(), "expected enum or struct")), + }; + let to_sql_config = ToSqlConfig::from_attributes(attrs)?.unwrap_or_default(); + Ok(Self::new(ident, to_sql_config)) } } @@ -82,6 +93,7 @@ impl ToTokens for PostgresHash { &format!("__pgx_internals_hash_{}", self.name), Span::call_site(), ); + let to_sql_config = &self.to_sql_config; let inv = quote! { #[no_mangle] pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity { @@ -96,6 +108,7 @@ impl ToTokens for PostgresHash { full_path: core::any::type_name::<#name>(), module_path: module_path!(), id: TypeId::of::<#name>(), + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Hash(submission) } diff --git a/pgx-utils/src/sql_entity_graph/postgres_ord.rs b/pgx-utils/src/sql_entity_graph/postgres_ord.rs index d270030d1..d564cbd73 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_ord.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_ord.rs @@ -2,9 +2,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{ parse::{Parse, ParseStream}, - DeriveInput, Ident, ItemEnum, ItemStruct, + DeriveInput, Ident, }; +use super::ToSqlConfig; + /// A parsed `#[derive(PostgresOrd)]` item. /// /// It should be used with [`syn::parse::Parse`] functions. @@ -51,27 +53,36 @@ use syn::{ #[derive(Debug, Clone)] pub struct PostgresOrd { pub name: Ident, + pub to_sql_config: ToSqlConfig, } impl PostgresOrd { - pub fn new(name: Ident) -> Self { - Self { name } + pub fn new(name: Ident, to_sql_config: ToSqlConfig) -> Self { + Self { + name, + to_sql_config, + } } pub fn from_derive_input(derive_input: DeriveInput) -> Result { - Ok(Self::new(derive_input.ident)) + let to_sql_config = + ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default(); + Ok(Self::new(derive_input.ident, to_sql_config)) } } impl Parse for PostgresOrd { fn parse(input: ParseStream) -> Result { - let parsed_enum: Result = input.parse(); - let parsed_struct: Result = input.parse(); - let ident = parsed_enum - .map(|x| x.ident) - .or_else(|_| parsed_struct.map(|x| x.ident)) - .map_err(|_| syn::Error::new(input.span(), "expected enum or struct"))?; - Ok(Self::new(ident)) + use syn::Item; + + let parsed = input.parse()?; + let (ident, attrs) = match &parsed { + Item::Enum(item) => (item.ident.clone(), item.attrs.as_slice()), + Item::Struct(item) => (item.ident.clone(), item.attrs.as_slice()), + _ => return Err(syn::Error::new(input.span(), "expected enum or struct")), + }; + let to_sql_config = ToSqlConfig::from_attributes(attrs)?.unwrap_or_default(); + Ok(Self::new(ident, to_sql_config)) } } @@ -82,6 +93,7 @@ impl ToTokens for PostgresOrd { &format!("__pgx_internals_ord_{}", self.name), Span::call_site(), ); + let to_sql_config = &self.to_sql_config; let inv = quote! { #[no_mangle] pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity { @@ -96,6 +108,7 @@ impl ToTokens for PostgresOrd { full_path: core::any::type_name::<#name>(), module_path: module_path!(), id: TypeId::of::<#name>(), + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Ord(submission) } diff --git a/pgx-utils/src/sql_entity_graph/postgres_type.rs b/pgx-utils/src/sql_entity_graph/postgres_type.rs index 805f0bdb5..2bd38bf32 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_type.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_type.rs @@ -9,6 +9,8 @@ use syn::{ DeriveInput, Generics, ItemStruct, }; +use super::ToSqlConfig; + /// A parsed `#[derive(PostgresType)]` item. /// /// It should be used with [`syn::parse::Parse`] functions. @@ -37,15 +39,23 @@ pub struct PostgresType { generics: Generics, in_fn: Ident, out_fn: Ident, + to_sql_config: ToSqlConfig, } impl PostgresType { - pub fn new(name: Ident, generics: Generics, in_fn: Ident, out_fn: Ident) -> Self { + pub fn new( + name: Ident, + generics: Generics, + in_fn: Ident, + out_fn: Ident, + to_sql_config: ToSqlConfig, + ) -> Self { Self { generics, name, in_fn, out_fn, + to_sql_config, } } @@ -59,6 +69,8 @@ impl PostgresType { )) } }; + let to_sql_config = + ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default(); let funcname_in = Ident::new( &format!("{}_in", derive_input.ident).to_lowercase(), derive_input.ident.span(), @@ -72,6 +84,7 @@ impl PostgresType { derive_input.generics, funcname_in, funcname_out, + to_sql_config, )) } @@ -93,6 +106,8 @@ impl PostgresType { impl Parse for PostgresType { fn parse(input: ParseStream) -> Result { let parsed: ItemStruct = input.parse()?; + let to_sql_config = + ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default(); let funcname_in = Ident::new( &format!("{}_in", parsed.ident).to_lowercase(), parsed.ident.span(), @@ -106,6 +121,7 @@ impl Parse for PostgresType { parsed.generics, funcname_in, funcname_out, + to_sql_config, )) } } @@ -127,6 +143,8 @@ impl ToTokens for PostgresType { Span::call_site(), ); + let to_sql_config = &self.to_sql_config; + let inv = quote! { #[no_mangle] pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity { @@ -172,7 +190,8 @@ impl ToTokens for PostgresType { let mut path_items: Vec<_> = out_fn.split("::").collect(); let _ = path_items.pop(); // Drop the one we don't want. path_items.join("::") - } + }, + to_sql_config: #to_sql_config, }; pgx::datum::sql_entity_graph::SqlGraphEntity::Type(submission) } diff --git a/pgx-utils/src/sql_entity_graph/to_sql.rs b/pgx-utils/src/sql_entity_graph/to_sql.rs new file mode 100644 index 000000000..a939d02a2 --- /dev/null +++ b/pgx-utils/src/sql_entity_graph/to_sql.rs @@ -0,0 +1,148 @@ +use std::hash::Hash; + +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::spanned::Spanned; +use syn::{AttrStyle, Attribute, Lit}; + +use super::{ArgValue, PgxArg, PgxAttribute}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ToSqlConfig { + pub enabled: bool, + pub callback: Option, + pub content: Option, +} +impl From for ToSqlConfig { + fn from(enabled: bool) -> Self { + Self { + enabled, + callback: None, + content: None, + } + } +} +impl From for ToSqlConfig { + fn from(path: syn::Path) -> Self { + Self { + enabled: true, + callback: Some(path), + content: None, + } + } +} +impl From for ToSqlConfig { + fn from(content: syn::LitStr) -> Self { + Self { + enabled: true, + callback: None, + content: Some(content), + } + } +} +impl Default for ToSqlConfig { + fn default() -> Self { + Self { + enabled: true, + callback: None, + content: None, + } + } +} + +const INVALID_ATTR_CONTENT: &str = + "expected `#[pgx(sql = content)]`, where `content` is a boolean, string, or path to a function"; + +impl ToSqlConfig { + /// Used for general purpose parsing from an attribute + pub fn from_attribute(attr: &Attribute) -> Result, syn::Error> { + if attr.style != AttrStyle::Outer { + return Err(syn::Error::new( + attr.span(), + "#[pgx(sql = ..)] is only valid in an outer context", + )); + } + + let attr = attr.parse_args::()?; + for arg in attr.args.iter() { + if let PgxArg::NameValue(ref nv) = arg { + if !nv.path.is_ident("sql") { + continue; + } + + match nv.value { + ArgValue::Path(ref callback_path) => { + return Ok(Some(Self { + enabled: true, + callback: Some(callback_path.clone()), + content: None, + })); + } + ArgValue::Lit(Lit::Bool(ref b)) => { + return Ok(Some(Self { + enabled: b.value, + callback: None, + content: None, + })); + } + ArgValue::Lit(Lit::Str(ref s)) => { + return Ok(Some(Self { + enabled: true, + callback: None, + content: Some(s.clone()), + })); + } + ArgValue::Lit(ref other) => { + return Err(syn::Error::new(other.span(), INVALID_ATTR_CONTENT)); + } + } + } + } + + Ok(None) + } + + /// Used to parse a generator config from a set of item attributes + pub fn from_attributes(attrs: &[Attribute]) -> Result, syn::Error> { + if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident("pgx")) { + Self::from_attribute(attr) + } else { + Ok(None) + } + } +} + +impl ToTokens for ToSqlConfig { + fn to_tokens(&self, tokens: &mut TokenStream2) { + let enabled = self.enabled; + let callback = &self.callback; + let content = &self.content; + if let Some(callback_path) = callback { + tokens.append_all(quote! { + ::pgx::datum::sql_entity_graph::ToSqlConfigEntity { + enabled: #enabled, + callback: Some(#callback_path), + content: None, + } + }); + return; + } + if let Some(sql) = content { + tokens.append_all(quote! { + ::pgx::datum::sql_entity_graph::ToSqlConfigEntity { + enabled: #enabled, + callback: None, + content: Some(#sql), + } + }); + return; + } + tokens.append_all(quote! { + ::pgx::datum::sql_entity_graph::ToSqlConfigEntity { + enabled: #enabled, + callback: None, + content: None, + } + }); + } +} diff --git a/pgx/src/datum/sql_entity_graph/aggregate.rs b/pgx/src/datum/sql_entity_graph/aggregate.rs index f744c7641..64ecc287d 100644 --- a/pgx/src/datum/sql_entity_graph/aggregate.rs +++ b/pgx/src/datum/sql_entity_graph/aggregate.rs @@ -1,6 +1,6 @@ use crate::{ aggregate::{FinalizeModify, ParallelOption}, - datum::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, ToSql}, + datum::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}, }; use core::{any::TypeId, cmp::Ordering}; use eyre::eyre as eyre_err; @@ -127,6 +127,7 @@ pub struct PgAggregateEntity { /// /// Corresponds to `hypothetical` in [`crate::aggregate::Aggregate`]. pub hypothetical: bool, + pub to_sql_config: ToSqlConfigEntity, } impl Ord for PgAggregateEntity { @@ -278,21 +279,29 @@ impl ToSql for PgAggregateEntity { })?; optional_attributes.push(( format!("\tMSTYPE = {}", sql), - format!("/* {}::MovingState = {} */", self.full_path, value.full_path), + format!( + "/* {}::MovingState = {} */", + self.full_path, value.full_path + ), )); } let mut optional_attributes_string = String::new(); for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() { - let optional_attribute_string = format!("{optional_attribute}{maybe_comma} {comment}{maybe_newline}", + let optional_attribute_string = format!( + "{optional_attribute}{maybe_comma} {comment}{maybe_newline}", optional_attribute = optional_attribute, - maybe_comma = if index == optional_attributes.len() -1 { + maybe_comma = if index == optional_attributes.len() - 1 { "" - } else { "," }, + } else { + "," + }, comment = comment, - maybe_newline = if index == optional_attributes.len() -1 { + maybe_newline = if index == optional_attributes.len() - 1 { "" - } else { "\n" } + } else { + "\n" + } ); optional_attributes_string += &optional_attribute_string; } diff --git a/pgx/src/datum/sql_entity_graph/mod.rs b/pgx/src/datum/sql_entity_graph/mod.rs index 4dfb56de3..84b73fe17 100644 --- a/pgx/src/datum/sql_entity_graph/mod.rs +++ b/pgx/src/datum/sql_entity_graph/mod.rs @@ -64,6 +64,111 @@ pub trait ToSql { fn to_sql(&self, context: &PgxSql) -> eyre::Result; } +/// The signature of a function that can transform a SqlGraphEntity to a SQL string +/// +/// This is used to provide a facility for overriding the default SQL generator behavior using +/// the `#[to_sql(path::to::function)]` attribute in circumstances where the default behavior is +/// not desirable. +/// +/// Implementations can invoke `ToSql::to_sql(entity, context)` on the unwrapped SqlGraphEntity +/// type should they wish to delegate to the default behavior for any reason. +pub type ToSqlFn = + fn( + &SqlGraphEntity, + &PgxSql, + ) -> std::result::Result>; + +/// Represents configuration options for tuning the SQL generator. +/// +/// When an item that can be rendered to SQL has these options at hand, they should be +/// respected. If an item does not have them, then it is not expected that the SQL generation +/// for those items can be modified. +/// +/// The default configuration has `enabled` set to `true`, and `callback` to `None`, which indicates +/// that the default SQL generation behavior will be used. These are intended to be mutually exclusive +/// options, so `callback` should only be set if generation is enabled. +/// +/// When `enabled` is false, no SQL is generated for the item being configured. +/// +/// When `callback` has a value, the corresponding `ToSql` implementation should invoke the +/// callback instead of performing their default behavior. +#[derive(Default, Clone)] +pub struct ToSqlConfigEntity { + pub enabled: bool, + pub callback: Option, + pub content: Option<&'static str>, +} +impl ToSqlConfigEntity { + /// Given a SqlGraphEntity, this function converts it to SQL based on the current configuration. + /// + /// If the config overrides the default behavior (i.e. using the `ToSql` trait), then `Some(eyre::Result)` + /// is returned. If the config does not override the default behavior, then `None` is returned. This can + /// be used to dispatch SQL generation in a single line, e.g.: + /// + /// ```rust,ignore + /// config.to_sql(entity, context).unwrap_or_else(|| entity.to_sql(context))? + /// ``` + pub fn to_sql( + &self, + entity: &SqlGraphEntity, + context: &PgxSql, + ) -> Option> { + use eyre::{eyre, WrapErr}; + + if !self.enabled { + return Some(Ok(String::default())); + } + + if let Some(content) = self.content { + return Some(Ok("\n".to_owned() + content)); + } + + if let Some(callback) = self.callback { + return Some( + callback(entity, context) + .map_err(|e| eyre!(e)) + .wrap_err("Failed to run specified `#[pgx(sql = path)] function`"), + ); + } + + None + } +} +impl std::cmp::PartialEq for ToSqlConfigEntity { + fn eq(&self, other: &Self) -> bool { + if self.enabled != other.enabled { + return false; + } + match (self.callback, other.callback) { + (None, None) => match (self.content, other.content) { + (None, None) => true, + (Some(a), Some(b)) => a == b, + _ => false, + }, + (Some(a), Some(b)) => std::ptr::eq(std::ptr::addr_of!(a), std::ptr::addr_of!(b)), + _ => false, + } + } +} +impl std::cmp::Eq for ToSqlConfigEntity {} +impl std::hash::Hash for ToSqlConfigEntity { + fn hash(&self, state: &mut H) { + self.enabled.hash(state); + self.callback.map(|cb| std::ptr::addr_of!(cb)).hash(state); + self.content.hash(state); + } +} +impl std::fmt::Debug for ToSqlConfigEntity { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let callback = self.callback.map(|cb| std::ptr::addr_of!(cb)); + f.debug_struct("ToSqlConfigEntity") + .field("enabled", &self.enabled) + .field("callback", &format_args!("{:?}", &callback)) + .field("content", &self.content) + .finish() + } +} + /// A mapping from a Rust type to a SQL type, with a `TypeId`. /// /// ```rust diff --git a/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs b/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs index e48847fda..91a41373b 100644 --- a/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs +++ b/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs @@ -10,7 +10,7 @@ pub use returning::PgExternReturnEntity; use pgx_utils::ExternArgs; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use pgx_utils::sql_entity_graph::SqlDeclared; use std::cmp::Ordering; @@ -29,7 +29,7 @@ pub struct PgExternEntity { pub fn_args: Vec, pub fn_return: PgExternReturnEntity, pub operator: Option, - pub overridden: Option<&'static str>, + pub to_sql_config: ToSqlConfigEntity, } impl Ord for PgExternEntity { @@ -238,25 +238,12 @@ impl ToSql for PgExternEntity { -- {module_path}::{name}\n\ {requires}\ {fn_sql}\ - {overridden}\ ", name = self.name, module_path = self.module_path, file = self.file, line = self.line, - fn_sql = if self.overridden.is_some() { - let mut inner = fn_sql - .lines() - .map(|f| format!("-- {}", f)) - .collect::>() - .join("\n"); - inner.push_str( - "\n--\n-- Overridden as (due to a `///` comment with a `pgxsql` code block):", - ); - inner - } else { - fn_sql - }, + fn_sql = fn_sql, requires = { let requires_attrs = self .extern_attrs @@ -283,59 +270,56 @@ impl ToSql for PgExternEntity { "".to_string() } }, - overridden = self - .overridden - .map(|f| String::from("\n") + f + "\n") - .unwrap_or_default(), ); tracing::trace!(sql = %ext_sql); - let rendered = match (self.overridden, &self.operator) { - (None, Some(op)) => { - let mut optionals = vec![]; - if let Some(it) = op.commutator { - optionals.push(format!("\tCOMMUTATOR = {}", it)); - }; - if let Some(it) = op.negator { - optionals.push(format!("\tNEGATOR = {}", it)); - }; - if let Some(it) = op.restrict { - optionals.push(format!("\tRESTRICT = {}", it)); - }; - if let Some(it) = op.join { - optionals.push(format!("\tJOIN = {}", it)); - }; - if op.hashes { - optionals.push(String::from("\tHASHES")); - }; - if op.merges { - optionals.push(String::from("\tMERGES")); - }; + let rendered = if let Some(op) = &self.operator { + let mut optionals = vec![]; + if let Some(it) = op.commutator { + optionals.push(format!("\tCOMMUTATOR = {}", it)); + }; + if let Some(it) = op.negator { + optionals.push(format!("\tNEGATOR = {}", it)); + }; + if let Some(it) = op.restrict { + optionals.push(format!("\tRESTRICT = {}", it)); + }; + if let Some(it) = op.join { + optionals.push(format!("\tJOIN = {}", it)); + }; + if op.hashes { + optionals.push(String::from("\tHASHES")); + }; + if op.merges { + optionals.push(String::from("\tMERGES")); + }; - let left_arg = self.fn_args.get(0).ok_or_else(|| { - eyre!("Did not find `left_arg` for operator `{}`.", self.name) - })?; - let left_arg_graph_index = context - .graph - .neighbors_undirected(self_index) - .find(|neighbor| match &context.graph[*neighbor] { - SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id), - _ => false, - }) - .ok_or_else(|| eyre!("Could not find left arg function in graph."))?; - let right_arg = self.fn_args.get(1).ok_or_else(|| { - eyre!("Did not find `left_arg` for operator `{}`.", self.name) - })?; - let right_arg_graph_index = context - .graph - .neighbors_undirected(self_index) - .find(|neighbor| match &context.graph[*neighbor] { - SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id), - _ => false, - }) - .ok_or_else(|| eyre!("Could not find right arg function in graph."))?; + let left_arg = self + .fn_args + .get(0) + .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; + let left_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id), + _ => false, + }) + .ok_or_else(|| eyre!("Could not find left arg function in graph."))?; + let right_arg = self + .fn_args + .get(1) + .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; + let right_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id), + _ => false, + }) + .ok_or_else(|| eyre!("Could not find right arg function in graph."))?; - let operator_sql = format!("\n\n\ + let operator_sql = format!("\n\n\ -- {file}:{line}\n\ -- {module_path}::{unaliased_name}\n\ CREATE OPERATOR {opname} (\n\ @@ -344,26 +328,26 @@ impl ToSql for PgExternEntity { \tRIGHTARG={schema_prefix_right}{right_arg}{maybe_comma} /* {right_name} */\n\ {optionals}\ );\ - ", - opname = op.opname.unwrap(), - file = self.file, - line = self.line, - name = self.name, - unaliased_name = self.unaliased_name, - module_path = self.module_path, - left_name = left_arg.full_path, - right_name = right_arg.full_path, - schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index), - left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?, - schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index), - right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, self.name))?, - maybe_comma = if optionals.len() >= 1 { "," } else { "" }, - optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, - ); - tracing::trace!(sql = %operator_sql); - ext_sql + &operator_sql - } - (None, None) | (Some(_), Some(_)) | (Some(_), None) => ext_sql, + ", + opname = op.opname.unwrap(), + file = self.file, + line = self.line, + name = self.name, + unaliased_name = self.unaliased_name, + module_path = self.module_path, + left_name = left_arg.full_path, + right_name = right_arg.full_path, + schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index), + left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?, + schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index), + right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, self.name))?, + maybe_comma = if optionals.len() >= 1 { "," } else { "" }, + optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, + ); + tracing::trace!(sql = %operator_sql); + ext_sql + &operator_sql + } else { + ext_sql }; Ok(rendered) } diff --git a/pgx/src/datum/sql_entity_graph/postgres_enum.rs b/pgx/src/datum/sql_entity_graph/postgres_enum.rs index 159173329..15ec1451e 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_enum.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_enum.rs @@ -3,7 +3,7 @@ use std::{ hash::{Hash, Hasher}, }; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; /// The output of a [`PostgresEnum`](crate::datum::sql_entity_graph::PostgresEnum) from `quote::ToTokens::to_tokens`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -15,6 +15,7 @@ pub struct PostgresEnumEntity { pub module_path: &'static str, pub mappings: std::collections::HashSet, pub variants: Vec<&'static str>, + pub to_sql_config: ToSqlConfigEntity, } impl Hash for PostgresEnumEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_hash.rs b/pgx/src/datum/sql_entity_graph/postgres_hash.rs index 0111cb7e6..af1fc1b85 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_hash.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_hash.rs @@ -1,4 +1,4 @@ -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use std::cmp::Ordering; /// The output of a [`PostgresHash`](crate::datum::sql_entity_graph::PostgresHash) from `quote::ToTokens::to_tokens`. @@ -10,6 +10,7 @@ pub struct PostgresHashEntity { pub full_path: &'static str, pub module_path: &'static str, pub id: core::any::TypeId, + pub to_sql_config: ToSqlConfigEntity, } impl PostgresHashEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_ord.rs b/pgx/src/datum/sql_entity_graph/postgres_ord.rs index 0020100ea..6ec1be5bf 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_ord.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_ord.rs @@ -1,4 +1,4 @@ -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use std::cmp::Ordering; /// The output of a [`PostgresOrd`](crate::datum::sql_entity_graph::PostgresOrd) from `quote::ToTokens::to_tokens`. @@ -10,6 +10,7 @@ pub struct PostgresOrdEntity { pub full_path: &'static str, pub module_path: &'static str, pub id: core::any::TypeId, + pub to_sql_config: ToSqlConfigEntity, } impl PostgresOrdEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_type.rs b/pgx/src/datum/sql_entity_graph/postgres_type.rs index 0873bd15d..7041a157c 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_type.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_type.rs @@ -5,7 +5,7 @@ use std::{ hash::{Hash, Hasher}, }; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; /// The output of a [`PostgresType`](crate::datum::sql_entity_graph::PostgresType) from `quote::ToTokens::to_tokens`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -20,6 +20,7 @@ pub struct PostgresTypeEntity { pub in_fn_module_path: String, pub out_fn: &'static str, pub out_fn_module_path: String, + pub to_sql_config: ToSqlConfigEntity, } impl Hash for PostgresTypeEntity { diff --git a/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs b/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs index 1d9598959..de7650f8f 100644 --- a/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs +++ b/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs @@ -86,37 +86,61 @@ impl ToSql for SqlGraphEntity { #[tracing::instrument(level = "debug", skip(self, context), fields(identifier = %self.rust_identifier()))] fn to_sql(&self, context: &super::PgxSql) -> eyre::Result { match self { - SqlGraphEntity::Schema(item) => if item.name != "public" && item.name != "pg_catalog" { - item.to_sql(context) - } else { Ok(String::default()) }, - SqlGraphEntity::CustomSql(item) => { - item.to_sql(context) - }, - SqlGraphEntity::Function(item) => if context.graph.neighbors_undirected(context.externs.get(item).unwrap().clone()).any(|neighbor| { - let neighbor_item = &context.graph[neighbor]; - match neighbor_item { - SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, .. }) => { - let is_in_fn = item.full_path.starts_with(in_fn_module_path) && item.full_path.ends_with(in_fn); - if is_in_fn { - tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an in_fn."); - } - let is_out_fn = item.full_path.starts_with(out_fn_module_path) && item.full_path.ends_with(out_fn); - if is_out_fn { - tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an out_fn."); - } - is_in_fn || is_out_fn - }, - _ => false, + SqlGraphEntity::Schema(item) => { + if item.name != "public" && item.name != "pg_catalog" { + item.to_sql(context) + } else { + Ok(String::default()) } - }) { - Ok(String::default()) - } else { item.to_sql(context) }, - SqlGraphEntity::Type(item) => item.to_sql(context), + } + SqlGraphEntity::CustomSql(item) => item.to_sql(context), + SqlGraphEntity::Function(item) => { + if let Some(result) = item.to_sql_config.to_sql(self, context) { + return result; + } + if context.graph.neighbors_undirected(context.externs.get(item).unwrap().clone()).any(|neighbor| { + let neighbor_item = &context.graph[neighbor]; + match neighbor_item { + SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, .. }) => { + let is_in_fn = item.full_path.starts_with(in_fn_module_path) && item.full_path.ends_with(in_fn); + if is_in_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an in_fn."); + } + let is_out_fn = item.full_path.starts_with(out_fn_module_path) && item.full_path.ends_with(out_fn); + if is_out_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an out_fn."); + } + is_in_fn || is_out_fn + }, + _ => false, + } + }) { + Ok(String::default()) + } else { + item.to_sql(context) + } + } + SqlGraphEntity::Type(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), SqlGraphEntity::BuiltinType(_) => Ok(String::default()), - SqlGraphEntity::Enum(item) => item.to_sql(context), - SqlGraphEntity::Ord(item) => item.to_sql(context), - SqlGraphEntity::Hash(item) => item.to_sql(context), - SqlGraphEntity::Aggregate(item) => item.to_sql(context), + SqlGraphEntity::Enum(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Ord(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Hash(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Aggregate(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), SqlGraphEntity::ExtensionRoot(item) => item.to_sql(context), } }