diff --git a/pgx-macros/src/lib.rs b/pgx-macros/src/lib.rs index 7824fc416..7ed1ddb68 100644 --- a/pgx-macros/src/lib.rs +++ b/pgx-macros/src/lib.rs @@ -402,6 +402,7 @@ Optionally accepts the following attributes: * `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `no_guard`: Do not use `#[pg_guard]` with the function. +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). Functions can accept and return any type which `pgx` supports. `pgx` supports many PostgreSQL types by default. New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`]. @@ -597,7 +598,7 @@ enum DogNames { ``` */ -#[proc_macro_derive(PostgresEnum, attributes(requires))] +#[proc_macro_derive(PostgresEnum, attributes(requires, pgx))] pub fn postgres_enum(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -683,9 +684,12 @@ Optionally accepts the following attributes: * `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type. * `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type. - +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires))] +#[proc_macro_derive( + PostgresType, + attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx) +)] pub fn postgres_type(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -911,12 +915,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresEq)] +#[proc_macro_derive(PostgresEq, attributes(pgx))] pub fn postgres_eq(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_eq(ast).into() + impl_postgres_eq(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -935,12 +943,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresOrd)] +#[proc_macro_derive(PostgresOrd, attributes(pgx))] pub fn postgres_ord(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_ord(ast).into() + impl_postgres_ord(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -956,12 +968,16 @@ enum DogNames { Brandy, } ``` +Optionally accepts the following attributes: +* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). */ -#[proc_macro_derive(PostgresHash)] +#[proc_macro_derive(PostgresHash, attributes(pgx))] pub fn postgres_hash(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_hash(ast).into() + impl_postgres_hash(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -991,15 +1007,26 @@ pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream { } /** -An inner attribute for [`#[pg_aggregate]`](macro@pg_aggregate). +A helper attribute for various contexts. + +## Usage with [`#[pg_aggregate]`](macro@pg_aggregate). It can be decorated on functions inside a [`#[pg_aggregate]`](macro@pg_aggregate) implementation. 
In this position, it takes the same args as [`#[pg_extern]`](macro@pg_extern), and those args have the same effect.

-Used outside of a [`#[pg_aggregate]`](macro@pg_aggregate), this does nothing.
+## Usage for configuring SQL generation
+
+This attribute can be used to control the behavior of the SQL generator on a decorated item,
+e.g. `#[pgx(sql = false)]`.
+
+Currently `sql` accepts one of the following:
+
+* Disable SQL generation with `#[pgx(sql = false)]`
+* Call a custom SQL generator function with `#[pgx(sql = path::to_function)]`
+* Render a specific fragment of SQL with a string `#[pgx(sql = "CREATE OR REPLACE FUNCTION ...")]`
+
 */
 #[proc_macro_attribute]
 pub fn pgx(_attr: TokenStream, item: TokenStream) -> TokenStream {
     item
 }
-
diff --git a/pgx-macros/src/operators.rs b/pgx-macros/src/operators.rs
index 8e5bb9852..fb17c4447 100644
--- a/pgx-macros/src/operators.rs
+++ b/pgx-macros/src/operators.rs
@@ -1,17 +1,18 @@
 use pgx_utils::{operator_common::*, sql_entity_graph};
+
 use quote::ToTokens;
 use syn::DeriveInput;
 
-pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream {
+pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
     let mut stream = proc_macro2::TokenStream::new();
 
     stream.extend(eq(&ast.ident));
     stream.extend(ne(&ast.ident));
-    stream
+    Ok(stream)
 }
 
-pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
+pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
     let mut stream = proc_macro2::TokenStream::new();
 
     stream.extend(lt(&ast.ident));
@@ -20,19 +21,19 @@ pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
     stream.extend(ge(&ast.ident));
     stream.extend(cmp(&ast.ident));
 
-    let sql_graph_entity_item = sql_entity_graph::PostgresOrd::new(ast.ident.clone());
+    let sql_graph_entity_item = sql_entity_graph::PostgresOrd::from_derive_input(ast)?;
     sql_graph_entity_item.to_tokens(&mut stream);
 
-    stream
+    Ok(stream)
 }
 
-pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> proc_macro2::TokenStream {
+pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
     let mut stream = proc_macro2::TokenStream::new();
 
     stream.extend(hash(&ast.ident));
 
-    let sql_graph_entity_item = sql_entity_graph::PostgresHash::new(ast.ident.clone());
+    let sql_graph_entity_item = sql_entity_graph::PostgresHash::from_derive_input(ast)?;
     sql_graph_entity_item.to_tokens(&mut stream);
 
-    stream
+    Ok(stream)
 }
diff --git a/pgx-macros/src/rewriter.rs b/pgx-macros/src/rewriter.rs
index c435173aa..faeb05b1b 100644
--- a/pgx-macros/src/rewriter.rs
+++ b/pgx-macros/src/rewriter.rs
@@ -8,10 +8,11 @@ use proc_macro2::{Ident, Span};
 use quote::{quote, quote_spanned, ToTokens};
 use std::ops::Deref;
 use std::str::FromStr;
+use syn::punctuated::Punctuated;
 use syn::spanned::Spanned;
 use syn::{
     FnArg, ForeignItem, ForeignItemFn, Generics, ItemFn, ItemForeignMod, Pat, ReturnType,
-    Signature, Type, Visibility,
+    Signature, Token, Type, Visibility,
 };
 
 pub struct PgGuardRewriter();
