Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 79 additions & 61 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ extern crate proc_macro;

use proc_macro::TokenStream;
use std::collections::HashSet;
use std::ffi::CString;

use proc_macro2::Ident;
use quote::{format_ident, quote, ToTokens};
Expand Down Expand Up @@ -1064,91 +1065,108 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
}

/// Derives the `GucEnum` trait, so that normal Rust enums can be used as a GUC.
#[proc_macro_derive(PostgresGucEnum, attributes(hidden))]
#[proc_macro_derive(PostgresGucEnum, attributes(name, hidden))]
pub fn postgres_guc_enum(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

impl_guc_enum(ast).unwrap_or_else(|e| e.into_compile_error()).into()
}

fn impl_guc_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();
use std::str::FromStr;
use syn::parse::Parse;

// validate that we're only operating on an enum
let enum_data = match ast.data {
Data::Enum(e) => e,
_ => {
return Err(syn::Error::new(
ast.span(),
"#[derive(PostgresGucEnum)] can only be applied to enums",
))
}
};
let enum_name = ast.ident;
let enum_len = enum_data.variants.len();

let mut from_match_arms = proc_macro2::TokenStream::new();
for (idx, e) in enum_data.variants.iter().enumerate() {
let label = &e.ident;
let idx = idx as i32;
from_match_arms.extend(quote! { #idx => #enum_name::#label, })
enum GucEnumAttribute {
Name(CString),
Hidden(bool),
}
from_match_arms.extend(quote! { _ => panic!("Unrecognized ordinal ")});

let mut ordinal_match_arms = proc_macro2::TokenStream::new();
for (idx, e) in enum_data.variants.iter().enumerate() {
let label = &e.ident;
let idx = idx as i32;
ordinal_match_arms.extend(quote! { #enum_name::#label => #idx, });
impl Parse for GucEnumAttribute {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let ident: Ident = input.parse()?;
let _: syn::token::Eq = input.parse()?;
match ident.to_string().as_str() {
"name" => input.parse::<syn::LitCStr>().map(|val| Self::Name(val.value())),
"hidden" => input.parse::<syn::LitBool>().map(|val| Self::Hidden(val.value())),
x => Err(syn::Error::new(input.span(), format!("unknown attribute {x}"))),
}
}
}

let mut build_array_body = proc_macro2::TokenStream::new();
for (idx, e) in enum_data.variants.iter().enumerate() {
let label = e.ident.to_string();
let mut hidden = false;

for att in e.attrs.iter() {
let att = quote! {#att}.to_string();
if att == "# [hidden]" {
hidden = true;
break;
// validate that we're only operating on an enum
let Data::Enum(data) = ast.data.clone() else {
return Err(syn::Error::new(
ast.span(),
"#[derive(PostgresGucEnum)] can only be applied to enums",
));
};
let ident = ast.ident.clone();
let mut config = Vec::new();
for (index, variant) in data.variants.iter().enumerate() {
let default_name = CString::from_str(&variant.ident.to_string())
.expect("the identifier contains a null character.");
let default_val = index as i32;
let default_hidden = false;
let mut name = None;
let mut hidden = None;
for attr in variant.attrs.iter() {
let tokens = attr.meta.require_name_value()?.to_token_stream();
let pair: GucEnumAttribute = syn::parse2(tokens)?;
match pair {
GucEnumAttribute::Name(value) => {
if name.replace(value).is_some() {
return Err(syn::Error::new(ast.span(), "too many #[name] attributes"));
}
}
GucEnumAttribute::Hidden(value) => {
if hidden.replace(value).is_some() {
return Err(syn::Error::new(ast.span(), "too many #[hidden] attributes"));
}
}
}
}

build_array_body.extend(quote! {
::pgrx::pgbox::PgBox::<_, ::pgrx::pgbox::AllocatedByPostgres>::with(&mut slice[#idx], |v| {
v.name = ::pgrx::memcxt::PgMemoryContexts::TopMemoryContext.pstrdup(#label);
v.val = #idx as i32;
v.hidden = #hidden;
});
});
let ident = variant.ident.clone();
let name = name.unwrap_or(default_name);
let val = default_val;
let hidden = hidden.unwrap_or(default_hidden);
config.push((ident, name, val, hidden));
}

stream.extend(quote! {
unsafe impl ::pgrx::guc::GucEnum<#enum_name> for #enum_name {
fn from_ordinal(ordinal: i32) -> #enum_name {
let config_idents = config.iter().map(|x| &x.0).collect::<Vec<_>>();
let config_names = config.iter().map(|x| &x.1).collect::<Vec<_>>();
let config_vals = config.iter().map(|x| &x.2).collect::<Vec<_>>();
let config_hiddens = config.iter().map(|x| &x.3).collect::<Vec<_>>();

Ok(quote! {
unsafe impl ::pgrx::guc::GucEnum for #ident {
fn from_ordinal(ordinal: i32) -> Self {
match ordinal {
#from_match_arms
#(#config_vals => Self::#config_idents,)*
_ => panic!("Unrecognized ordinal"),
}
}

fn to_ordinal(&self) -> i32 {
match *self {
#ordinal_match_arms
match self {
#(Self::#config_idents => #config_vals,)*
}
}

fn config_matrix() -> *const ::pgrx::pg_sys::config_enum_entry {
unsafe {
let slice = ::pgrx::memcxt::PgMemoryContexts::TopMemoryContext.palloc0_slice::<::pgrx::pg_sys::config_enum_entry>(#enum_len + 1usize);
#build_array_body
slice.as_ptr()
}
}
const CONFIG_ENUM_ENTRY: *const ::pgrx::pg_sys::config_enum_entry = [
#(
::pgrx::pg_sys::config_enum_entry {
name: #config_names.as_ptr(),
val: #config_vals,
hidden: #config_hiddens,
},
)*
::pgrx::pg_sys::config_enum_entry {
name: core::ptr::null(),
val: 0,
hidden: false,
},
].as_ptr();
}
});

Ok(stream)
})
}

#[derive(Debug, Hash, Ord, PartialOrd, Eq, PartialEq)]
Expand Down
1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg18.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
#include "partitioning/partprune.h"
#include "plpgsql.h"
#include "postmaster/bgworker.h"
#include "postmaster/interrupt.h"
#include "postmaster/postmaster.h"
#include "postmaster/syslogger.h"
#include "replication/logical.h"
Expand Down
61 changes: 34 additions & 27 deletions pgrx-tests/src/tests/guc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ mod tests {
fn test_bool_guc() {
static GUC: GucSetting<bool> = GucSetting::<bool>::new(true);
GucRegistry::define_bool_guc(
"test.bool",
"test bool gucs",
"test bool gucs",
c"test.bool",
c"test bool gucs",
c"test bool gucs",
&GUC,
GucContext::Userset,
GucFlags::default(),
Expand All @@ -41,9 +41,9 @@ mod tests {
fn test_int_guc() {
static GUC: GucSetting<i32> = GucSetting::<i32>::new(42);
GucRegistry::define_int_guc(
"test.int",
"test int guc",
"test int guc",
c"test.int",
c"test int guc",
c"test int guc",
&GUC,
-1,
42,
Expand All @@ -63,9 +63,9 @@ mod tests {
fn test_mb_guc() {
static GUC: GucSetting<i32> = GucSetting::<i32>::new(42);
GucRegistry::define_int_guc(
"test.megabytes",
"test megabytes guc",
"test megabytes guc",
c"test.megabytes",
c"test megabytes guc",
c"test megabytes guc",
&GUC,
-1,
42000,
Expand All @@ -82,9 +82,9 @@ mod tests {
fn test_float_guc() {
static GUC: GucSetting<f64> = GucSetting::<f64>::new(42.42);
GucRegistry::define_float_guc(
"test.float",
"test float guc",
"test float guc",
c"test.float",
c"test float guc",
c"test float guc",
&GUC,
-1.0f64,
43.0f64,
Expand All @@ -108,9 +108,9 @@ mod tests {
static GUC: GucSetting<Option<CString>> =
GucSetting::<Option<CString>>::new(Some(c"this is a test"));
GucRegistry::define_string_guc(
"test.string",
"test string guc",
"test string guc",
c"test.string",
c"test string guc",
c"test string guc",
&GUC,
GucContext::Userset,
GucFlags::default(),
Expand All @@ -129,9 +129,9 @@ mod tests {
fn test_string_guc_null_default() {
static GUC: GucSetting<Option<CString>> = GucSetting::<Option<CString>>::new(None);
GucRegistry::define_string_guc(
"test.string",
"test string guc",
"test string guc",
c"test.string",
c"test string guc",
c"test string guc",
&GUC,
GucContext::Userset,
GucFlags::default(),
Expand All @@ -152,12 +152,16 @@ mod tests {
One,
Two,
Three,
#[name = c"five"]
Four,
#[hidden = true]
Six,
}
static GUC: GucSetting<TestEnum> = GucSetting::<TestEnum>::new(TestEnum::Two);
GucRegistry::define_enum_guc(
"test.enum",
"test enum guc",
"test enum guc",
c"test.enum",
c"test enum guc",
c"test enum guc",
&GUC,
GucContext::Userset,
GucFlags::default(),
Expand All @@ -169,6 +173,9 @@ mod tests {

Spi::run("SET test.enum = 'three'").expect("SPI failed");
assert_eq!(GUC.get(), TestEnum::Three);

Spi::run("SET test.enum = 'five'").expect("SPI failed");
assert_eq!(GUC.get(), TestEnum::Four);
}

#[pg_test]
Expand All @@ -179,17 +186,17 @@ mod tests {
static GUC_NO_SHOW: GucSetting<bool> = GucSetting::<bool>::new(true);
static GUC_NO_RESET_ALL: GucSetting<bool> = GucSetting::<bool>::new(true);
GucRegistry::define_bool_guc(
"test.no_show",
"test no show gucs",
"test no show gucs",
c"test.no_show",
c"test no show gucs",
c"test no show gucs",
&GUC_NO_SHOW,
GucContext::Userset,
no_show_flag,
);
GucRegistry::define_bool_guc(
"test.no_reset_all",
"test no reset gucs",
"test no reset gucs",
c"test.no_reset_all",
c"test no reset gucs",
c"test no reset gucs",
&GUC_NO_RESET_ALL,
GucContext::Userset,
GucFlags::NO_RESET_ALL,
Expand Down
Loading
Loading