Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extend sql generation via to_sql attribute #410

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions pgx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ Optionally accepts the following attributes:
* `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
* `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html).
* `no_guard`: Do not use `#[pg_guard]` with the function.
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).

Functions can accept and return any type which `pgx` supports. `pgx` supports many PostgreSQL types by default.
New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`].
Expand Down Expand Up @@ -597,7 +598,7 @@ enum DogNames {
```

*/
#[proc_macro_derive(PostgresEnum, attributes(requires))]
#[proc_macro_derive(PostgresEnum, attributes(requires, pgx))]
pub fn postgres_enum(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -683,9 +684,12 @@ Optionally accepts the following attributes:

* `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type.
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.

* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires))]
#[proc_macro_derive(
PostgresType,
attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -911,12 +915,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:

bitwalker marked this conversation as resolved.
Show resolved Hide resolved
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresEq)]
#[proc_macro_derive(PostgresEq, attributes(pgx))]
pub fn postgres_eq(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_eq(ast).into()
impl_postgres_eq(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand All @@ -935,12 +943,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:

bitwalker marked this conversation as resolved.
Show resolved Hide resolved
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresOrd)]
#[proc_macro_derive(PostgresOrd, attributes(pgx))]
pub fn postgres_ord(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_ord(ast).into()
impl_postgres_ord(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand All @@ -956,12 +968,16 @@ enum DogNames {
Brandy,
}
```
Optionally accepts the following attributes:

bitwalker marked this conversation as resolved.
Show resolved Hide resolved
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresHash)]
#[proc_macro_derive(PostgresHash, attributes(pgx))]
pub fn postgres_hash(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_hash(ast).into()
impl_postgres_hash(ast)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

/**
Expand Down Expand Up @@ -991,15 +1007,26 @@ pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream {
}

/**
An inner attribute for [`#[pg_aggregate]`](macro@pg_aggregate).
A helper attribute for various contexts.

## Usage with [`#[pg_aggregate]`](macro@pg_aggregate).

It can be decorated on functions inside a [`#[pg_aggregate]`](macro@pg_aggregate) implementation.
In this position, it takes the same args as [`#[pg_extern]`](macro@pg_extern), and those args have the same effect.

Used outside of a [`#[pg_aggregate]`](macro@pg_aggregate), this does nothing.
## Usage for configuring SQL generation

This attribute can be used to control the behavior of the SQL generator on a decorated item,
e.g. `#[pgx(sql = false)]`

Currently `sql` can be provided one of the following:

* Disable SQL generation with `#[pgx(sql = false)]`
* Call custom SQL generator function with `#[pgx(sql = path::to_function)]`
* Render a specific fragment of SQL with a string `#[pgx(sql = "CREATE OR REPLACE FUNCTION ...")]`

*/
#[proc_macro_attribute]
pub fn pgx(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}

17 changes: 9 additions & 8 deletions pgx-macros/src/operators.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use pgx_utils::{operator_common::*, sql_entity_graph};

use quote::ToTokens;
use syn::DeriveInput;

pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(eq(&ast.ident));
stream.extend(ne(&ast.ident));

stream
Ok(stream)
}

pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(lt(&ast.ident));
Expand All @@ -20,19 +21,19 @@ pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> proc_macro2::TokenStream {
stream.extend(ge(&ast.ident));
stream.extend(cmp(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresOrd::new(ast.ident.clone());
let sql_graph_entity_item = sql_entity_graph::PostgresOrd::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

stream
Ok(stream)
}

pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> proc_macro2::TokenStream {
pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();

stream.extend(hash(&ast.ident));

let sql_graph_entity_item = sql_entity_graph::PostgresHash::new(ast.ident.clone());
let sql_graph_entity_item = sql_entity_graph::PostgresHash::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

stream
Ok(stream)
}
17 changes: 13 additions & 4 deletions pgx-macros/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens};
use std::ops::Deref;
use std::str::FromStr;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
FnArg, ForeignItem, ForeignItemFn, Generics, ItemFn, ItemForeignMod, Pat, ReturnType,
Signature, Type, Visibility,
Signature, Token, Type, Visibility,
};