@@ -224,7 +225,11 @@ impl PgGuardRewriter {
         let return_type =
             proc_macro2::TokenStream::from_str(return_type.trim_start_matches("->")).unwrap();
         let return_type = quote! {impl std::iter::Iterator<Item = #return_type>};
-        let attrs = entity_submission.unwrap().extern_attr_tokens();
+        let attrs = entity_submission
+            .unwrap()
+            .extern_attrs()
+            .iter()
+            .collect::<Punctuated<_, Token![,]>>();
 
         func.sig.output = ReturnType::Default;
         let sig = func.sig;
@@ -708,7 +713,9 @@ impl FunctionSignatureRewriter {
     fn args(&self, is_raw: bool) -> proc_macro2::TokenStream {
         if self.func.sig.inputs.len() == 1 && self.return_type_is_datum() {
             if let FnArg::Typed(ty) = self.func.sig.inputs.first().unwrap() {
-                if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") {
+                if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo")
+                    || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo")
+                {
                     return proc_macro2::TokenStream::new();
                 }
             }
@@ -738,7 +745,9 @@ impl FunctionSignatureRewriter {
                         quote_spanned! {ident.span()=>
                             let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i);
                         }
-                    } else if type_matches(&type_, "pg_sys :: FunctionCallInfo") || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") {
+                    } else if type_matches(&type_, "pg_sys :: FunctionCallInfo")
+                        || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo")
+                    {
                         have_fcinfo = true;
                         quote_spanned! {ident.span()=>
                             let #name = fcinfo;
diff --git a/pgx-tests/src/tests/schema_tests.rs b/pgx-tests/src/tests/schema_tests.rs
index 61bc6fe10..66ae5bede 100644
--- a/pgx-tests/src/tests/schema_tests.rs
+++ b/pgx-tests/src/tests/schema_tests.rs
@@ -4,14 +4,66 @@ use pgx::*;
 
 #[pgx::pg_schema]
 mod test_schema {
+    use pgx::datum::sql_entity_graph::{PgxSql, SqlGraphEntity};
     use pgx::*;
     use serde::{Deserialize, Serialize};
 
     #[pg_extern]
     fn func_in_diff_schema() {}
 
+    #[pg_extern(sql = false)]
+    fn func_elided_from_schema() {}
+
+    #[pg_extern(sql = generate_function)]
+    fn func_generated_with_custom_sql() {}
+
     #[derive(Debug, PostgresType, Serialize, Deserialize)]
     pub struct TestType(pub u64);
+
+    #[derive(Debug, PostgresType, Serialize, Deserialize)]
+    #[pgx(sql = false)]
+    pub struct ElidedType(pub u64);
+
+    #[derive(Debug, PostgresType, Serialize, Deserialize)]
+    #[pgx(sql = generate_type)]
+    pub struct OtherType(pub u64);
+
+    #[derive(Debug, PostgresType, Serialize, Deserialize)]
+    #[pgx(sql = "CREATE TYPE test_schema.ManuallyRenderedType;")]
+    pub struct OverriddenType(pub u64);
+
+    fn generate_function(
+        entity: &SqlGraphEntity,
+        _context: &PgxSql,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
+        if let SqlGraphEntity::Function(ref func) = entity {
+            Ok(format!("\
+                CREATE OR REPLACE FUNCTION test_schema.\"func_generated_with_custom_name\"() RETURNS void\n\
+                LANGUAGE c /* Rust */\n\
+                AS 'MODULE_PATHNAME', '{unaliased_name}_wrapper';\
+                ",
+                unaliased_name = func.unaliased_name,
+            ))
+        } else {
+            panic!("expected extern function entity, got {:?}", entity);
+        }
+    }
+
+    fn generate_type(
+        entity: &SqlGraphEntity,
+        _context: &PgxSql,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
+        if let SqlGraphEntity::Type(ref ty) = entity {
+            Ok(format!(
+                "\n\
+                CREATE TYPE test_schema.Custom{name};\
+                ",
+                name = ty.name,
+            ))
+        } else {
+            panic!("expected type entity, got {:?}", entity);
+        }
+    }
 }
 
 #[pg_extern(schema = "test_schema")]
@@ -44,4 +96,64 @@ mod tests {
     fn test_type_in_different_schema() {
         Spi::run("SELECT type_in_diff_schema();");
     }
+
+    #[pg_test]
+    fn elided_extern_is_elided() {
+        // Validate that a function we know exists, exists
+        let result: bool = Spi::get_one(
+            "SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_in_diff_schema');",
+        )
+        .expect("expected result");
+        assert_eq!(result, true);
+
+        // Validate that the function we expect not to exist,
doesn't + let result: bool = Spi::get_one( + "SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_elided_from_schema');", + ) + .expect("expected result"); + assert_eq!(result, false); + } + + #[pg_test] + fn elided_type_is_elided() { + // Validate that a type we know exists, exists + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'testtype');") + .expect("expected result"); + assert_eq!(result, true); + + // Validate that the type we expect not to exist, doesn't + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'elidedtype');") + .expect("expected result"); + assert_eq!(result, false); + } + + #[pg_test] + fn custom_to_sql_extern() { + // Validate that the function we generated has the modifications we expect + let result: bool = Spi::get_one("SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_generated_with_custom_name');").expect("expected result"); + assert_eq!(result, true); + + Spi::run("SELECT test_schema.func_generated_with_custom_name();"); + } + + #[pg_test] + fn custom_to_sql_type() { + // Validate that the type we generated has the expected modifications + let result: bool = + Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'customothertype');") + .expect("expected result"); + assert_eq!(result, true); + } + + #[pg_test] + fn custom_handwritten_to_sql_type() { + // Validate that the SQL we provided was used + let result: bool = Spi::get_one( + "SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'manuallyrenderedtype');", + ) + .expect("expected result"); + assert_eq!(result, true); + } } diff --git a/pgx-utils/src/lib.rs b/pgx-utils/src/lib.rs index 3b06d392f..8057f024f 100644 --- a/pgx-utils/src/lib.rs +++ b/pgx-utils/src/lib.rs @@ -311,6 +311,12 @@ pub fn parse_extern_attributes(attr: TokenStream) -> HashSet { let name = name[1..name.len() - 1].to_string(); args.insert(ExternArgs::Name(name.to_string())) } + // Recognized, but not handled as an extern argument + "sql" => { + let _punc = itr.next().unwrap(); + let _value = itr.next().unwrap(); + false + } _ => false, }; } diff --git a/pgx-utils/src/sql_entity_graph/mod.rs b/pgx-utils/src/sql_entity_graph/mod.rs index 04d4394a4..899e2a355 100644 --- a/pgx-utils/src/sql_entity_graph/mod.rs +++ b/pgx-utils/src/sql_entity_graph/mod.rs @@ -2,22 +2,26 @@ mod extension_sql; mod pg_aggregate; mod pg_extern; mod pg_schema; +mod pgx_attribute; mod positioning_ref; mod postgres_enum; mod postgres_hash; mod postgres_ord; mod postgres_type; +mod to_sql; pub use super::ExternArgs; pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared}; pub use pg_aggregate::PgAggregate; pub use pg_extern::{Argument, PgExtern, PgOperator}; pub use pg_schema::Schema; +pub use pgx_attribute::{ArgValue, NameValueArg, PgxArg, PgxAttribute}; pub use positioning_ref::PositioningRef; pub use postgres_enum::PostgresEnum; pub use postgres_hash::PostgresHash; pub use postgres_ord::PostgresOrd; pub use postgres_type::PostgresType; +pub use to_sql::ToSqlConfig; /// Reexports for the pgx SQL generator binaries. #[doc(hidden)] diff --git a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs index 223072d1f..f7cd9671c 100644 --- a/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs +++ b/pgx-utils/src/sql_entity_graph/pg_aggregate/mod.rs @@ -14,6 +14,8 @@ use syn::{ Expr, }; +use super::ToSqlConfig; + // We support only 32 tuples... 
 const ARG_NAMES: [&str; 32] = [
     "arg_one",
@@ -77,10 +79,12 @@ pub struct PgAggregate {
     fn_moving_state_inverse: Option<Ident>,
     fn_moving_finalize: Option<Ident>,
     hypothetical: bool,
+    to_sql_config: ToSqlConfig,
 }
 
 impl PgAggregate {
     pub fn new(mut item_impl: ItemImpl) -> Result<Self, syn::Error> {
+        let to_sql_config = ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
         let target_path = get_target_path(&item_impl)?;
         let target_ident = get_target_ident(&target_path)?;
         let snake_case_target_ident = Ident::new(
@@ -495,6 +499,7 @@ impl PgAggregate {
             } else {
                 false
             },
+            to_sql_config,
         })
     }
 
@@ -534,6 +539,7 @@ impl PgAggregate {
         let fn_moving_state_iter = self.fn_moving_state.iter();
         let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
         let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
+        let to_sql_config = &self.to_sql_config;
 
         let entity_item_fn: ItemFn = parse_quote! {
             #[no_mangle]
@@ -569,6 +575,7 @@ impl PgAggregate {
                 sortop: None#( .unwrap_or(Some(#const_sort_operator_iter)) )*,
                 parallel: None#( .unwrap_or(#const_parallel_iter) )*,
                 hypothetical: #hypothetical,
+                to_sql_config: #to_sql_config,
             };
             pgx::datum::sql_entity_graph::SqlGraphEntity::Aggregate(submission)
         }
diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs
index 5e932acfe..c0a6367f8 100644
--- a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs
+++ b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs
@@ -1,4 +1,4 @@
-use crate::sql_entity_graph::PositioningRef;
+use crate::sql_entity_graph::{PositioningRef, ToSqlConfig};
 use proc_macro2::{Span, TokenStream as TokenStream2};
 use quote::{quote, ToTokens, TokenStreamExt};
 use syn::{
@@ -7,29 +7,6 @@
     Token,
 };
 
-#[derive(Debug, Clone)]
-pub struct PgxAttributes {
-    pub attrs: Punctuated<Attribute, Token![,]>,
-}
-
-impl Parse for PgxAttributes {
-    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
-        Ok(Self {
-            attrs: input.parse_terminated(Attribute::parse)?,
-        })
-    }
-}
-
-impl ToTokens for PgxAttributes {
-    fn to_tokens(&self, tokens: &mut TokenStream2) {
-        let attrs = &self.attrs;
-        let quoted = quote! {
-            vec![#attrs]
-        };
-        tokens.append_all(quoted);
-    }
-}
-
 #[derive(Debug, Clone, Hash, Eq, PartialEq)]
 pub enum Attribute {
     Immutable,
@@ -46,6 +23,7 @@ pub enum Attribute {
     Name(syn::LitStr),
     Cost(syn::Expr),
     Requires(Punctuated<PositioningRef, Token![,]>),
+    Sql(ToSqlConfig),
 }
 
 impl ToTokens for Attribute {
@@ -85,6 +63,10 @@ impl ToTokens for Attribute {
                     .collect::<Vec<_>>();
                 quote! { pgx::datum::sql_entity_graph::ExternArgs::Requires(vec![#(#items_iter),*],) }
             }
+            // This attribute is handled separately
+            Attribute::Sql(_) => {
+                return;
+            }
         };
         tokens.append_all(quoted);
     }
@@ -129,6 +111,23 @@ impl Parse for Attribute {
                 let _bracket = syn::bracketed!(content in input);
                 Self::Requires(content.parse_terminated(PositioningRef::parse)?)
             }
+            "sql" => {
+                use crate::sql_entity_graph::ArgValue;
+                use syn::Lit;
+
+                let _eq: Token![=] = input.parse()?;
+                match input.parse::<ArgValue>()? {
+                    ArgValue::Path(p) => Self::Sql(ToSqlConfig::from(p)),
+                    ArgValue::Lit(Lit::Bool(b)) => Self::Sql(ToSqlConfig::from(b.value)),
+                    ArgValue::Lit(Lit::Str(s)) => Self::Sql(ToSqlConfig::from(s)),
+                    ArgValue::Lit(other) => {
+                        return Err(syn::Error::new(
+                            other.span(),
+                            "expected boolean, path, or string literal",
+                        ))
+                    }
+                }
+            }
             _ => return Err(syn::Error::new(Span::call_site(), "Invalid option")),
         };
         Ok(found)
diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs
index 227659f07..6728f2d74 100644
--- a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs
+++ b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs
@@ -5,19 +5,22 @@ mod returning;
 mod search_path;
 
 pub use argument::Argument;
-use attribute::{Attribute, PgxAttributes};
+use attribute::Attribute;
 pub use operator::PgOperator;
 use operator::{PgxOperatorAttributeWithIdent, PgxOperatorOpName};
-use returning::Returning;
 pub(crate) use returning::NameMacro;
+use returning::Returning;
 use search_path::SearchPathList;
 
 use eyre::WrapErr;
 use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
 use quote::{quote, ToTokens, TokenStreamExt};
 use std::convert::TryFrom;
-use syn::parse::{Parse, ParseStream};
-use syn::Meta;
+use syn::parse::{Parse, ParseStream, Parser};
+use syn::punctuated::Punctuated;
+use syn::{Meta, Token};
+
+use crate::sql_entity_graph::ToSqlConfig;
 
 /// A parsed `#[pg_extern]` item.
 ///
 /// It should be used with [`syn::parse::Parse`] functions.
@@ -42,42 +45,35 @@ use syn::Meta;
 /// ```
 #[derive(Debug, Clone)]
 pub struct PgExtern {
-    attrs: Option<PgxAttributes>,
-    attr_tokens: proc_macro2::TokenStream,
+    attrs: Vec<Attribute>,
     func: syn::ItemFn,
+    to_sql_config: ToSqlConfig,
 }
 
 impl PgExtern {
     fn name(&self) -> String {
         self.attrs
-            .as_ref()
-            .and_then(|a| {
-                a.attrs.iter().find_map(|candidate| match candidate {
-                    Attribute::Name(name) => Some(name.value()),
-                    _ => None,
-                })
+            .iter()
+            .find_map(|a| match a {
+                Attribute::Name(name) => Some(name.value()),
+                _ => None,
             })
             .unwrap_or_else(|| self.func.sig.ident.to_string())
     }
 
     fn schema(&self) -> Option<String> {
-        self.attrs.as_ref().and_then(|a| {
-            a.attrs.iter().find_map(|candidate| match candidate {
-                Attribute::Schema(name) => Some(name.value()),
-                _ => None,
-            })
+        self.attrs.iter().find_map(|a| match a {
+            Attribute::Schema(name) => Some(name.value()),
+            _ => None,
         })
     }
 
-    fn extern_attrs(&self) -> Option<&PgxAttributes> {
-        self.attrs.as_ref()
+    pub fn extern_attrs(&self) -> &[Attribute] {
+        self.attrs.as_slice()
     }
 
-    pub fn extern_attr_tokens(&self) -> &proc_macro2::TokenStream {
-        &self.attr_tokens
-    }
-
-    fn overridden(&self) -> Option<String> {
+    fn overridden(&self) -> Option<syn::LitStr> {
+        let mut span = None;
         let mut retval = None;
         let mut in_commented_sql_block = false;
         for attr in &self.func.attrs {
@@ -88,7 +84,8 @@ impl PgExtern {
                 Meta::Path(_) | Meta::List(_) => continue,
                 Meta::NameValue(mnv) => mnv,
             };
-            if let syn::Lit::Str(inner) = content.lit {
+            if let syn::Lit::Str(ref inner) = content.lit {
+                span.get_or_insert(content.lit.span());
                 if !in_commented_sql_block && inner.value().trim() == "```pgxsql" {
                     in_commented_sql_block = true;
                 } else if in_commented_sql_block && inner.value().trim() == "```" {
@@ -105,7 +102,7 @@ impl PgExtern {
                 }
             }
         }
-        retval
+        retval.map(|s| syn::LitStr::new(s.as_ref(), span.unwrap()))
     }
 
     fn operator(&self) -> Option<PgOperator> {
@@ -191,12 +188,27 @@ impl PgExtern {
     }
 
     pub fn new(attr: TokenStream2, item: TokenStream2) -> Result<Self, syn::Error> {
-        let attrs = syn::parse2::<PgxAttributes>(attr.clone()).ok();
+        let mut attrs = Vec::new();
+        let mut to_sql_config: Option<ToSqlConfig> = None;
+
+        let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
+        let punctuated_attrs = parser.parse2(attr)?;
+        for pair in punctuated_attrs.into_pairs() {
+            match pair.into_value() {
+                Attribute::Sql(config) => {
+                    to_sql_config.get_or_insert(config);
+                }
+                attr => {
+                    attrs.push(attr);
+                }
+            }
+        }
+
         let func = syn::parse2::<syn::ItemFn>(item)?;
         Ok(Self {
-            attrs: attrs,
-            attr_tokens: attr,
-            func: func,
+            attrs,
+            func,
+            to_sql_config: to_sql_config.unwrap_or_default(),
         })
     }
 }
@@ -207,7 +219,7 @@ impl ToTokens for PgExtern {
         let name = self.name();
         let schema = self.schema();
         let schema_iter = schema.iter();
-        let extern_attrs = self.extern_attrs();
+        let extern_attrs = self.attrs.iter().collect::<Punctuated<_, Token![,]>>();
         let search_path = self.search_path().into_iter();
         let inputs = self.inputs().unwrap();
         let returns = match self.returns() {
@@ -221,7 +233,14 @@ impl ToTokens for PgExtern {
             }
         };
         let operator = self.operator().into_iter();
-        let overridden = self.overridden().into_iter();
+        let to_sql_config = match self.overridden() {
+            None => self.to_sql_config.clone(),
+            Some(content) => {
+                let mut config = self.to_sql_config.clone();
+                config.content = Some(content);
+                config
+            }
+        };
 
         let sql_graph_entity_fn_name =
             syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site());
@@ -240,12 +259,12 @@ impl ToTokens for PgExtern {
                     line: line!(),
                     module_path: core::module_path!(),
                     full_path: concat!(core::module_path!(), "::", stringify!(#ident)),
-                    extern_attrs: #extern_attrs,
+                    extern_attrs: vec![#extern_attrs],
                     search_path: None#( .unwrap_or(Some(vec![#search_path])) )*,
                     fn_args: vec![#(#inputs),*],
                     fn_return: #returns,
                     operator: None#( .unwrap_or(Some(#operator)) )*,
-                    overridden: None#( .unwrap_or(Some(#overridden)) )*,
+                    to_sql_config: #to_sql_config,
                 };
                 pgx::datum::sql_entity_graph::SqlGraphEntity::Function(submission)
             }
@@ -256,13 +275,27 @@ impl ToTokens for PgExtern {
 
 impl Parse for PgExtern {
     fn parse(input: ParseStream) -> Result<Self, syn::Error> {
-        let attrs: Option<PgxAttributes> = input.parse().ok();
-        let func = input.parse()?;
-        let attr_tokens: proc_macro2::TokenStream = attrs.clone().into_token_stream();
+        let mut attrs = Vec::new();
+        let mut to_sql_config: Option<ToSqlConfig> = None;
+
+        let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
+        let punctuated_attrs = input.call(parser).ok().unwrap_or_default();
+        for pair in punctuated_attrs.into_pairs() {
+            match pair.into_value() {
+                Attribute::Sql(config) => {
+                    to_sql_config.get_or_insert(config);
+                }
+                attr => {
+                    attrs.push(attr);
+                }
+            }
+        }
+
+        let func: syn::ItemFn = input.parse()?;
         Ok(Self {
             attrs,
-            attr_tokens,
             func,
+            to_sql_config: to_sql_config.unwrap_or_default(),
         })
     }
 }
diff --git a/pgx-utils/src/sql_entity_graph/pgx_attribute.rs b/pgx-utils/src/sql_entity_graph/pgx_attribute.rs
new file mode 100644
index 000000000..3fcf3faee
--- /dev/null
+++ b/pgx-utils/src/sql_entity_graph/pgx_attribute.rs
@@ -0,0 +1,80 @@
+use syn::parse::{Parse, ParseStream};
+use syn::{parenthesized, punctuated::Punctuated};
+use syn::{token, Token};
+
+/// This struct is intended to represent the contents of the `#[pgx]` attribute when parsed.
+///
+/// The intended usage is to parse an `Attribute`, then use `attr.parse_args::<PgxAttribute>()?` to
+/// parse the contents of the attribute into this struct.
+///
+/// We use this rather than `Attribute::parse_meta` because `parse_meta` does not support bare paths
+/// as values of a `MetaNameValue`, and we want to support that to avoid conflating SQL strings with
+/// paths-as-strings. We re-use as many of the standard `parse_meta` structure types as possible, though.
+pub struct PgxAttribute {
+    pub args: Vec<PgxArg>,
+}
+
+impl Parse for PgxAttribute {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        let parser = Punctuated::<PgxArg, Token![,]>::parse_terminated;
+        let punctuated = input.call(parser)?;
+        let args = punctuated
+            .into_pairs()
+            .map(|p| p.into_value())
+            .collect::<Vec<_>>();
+        Ok(Self { args })
+    }
+}
+
+/// This enum is akin to `syn::Meta`, but supports a custom `NameValue` variant which allows
+/// for bare paths in the value position.
+pub enum PgxArg {
+    Path(syn::Path),
+    List(syn::MetaList),
+    NameValue(NameValueArg),
+}
+
+impl Parse for PgxArg {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        let path = input.parse::<syn::Path>()?;
+        if input.peek(token::Paren) {
+            let content;
+            Ok(Self::List(syn::MetaList {
+                path,
+                paren_token: parenthesized!(content in input),
+                nested: content.parse_terminated(syn::NestedMeta::parse)?,
+            }))
+        } else if input.peek(Token![=]) {
+            Ok(Self::NameValue(NameValueArg {
+                path,
+                eq_token: input.parse()?,
+                value: input.parse()?,
+            }))
+        } else {
+            Ok(Self::Path(path))
+        }
+    }
+}
+
+/// This struct is akin to `syn::MetaNameValue`, but allows for more than just `syn::Lit` as a value.
+pub struct NameValueArg {
+    pub path: syn::Path,
+    pub eq_token: syn::token::Eq,
+    pub value: ArgValue,
+}
+
+/// This is the type of a value that can be used in the value position of a `name = value` attribute argument.
+pub enum ArgValue {
+    Path(syn::Path),
+    Lit(syn::Lit),
+}
+
+impl Parse for ArgValue {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        if input.peek(syn::Lit) {
+            return Ok(Self::Lit(input.parse()?));
+        }
+
+        Ok(Self::Path(input.parse()?))
+    }
+}
diff --git a/pgx-utils/src/sql_entity_graph/postgres_enum.rs b/pgx-utils/src/sql_entity_graph/postgres_enum.rs
index 6c43880a9..b2c0bd1ef 100644
--- a/pgx-utils/src/sql_entity_graph/postgres_enum.rs
+++ b/pgx-utils/src/sql_entity_graph/postgres_enum.rs
@@ -6,6 +6,8 @@ use syn::{
 };
 use syn::{punctuated::Punctuated, Ident, Token};
 
+use super::ToSqlConfig;
+
 /// A parsed `#[derive(PostgresEnum)]` item.
 ///
 /// It should be used with [`syn::parse::Parse`] functions.
@@ -33,6 +35,7 @@ pub struct PostgresEnum {
     name: Ident,
     generics: Generics,
     variants: Punctuated<syn::Variant, Token![,]>,
+    to_sql_config: ToSqlConfig,
 }
 
 impl PostgresEnum {
@@ -40,15 +43,19 @@ impl PostgresEnum {
         name: Ident,
         generics: Generics,
         variants: Punctuated<syn::Variant, Token![,]>,
+        to_sql_config: ToSqlConfig,
     ) -> Self {
         Self {
             name,
             generics,
             variants,
+            to_sql_config,
         }
     }
 
     pub fn from_derive_input(derive_input: DeriveInput) -> Result<Self, syn::Error> {
+        let to_sql_config =
+            ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default();
         let data_enum = match derive_input.data {
             syn::Data::Enum(data_enum) => data_enum,
             syn::Data::Union(_) | syn::Data::Struct(_) => {
@@ -59,6 +66,7 @@ impl PostgresEnum {
             derive_input.ident,
             derive_input.generics,
             data_enum.variants,
+            to_sql_config,
         ))
     }
 }
@@ -66,7 +74,14 @@ impl PostgresEnum {
 impl Parse for PostgresEnum {
     fn parse(input: ParseStream) -> Result<Self, syn::Error> {
         let parsed: ItemEnum = input.parse()?;
-        Ok(Self::new(parsed.ident, parsed.generics, parsed.variants))
+        let to_sql_config =
+            ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default();
+        Ok(Self::new(
+            parsed.ident,
+            parsed.generics,
+            parsed.variants,
+            to_sql_config,
+        ))
     }
 }
 
@@ -84,6 +99,8 @@ impl ToTokens for PostgresEnum {
         let sql_graph_entity_fn_name =
             syn::Ident::new(&format!("__pgx_internals_enum_{}", name), Span::call_site());
 
+        let to_sql_config = &self.to_sql_config;
+
         let inv = quote! {
            #[no_mangle]
            pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity {
@@ -104,6 +121,7 @@ impl ToTokens for PostgresEnum {
                    full_path: core::any::type_name::<#name #ty_generics>(),
                    mappings,
                    variants: vec![ #(  stringify!(#variants)  ),* ],
+                    to_sql_config: #to_sql_config,
                };
                pgx::datum::sql_entity_graph::SqlGraphEntity::Enum(submission)
            }
diff --git a/pgx-utils/src/sql_entity_graph/postgres_hash.rs b/pgx-utils/src/sql_entity_graph/postgres_hash.rs
index c06e50206..68a2919f5 100644
--- a/pgx-utils/src/sql_entity_graph/postgres_hash.rs
+++ b/pgx-utils/src/sql_entity_graph/postgres_hash.rs
@@ -2,9 +2,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
 use quote::{quote, ToTokens, TokenStreamExt};
 use syn::{
     parse::{Parse, ParseStream},
-    DeriveInput, Ident, ItemEnum, ItemStruct,
+    DeriveInput, Ident,
 };
 
+use super::ToSqlConfig;
+
 /// A parsed `#[derive(PostgresHash)]` item.
 ///
 /// It should be used with [`syn::parse::Parse`] functions.
@@ -51,27 +53,36 @@ use syn::{
 #[derive(Debug, Clone)]
 pub struct PostgresHash {
     pub name: Ident,
+    pub to_sql_config: ToSqlConfig,
 }
 
 impl PostgresHash {
-    pub fn new(name: Ident) -> Self {
-        Self { name }
+    pub fn new(name: Ident, to_sql_config: ToSqlConfig) -> Self {
+        Self {
+            name,
+            to_sql_config,
+        }
     }
 
     pub fn from_derive_input(derive_input: DeriveInput) -> Result<Self, syn::Error> {
-        Ok(Self::new(derive_input.ident))
+        let to_sql_config =
+            ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default();
+        Ok(Self::new(derive_input.ident, to_sql_config))
     }
 }
 
 impl Parse for PostgresHash {
     fn parse(input: ParseStream) -> Result<Self, syn::Error> {
-        let parsed_enum: Result<ItemEnum, syn::Error> = input.parse();
-        let parsed_struct: Result<ItemStruct, syn::Error> = input.parse();
-        let ident = parsed_enum
-            .map(|x| x.ident)
-            .or_else(|_| parsed_struct.map(|x| x.ident))
-            .map_err(|_| syn::Error::new(input.span(), "expected enum or struct"))?;
-        Ok(Self::new(ident))
+        use syn::Item;
+
+        let parsed = input.parse()?;
+        let (ident, attrs) = match &parsed {
+            Item::Enum(item) => (item.ident.clone(), item.attrs.as_slice()),
+            Item::Struct(item) => (item.ident.clone(), item.attrs.as_slice()),
+            _ => return Err(syn::Error::new(input.span(), "expected enum or struct")),
+        };
+        let to_sql_config = ToSqlConfig::from_attributes(attrs)?.unwrap_or_default();
+        Ok(Self::new(ident, to_sql_config))
     }
 }
 
@@ -82,6 +93,7 @@ impl ToTokens for PostgresHash {
             &format!("__pgx_internals_hash_{}", self.name),
             Span::call_site(),
        );
+        let to_sql_config = &self.to_sql_config;
         let inv = quote! {
             #[no_mangle]
             pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity {
@@ -96,6 +108,7 @@ impl ToTokens for PostgresHash {
                     full_path: core::any::type_name::<#name>(),
                     module_path: module_path!(),
                     id: TypeId::of::<#name>(),
+                    to_sql_config: #to_sql_config,
                };
                pgx::datum::sql_entity_graph::SqlGraphEntity::Hash(submission)
            }
diff --git a/pgx-utils/src/sql_entity_graph/postgres_ord.rs b/pgx-utils/src/sql_entity_graph/postgres_ord.rs
index d270030d1..d564cbd73 100644
--- a/pgx-utils/src/sql_entity_graph/postgres_ord.rs
+++ b/pgx-utils/src/sql_entity_graph/postgres_ord.rs
@@ -2,9 +2,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
 use quote::{quote, ToTokens, TokenStreamExt};
 use syn::{
     parse::{Parse, ParseStream},
-    DeriveInput, Ident, ItemEnum, ItemStruct,
+    DeriveInput, Ident,
 };
 
+use super::ToSqlConfig;
+
 /// A parsed `#[derive(PostgresOrd)]` item.
 ///
 /// It should be used with [`syn::parse::Parse`] functions.
@@ -51,27 +53,36 @@ use syn::{
 #[derive(Debug, Clone)]
 pub struct PostgresOrd {
     pub name: Ident,
+    pub to_sql_config: ToSqlConfig,
 }
 
 impl PostgresOrd {
-    pub fn new(name: Ident) -> Self {
-        Self { name }
+    pub fn new(name: Ident, to_sql_config: ToSqlConfig) -> Self {
+        Self {
+            name,
+            to_sql_config,
+        }
     }
 
     pub fn from_derive_input(derive_input: DeriveInput) -> Result<Self, syn::Error> {
-        Ok(Self::new(derive_input.ident))
+        let to_sql_config =
+            ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default();
+        Ok(Self::new(derive_input.ident, to_sql_config))
    }
 }
 
 impl Parse for PostgresOrd {
     fn parse(input: ParseStream) -> Result<Self, syn::Error> {
-        let parsed_enum: Result<ItemEnum, syn::Error> = input.parse();
-        let parsed_struct: Result<ItemStruct, syn::Error> = input.parse();
-        let ident = parsed_enum
-            .map(|x| x.ident)
-            .or_else(|_| parsed_struct.map(|x| x.ident))
-            .map_err(|_| syn::Error::new(input.span(), "expected enum or struct"))?;
-        Ok(Self::new(ident))
+        use syn::Item;
+
+        let parsed = input.parse()?;
+        let (ident, attrs) = match &parsed {
+            Item::Enum(item) => (item.ident.clone(), item.attrs.as_slice()),
+            Item::Struct(item) => (item.ident.clone(), item.attrs.as_slice()),
+            _ => return Err(syn::Error::new(input.span(), "expected enum or struct")),
+        };
+        let to_sql_config = ToSqlConfig::from_attributes(attrs)?.unwrap_or_default();
+        Ok(Self::new(ident, to_sql_config))
     }
 }
 
@@ -82,6 +93,7 @@ impl ToTokens for PostgresOrd {
             &format!("__pgx_internals_ord_{}", self.name),
             Span::call_site(),
        );
+        let to_sql_config = &self.to_sql_config;
         let inv = quote! {
             #[no_mangle]
             pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity {
@@ -96,6 +108,7 @@ impl ToTokens for PostgresOrd {
                     full_path: core::any::type_name::<#name>(),
                     module_path: module_path!(),
                     id: TypeId::of::<#name>(),
+                    to_sql_config: #to_sql_config,
                };
                pgx::datum::sql_entity_graph::SqlGraphEntity::Ord(submission)
            }
diff --git a/pgx-utils/src/sql_entity_graph/postgres_type.rs b/pgx-utils/src/sql_entity_graph/postgres_type.rs
index 805f0bdb5..2bd38bf32 100644
--- a/pgx-utils/src/sql_entity_graph/postgres_type.rs
+++ b/pgx-utils/src/sql_entity_graph/postgres_type.rs
@@ -9,6 +9,8 @@ use syn::{
     DeriveInput, Generics, ItemStruct,
 };
 
+use super::ToSqlConfig;
+
 /// A parsed `#[derive(PostgresType)]` item.
 ///
 /// It should be used with [`syn::parse::Parse`] functions.
@@ -37,15 +39,23 @@ pub struct PostgresType {
     generics: Generics,
     in_fn: Ident,
     out_fn: Ident,
+    to_sql_config: ToSqlConfig,
 }
 
 impl PostgresType {
-    pub fn new(name: Ident, generics: Generics, in_fn: Ident, out_fn: Ident) -> Self {
+    pub fn new(
+        name: Ident,
+        generics: Generics,
+        in_fn: Ident,
+        out_fn: Ident,
+        to_sql_config: ToSqlConfig,
+    ) -> Self {
         Self {
             generics,
             name,
             in_fn,
             out_fn,
+            to_sql_config,
         }
     }
 
@@ -59,6 +69,8 @@ impl PostgresType {
                ))
            }
        };
+        let to_sql_config =
+            ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default();
         let funcname_in = Ident::new(
             &format!("{}_in", derive_input.ident).to_lowercase(),
             derive_input.ident.span(),
@@ -72,6 +84,7 @@ impl PostgresType {
             derive_input.generics,
             funcname_in,
             funcname_out,
+            to_sql_config,
        ))
    }
 
@@ -93,6 +106,8 @@ impl PostgresType {
 impl Parse for PostgresType {
     fn parse(input: ParseStream) -> Result<Self, syn::Error> {
         let parsed: ItemStruct = input.parse()?;
+        let to_sql_config =
+            ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default();
         let funcname_in = Ident::new(
             &format!("{}_in", parsed.ident).to_lowercase(),
             parsed.ident.span(),
@@ -106,6 +121,7 @@ impl Parse for PostgresType {
             parsed.generics,
             funcname_in,
             funcname_out,
+            to_sql_config,
        ))
    }
 }
@@ -127,6 +143,8 @@ impl ToTokens for PostgresType {
             Span::call_site(),
        );
 
+        let to_sql_config = &self.to_sql_config;
+
         let inv = quote! {
             #[no_mangle]
             pub extern "C" fn #sql_graph_entity_fn_name() -> pgx::datum::sql_entity_graph::SqlGraphEntity {
@@ -172,7 +190,8 @@ impl ToTokens for PostgresType {
                        let mut path_items: Vec<_> = out_fn.split("::").collect();
                        let _ = path_items.pop(); // Drop the one we don't want.
                        path_items.join("::")
-                    }
+                    },
+                    to_sql_config: #to_sql_config,
                };
                pgx::datum::sql_entity_graph::SqlGraphEntity::Type(submission)
            }
diff --git a/pgx-utils/src/sql_entity_graph/to_sql.rs b/pgx-utils/src/sql_entity_graph/to_sql.rs
new file mode 100644
index 000000000..a939d02a2
--- /dev/null
+++ b/pgx-utils/src/sql_entity_graph/to_sql.rs
@@ -0,0 +1,148 @@
+use std::hash::Hash;
+
+use proc_macro2::TokenStream as TokenStream2;
+use quote::{quote, ToTokens, TokenStreamExt};
+use syn::spanned::Spanned;
+use syn::{AttrStyle, Attribute, Lit};
+
+use super::{ArgValue, PgxArg, PgxAttribute};
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct ToSqlConfig {
+    pub enabled: bool,
+    pub callback: Option<syn::Path>,
+    pub content: Option<syn::LitStr>,
+}
+impl From<bool> for ToSqlConfig {
+    fn from(enabled: bool) -> Self {
+        Self {
+            enabled,
+            callback: None,
+            content: None,
+        }
+    }
+}
+impl From<syn::Path> for ToSqlConfig {
+    fn from(path: syn::Path) -> Self {
+        Self {
+            enabled: true,
+            callback: Some(path),
+            content: None,
+        }
+    }
+}
+impl From<syn::LitStr> for ToSqlConfig {
+    fn from(content: syn::LitStr) -> Self {
+        Self {
+            enabled: true,
+            callback: None,
+            content: Some(content),
+        }
+    }
+}
+impl Default for ToSqlConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            callback: None,
+            content: None,
+        }
+    }
+}
+
+const INVALID_ATTR_CONTENT: &str =
+    "expected `#[pgx(sql = content)]`, where `content` is a boolean, string, or path to a function";
+
+impl ToSqlConfig {
+    /// Used for general purpose parsing from an attribute
+    pub fn from_attribute(attr: &Attribute) -> Result<Option<Self>, syn::Error> {
+        if attr.style != AttrStyle::Outer {
+            return Err(syn::Error::new(
+                attr.span(),
+                "#[pgx(sql = ..)] is only valid in an outer context",
+            ));
+        }
+
+        let attr = attr.parse_args::<PgxAttribute>()?;
+        for arg in attr.args.iter() {
+            if let PgxArg::NameValue(ref nv) = arg {
+                if !nv.path.is_ident("sql") {
+                    continue;
+                }
+
+                match nv.value {
+                    ArgValue::Path(ref callback_path) => {
+                        return Ok(Some(Self {
+                            enabled: true,
+                            callback: Some(callback_path.clone()),
+                            content: None,
+                        }));
+                    }
+                    ArgValue::Lit(Lit::Bool(ref b)) => {
+                        return Ok(Some(Self {
+                            enabled: b.value,
+                            callback: None,
+                            content: None,
+                        }));
+                    }
+                    ArgValue::Lit(Lit::Str(ref s)) => {
+                        return Ok(Some(Self {
+                            enabled: true,
+                            callback: None,
+                            content: Some(s.clone()),
+                        }));
+                    }
+                    ArgValue::Lit(ref other) => {
+                        return Err(syn::Error::new(other.span(), INVALID_ATTR_CONTENT));
+                    }
+                }
+            }
+        }
+
+        Ok(None)
+    }
+
+    /// Used to parse a generator config from a set of item attributes
+    pub fn from_attributes(attrs: &[Attribute]) -> Result<Option<Self>, syn::Error> {
+        if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident("pgx")) {
+            Self::from_attribute(attr)
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+impl ToTokens for ToSqlConfig {
+    fn to_tokens(&self, tokens: &mut TokenStream2) {
+        let enabled = self.enabled;
+        let callback = &self.callback;
+        let content = &self.content;
+        if let Some(callback_path) = callback {
+            tokens.append_all(quote! {
+                ::pgx::datum::sql_entity_graph::ToSqlConfigEntity {
+                    enabled: #enabled,
+                    callback: Some(#callback_path),
+                    content: None,
+                }
+            });
+            return;
+        }
+        if let Some(sql) = content {
+            tokens.append_all(quote! {
+                ::pgx::datum::sql_entity_graph::ToSqlConfigEntity {
+                    enabled: #enabled,
+                    callback: None,
+                    content: Some(#sql),
+                }
+            });
+            return;
+        }
+        tokens.append_all(quote! {
+            ::pgx::datum::sql_entity_graph::ToSqlConfigEntity {
+                enabled: #enabled,
+                callback: None,
+                content: None,
+            }
+        });
+    }
+}
diff --git a/pgx/src/datum/sql_entity_graph/aggregate.rs b/pgx/src/datum/sql_entity_graph/aggregate.rs
index f744c7641..64ecc287d 100644
--- a/pgx/src/datum/sql_entity_graph/aggregate.rs
+++ b/pgx/src/datum/sql_entity_graph/aggregate.rs
@@ -1,6 +1,6 @@
 use crate::{
     aggregate::{FinalizeModify, ParallelOption},
-    datum::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, ToSql},
+    datum::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity},
 };
 use core::{any::TypeId, cmp::Ordering};
 use eyre::eyre as eyre_err;
@@ -127,6 +127,7 @@ pub struct PgAggregateEntity {
     ///
     /// Corresponds to `hypothetical` in [`crate::aggregate::Aggregate`].
     pub hypothetical: bool,
+    pub to_sql_config: ToSqlConfigEntity,
 }
 
 impl Ord for PgAggregateEntity {
@@ -278,21 +279,29 @@ impl ToSql for PgAggregateEntity {
            })?;
             optional_attributes.push((
                 format!("\tMSTYPE = {}", sql),
-                format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
+                format!(
+                    "/* {}::MovingState = {} */",
+                    self.full_path, value.full_path
+                ),
            ));
        }
 
         let mut optional_attributes_string = String::new();
         for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
-            let optional_attribute_string = format!("{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
+            let optional_attribute_string = format!(
+                "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
                 optional_attribute = optional_attribute,
-                maybe_comma = if index == optional_attributes.len() -1 {
+                maybe_comma = if index == optional_attributes.len() - 1 {
                     ""
-                } else { "," },
+                } else {
+                    ","
+                },
                 comment = comment,
-                maybe_newline = if index == optional_attributes.len() -1 {
+                maybe_newline = if index == optional_attributes.len() - 1 {
                     ""
-                } else { "\n" }
+                } else {
+                    "\n"
+                }
            );
             optional_attributes_string += &optional_attribute_string;
        }
diff --git a/pgx/src/datum/sql_entity_graph/mod.rs b/pgx/src/datum/sql_entity_graph/mod.rs
index 4dfb56de3..84b73fe17 100644
--- a/pgx/src/datum/sql_entity_graph/mod.rs
+++ b/pgx/src/datum/sql_entity_graph/mod.rs
@@ -64,6 +64,111 @@ pub trait ToSql {
     fn to_sql(&self, context: &PgxSql) -> eyre::Result<String>;
 }
 
+/// The signature of a function that can transform a SqlGraphEntity to a SQL string
+///
+/// This is used to provide a facility for overriding the default SQL generator behavior using
+/// the `#[pgx(sql = path::to::function)]` attribute in circumstances where the default behavior is
+/// not desirable.
+///
+/// Implementations can invoke `ToSql::to_sql(entity, context)` on the unwrapped SqlGraphEntity
+/// type should they wish to delegate to the default behavior for any reason.
+pub type ToSqlFn =
+    fn(
+        &SqlGraphEntity,
+        &PgxSql,
+    ) -> std::result::Result<String, Box<dyn std::error::Error + Send + Sync + 'static>>;
+
+/// Represents configuration options for tuning the SQL generator.
+///
+/// When an item that can be rendered to SQL has these options at hand, they should be
+/// respected. If an item does not have them, then it is not expected that the SQL generation
+/// for those items can be modified.
+///
+/// The default configuration has `enabled` set to `true`, and `callback` to `None`, which indicates
+/// that the default SQL generation behavior will be used. These are intended to be mutually exclusive
+/// options, so `callback` should only be set if generation is enabled.
+///
+/// When `enabled` is false, no SQL is generated for the item being configured.
+///
+/// When `callback` has a value, the corresponding `ToSql` implementation should invoke the
+/// callback instead of performing their default behavior.
+#[derive(Default, Clone)]
+pub struct ToSqlConfigEntity {
+    pub enabled: bool,
+    pub callback: Option<ToSqlFn>,
+    pub content: Option<&'static str>,
+}
+impl ToSqlConfigEntity {
+    /// Given a SqlGraphEntity, this function converts it to SQL based on the current configuration.
+    ///
+    /// If the config overrides the default behavior (i.e. using the `ToSql` trait), then `Some(eyre::Result)`
+    /// is returned. If the config does not override the default behavior, then `None` is returned. This can
+    /// be used to dispatch SQL generation in a single line, e.g.:
+    ///
+    /// ```rust,ignore
+    /// config.to_sql(entity, context).unwrap_or_else(|| entity.to_sql(context))?
+ /// ``` + pub fn to_sql( + &self, + entity: &SqlGraphEntity, + context: &PgxSql, + ) -> Option> { + use eyre::{eyre, WrapErr}; + + if !self.enabled { + return Some(Ok(String::default())); + } + + if let Some(content) = self.content { + return Some(Ok("\n".to_owned() + content)); + } + + if let Some(callback) = self.callback { + return Some( + callback(entity, context) + .map_err(|e| eyre!(e)) + .wrap_err("Failed to run specified `#[pgx(sql = path)] function`"), + ); + } + + None + } +} +impl std::cmp::PartialEq for ToSqlConfigEntity { + fn eq(&self, other: &Self) -> bool { + if self.enabled != other.enabled { + return false; + } + match (self.callback, other.callback) { + (None, None) => match (self.content, other.content) { + (None, None) => true, + (Some(a), Some(b)) => a == b, + _ => false, + }, + (Some(a), Some(b)) => std::ptr::eq(std::ptr::addr_of!(a), std::ptr::addr_of!(b)), + _ => false, + } + } +} +impl std::cmp::Eq for ToSqlConfigEntity {} +impl std::hash::Hash for ToSqlConfigEntity { + fn hash(&self, state: &mut H) { + self.enabled.hash(state); + self.callback.map(|cb| std::ptr::addr_of!(cb)).hash(state); + self.content.hash(state); + } +} +impl std::fmt::Debug for ToSqlConfigEntity { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let callback = self.callback.map(|cb| std::ptr::addr_of!(cb)); + f.debug_struct("ToSqlConfigEntity") + .field("enabled", &self.enabled) + .field("callback", &format_args!("{:?}", &callback)) + .field("content", &self.content) + .finish() + } +} + /// A mapping from a Rust type to a SQL type, with a `TypeId`. /// /// ```rust diff --git a/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs b/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs index e48847fda..91a41373b 100644 --- a/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs +++ b/pgx/src/datum/sql_entity_graph/pg_extern/mod.rs @@ -10,7 +10,7 @@ pub use returning::PgExternReturnEntity; use pgx_utils::ExternArgs; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use pgx_utils::sql_entity_graph::SqlDeclared; use std::cmp::Ordering; @@ -29,7 +29,7 @@ pub struct PgExternEntity { pub fn_args: Vec, pub fn_return: PgExternReturnEntity, pub operator: Option, - pub overridden: Option<&'static str>, + pub to_sql_config: ToSqlConfigEntity, } impl Ord for PgExternEntity { @@ -238,25 +238,12 @@ impl ToSql for PgExternEntity { -- {module_path}::{name}\n\ {requires}\ {fn_sql}\ - {overridden}\ ", name = self.name, module_path = self.module_path, file = self.file, line = self.line, - fn_sql = if self.overridden.is_some() { - let mut inner = fn_sql - .lines() - .map(|f| format!("-- {}", f)) - .collect::>() - .join("\n"); - inner.push_str( - "\n--\n-- Overridden as (due to a `///` comment with a `pgxsql` code block):", - ); - inner - } else { - fn_sql - }, + fn_sql = fn_sql, requires = { let requires_attrs = self .extern_attrs @@ -283,59 +270,56 @@ impl ToSql for PgExternEntity { "".to_string() } }, - overridden = self - .overridden - .map(|f| String::from("\n") + f + "\n") - .unwrap_or_default(), ); tracing::trace!(sql = %ext_sql); - let rendered = match (self.overridden, &self.operator) { - (None, Some(op)) => { - let mut optionals = vec![]; - if let Some(it) = op.commutator { - optionals.push(format!("\tCOMMUTATOR = {}", it)); - }; - if let Some(it) = op.negator { - optionals.push(format!("\tNEGATOR = {}", it)); - }; - if let Some(it) = op.restrict { - optionals.push(format!("\tRESTRICT = {}", it)); - }; - if let 
Some(it) = op.join { - optionals.push(format!("\tJOIN = {}", it)); - }; - if op.hashes { - optionals.push(String::from("\tHASHES")); - }; - if op.merges { - optionals.push(String::from("\tMERGES")); - }; + let rendered = if let Some(op) = &self.operator { + let mut optionals = vec![]; + if let Some(it) = op.commutator { + optionals.push(format!("\tCOMMUTATOR = {}", it)); + }; + if let Some(it) = op.negator { + optionals.push(format!("\tNEGATOR = {}", it)); + }; + if let Some(it) = op.restrict { + optionals.push(format!("\tRESTRICT = {}", it)); + }; + if let Some(it) = op.join { + optionals.push(format!("\tJOIN = {}", it)); + }; + if op.hashes { + optionals.push(String::from("\tHASHES")); + }; + if op.merges { + optionals.push(String::from("\tMERGES")); + }; - let left_arg = self.fn_args.get(0).ok_or_else(|| { - eyre!("Did not find `left_arg` for operator `{}`.", self.name) - })?; - let left_arg_graph_index = context - .graph - .neighbors_undirected(self_index) - .find(|neighbor| match &context.graph[*neighbor] { - SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id), - _ => false, - }) - .ok_or_else(|| eyre!("Could not find left arg function in graph."))?; - let right_arg = self.fn_args.get(1).ok_or_else(|| { - eyre!("Did not find `left_arg` for operator `{}`.", self.name) - })?; - let right_arg_graph_index = context - .graph - .neighbors_undirected(self_index) - .find(|neighbor| match &context.graph[*neighbor] { - SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id), - _ => false, - }) - .ok_or_else(|| eyre!("Could not find right arg function in graph."))?; + let left_arg = self + .fn_args + .get(0) + .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; + let left_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id), + _ => false, + }) + .ok_or_else(|| eyre!("Could not find left arg function in graph."))?; + let right_arg = self + .fn_args + .get(1) + .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; + let right_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id), + _ => false, + }) + .ok_or_else(|| eyre!("Could not find right arg function in graph."))?; - let operator_sql = format!("\n\n\ + let operator_sql = format!("\n\n\ -- {file}:{line}\n\ -- {module_path}::{unaliased_name}\n\ CREATE OPERATOR {opname} (\n\ @@ -344,26 +328,26 @@ impl ToSql for PgExternEntity { \tRIGHTARG={schema_prefix_right}{right_arg}{maybe_comma} /* {right_name} */\n\ {optionals}\ );\ - ", - opname = op.opname.unwrap(), - file = self.file, - line = self.line, - name = self.name, - unaliased_name = self.unaliased_name, - module_path = self.module_path, - left_name = left_arg.full_path, - right_name = right_arg.full_path, - schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index), - left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?, - schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index), - right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, 
self.name))?, - maybe_comma = if optionals.len() >= 1 { "," } else { "" }, - optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, - ); - tracing::trace!(sql = %operator_sql); - ext_sql + &operator_sql - } - (None, None) | (Some(_), Some(_)) | (Some(_), None) => ext_sql, + ", + opname = op.opname.unwrap(), + file = self.file, + line = self.line, + name = self.name, + unaliased_name = self.unaliased_name, + module_path = self.module_path, + left_name = left_arg.full_path, + right_name = right_arg.full_path, + schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index), + left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?, + schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index), + right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, self.name))?, + maybe_comma = if optionals.len() >= 1 { "," } else { "" }, + optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, + ); + tracing::trace!(sql = %operator_sql); + ext_sql + &operator_sql + } else { + ext_sql }; Ok(rendered) } diff --git a/pgx/src/datum/sql_entity_graph/postgres_enum.rs b/pgx/src/datum/sql_entity_graph/postgres_enum.rs index 159173329..15ec1451e 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_enum.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_enum.rs @@ -3,7 +3,7 @@ use std::{ hash::{Hash, Hasher}, }; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; /// The output of a [`PostgresEnum`](crate::datum::sql_entity_graph::PostgresEnum) from `quote::ToTokens::to_tokens`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -15,6 +15,7 @@ pub struct PostgresEnumEntity { pub module_path: &'static str, pub mappings: std::collections::HashSet, pub variants: Vec<&'static str>, + pub to_sql_config: ToSqlConfigEntity, } impl Hash for PostgresEnumEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_hash.rs b/pgx/src/datum/sql_entity_graph/postgres_hash.rs index 0111cb7e6..af1fc1b85 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_hash.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_hash.rs @@ -1,4 +1,4 @@ -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use std::cmp::Ordering; /// The output of a [`PostgresHash`](crate::datum::sql_entity_graph::PostgresHash) from `quote::ToTokens::to_tokens`. @@ -10,6 +10,7 @@ pub struct PostgresHashEntity { pub full_path: &'static str, pub module_path: &'static str, pub id: core::any::TypeId, + pub to_sql_config: ToSqlConfigEntity, } impl PostgresHashEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_ord.rs b/pgx/src/datum/sql_entity_graph/postgres_ord.rs index 0020100ea..6ec1be5bf 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_ord.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_ord.rs @@ -1,4 +1,4 @@ -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; use std::cmp::Ordering; /// The output of a [`PostgresOrd`](crate::datum::sql_entity_graph::PostgresOrd) from `quote::ToTokens::to_tokens`. 
@@ -10,6 +10,7 @@ pub struct PostgresOrdEntity { pub full_path: &'static str, pub module_path: &'static str, pub id: core::any::TypeId, + pub to_sql_config: ToSqlConfigEntity, } impl PostgresOrdEntity { diff --git a/pgx/src/datum/sql_entity_graph/postgres_type.rs b/pgx/src/datum/sql_entity_graph/postgres_type.rs index 0873bd15d..7041a157c 100644 --- a/pgx/src/datum/sql_entity_graph/postgres_type.rs +++ b/pgx/src/datum/sql_entity_graph/postgres_type.rs @@ -5,7 +5,7 @@ use std::{ hash::{Hash, Hasher}, }; -use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql}; +use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql, ToSqlConfigEntity}; /// The output of a [`PostgresType`](crate::datum::sql_entity_graph::PostgresType) from `quote::ToTokens::to_tokens`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -20,6 +20,7 @@ pub struct PostgresTypeEntity { pub in_fn_module_path: String, pub out_fn: &'static str, pub out_fn_module_path: String, + pub to_sql_config: ToSqlConfigEntity, } impl Hash for PostgresTypeEntity { diff --git a/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs b/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs index 1d9598959..de7650f8f 100644 --- a/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs +++ b/pgx/src/datum/sql_entity_graph/sql_graph_entity.rs @@ -86,37 +86,61 @@ impl ToSql for SqlGraphEntity { #[tracing::instrument(level = "debug", skip(self, context), fields(identifier = %self.rust_identifier()))] fn to_sql(&self, context: &super::PgxSql) -> eyre::Result { match self { - SqlGraphEntity::Schema(item) => if item.name != "public" && item.name != "pg_catalog" { - item.to_sql(context) - } else { Ok(String::default()) }, - SqlGraphEntity::CustomSql(item) => { - item.to_sql(context) - }, - SqlGraphEntity::Function(item) => if context.graph.neighbors_undirected(context.externs.get(item).unwrap().clone()).any(|neighbor| { - let neighbor_item = &context.graph[neighbor]; - match neighbor_item { - SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, .. }) => { - let is_in_fn = item.full_path.starts_with(in_fn_module_path) && item.full_path.ends_with(in_fn); - if is_in_fn { - tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an in_fn."); - } - let is_out_fn = item.full_path.starts_with(out_fn_module_path) && item.full_path.ends_with(out_fn); - if is_out_fn { - tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an out_fn."); - } - is_in_fn || is_out_fn - }, - _ => false, + SqlGraphEntity::Schema(item) => { + if item.name != "public" && item.name != "pg_catalog" { + item.to_sql(context) + } else { + Ok(String::default()) } - }) { - Ok(String::default()) - } else { item.to_sql(context) }, - SqlGraphEntity::Type(item) => item.to_sql(context), + } + SqlGraphEntity::CustomSql(item) => item.to_sql(context), + SqlGraphEntity::Function(item) => { + if let Some(result) = item.to_sql_config.to_sql(self, context) { + return result; + } + if context.graph.neighbors_undirected(context.externs.get(item).unwrap().clone()).any(|neighbor| { + let neighbor_item = &context.graph[neighbor]; + match neighbor_item { + SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, .. 
}) => { + let is_in_fn = item.full_path.starts_with(in_fn_module_path) && item.full_path.ends_with(in_fn); + if is_in_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an in_fn."); + } + let is_out_fn = item.full_path.starts_with(out_fn_module_path) && item.full_path.ends_with(out_fn); + if is_out_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an out_fn."); + } + is_in_fn || is_out_fn + }, + _ => false, + } + }) { + Ok(String::default()) + } else { + item.to_sql(context) + } + } + SqlGraphEntity::Type(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), SqlGraphEntity::BuiltinType(_) => Ok(String::default()), - SqlGraphEntity::Enum(item) => item.to_sql(context), - SqlGraphEntity::Ord(item) => item.to_sql(context), - SqlGraphEntity::Hash(item) => item.to_sql(context), - SqlGraphEntity::Aggregate(item) => item.to_sql(context), + SqlGraphEntity::Enum(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Ord(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Hash(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), + SqlGraphEntity::Aggregate(item) => item + .to_sql_config + .to_sql(self, context) + .unwrap_or_else(|| item.to_sql(context)), SqlGraphEntity::ExtensionRoot(item) => item.to_sql(context), } }