Skip to content

Commit

Permalink
chore: refactor udfs into sqlbuiltins (#2214)
Browse files Browse the repository at this point in the history
closes #2213

Some context.. 

After finishing up the docs in
#2178, I realized that we were
missing **all** of our udfs. _kdl_*, and all of our postgres specific
functions._

It wasn't really obvious why either. They were in a separate part of the
codebase hidden away. This pr moves all of the udfs out of sqlexec and
puts them into sqlbuiltins alongside our udtfs (user defined table
functions).

---------

Co-authored-by: Sean Smith <[email protected]>
  • Loading branch information
universalmind303 and scsmithr authored Dec 6, 2023
1 parent 9912c82 commit da785d9
Show file tree
Hide file tree
Showing 16 changed files with 838 additions and 604 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/datafusion_ext/src/planner/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
return Ok(expr);
}
}

// next, aggregate built-ins
if let Ok(fun) = AggregateFunction::from_str(&name) {
let distinct = function.distinct;
Expand All @@ -139,6 +140,7 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
fun, args, distinct, None, order_by,
)));
};

// User defined aggregate functions
if let Some(fm) = self.schema_provider.get_aggregate_meta(&name).await {
let args = self
Expand All @@ -162,7 +164,7 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
.function_args_to_expr(function.args, schema, planner_context)
.await?;

if let Some(expr) = self.schema_provider.get_builtin(&name, args) {
if let Some(expr) = self.schema_provider.get_scalar_udf(&name, args) {
return Ok(expr);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/datafusion_ext/src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub trait AsyncContextProvider: Send + Sync {
) -> Result<Arc<dyn TableSource>>;

/// Getter for a UDF description
fn get_builtin(&mut self, name: &str, args: Vec<Expr>) -> Option<Expr>;
fn get_scalar_udf(&mut self, name: &str, args: Vec<Expr>) -> Option<Expr>;
/// Getter for a UDAF description
async fn get_aggregate_meta(&mut self, name: &str) -> Option<Arc<AggregateUDF>>;
/// Getter for system/user-defined variable type
Expand Down
24 changes: 21 additions & 3 deletions crates/metastore/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use sqlbuiltins::builtins::{
BuiltinDatabase, BuiltinSchema, BuiltinTable, BuiltinView, DATABASE_DEFAULT, DEFAULT_SCHEMA,
FIRST_NON_STATIC_OID,
};
use sqlbuiltins::functions::{BUILTIN_FUNCS, BUILTIN_TABLE_FUNCS};
use sqlbuiltins::functions::FUNCTION_REGISTRY;
use sqlbuiltins::validation::{
validate_database_tunnel_support, validate_object_name, validate_table_tunnel_support,
};
Expand Down Expand Up @@ -1297,7 +1297,7 @@ impl BuiltinCatalog {
oid += 1;
}

for func in BUILTIN_TABLE_FUNCS.iter_funcs() {
for func in FUNCTION_REGISTRY.table_funcs() {
// Put them all in the default schema.
let schema_id = schema_names
.get(DEFAULT_SCHEMA)
Expand All @@ -1315,7 +1315,25 @@ impl BuiltinCatalog {
oid += 1;
}

for func in BUILTIN_FUNCS.iter_funcs() {
for func in FUNCTION_REGISTRY.scalar_functions() {
// Put them all in the default schema.
let schema_id = schema_names
.get(DEFAULT_SCHEMA)
.ok_or_else(|| MetastoreError::MissingNamedSchema(DEFAULT_SCHEMA.to_string()))?;

insert_entry(
oid,
CatalogEntry::Function(func.as_function_entry(oid, *schema_id)),
)?;
schema_objects
.get_mut(schema_id)
.unwrap()
.functions
.insert(func.name().to_string(), oid);

oid += 1;
}
for func in FUNCTION_REGISTRY.scalar_udfs() {
// Put them all in the default schema.
let schema_id = schema_names
.get(DEFAULT_SCHEMA)
Expand Down
2 changes: 2 additions & 0 deletions crates/sqlbuiltins/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ once_cell = "1.18.0"
num-traits = "0.2.17"
url.workspace = true
strum = "0.25.0"
kdl = "5.0.0-alpha.1"

162 changes: 140 additions & 22 deletions crates/sqlbuiltins/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@ mod aggregates;
mod scalars;
mod table;

use self::scalars::ArrowCastFunction;
use self::table::BuiltinTableFuncs;
use datafusion::logical_expr::{AggregateFunction, BuiltinScalarFunction, Signature};
use self::scalars::df_scalars::ArrowCastFunction;
use self::scalars::kdl::{KDLMatches, KDLSelect};
use self::scalars::{postgres::*, ConnectionId, Version};
use self::table::{BuiltinTableFuncs, TableFunc};

use datafusion::logical_expr::{AggregateFunction, BuiltinScalarFunction, Expr, Signature};
use once_cell::sync::Lazy;
use protogen::metastore::types::catalog::{
EntryMeta, EntryType, FunctionEntry, FunctionType, RuntimePreference,
};

use std::collections::HashMap;
use std::sync::Arc;

/// Builtin table returning functions available for all sessions.
pub static BUILTIN_TABLE_FUNCS: Lazy<BuiltinTableFuncs> = Lazy::new(BuiltinTableFuncs::new);
static BUILTIN_TABLE_FUNCS: Lazy<BuiltinTableFuncs> = Lazy::new(BuiltinTableFuncs::new);
pub static ARROW_CAST_FUNC: Lazy<ArrowCastFunction> = Lazy::new(|| ArrowCastFunction {});
pub static BUILTIN_FUNCS: Lazy<BuiltinFuncs> = Lazy::new(BuiltinFuncs::new);
pub static FUNCTION_REGISTRY: Lazy<FunctionRegistry> = Lazy::new(FunctionRegistry::new);

/// A builtin function.
/// This trait is implemented by all builtin functions.
/// This is used to derive catalog entries for all supported functions.
/// Any new function MUST implement this trait.
pub trait BuiltinFunction: Sync + Send {
/// The name for this function. This name will be used when looking up
/// function implementations.
Expand Down Expand Up @@ -47,7 +53,7 @@ pub trait BuiltinFunction: Sync + Send {
// Returns the function type. 'aggregate', 'scalar', or 'table'
fn function_type(&self) -> FunctionType;

/// Convert to a builtin `FunctionEntry`
/// Convert to a builtin [`FunctionEntry`] for catalogging.
///
/// The default implementation is suitable for aggregates and scalars. Table
/// functions need to set runtime preference manually.
Expand All @@ -73,7 +79,7 @@ pub trait BuiltinFunction: Sync + Send {
}
}

/// The same as `BuiltinFunction` , but with const values.
/// The same as [`BuiltinFunction`] , but with const values.
pub trait ConstBuiltinFunction: Sync + Send {
const NAME: &'static str;
const DESCRIPTION: &'static str;
Expand All @@ -83,10 +89,43 @@ pub trait ConstBuiltinFunction: Sync + Send {
None
}
}
/// The namespace of a function.
///
/// Optional -> "namespace.function" || "function"
///
/// Required -> "namespace.function"
///
/// None -> "function"
pub enum FunctionNamespace {
/// The function can either be called under the namespace, or under global
/// e.g. "pg_catalog.current_user" or "current_user"
Optional(&'static str),
/// The function must be called under the namespace
/// e.g. "foo.my_function"
Required(&'static str),
/// The function can only be called under the global namespace
/// e.g. "avg"
None,
}

/// A custom builtin function provided by GlareDB.
///
/// These are functions that are implemented directly in GlareDB.
/// Unlike [`BuiltinFunction`], this contains an implementation of a UDF, and is not just a catalog entry for a DataFusion function.
///
/// Note: upcoming release of DataFusion will have a similar trait that'll likely be used instead.
pub trait BuiltinScalarUDF: BuiltinFunction {
fn as_expr(&self, args: Vec<Expr>) -> Expr;
/// The namespace of the function.
/// Defaults to global (None)
fn namespace(&self) -> FunctionNamespace {
FunctionNamespace::None
}
}

impl<T> BuiltinFunction for T
where
T: ConstBuiltinFunction,
T: ConstBuiltinFunction + Sized,
{
fn name(&self) -> &str {
Self::NAME
Expand All @@ -105,11 +144,16 @@ where
}
}

pub struct BuiltinFuncs {
/// Builtin Functions available for all sessions.
/// This is functionally equivalent to the datafusion `SessionState::scalar_functions`
/// We use our own implementation to allow us to have finer grained control over them.
/// We also don't have any session specific functions (for now), so it makes more sense to have a const global.
pub struct FunctionRegistry {
funcs: HashMap<String, Arc<dyn BuiltinFunction>>,
udfs: HashMap<String, Arc<dyn BuiltinScalarUDF>>,
}

impl BuiltinFuncs {
impl FunctionRegistry {
pub fn new() -> Self {
use strum::IntoEnumIterator;
let scalars = BuiltinScalarFunction::iter().map(|f| {
Expand All @@ -126,17 +170,91 @@ impl BuiltinFuncs {
let arrow_cast = (arrow_cast.name().to_string(), arrow_cast);
let arrow_cast = std::iter::once(arrow_cast);

// GlareDB specific functions
let udfs: Vec<Arc<dyn BuiltinScalarUDF>> = vec![
// Postgres functions
Arc::new(HasSchemaPrivilege),
Arc::new(HasDatabasePrivilege),
Arc::new(HasTablePrivilege),
Arc::new(CurrentSchemas),
Arc::new(CurrentUser),
Arc::new(CurrentRole),
Arc::new(CurrentSchema),
Arc::new(CurrentDatabase),
Arc::new(CurrentCatalog),
Arc::new(User),
Arc::new(PgGetUserById),
Arc::new(PgTableIsVisible),
Arc::new(PgEncodingToChar),
Arc::new(PgArrayToString),
// KDL functions
Arc::new(KDLMatches),
Arc::new(KDLSelect),
// Other functions
Arc::new(ConnectionId),
Arc::new(Version),
];
let udfs = udfs
.into_iter()
.flat_map(|f| {
let entry = (f.name().to_string(), f.clone());
match f.namespace() {
// we register the function under both the namespaced entry and the normal entry
// e.g. select foo.my_function() or select my_function()
FunctionNamespace::Optional(namespace) => {
let namespaced_entry = (format!("{}.{}", namespace, f.name()), f.clone());
vec![entry, namespaced_entry]
}
// we only register the function under the namespaced entry
// e.g. select foo.my_function()
FunctionNamespace::Required(namespace) => {
let namespaced_entry = (format!("{}.{}", namespace, f.name()), f.clone());
vec![namespaced_entry]
}
// we only register the function under the normal entry
// e.g. select my_function()
FunctionNamespace::None => {
vec![entry]
}
}
})
.collect::<HashMap<_, _>>();

let funcs: HashMap<String, Arc<dyn BuiltinFunction>> =
scalars.chain(aggregates).chain(arrow_cast).collect();

BuiltinFuncs { funcs }
FunctionRegistry { funcs, udfs }
}

pub fn find_function(&self, name: &str) -> Option<Arc<dyn BuiltinFunction>> {
self.funcs.get(name).cloned()
}
pub fn iter_funcs(&self) -> impl Iterator<Item = &Arc<dyn BuiltinFunction>> {

/// Find a scalar UDF by name
/// This is separate from `find_function` because we want to avoid downcasting
/// We already match on BuiltinScalarFunction and AggregateFunction when parsing the AST, so we just need to match on the UDF here.
pub fn get_scalar_udf(&self, name: &str) -> Option<Arc<dyn BuiltinScalarUDF>> {
self.udfs.get(name).cloned()
}

pub fn scalar_functions(&self) -> impl Iterator<Item = &Arc<dyn BuiltinFunction>> {
self.funcs.values()
}

pub fn scalar_udfs(&self) -> impl Iterator<Item = &Arc<dyn BuiltinScalarUDF>> {
self.udfs.values()
}
/// Return an iterator over all builtin table functions.
pub fn table_funcs(&self) -> impl Iterator<Item = &Arc<dyn TableFunc>> {
BUILTIN_TABLE_FUNCS.iter_funcs()
}

pub fn get_table_func(&self, name: &str) -> Option<Arc<dyn TableFunc>> {
BUILTIN_TABLE_FUNCS.find_function(name)
}
}

impl Default for BuiltinFuncs {
impl Default for FunctionRegistry {
fn default() -> Self {
Self::new()
}
Expand All @@ -151,27 +269,27 @@ macro_rules! document {
pub struct $item;

impl $item {
const DESCRIPTION: &'static str = $doc;
const EXAMPLE: &'static str = $example;
const NAME: &'static str = stringify!($item);
pub const DESCRIPTION: &'static str = $doc;
pub const EXAMPLE: &'static str = $example;
pub const NAME: &'static str = stringify!($item);
}
};
(doc => $doc:expr, example => $example:expr, $name:expr => $item:ident) => {
#[doc = $doc]
pub struct $item;

impl $item {
const DESCRIPTION: &'static str = $doc;
const EXAMPLE: &'static str = $example;
const NAME: &'static str = $name;
pub const DESCRIPTION: &'static str = $doc;
pub const EXAMPLE: &'static str = $example;
pub const NAME: &'static str = $name;
}
};
// uses an existing struct
($doc:expr, $example:expr, name => $name:expr, implementation => $item:ident) => {
impl $item {
const DESCRIPTION: &'static str = $doc;
const EXAMPLE: &'static str = $example;
const NAME: &'static str = $name;
pub const DESCRIPTION: &'static str = $doc;
pub const EXAMPLE: &'static str = $example;
pub const NAME: &'static str = $name;
}
};
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
// we make use of the document! macro to generate the documentation for the builtin functions.
// specifically the `stringify!` macro is used to get the name of the function.
// `Abs` would otherwise be `Abs` instead of `abs`. and so on.
#![allow(non_camel_case_types)]

use crate::{
document,
functions::{BuiltinFunction, ConstBuiltinFunction},
};
use datafusion::logical_expr::BuiltinScalarFunction;
use protogen::metastore::types::catalog::FunctionType;
#![allow(non_camel_case_types)]

use super::*;
pub struct ArrowCastFunction {}

impl ConstBuiltinFunction for ArrowCastFunction {
Expand Down
Loading

0 comments on commit da785d9

Please sign in to comment.