pub struct PgGuardRewriter();
Expand Down Expand Up @@ -224,7 +225,11 @@ impl PgGuardRewriter {
let return_type =
proc_macro2::TokenStream::from_str(return_type.trim_start_matches("->")).unwrap();
let return_type = quote! {impl std::iter::Iterator<Item = #return_type>};
let attrs = entity_submission.unwrap().extern_attr_tokens();
let attrs = entity_submission
.unwrap()
.extern_attrs()
.iter()
.collect::<Punctuated<_, Token![,]>>();

func.sig.output = ReturnType::Default;
let sig = func.sig;
Expand Down Expand Up @@ -708,7 +713,9 @@ impl FunctionSignatureRewriter {
fn args(&self, is_raw: bool) -> proc_macro2::TokenStream {
if self.func.sig.inputs.len() == 1 && self.return_type_is_datum() {
if let FnArg::Typed(ty) = self.func.sig.inputs.first().unwrap() {
if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo") || type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo") {
if type_matches(&ty.ty, "pg_sys :: FunctionCallInfo")
|| type_matches(&ty.ty, "pgx :: pg_sys :: FunctionCallInfo")
{
return proc_macro2::TokenStream::new();
}
}
Expand Down Expand Up @@ -738,7 +745,9 @@ impl FunctionSignatureRewriter {
quote_spanned! {ident.span()=>
let #name = pgx::pg_getarg::<#option_type>(fcinfo, #i);
}
} else if type_matches(&type_, "pg_sys :: FunctionCallInfo") || type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo") {
} else if type_matches(&type_, "pg_sys :: FunctionCallInfo")
|| type_matches(&type_, "pgx :: pg_sys :: FunctionCallInfo")
{
have_fcinfo = true;
quote_spanned! {ident.span()=>
let #name = fcinfo;
Expand Down
112 changes: 112 additions & 0 deletions pgx-tests/src/tests/schema_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,66 @@ use pgx::*;

#[pgx::pg_schema]
mod test_schema {
use pgx::datum::sql_entity_graph::{PgxSql, SqlGraphEntity};
use pgx::*;
use serde::{Deserialize, Serialize};

#[pg_extern]
fn func_in_diff_schema() {}

#[pg_extern(sql = false)]
fn func_elided_from_schema() {}

#[pg_extern(sql = generate_function)]
fn func_generated_with_custom_sql() {}

#[derive(Debug, PostgresType, Serialize, Deserialize)]
pub struct TestType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = false)]
pub struct ElidedType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = generate_type)]
pub struct OtherType(pub u64);

#[derive(Debug, PostgresType, Serialize, Deserialize)]
#[pgx(sql = "CREATE TYPE test_schema.ManuallyRenderedType;")]
pub struct OverriddenType(pub u64);

fn generate_function(
entity: &SqlGraphEntity,
_context: &PgxSql,
) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
if let SqlGraphEntity::Function(ref func) = entity {
Ok(format!("\
CREATE OR REPLACE FUNCTION test_schema.\"func_generated_with_custom_name\"() RETURNS void\n\
LANGUAGE c /* Rust */\n\
AS 'MODULE_PATHNAME', '{unaliased_name}_wrapper';\
",
unaliased_name = func.unaliased_name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅 I'm gonna need to clean up this API a bit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a similar thought when I was thinking through how to expose stuff to those functions haha. In my perfect world, users of this would get access to a hypothetical SqlBuilder API that let you replicate this using the builder pattern:

fn generate_function(entity: &SqlGraphEntity, builder: &mut SqlBuilder, _context: &PgxSql) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    if let SqlGraphEntity::Function(ref f) = entity {
        // get_or_create lets you modify existing entities after initial construction, but before the builder is consumed
        // returns a entity builder that holds a mutable reference to the parent builder
        // when the entity is finalized, only then is the entity added to the parent builder
        let mut func = builder.get_or_create_function(f.name);
        func
          .args(&[]) // defaults to no arguments if not set
          .returns("void") // defaults to 'void' if not set
          .replace(true) // indicates the builder should use CREATE OR REPLACE
          .lang(FunctionLanguage::Rust) // lowers to LANGUAGE c /* Rust */
          .as_nif("MODULE_PATHNAME", format!("{}_wrapper", func.unaliased_name)) // indicates the function is a Natively Implemented Function from the given module
          .build()?; // adds the built function to the parent builder
        Ok(())
    } else {
        panic!("expected extern function entity, got {:?}", entity);
    }
}

It'd be a big undertaking to cover all possible entities, but could probably start by just covering the cases that pgx itself requires, expose some kind of escape hatch akin to extension_sql!, and expand the API with specific entities as needed. Not sure if it is worth it or not, but would be a lot nicer and less error prone to work with!

))
} else {
panic!("expected extern function entity, got {:?}", entity);
}
}

fn generate_type(
entity: &SqlGraphEntity,
_context: &PgxSql,
) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
if let SqlGraphEntity::Type(ref ty) = entity {
Ok(format!(
"\n\
CREATE TYPE test_schema.Custom{name};\
",
name = ty.name,
))
} else {
panic!("expected type entity, got {:?}", entity);
}
}
}

