diff --git a/pgrx-macros/src/lib.rs b/pgrx-macros/src/lib.rs index 52d02500b4..37c70bdfcc 100644 --- a/pgrx-macros/src/lib.rs +++ b/pgrx-macros/src/lib.rs @@ -11,6 +11,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use std::collections::HashSet; +use std::ffi::CString; use proc_macro2::Ident; use quote::{format_ident, quote, ToTokens}; @@ -1064,7 +1065,7 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result } /// Derives the `GucEnum` trait, so that normal Rust enums can be used as a GUC. -#[proc_macro_derive(PostgresGucEnum, attributes(hidden))] +#[proc_macro_derive(PostgresGucEnum, attributes(name, hidden))] pub fn postgres_guc_enum(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -1072,83 +1073,100 @@ pub fn postgres_guc_enum(input: TokenStream) -> TokenStream { } fn impl_guc_enum(ast: DeriveInput) -> syn::Result { - let mut stream = proc_macro2::TokenStream::new(); + use std::str::FromStr; + use syn::parse::Parse; - // validate that we're only operating on an enum - let enum_data = match ast.data { - Data::Enum(e) => e, - _ => { - return Err(syn::Error::new( - ast.span(), - "#[derive(PostgresGucEnum)] can only be applied to enums", - )) - } - }; - let enum_name = ast.ident; - let enum_len = enum_data.variants.len(); - - let mut from_match_arms = proc_macro2::TokenStream::new(); - for (idx, e) in enum_data.variants.iter().enumerate() { - let label = &e.ident; - let idx = idx as i32; - from_match_arms.extend(quote! { #idx => #enum_name::#label, }) + enum GucEnumAttribute { + Name(CString), + Hidden(bool), } - from_match_arms.extend(quote! { _ => panic!("Unrecognized ordinal ")}); - let mut ordinal_match_arms = proc_macro2::TokenStream::new(); - for (idx, e) in enum_data.variants.iter().enumerate() { - let label = &e.ident; - let idx = idx as i32; - ordinal_match_arms.extend(quote! { #enum_name::#label => #idx, }); + impl Parse for GucEnumAttribute { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let ident: Ident = input.parse()?; + let _: syn::token::Eq = input.parse()?; + match ident.to_string().as_str() { + "name" => input.parse::().map(|val| Self::Name(val.value())), + "hidden" => input.parse::().map(|val| Self::Hidden(val.value())), + x => Err(syn::Error::new(input.span(), format!("unknown attribute {x}"))), + } + } } - let mut build_array_body = proc_macro2::TokenStream::new(); - for (idx, e) in enum_data.variants.iter().enumerate() { - let label = e.ident.to_string(); - let mut hidden = false; - - for att in e.attrs.iter() { - let att = quote! {#att}.to_string(); - if att == "# [hidden]" { - hidden = true; - break; + // validate that we're only operating on an enum + let Data::Enum(data) = ast.data.clone() else { + return Err(syn::Error::new( + ast.span(), + "#[derive(PostgresGucEnum)] can only be applied to enums", + )); + }; + let ident = ast.ident.clone(); + let mut config = Vec::new(); + for (index, variant) in data.variants.iter().enumerate() { + let default_name = CString::from_str(&variant.ident.to_string()) + .expect("the identifier contains a null character."); + let default_val = index as i32; + let default_hidden = false; + let mut name = None; + let mut hidden = None; + for attr in variant.attrs.iter() { + let tokens = attr.meta.require_name_value()?.to_token_stream(); + let pair: GucEnumAttribute = syn::parse2(tokens)?; + match pair { + GucEnumAttribute::Name(value) => { + if name.replace(value).is_some() { + return Err(syn::Error::new(ast.span(), "too many #[name] attributes")); + } + } + GucEnumAttribute::Hidden(value) => { + if hidden.replace(value).is_some() { + return Err(syn::Error::new(ast.span(), "too many #[hidden] attributes")); + } + } } } - - build_array_body.extend(quote! { - ::pgrx::pgbox::PgBox::<_, ::pgrx::pgbox::AllocatedByPostgres>::with(&mut slice[#idx], |v| { - v.name = ::pgrx::memcxt::PgMemoryContexts::TopMemoryContext.pstrdup(#label); - v.val = #idx as i32; - v.hidden = #hidden; - }); - }); + let ident = variant.ident.clone(); + let name = name.unwrap_or(default_name); + let val = default_val; + let hidden = hidden.unwrap_or(default_hidden); + config.push((ident, name, val, hidden)); } - - stream.extend(quote! { - unsafe impl ::pgrx::guc::GucEnum<#enum_name> for #enum_name { - fn from_ordinal(ordinal: i32) -> #enum_name { + let config_idents = config.iter().map(|x| &x.0).collect::>(); + let config_names = config.iter().map(|x| &x.1).collect::>(); + let config_vals = config.iter().map(|x| &x.2).collect::>(); + let config_hiddens = config.iter().map(|x| &x.3).collect::>(); + + Ok(quote! { + unsafe impl ::pgrx::guc::GucEnum for #ident { + fn from_ordinal(ordinal: i32) -> Self { match ordinal { - #from_match_arms + #(#config_vals => Self::#config_idents,)* + _ => panic!("Unrecognized ordinal"), } } fn to_ordinal(&self) -> i32 { - match *self { - #ordinal_match_arms + match self { + #(Self::#config_idents => #config_vals,)* } } - fn config_matrix() -> *const ::pgrx::pg_sys::config_enum_entry { - unsafe { - let slice = ::pgrx::memcxt::PgMemoryContexts::TopMemoryContext.palloc0_slice::<::pgrx::pg_sys::config_enum_entry>(#enum_len + 1usize); - #build_array_body - slice.as_ptr() - } - } + const CONFIG_ENUM_ENTRY: *const ::pgrx::pg_sys::config_enum_entry = [ + #( + ::pgrx::pg_sys::config_enum_entry { + name: #config_names.as_ptr(), + val: #config_vals, + hidden: #config_hiddens, + }, + )* + ::pgrx::pg_sys::config_enum_entry { + name: core::ptr::null(), + val: 0, + hidden: false, + }, + ].as_ptr(); } - }); - - Ok(stream) + }) } #[derive(Debug, Hash, Ord, PartialOrd, Eq, PartialEq)] diff --git a/pgrx-pg-sys/include/pg18.h b/pgrx-pg-sys/include/pg18.h index 948c26313f..3c8d18be1f 100644 --- a/pgrx-pg-sys/include/pg18.h +++ b/pgrx-pg-sys/include/pg18.h @@ -139,6 +139,7 @@ #include "partitioning/partprune.h" #include "plpgsql.h" #include "postmaster/bgworker.h" +#include "postmaster/interrupt.h" #include "postmaster/postmaster.h" #include "postmaster/syslogger.h" #include "replication/logical.h" diff --git a/pgrx-tests/src/tests/guc_tests.rs b/pgrx-tests/src/tests/guc_tests.rs index 49eb704e66..79c570954c 100644 --- a/pgrx-tests/src/tests/guc_tests.rs +++ b/pgrx-tests/src/tests/guc_tests.rs @@ -21,9 +21,9 @@ mod tests { fn test_bool_guc() { static GUC: GucSetting = GucSetting::::new(true); GucRegistry::define_bool_guc( - "test.bool", - "test bool gucs", - "test bool gucs", + c"test.bool", + c"test bool gucs", + c"test bool gucs", &GUC, GucContext::Userset, GucFlags::default(), @@ -41,9 +41,9 @@ mod tests { fn test_int_guc() { static GUC: GucSetting = GucSetting::::new(42); GucRegistry::define_int_guc( - "test.int", - "test int guc", - "test int guc", + c"test.int", + c"test int guc", + c"test int guc", &GUC, -1, 42, @@ -63,9 +63,9 @@ mod tests { fn test_mb_guc() { static GUC: GucSetting = GucSetting::::new(42); GucRegistry::define_int_guc( - "test.megabytes", - "test megabytes guc", - "test megabytes guc", + c"test.megabytes", + c"test megabytes guc", + c"test megabytes guc", &GUC, -1, 42000, @@ -82,9 +82,9 @@ mod tests { fn test_float_guc() { static GUC: GucSetting = GucSetting::::new(42.42); GucRegistry::define_float_guc( - "test.float", - "test float guc", - "test float guc", + c"test.float", + c"test float guc", + c"test float guc", &GUC, -1.0f64, 43.0f64, @@ -108,9 +108,9 @@ mod tests { static GUC: GucSetting> = GucSetting::>::new(Some(c"this is a test")); GucRegistry::define_string_guc( - "test.string", - "test string guc", - "test string guc", + c"test.string", + c"test string guc", + c"test string guc", &GUC, GucContext::Userset, GucFlags::default(), @@ -129,9 +129,9 @@ mod tests { fn test_string_guc_null_default() { static GUC: GucSetting> = GucSetting::>::new(None); GucRegistry::define_string_guc( - "test.string", - "test string guc", - "test string guc", + c"test.string", + c"test string guc", + c"test string guc", &GUC, GucContext::Userset, GucFlags::default(), @@ -152,12 +152,16 @@ mod tests { One, Two, Three, + #[name = c"five"] + Four, + #[hidden = true] + Six, } static GUC: GucSetting = GucSetting::::new(TestEnum::Two); GucRegistry::define_enum_guc( - "test.enum", - "test enum guc", - "test enum guc", + c"test.enum", + c"test enum guc", + c"test enum guc", &GUC, GucContext::Userset, GucFlags::default(), @@ -169,6 +173,9 @@ mod tests { Spi::run("SET test.enum = 'three'").expect("SPI failed"); assert_eq!(GUC.get(), TestEnum::Three); + + Spi::run("SET test.enum = 'five'").expect("SPI failed"); + assert_eq!(GUC.get(), TestEnum::Four); } #[pg_test] @@ -179,17 +186,17 @@ mod tests { static GUC_NO_SHOW: GucSetting = GucSetting::::new(true); static GUC_NO_RESET_ALL: GucSetting = GucSetting::::new(true); GucRegistry::define_bool_guc( - "test.no_show", - "test no show gucs", - "test no show gucs", + c"test.no_show", + c"test no show gucs", + c"test no show gucs", &GUC_NO_SHOW, GucContext::Userset, no_show_flag, ); GucRegistry::define_bool_guc( - "test.no_reset_all", - "test no reset gucs", - "test no reset gucs", + c"test.no_reset_all", + c"test no reset gucs", + c"test no reset gucs", &GUC_NO_RESET_ALL, GucContext::Userset, GucFlags::NO_RESET_ALL, diff --git a/pgrx/src/guc.rs b/pgrx/src/guc.rs index 4311f01bd2..46948db9bb 100644 --- a/pgrx/src/guc.rs +++ b/pgrx/src/guc.rs @@ -8,10 +8,11 @@ //LICENSE //LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. //! Provides a safe interface into Postgres' Configuration System (GUC) -use crate::{pg_sys, PgMemoryContexts}; +use crate::pg_sys; use core::ffi::CStr; pub use pgrx_macros::PostgresGucEnum; -use std::{cell::Cell, ffi::CString}; +use std::cell::Cell; +use std::ffi::CString; /// Defines at what level this GUC can be set pub enum GucContext { @@ -104,25 +105,31 @@ bitflags! { /// A trait that can be derived using [`PostgresGucEnum`] on enums, such that they can be /// used as a GUC. -pub unsafe trait GucEnum -where - T: Copy, -{ - fn from_ordinal(ordinal: i32) -> T; +/// +/// # Safety +/// +/// [`GucEnum::CONFIG_ENUM_ENTRY`] must be a valid pointer to the config enum entry. +pub unsafe trait GucEnum: Copy + Send + Sync { + fn from_ordinal(ordinal: i32) -> Self; fn to_ordinal(&self) -> i32; - fn config_matrix() -> *const pg_sys::config_enum_entry; + const CONFIG_ENUM_ENTRY: *const pg_sys::config_enum_entry; } +/// A trait that indicates that the type can be used as a GUC value. +/// +/// # Safety +/// +/// [`GucValue::Raw`] must be `Send` and `Sync`, or it's a pointer type. pub unsafe trait GucValue { type Raw: Copy; unsafe fn from_raw(raw: Self::Raw) -> Self; - type BoolVal: Copy; + type BootVal: Copy + Send + Sync; } /// A safe wrapper around a global variable that can be edited through a GUC pub struct GucSetting { value: Cell, - boot_val: T::BoolVal, + boot_val: T::BootVal, } unsafe impl Sync for GucSetting {} @@ -143,7 +150,7 @@ unsafe impl GucValue for bool { unsafe fn from_raw(raw: Self::Raw) -> Self { raw } - type BoolVal = (); + type BootVal = (); } impl GucSetting { pub const fn new(value: bool) -> Self { @@ -156,7 +163,7 @@ unsafe impl GucValue for i32 { unsafe fn from_raw(raw: Self::Raw) -> Self { raw } - type BoolVal = (); + type BootVal = (); } impl GucSetting { pub const fn new(value: i32) -> Self { @@ -169,7 +176,7 @@ unsafe impl GucValue for f64 { unsafe fn from_raw(raw: Self::Raw) -> Self { raw } - type BoolVal = (); + type BootVal = (); } impl GucSetting { pub const fn new(value: f64) -> Self { @@ -186,7 +193,7 @@ unsafe impl GucValue for Option { Some(CStr::from_ptr(raw).to_owned()) } } - type BoolVal = (); + type BootVal = (); } impl GucSetting> { pub const fn new(value: Option<&'static CStr>) -> Self { @@ -201,14 +208,14 @@ impl GucSetting> { } } -unsafe impl + Copy> GucValue for T { +unsafe impl GucValue for T { type Raw = i32; unsafe fn from_raw(raw: Self::Raw) -> Self { T::from_ordinal(raw) } - type BoolVal = T; + type BootVal = T; } -impl + Copy> GucSetting { +impl GucSetting { pub const fn new(value: T) -> Self { GucSetting { value: Cell::new(0), boot_val: value } } @@ -216,21 +223,22 @@ impl + Copy> GucSetting { /// A struct that has associated functions to register new GUCs pub struct GucRegistry {} + impl GucRegistry { pub fn define_bool_guc( - name: &str, - short_description: &str, - long_description: &str, + name: &'static CStr, + short_description: &'static CStr, + long_description: &'static CStr, setting: &'static GucSetting, context: GucContext, flags: GucFlags, ) { unsafe { pg_sys::DefineCustomBoolVariable( - PgMemoryContexts::TopMemoryContext.pstrdup(name), - PgMemoryContexts::TopMemoryContext.pstrdup(short_description), - PgMemoryContexts::TopMemoryContext.pstrdup(long_description), - setting.as_ptr(), + name.as_ptr(), + short_description.as_ptr(), + long_description.as_ptr(), + setting.value.as_ptr(), setting.value.get(), context as isize as _, flags.bits(), @@ -242,9 +250,9 @@ impl GucRegistry { } pub fn define_int_guc( - name: &str, - short_description: &str, - long_description: &str, + name: &'static CStr, + short_description: &'static CStr, + long_description: &'static CStr, setting: &'static GucSetting, min_value: i32, max_value: i32, @@ -253,10 +261,10 @@ impl GucRegistry { ) { unsafe { pg_sys::DefineCustomIntVariable( - PgMemoryContexts::TopMemoryContext.pstrdup(name), - PgMemoryContexts::TopMemoryContext.pstrdup(short_description), - PgMemoryContexts::TopMemoryContext.pstrdup(long_description), - setting.as_ptr(), + name.as_ptr(), + short_description.as_ptr(), + long_description.as_ptr(), + setting.value.as_ptr(), setting.value.get(), min_value, max_value, @@ -270,19 +278,19 @@ impl GucRegistry { } pub fn define_string_guc( - name: &str, - short_description: &str, - long_description: &str, + name: &'static CStr, + short_description: &'static CStr, + long_description: &'static CStr, setting: &'static GucSetting>, context: GucContext, flags: GucFlags, ) { unsafe { pg_sys::DefineCustomStringVariable( - PgMemoryContexts::TopMemoryContext.pstrdup(name), - PgMemoryContexts::TopMemoryContext.pstrdup(short_description), - PgMemoryContexts::TopMemoryContext.pstrdup(long_description), - setting.as_ptr(), + name.as_ptr(), + short_description.as_ptr(), + long_description.as_ptr(), + setting.value.as_ptr(), setting.value.get(), context as isize as _, flags.bits(), @@ -294,9 +302,9 @@ impl GucRegistry { } pub fn define_float_guc( - name: &str, - short_description: &str, - long_description: &str, + name: &'static CStr, + short_description: &'static CStr, + long_description: &'static CStr, setting: &'static GucSetting, min_value: f64, max_value: f64, @@ -305,10 +313,10 @@ impl GucRegistry { ) { unsafe { pg_sys::DefineCustomRealVariable( - PgMemoryContexts::TopMemoryContext.pstrdup(name), - PgMemoryContexts::TopMemoryContext.pstrdup(short_description), - PgMemoryContexts::TopMemoryContext.pstrdup(long_description), - setting.as_ptr(), + name.as_ptr(), + short_description.as_ptr(), + long_description.as_ptr(), + setting.value.as_ptr(), setting.value.get(), min_value, max_value, @@ -321,10 +329,10 @@ impl GucRegistry { } } - pub fn define_enum_guc + Copy>( - name: &str, - short_description: &str, - long_description: &str, + pub fn define_enum_guc( + name: &'static CStr, + short_description: &'static CStr, + long_description: &'static CStr, setting: &'static GucSetting, context: GucContext, flags: GucFlags, @@ -332,12 +340,12 @@ impl GucRegistry { setting.value.set(setting.boot_val.to_ordinal()); unsafe { pg_sys::DefineCustomEnumVariable( - PgMemoryContexts::TopMemoryContext.pstrdup(name), - PgMemoryContexts::TopMemoryContext.pstrdup(short_description), - PgMemoryContexts::TopMemoryContext.pstrdup(long_description), - setting.as_ptr(), + name.as_ptr(), + short_description.as_ptr(), + long_description.as_ptr(), + setting.value.as_ptr(), setting.value.get(), - T::config_matrix(), + T::CONFIG_ENUM_ENTRY, context as isize as _, flags.bits(), None,