Skip to content

Commit

Permalink
Complete integration
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Nov 3, 2023
1 parent 1118305 commit b3e25be
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 106 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ ctor = "0.2.0"
datafusion = { path = "datafusion/core" }
datafusion-common = { path = "datafusion/common" }
datafusion-expr = { path = "datafusion/expr" }
datafusion-functions = { path = "datafusion/functions" }
datafusion-sql = { path = "datafusion/sql" }
datafusion-optimizer = { path = "datafusion/optimizer" }
datafusion-physical-expr = { path = "datafusion/physical-expr" }
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"]
compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression"]
crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"]
default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"]
encoding_expressions = ["datafusion-physical-expr/encoding_expressions"]
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = []
parquet = ["datafusion-common/parquet", "dep:parquet"]
Expand All @@ -65,6 +65,7 @@ dashmap = { workspace = true }
datafusion-common = { path = "../common", version = "32.0.0", features = ["object_store"], default-features = false }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-optimizer = { path = "../optimizer", version = "32.0.0", default-features = false }
datafusion-physical-expr = { path = "../physical-expr", version = "32.0.0", default-features = false }
datafusion-physical-plan = { workspace = true }
Expand Down
107 changes: 48 additions & 59 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use datafusion_common::{
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
};
pub use datafusion_execution::registry::MutableFunctionRegistry;
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
Expand Down Expand Up @@ -796,6 +795,48 @@ impl SessionContext {
.add_var_provider(variable_type, provider);
}

/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
pub fn register_udf(&self, f: ScalarUDF) {
self.state
.write()
.scalar_functions
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
///
/// Note in SQL queries, aggregate names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
pub fn register_udaf(&self, f: AggregateUDF) {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
}

/// Registers a window UDF within this context.
///
/// Note in SQL queries, window function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
pub fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
}

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
Expand Down Expand Up @@ -1117,50 +1158,6 @@ impl FunctionRegistry for SessionContext {
}
}

impl MutableFunctionRegistry for SessionContext {
/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
fn register_udf(&self, f: ScalarUDF) {
self.state
.write()
.scalar_functions
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
///
/// Note in SQL queries, aggregate names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
fn register_udaf(&self, f: AggregateUDF) {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
}

/// Registers a window UDF within this context.
///
/// Note in SQL queries, window function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
}
}

/// A planner used to add extensions to DataFusion logical and physical plans.
#[async_trait]
pub trait QueryPlanner {
Expand Down Expand Up @@ -1301,14 +1298,18 @@ impl SessionState {
);
}

// register built in functions
let mut scalar_functions = HashMap::new();
datafusion_functions::register_all(&mut scalar_functions);

SessionState {
session_id,
analyzer: Analyzer::new(),
optimizer: Optimizer::new(),
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
scalar_functions: HashMap::new(),
scalar_functions,
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: Arc::new(EmptySerializerRegistry),
Expand All @@ -1318,19 +1319,7 @@ impl SessionState {
table_factories,
}
}
/// Returns new [`SessionState`] using the provided
/// [`SessionConfig`] and [`RuntimeEnv`].
#[deprecated(
since = "32.0.0",
note = "Use SessionState::new_with_config_rt_and_catalog_list"
)]
pub fn with_config_rt_and_catalog_list(
config: SessionConfig,
runtime: Arc<RuntimeEnv>,
catalog_list: Arc<dyn CatalogList>,
) -> Self {
Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list)
}

fn register_default_schema(
config: &SessionConfig,
table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>,
Expand Down
4 changes: 1 addition & 3 deletions datafusion/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
//! ```

pub use crate::dataframe::DataFrame;
pub use crate::execution::context::{
MutableFunctionRegistry, SQLOptions, SessionConfig, SessionContext,
};
pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext};
pub use crate::execution::options::{
AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use datafusion::{
},
assert_batches_eq,
error::Result,
execution::MutableFunctionRegistry,
logical_expr::{
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow_schema::DataType;
use datafusion::{assert_batches_eq, prelude::SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_execution::MutableFunctionRegistry;
use datafusion_expr::{
function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction,
Signature, Volatility, WindowUDF,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ mod stream;
mod task;

pub use disk_manager::DiskManager;
pub use registry::{FunctionRegistry, MutableFunctionRegistry};
pub use registry::FunctionRegistry;
pub use stream::{RecordBatchStream, SendableRecordBatchStream};
pub use task::TaskContext;
30 changes: 0 additions & 30 deletions datafusion/execution/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,36 +36,6 @@ pub trait FunctionRegistry {
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
}

/// A Function registry that can have functions registered
pub trait MutableFunctionRegistry {
/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
fn register_udf(&self, f: ScalarUDF);

/// Registers an aggregate UDF within this context.
///
/// Note in SQL queries, aggregate names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
fn register_udaf(&self, f: AggregateUDF);

/// Registers a window UDF within this context.
///
/// Note in SQL queries, window function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
fn register_udwf(&self, f: WindowUDF);
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
pub trait SerializerRegistry: Send + Sync {
/// Serialize this node to a byte array. This serialization should not include
Expand Down
16 changes: 9 additions & 7 deletions datafusion/functions/src/encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ mod inner;

use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_execution::registry::MutableFunctionRegistry;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{
ColumnarValue, FunctionImplementation, ScalarUDF, Signature, Volatility,
};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use DataType::*;

Expand All @@ -39,12 +39,14 @@ pub fn decode_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(Arc::new(DecodeFunc {}))
}

/// Registers the `encode` and `decode` functions with the function registry
pub fn register(registry: &dyn MutableFunctionRegistry) -> Result<()> {
registry.register_udf(encode_udf());
registry.register_udf(decode_udf());
fn insert(registry: &mut HashMap<String, Arc<ScalarUDF>>, udf: ScalarUDF) {
registry.insert(udf.name().to_string(), Arc::new(udf));
}

Ok(())
/// Registers the `encode` and `decode` functions with the function registry
pub fn register(registry: &mut HashMap<String, Arc<ScalarUDF>>) {
insert(registry, encode_udf());
insert(registry, decode_udf());
}

struct EncodeFunc {}
Expand Down Expand Up @@ -95,7 +97,7 @@ static DECODE_SIGNATURE: OnceLock<Signature> = OnceLock::new();

impl FunctionImplementation for DecodeFunc {
fn name(&self) -> &str {
"encode"
"decode"
}

fn signature(&self) -> &Signature {
Expand Down
9 changes: 9 additions & 0 deletions datafusion/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,13 @@

//! Several packages of built in functions for DataFusion

use datafusion_expr::ScalarUDF;
use std::collections::HashMap;
use std::sync::Arc;

pub mod encoding;

/// Registers all "built in" functions from this crate with the provided registry
pub fn register_all(registry: &mut HashMap<String, Arc<ScalarUDF>>) {
encoding::register(registry);
}
1 change: 0 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use datafusion::datasource::provider::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::MutableFunctionRegistry;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext};
use datafusion::test_util::{TestTableFactory, TestTableProvider};
Expand Down
1 change: 0 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
use datafusion::execution::context::ExecutionProps;
use datafusion::execution::MutableFunctionRegistry;
use datafusion::logical_expr::{
create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility,
};
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion::execution::{FunctionRegistry, MutableFunctionRegistry};
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, create_udf, lit};
Expand Down

0 comments on commit b3e25be

Please sign in to comment.