#[pg_extern(schema = "test_schema")]
Expand Down Expand Up @@ -44,4 +96,64 @@ mod tests {
fn test_type_in_different_schema() {
Spi::run("SELECT type_in_diff_schema();");
}

#[pg_test]
fn elided_extern_is_elided() {
// Validate that a function we know exists, exists
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_in_diff_schema');",
)
.expect("expected result");
assert_eq!(result, true);

// Validate that the function we expect not to exist, doesn't
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_elided_from_schema');",
)
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn elided_type_is_elided() {
// Validate that a type we know exists, exists
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'testtype');")
.expect("expected result");
assert_eq!(result, true);

// Validate that the type we expect not to exist, doesn't
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'elidedtype');")
.expect("expected result");
assert_eq!(result, false);
}

#[pg_test]
fn custom_to_sql_extern() {
// Validate that the function we generated has the modifications we expect
let result: bool = Spi::get_one("SELECT exists(SELECT 1 FROM pg_proc WHERE proname = 'func_generated_with_custom_name');").expect("expected result");
assert_eq!(result, true);

Spi::run("SELECT test_schema.func_generated_with_custom_name();");
}

#[pg_test]
fn custom_to_sql_type() {
// Validate that the type we generated has the expected modifications
let result: bool =
Spi::get_one("SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'customothertype');")
.expect("expected result");
assert_eq!(result, true);
}

#[pg_test]
fn custom_handwritten_to_sql_type() {
// Validate that the SQL we provided was used
let result: bool = Spi::get_one(
"SELECT exists(SELECT 1 FROM pg_type WHERE typname = 'manuallyrenderedtype');",
)
.expect("expected result");
assert_eq!(result, true);
}
}
6 changes: 6 additions & 0 deletions pgx-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
let name = name[1..name.len() - 1].to_string();
args.insert(ExternArgs::Name(name.to_string()))
}
// Recognized, but not handled as an extern argument
"sql" => {
let _punc = itr.next().unwrap();
let _value = itr.next().unwrap();
false
}
_ => false,
};
}
Expand Down
4 changes: 4 additions & 0 deletions pgx-utils/src/sql_entity_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@ mod extension_sql;
mod pg_aggregate;
mod pg_extern;
mod pg_schema;
mod pgx_attribute;
mod positioning_ref;
mod postgres_enum;
mod postgres_hash;
mod postgres_ord;
mod postgres_type;
mod to_sql;

pub use super::ExternArgs;
pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared};
pub use pg_aggregate::PgAggregate;
pub use pg_extern::{Argument, PgExtern, PgOperator};
pub use pg_schema::Schema;
pub use pgx_attribute::{ArgValue, NameValueArg, PgxArg, PgxAttribute};
pub use positioning_ref::PositioningRef;
pub use postgres_enum::PostgresEnum;
pub use postgres_hash::PostgresHash;
pub use postgres_ord::PostgresOrd;
pub use postgres_type::PostgresType;
pub use to_sql::ToSqlConfig;

/// Reexports for the pgx SQL generator binaries.
#[doc(hidden)]
Expand Down
Loading