From fab06da82eb351009381ede327009dcfd5421c65 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 10 Jul 2024 16:59:55 -0400 Subject: [PATCH 1/8] Create a SessionStateBuilder and use it for creating anything but a basic SessionState. --- datafusion-cli/src/catalog.rs | 9 +- datafusion/core/src/execution/context/mod.rs | 9 +- .../core/src/execution/session_state.rs | 624 +++++++++++++----- datafusion/core/tests/memory_limit/mod.rs | 5 +- .../tests/user_defined/user_defined_plan.rs | 9 +- .../tests/cases/roundtrip_logical_plan.rs | 8 +- 6 files changed, 495 insertions(+), 169 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index c11eb3280c20..7aa8937a3b78 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -29,6 +29,7 @@ use datafusion::datasource::listing::{ use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; +use datafusion::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use dirs::home_dir; @@ -162,6 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); + let mut builder = SessionStateBuilder::new_from_existing(state.clone()); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); @@ -178,13 +180,16 @@ impl SchemaProvider for DynamicFileSchemaProvider { // to any command options so the only choice is to use an empty collection match scheme { "s3" | "oss" | "cos" => { - state = state.add_table_options_extension(AwsOptions::default()); + builder = + builder.with_table_options_extension(AwsOptions::default()); } "gs" | "gcs" => { - state = state.add_table_options_extension(GcpOptions::default()) + builder = + builder.with_table_options_extension(GcpOptions::default()) } _ => {} }; + state = builder.build(); let store = get_object_store( &state, table_url.scheme(), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4b9e3e843341..63009f6eb959 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -315,7 +315,7 @@ impl SessionContext { } /// Creates a new `SessionContext` using the provided [`SessionState`] - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] + #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_state")] pub fn with_state(state: SessionState) -> Self { Self::new_with_state(state) } @@ -1574,6 +1574,7 @@ mod tests { use datafusion_common_runtime::SpawnedTask; use crate::catalog::schema::SchemaProvider; + use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; use tempfile::TempDir; @@ -1734,8 +1735,10 @@ mod tests { async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); let session_state = - SessionState::new_with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); + SessionStateBuilder::new_with_config_rt(SessionConfig::new(), runtime) + .with_defaults(true) + .with_query_planner(Arc::new(MyQueryPlanner {})) + .build(); let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c123ebb22ecb..c42499fedaea 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -91,7 +91,8 @@ use uuid::Uuid; /// /// Note that there is no `Default` or `new()` for SessionState, /// to avoid accidentally running queries or other operations without passing through -/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionContext`]. +/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionStateBuilder`] and +/// [`SessionContext`]. /// /// [`SessionContext`]: crate::execution::context::SessionContext #[derive(Clone)] @@ -140,7 +141,6 @@ pub struct SessionState { table_factories: HashMap>, /// Runtime environment runtime_env: Arc, - /// [FunctionFactory] to support pluggable user defined function handler. /// /// It will be invoked on `CREATE FUNCTION` statements. @@ -153,6 +153,7 @@ impl Debug for SessionState { f.debug_struct("SessionState") .field("session_id", &self.session_id) .field("analyzer", &"...") + .field("expr_planners", &"...") .field("optimizer", &"...") .field("physical_optimizers", &"...") .field("query_planner", &"...") @@ -195,122 +196,10 @@ impl SessionState { runtime: Arc, catalog_list: Arc, ) -> Self { - let session_id = Uuid::new_v4().to_string(); - - // Create table_factories for all default formats - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - if config.create_default_catalog_and_schema() { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema( - &config, - &table_factories, - &runtime, - &default_catalog, - ); - - catalog_list.register_catalog( - config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); - } - - let expr_planners: Vec> = vec![ - Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), - ]; - - let mut new_self = SessionState { - session_id, - analyzer: Analyzer::new(), - expr_planners, - optimizer: Optimizer::new(), - physical_optimizers: PhysicalOptimizer::new(), - query_planner: Arc::new(DefaultQueryPlanner {}), - catalog_list, - table_functions: HashMap::new(), - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - window_functions: HashMap::new(), - serializer_registry: Arc::new(EmptySerializerRegistry), - file_formats: HashMap::new(), - table_options: TableOptions::default_from_session_config(config.options()), - config, - execution_props: ExecutionProps::new(), - runtime_env: runtime, - table_factories, - function_factory: None, - }; - - #[cfg(feature = "parquet")] - if let Err(e) = - new_self.register_file_format(Arc::new(ParquetFormatFactory::new()), false) - { - log::info!("Unable to register default ParquetFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(JsonFormatFactory::new()), false) - { - log::info!("Unable to register default JsonFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(CsvFormatFactory::new()), false) - { - log::info!("Unable to register default CsvFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(ArrowFormatFactory::new()), false) - { - log::info!("Unable to register default ArrowFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(AvroFormatFactory::new()), false) - { - log::info!("Unable to register default AvroFormat: {e}") - }; - - // register built in functions - functions::register_all(&mut new_self) - .expect("can not register built in functions"); - - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(&mut new_self) - .expect("can not register array expressions"); - - functions_aggregate::register_all(&mut new_self) - .expect("can not register aggregate functions"); - - new_self + SessionStateBuilder::new_with_config_rt(config, runtime) + .with_defaults(true) + .with_catalog_list(catalog_list) + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. @@ -325,44 +214,6 @@ impl SessionState { ) -> Self { Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) } - fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); - } pub(crate) fn resolve_table_ref( &self, @@ -400,12 +251,14 @@ impl SessionState { }) } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the random session id. pub fn with_session_id(mut self, session_id: String) -> Self { self.session_id = session_id; self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// override default query planner with `query_planner` pub fn with_query_planner( mut self, @@ -415,6 +268,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Override the [`AnalyzerRule`]s optimizer plan rules. pub fn with_analyzer_rules( mut self, @@ -424,6 +278,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`OptimizerRule`]s used to optimize plans pub fn with_optimizer_rules( mut self, @@ -433,6 +288,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans pub fn with_physical_optimizer_rules( mut self, @@ -452,6 +308,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `optimizer_rule` to the end of the list of /// [`OptimizerRule`]s used to rewrite queries. pub fn add_optimizer_rule( @@ -472,6 +329,7 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `physical_optimizer_rule` to the end of the list of /// [`PhysicalOptimizerRule`]s used to rewrite queries. pub fn add_physical_optimizer_rule( @@ -482,6 +340,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Adds a new [`ConfigExtension`] to TableOptions pub fn add_table_options_extension( mut self, @@ -491,6 +350,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn with_function_factory( mut self, @@ -505,6 +365,7 @@ impl SessionState { self.function_factory = Some(function_factory); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the extension [`SerializerRegistry`] pub fn with_serializer_registry( mut self, @@ -976,6 +837,457 @@ impl SessionState { } } +/// A builder to be used for building [`SessionState`]'s. Defaults will be used for all values +/// unless explicitly provided. Note that there is no `Default` or `new()` for SessionState, +/// to avoid accidentally running queries or other operations without passing through +/// the [`SessionConfig`] or [`RuntimeEnv`]. +pub struct SessionStateBuilder { + state: SessionState, + use_defaults: bool, +} + +impl SessionStateBuilder { + /// Returns new [`SessionStateBuilder`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + pub fn new_with_config_rt( + config: SessionConfig, + runtime_env: Arc, + ) -> Self { + let session_id = Uuid::new_v4().to_string(); + let catalog_list = + Arc::new(MemoryCatalogProviderList::new()) as Arc; + + Self { + state: SessionState { + session_id, + analyzer: Analyzer::new(), + expr_planners: vec![], + optimizer: Optimizer::new(), + physical_optimizers: PhysicalOptimizer::new(), + query_planner: Arc::new(DefaultQueryPlanner {}), + catalog_list, + table_functions: HashMap::new(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + serializer_registry: Arc::new(EmptySerializerRegistry), + file_formats: HashMap::new(), + table_options: TableOptions::default_from_session_config( + config.options(), + ), + config, + execution_props: ExecutionProps::new(), + table_factories: HashMap::new(), + runtime_env, + function_factory: None, + }, + use_defaults: true, + } + } + + /// Returns a new [SessionStateBuilder] based on an existing [SessionState] + /// The session id for the new builder will be reset to a unique value, all + /// other fields will be set to what is set in the provided session state + pub fn new_from_existing(existing: SessionState) -> Self { + let session_id = Uuid::new_v4().to_string(); + + Self { + state: SessionState { + session_id, + ..existing + }, + use_defaults: true, + } + } + + /// Set to true (default = true) if defaults for table_factories, expr_planners, file formats + /// and builtin functions should be set. + /// Note that there is an explicit option for enabling catalog and schema default + /// via [SessionConfig::create_default_catalog_and_schema] which will only be used + /// if the use_defaults is enabled here. + /// Also note that if a field is explicitly set to a non-empty value - + /// for example by using the [SessionStateBuilder::with_file_formats] function, + /// then defaults for that field will not be set. + pub fn with_defaults(mut self, use_defaults: bool) -> Self { + self.use_defaults = use_defaults; + self + } + + /// Replace the random session id. + pub fn with_session_id(mut self, session_id: String) -> Self { + self.state.session_id = session_id; + self + } + + /// Override the [`AnalyzerRule`]s optimizer plan rules. + pub fn with_analyzer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.state.analyzer = Analyzer::with_rules(rules); + self + } + + /// Add `analyzer_rule` to the end of the list of + /// [`AnalyzerRule`]s used to rewrite queries. + pub fn add_analyzer_rule( + mut self, + analyzer_rule: Arc, + ) -> Self { + self.state.analyzer.rules.push(analyzer_rule); + self + } + + /// Replace the entire list of [`OptimizerRule`]s used to optimize plans + pub fn with_optimizer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.state.optimizer = Optimizer::with_rules(rules); + self + } + + /// Add `optimizer_rule` to the end of the list of + /// [`OptimizerRule`]s used to rewrite queries. + pub fn add_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + self.state.optimizer.rules.push(optimizer_rule); + self + } + + /// Replace the entire list of [`ExprPlanner`]s used to customize the behavior of the SQL planner + pub fn with_expr_planners( + mut self, + expr_planners: Vec>, + ) -> Self { + self.state.expr_planners = expr_planners; + self + } + + /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans + pub fn with_physical_optimizer_rules( + mut self, + physical_optimizers: Vec>, + ) -> Self { + self.state.physical_optimizers = + PhysicalOptimizer::with_rules(physical_optimizers); + self + } + + /// Add `physical_optimizer_rule` to the end of the list of + /// [`PhysicalOptimizerRule`]s used to rewrite queries. + pub fn add_physical_optimizer_rule( + mut self, + physical_optimizer_rule: Arc, + ) -> Self { + self.state + .physical_optimizers + .rules + .push(physical_optimizer_rule); + self + } + + /// override default query planner with `query_planner` + pub fn with_query_planner( + mut self, + query_planner: Arc, + ) -> Self { + self.state.query_planner = query_planner; + self + } + + /// override default catalog list with `catalog_list` + pub fn with_catalog_list( + mut self, + catalog_list: Arc, + ) -> Self { + self.state.catalog_list = catalog_list; + self + } + + /// override default table functions with `table_functions` + pub fn with_table_functions( + mut self, + table_functions: HashMap>, + ) -> Self { + self.state.table_functions = table_functions; + self + } + + /// override default scalar functions with `scalar_functions` + pub fn with_scalar_functions( + mut self, + scalar_functions: HashMap>, + ) -> Self { + self.state.scalar_functions = scalar_functions; + self + } + + /// override default aggregate functions with `aggregate_functions` + pub fn with_aggregate_functions( + mut self, + aggregate_functions: HashMap>, + ) -> Self { + self.state.aggregate_functions = aggregate_functions; + self + } + + /// override default window functions with `window_functions` + pub fn with_window_functions( + mut self, + window_functions: HashMap>, + ) -> Self { + self.state.window_functions = window_functions; + self + } + + /// Registers a [`SerializerRegistry`] + pub fn with_serializer_registry( + mut self, + serializer_registry: Arc, + ) -> Self { + self.state.serializer_registry = serializer_registry; + self + } + + /// override default list of file formats with `file_formats` + pub fn with_file_formats( + mut self, + file_formats: HashMap>, + ) -> Self { + self.state.file_formats = file_formats; + self + } + + /// override the session config with `config` + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.state.config = config; + self + } + + /// override default table options with `table_options` + pub fn with_table_options(mut self, table_options: TableOptions) -> Self { + self.state.table_options = table_options; + self + } + + /// Adds a new [`ConfigExtension`] to TableOptions + pub fn with_table_options_extension( + mut self, + extension: T, + ) -> Self { + self.state.table_options.extensions.insert(extension); + self + } + + /// override default execution props with `execution_props` + pub fn with_execution_props(mut self, execution_props: ExecutionProps) -> Self { + self.state.execution_props = execution_props; + self + } + + /// override default table factories with `table_factories` + pub fn with_table_factories( + mut self, + table_factories: HashMap>, + ) -> Self { + self.state.table_factories = table_factories; + self + } + + /// override the runtime env with `runtime_env` + pub fn with_runtime_env(mut self, runtime_env: Arc) -> Self { + self.state.runtime_env = runtime_env; + self + } + + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + mut self, + function_factory: Option>, + ) -> Self { + self.state.function_factory = function_factory; + self + } + + /// build a [`SessionState`] with the current configuration + pub fn build(mut self) -> SessionState { + if self.use_defaults { + if self.state.table_factories.is_empty() { + self.state.table_factories = + SessionStateDefaults::default_table_factories(); + } + if self.state.expr_planners.is_empty() { + self.state.expr_planners = SessionStateDefaults::default_expr_planners(); + } + + if self.state.config.create_default_catalog_and_schema() { + let default_catalog = SessionStateDefaults::default_catalog( + &self.state.config, + &self.state.table_factories, + &self.state.runtime_env, + ); + + self.state.catalog_list.register_catalog( + self.state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ); + } + + if self.state.file_formats.is_empty() { + SessionStateDefaults::register_file_format_defaults(&mut self.state); + } + + if self.state.scalar_functions.is_empty() { + SessionStateDefaults::register_builtin_functions(&mut self.state); + } + } + + self.state.clone() + } +} + +struct SessionStateDefaults {} + +impl SessionStateDefaults { + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + pub fn register_file_format_defaults(state: &mut SessionState) { + #[cfg(feature = "parquet")] + if let Err(e) = + state.register_file_format(Arc::new(ParquetFormatFactory::new()), false) + { + log::info!("Unable to register default ParquetFormat: {e}") + }; + + if let Err(e) = + state.register_file_format(Arc::new(JsonFormatFactory::new()), false) + { + log::info!("Unable to register default JsonFormat: {e}") + }; + + if let Err(e) = + state.register_file_format(Arc::new(CsvFormatFactory::new()), false) + { + log::info!("Unable to register default CsvFormat: {e}") + }; + + if let Err(e) = + state.register_file_format(Arc::new(ArrowFormatFactory::new()), false) + { + log::info!("Unable to register default ArrowFormat: {e}") + }; + + if let Err(e) = + state.register_file_format(Arc::new(AvroFormatFactory::new()), false) + { + log::info!("Unable to register default AvroFormat: {e}") + }; + } + + pub fn register_builtin_functions(state: &mut SessionState) { + // register built in functions + functions::register_all(state).expect("can not register built in functions"); + + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f61ee5d9ab98..997ad2a8341a 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -37,6 +37,7 @@ use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; @@ -434,13 +435,13 @@ impl TestCase { let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution - let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + let state = SessionStateBuilder::new_with_config_rt(config, Arc::new(runtime)); let state = match scenario.rules() { Some(rules) => state.with_physical_optimizer_rules(rules), None => state, }; - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new_with_state(state.build()); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 38ed142cf922..20d315fe97ab 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,6 +92,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; @@ -290,10 +291,12 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let mut state = SessionState::new_with_config_rt(config, runtime) + let state = SessionStateBuilder::new_with_config_rt(config, runtime) + .with_defaults(true) .with_query_planner(Arc::new(TopKQueryPlanner {})) - .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + .add_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .add_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); SessionContext::new_with_state(state) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 2893b1a31a26..4866fa587939 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -28,7 +28,6 @@ use std::sync::Arc; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ @@ -37,6 +36,7 @@ use datafusion::logical_expr::{ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -1121,11 +1121,13 @@ async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { } async fn create_context() -> Result { - let mut state = SessionState::new_with_config_rt( + let mut state = SessionStateBuilder::new_with_config_rt( SessionConfig::default(), Arc::new(RuntimeEnv::default()), ) - .with_serializer_registry(Arc::new(MockSerializerRegistry)); + .with_defaults(true) + .with_serializer_registry(Arc::new(MockSerializerRegistry)) + .build(); // register udaf for test, e.g. `sum()` datafusion_functions_aggregate::register_all(&mut state) From de390823b4f660363eeb8d0eed360308367a7dd8 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 10 Jul 2024 17:12:55 -0400 Subject: [PATCH 2/8] Updated new_from_existing to take a reference to the existing SessionState and clone it. --- datafusion-cli/src/catalog.rs | 2 +- datafusion/core/src/execution/session_state.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 7aa8937a3b78..c7fca9b0121e 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -163,7 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); - let mut builder = SessionStateBuilder::new_from_existing(state.clone()); + let mut builder = SessionStateBuilder::new_from_existing(&state); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c42499fedaea..f1c7faa47fcc 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -888,13 +888,13 @@ impl SessionStateBuilder { /// Returns a new [SessionStateBuilder] based on an existing [SessionState] /// The session id for the new builder will be reset to a unique value, all /// other fields will be set to what is set in the provided session state - pub fn new_from_existing(existing: SessionState) -> Self { + pub fn new_from_existing(existing: &SessionState) -> Self { let session_id = Uuid::new_v4().to_string(); Self { state: SessionState { session_id, - ..existing + ..existing.clone() }, use_defaults: true, } From 7d1bc09a1c20d8a2f7560032b5b29624b4fae164 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 10 Jul 2024 17:14:34 -0400 Subject: [PATCH 3/8] Minor documentation update. --- datafusion/core/src/execution/session_state.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index f1c7faa47fcc..c435ae688ae3 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -886,8 +886,8 @@ impl SessionStateBuilder { } /// Returns a new [SessionStateBuilder] based on an existing [SessionState] - /// The session id for the new builder will be reset to a unique value, all - /// other fields will be set to what is set in the provided session state + /// The session id for the new builder will be set to a unique value; all + /// other fields will be cloned from what is set in the provided session state pub fn new_from_existing(existing: &SessionState) -> Self { let session_id = Uuid::new_v4().to_string(); From 490214dc22bde203477764df2458fb37e10e6a52 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 10 Jul 2024 17:50:01 -0400 Subject: [PATCH 4/8] SessionStateDefaults improvements. --- .../core/src/execution/session_state.rs | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c435ae688ae3..6cf760bb6e0a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1137,11 +1137,16 @@ impl SessionStateBuilder { } if self.state.file_formats.is_empty() { - SessionStateDefaults::register_file_format_defaults(&mut self.state); + SessionStateDefaults::register_default_file_formats(&mut self.state); } if self.state.scalar_functions.is_empty() { - SessionStateDefaults::register_builtin_functions(&mut self.state); + SessionStateDefaults::register_scalar_functions(&mut self.state); + SessionStateDefaults::register_array_functions(&mut self.state); + } + + if self.state.aggregate_functions.is_empty() { + SessionStateDefaults::register_aggregate_functions(&mut self.state); } } @@ -1149,9 +1154,12 @@ impl SessionStateBuilder { } } -struct SessionStateDefaults {} +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s pub fn default_table_factories() -> HashMap> { let mut table_factories: HashMap> = HashMap::new(); @@ -1166,6 +1174,7 @@ impl SessionStateDefaults { table_factories } + /// returns the default MemoryCatalogProvider pub fn default_catalog( config: &SessionConfig, table_factories: &HashMap>, @@ -1185,6 +1194,7 @@ impl SessionStateDefaults { default_catalog } + /// returns the list of default [`ExprPlanner`]s pub fn default_expr_planners() -> Vec> { let expr_planners: Vec> = vec![ Arc::new(functions::core::planner::CoreFunctionPlanner::default()), @@ -1203,6 +1213,32 @@ impl SessionStateDefaults { expr_planners } + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema pub fn register_default_schema( config: &SessionConfig, table_factories: &HashMap>, @@ -1242,7 +1278,8 @@ impl SessionStateDefaults { .expect("Failed to register default schema"); } - pub fn register_file_format_defaults(state: &mut SessionState) { + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { #[cfg(feature = "parquet")] if let Err(e) = state.register_file_format(Arc::new(ParquetFormatFactory::new()), false) @@ -1274,18 +1311,6 @@ impl SessionStateDefaults { log::info!("Unable to register default AvroFormat: {e}") }; } - - pub fn register_builtin_functions(state: &mut SessionState) { - // register built in functions - functions::register_all(state).expect("can not register built in functions"); - - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(state).expect("can not register array expressions"); - - functions_aggregate::register_all(state) - .expect("can not register aggregate functions"); - } } struct SessionContextProvider<'a> { From e3fad6df11d3b1de69cba1894878d9f418bd2dcd Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 12 Jul 2024 18:30:07 -0400 Subject: [PATCH 5/8] Reworked how SessionStateBuilder works from PR feedback. --- datafusion-cli/src/catalog.rs | 12 +- .../examples/custom_file_format.rs | 9 +- .../core/src/datasource/file_format/csv.rs | 7 +- datafusion/core/src/execution/context/mod.rs | 24 +- .../core/src/execution/session_state.rs | 651 ++++++++++++------ datafusion/core/src/physical_planner.rs | 7 +- datafusion/core/src/test/object_store.rs | 8 +- datafusion/core/tests/dataframe/mod.rs | 19 +- datafusion/core/tests/memory_limit/mod.rs | 13 +- .../core/tests/parquet/file_statistics.rs | 6 +- datafusion/core/tests/sql/create_drop.rs | 13 +- .../tests/user_defined/user_defined_plan.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 8 +- .../tests/cases/roundtrip_logical_plan.rs | 13 +- 14 files changed, 538 insertions(+), 258 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index c7fca9b0121e..b83f65975610 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -163,7 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); - let mut builder = SessionStateBuilder::new_from_existing(&state); + let mut builder = SessionStateBuilder::from(state.clone()); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); @@ -180,12 +180,14 @@ impl SchemaProvider for DynamicFileSchemaProvider { // to any command options so the only choice is to use an empty collection match scheme { "s3" | "oss" | "cos" => { - builder = - builder.with_table_options_extension(AwsOptions::default()); + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(AwsOptions::default()) + } } "gs" | "gcs" => { - builder = - builder.with_table_options_extension(GcpOptions::default()) + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(GcpOptions::default()) + } } _ => {} }; diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index fe936418bce4..bdb702375c94 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -22,6 +22,7 @@ use arrow::{ datatypes::UInt64Type, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ datasource::{ file_format::{ @@ -32,9 +33,9 @@ use datafusion::{ MemTable, }, error::Result, - execution::{context::SessionState, runtime_env::RuntimeEnv}, + execution::context::SessionState, physical_plan::ExecutionPlan, - prelude::{SessionConfig, SessionContext}, + prelude::SessionContext, }; use datafusion_common::{GetExt, Statistics}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -176,9 +177,7 @@ impl GetExt for TSVFileFactory { #[tokio::main] async fn main() -> Result<()> { // Create a new context with the default configuration - let config = SessionConfig::new(); - let runtime = RuntimeEnv::default(); - let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // Register the custom file format let file_format = Arc::new(TSVFileFactory::new()); diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 92cb11e2b47a..baeaf51fb56d 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -632,6 +632,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; + use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -814,7 +815,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new()).unwrap()); let mut cfg = SessionConfig::new(); cfg.options_mut().catalog.has_header = true; - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 63009f6eb959..640a9b14a65f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -73,6 +73,7 @@ use object_store::ObjectStore; use parking_lot::RwLock; use url::Url; +use crate::execution::session_state::SessionStateBuilder; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; @@ -294,7 +295,11 @@ impl SessionContext { /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); Self::new_with_state(state) } @@ -1708,7 +1713,11 @@ mod tests { .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; @@ -1734,11 +1743,12 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = - SessionStateBuilder::new_with_config_rt(SessionConfig::new(), runtime) - .with_defaults(true) - .with_query_planner(Arc::new(MyQueryPlanner {})) - .build(); + let session_state = SessionStateBuilder::new() + .with_config(SessionConfig::new()) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(MyQueryPlanner {})) + .build(); let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 3e7c07158482..bdea5fde69bc 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -77,6 +77,8 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use itertools::Itertools; +use log::{debug, info}; use sqlparser::ast::Expr as SQLExpr; use sqlparser::dialect::dialect_from_str; use std::collections::hash_map::Entry; @@ -89,6 +91,25 @@ use uuid::Uuid; /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. /// +/// Use the [`SessionStateBuilder`] to build a SessionState object. +/// +/// ``` +/// use datafusion::prelude::*; +/// # use datafusion::{error::Result, assert_batches_eq}; +/// # use datafusion::execution::session_state::SessionStateBuilder; +/// # use datafusion_execution::runtime_env::RuntimeEnv; +/// # use std::sync::Arc; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let state = SessionStateBuilder::new() +/// .with_config(SessionConfig::new()) +/// .with_runtime_env(Arc::new(RuntimeEnv::default())) +/// .with_default_features() +/// .build(); +/// Ok(()) +/// # } +/// ``` +/// /// Note that there is no `Default` or `new()` for SessionState, /// to avoid accidentally running queries or other operations without passing through /// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionStateBuilder`] and @@ -176,43 +197,56 @@ impl Debug for SessionState { impl SessionState { /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let catalog_list = - Arc::new(MemoryCatalogProviderList::new()) as Arc; - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - Self::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - SessionStateBuilder::new_with_config_rt(config, runtime) - .with_defaults(true) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) .with_catalog_list(catalog_list) + .with_default_features() .build() } + /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated( - since = "32.0.0", - note = "Use SessionState::new_with_config_rt_and_catalog_list" - )] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() } pub(crate) fn resolve_table_ref( @@ -721,19 +755,20 @@ impl SessionState { &self.table_options } - /// Return mutable table opptions + /// Return mutable table options pub fn table_options_mut(&mut self) -> &mut TableOptions { &mut self.table_options } - /// Registers a [`ConfigExtension`] as a table option extention that can be + /// Registers a [`ConfigExtension`] as a table option extension that can be /// referenced from SQL statements executed against this context. pub fn register_table_options_extension(&mut self, extension: T) { self.table_options.extensions.insert(extension) } - /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or CREATE EXTERNAL TABLE statements for reading - /// and writing files of custom formats. + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or + /// CREATE EXTERNAL TABLE statements for reading and writing files of custom + /// formats. pub fn register_file_format( &mut self, file_format: Arc, @@ -813,7 +848,7 @@ impl SessionState { ); } - /// Deregsiter a user defined table function + /// Deregister a user defined table function pub fn deregister_udtf( &mut self, name: &str, @@ -837,94 +872,124 @@ impl SessionState { } } -/// A builder to be used for building [`SessionState`]'s. Defaults will be used for all values -/// unless explicitly provided. Note that there is no `Default` or `new()` for SessionState, -/// to avoid accidentally running queries or other operations without passing through -/// the [`SessionConfig`] or [`RuntimeEnv`]. +/// A builder to be used for building [`SessionState`]'s. Defaults will +/// be used for all values unless explicitly provided. pub struct SessionStateBuilder { - state: SessionState, - use_defaults: bool, + session_id: Option, + analyzer: Option, + expr_planners: Option>>, + optimizer: Option, + physical_optimizers: Option, + query_planner: Option>, + catalog_list: Option>, + table_functions: Option>>, + scalar_functions: Option>>, + aggregate_functions: Option>>, + window_functions: Option>>, + serializer_registry: Option>, + file_formats: Option>>, + config: Option, + table_options: Option, + execution_props: Option, + table_factories: Option>>, + runtime_env: Option>, + function_factory: Option>, + // fields to support convenience functions + analyzer_rules: Option>>, + optimizer_rules: Option>>, + physical_optimizer_rules: Option>>, } impl SessionStateBuilder { - /// Returns new [`SessionStateBuilder`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - pub fn new_with_config_rt( - config: SessionConfig, - runtime_env: Arc, - ) -> Self { - let session_id = Uuid::new_v4().to_string(); - let catalog_list = - Arc::new(MemoryCatalogProviderList::new()) as Arc; - + /// Returns a new [`SessionStateBuilder`] with no options set. + pub fn new() -> Self { Self { - state: SessionState { - session_id, - analyzer: Analyzer::new(), - expr_planners: vec![], - optimizer: Optimizer::new(), - physical_optimizers: PhysicalOptimizer::new(), - query_planner: Arc::new(DefaultQueryPlanner {}), - catalog_list, - table_functions: HashMap::new(), - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - window_functions: HashMap::new(), - serializer_registry: Arc::new(EmptySerializerRegistry), - file_formats: HashMap::new(), - table_options: TableOptions::default_from_session_config( - config.options(), - ), - config, - execution_props: ExecutionProps::new(), - table_factories: HashMap::new(), - runtime_env, - function_factory: None, - }, - use_defaults: true, + session_id: None, + analyzer: None, + expr_planners: None, + optimizer: None, + physical_optimizers: None, + query_planner: None, + catalog_list: None, + table_functions: None, + scalar_functions: None, + aggregate_functions: None, + window_functions: None, + serializer_registry: None, + file_formats: None, + table_options: None, + config: None, + execution_props: None, + table_factories: None, + runtime_env: None, + function_factory: None, + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, } } /// Returns a new [SessionStateBuilder] based on an existing [SessionState] - /// The session id for the new builder will be set to a unique value; all - /// other fields will be cloned from what is set in the provided session state - pub fn new_from_existing(existing: &SessionState) -> Self { - let session_id = Uuid::new_v4().to_string(); + /// The session id for the new builder will be unset; all other fields will + /// be cloned from what is set in the provided session state + pub fn new_from_existing(existing: SessionState) -> Self { + let cloned = existing.clone(); Self { - state: SessionState { - session_id, - ..existing.clone() - }, - use_defaults: true, + session_id: None, + analyzer: Some(cloned.analyzer), + expr_planners: Some(cloned.expr_planners), + optimizer: Some(cloned.optimizer), + physical_optimizers: Some(cloned.physical_optimizers), + query_planner: Some(cloned.query_planner), + catalog_list: Some(cloned.catalog_list), + table_functions: Some(cloned.table_functions), + scalar_functions: Some(cloned.scalar_functions.into_values().collect_vec()), + aggregate_functions: Some( + cloned.aggregate_functions.into_values().collect_vec(), + ), + window_functions: Some(cloned.window_functions.into_values().collect_vec()), + serializer_registry: Some(cloned.serializer_registry), + file_formats: Some(cloned.file_formats.into_values().collect_vec()), + config: Some(cloned.config), + table_options: Some(cloned.table_options), + execution_props: Some(cloned.execution_props), + table_factories: Some(cloned.table_factories), + runtime_env: Some(cloned.runtime_env), + function_factory: cloned.function_factory, + + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, } } - /// Set to true (default = true) if defaults for table_factories, expr_planners, file formats - /// and builtin functions should be set. - /// Note that there is an explicit option for enabling catalog and schema default - /// via [SessionConfig::create_default_catalog_and_schema] which will only be used - /// if the use_defaults is enabled here. - /// Also note that if a field is explicitly set to a non-empty value - - /// for example by using the [SessionStateBuilder::with_file_formats] function, - /// then defaults for that field will not be set. - pub fn with_defaults(mut self, use_defaults: bool) -> Self { - self.use_defaults = use_defaults; + /// Set defaults for table_factories, file formats, expr_planners and builtin + /// scalar and aggregate functions. + pub fn with_default_features(mut self) -> Self { + self.table_factories = Some(SessionStateDefaults::default_table_factories()); + self.file_formats = Some(SessionStateDefaults::default_file_formats()); + self.expr_planners = Some(SessionStateDefaults::default_expr_planners()); + self.scalar_functions = Some(SessionStateDefaults::default_scalar_functions()); + self.aggregate_functions = + Some(SessionStateDefaults::default_aggregate_functions()); self } - /// Replace the random session id. + /// Set the session id. pub fn with_session_id(mut self, session_id: String) -> Self { - self.state.session_id = session_id; + self.session_id = Some(session_id); self } - /// Override the [`AnalyzerRule`]s optimizer plan rules. + /// Set the [`AnalyzerRule`]s optimizer plan rules. pub fn with_analyzer_rules( mut self, rules: Vec>, ) -> Self { - self.state.analyzer = Analyzer::with_rules(rules); + self.analyzer = Some(Analyzer::with_rules(rules)); self } @@ -934,16 +999,18 @@ impl SessionStateBuilder { mut self, analyzer_rule: Arc, ) -> Self { - self.state.analyzer.rules.push(analyzer_rule); + let mut rules = self.analyzer_rules.unwrap_or_default(); + rules.push(analyzer_rule); + self.analyzer_rules = Some(rules); self } - /// Replace the entire list of [`OptimizerRule`]s used to optimize plans + /// Set the [`OptimizerRule`]s used to optimize plans. pub fn with_optimizer_rules( mut self, rules: Vec>, ) -> Self { - self.state.optimizer = Optimizer::with_rules(rules); + self.optimizer = Some(Optimizer::with_rules(rules)); self } @@ -953,26 +1020,28 @@ impl SessionStateBuilder { mut self, optimizer_rule: Arc, ) -> Self { - self.state.optimizer.rules.push(optimizer_rule); + let mut rules = self.optimizer_rules.unwrap_or_default(); + rules.push(optimizer_rule); + self.optimizer_rules = Some(rules); self } - /// Replace the entire list of [`ExprPlanner`]s used to customize the behavior of the SQL planner + /// Set the [`ExprPlanner`]s used to customize the behavior of the SQL planner. pub fn with_expr_planners( mut self, expr_planners: Vec>, ) -> Self { - self.state.expr_planners = expr_planners; + self.expr_planners = Some(expr_planners); self } - /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans + /// Set tje [`PhysicalOptimizerRule`]s used to optimize plans. pub fn with_physical_optimizer_rules( mut self, physical_optimizers: Vec>, ) -> Self { - self.state.physical_optimizers = - PhysicalOptimizer::with_rules(physical_optimizers); + self.physical_optimizers = + Some(PhysicalOptimizer::with_rules(physical_optimizers)); self } @@ -982,175 +1051,368 @@ impl SessionStateBuilder { mut self, physical_optimizer_rule: Arc, ) -> Self { - self.state - .physical_optimizers - .rules - .push(physical_optimizer_rule); + let mut rules = self.physical_optimizer_rules.unwrap_or_default(); + rules.push(physical_optimizer_rule); + self.physical_optimizer_rules = Some(rules); self } - /// override default query planner with `query_planner` + /// Set the [`QueryPlanner`] pub fn with_query_planner( mut self, query_planner: Arc, ) -> Self { - self.state.query_planner = query_planner; + self.query_planner = Some(query_planner); self } - /// override default catalog list with `catalog_list` + /// Set the [`CatalogProviderList`] pub fn with_catalog_list( mut self, catalog_list: Arc, ) -> Self { - self.state.catalog_list = catalog_list; + self.catalog_list = Some(catalog_list); self } - /// override default table functions with `table_functions` + /// Set the map of [`TableFunction`]s pub fn with_table_functions( mut self, table_functions: HashMap>, ) -> Self { - self.state.table_functions = table_functions; + self.table_functions = Some(table_functions); self } - /// override default scalar functions with `scalar_functions` + /// Set the map of [`ScalarUDF`]s pub fn with_scalar_functions( mut self, - scalar_functions: HashMap>, + scalar_functions: Vec>, ) -> Self { - self.state.scalar_functions = scalar_functions; + self.scalar_functions = Some(scalar_functions); self } - /// override default aggregate functions with `aggregate_functions` + /// Set the map of [`AggregateUDF`]s pub fn with_aggregate_functions( mut self, - aggregate_functions: HashMap>, + aggregate_functions: Vec>, ) -> Self { - self.state.aggregate_functions = aggregate_functions; + self.aggregate_functions = Some(aggregate_functions); self } - /// override default window functions with `window_functions` + /// Set the map of [`WindowUDF`]s pub fn with_window_functions( mut self, - window_functions: HashMap>, + window_functions: Vec>, ) -> Self { - self.state.window_functions = window_functions; + self.window_functions = Some(window_functions); self } - /// Registers a [`SerializerRegistry`] + /// Set the [`SerializerRegistry`] pub fn with_serializer_registry( mut self, serializer_registry: Arc, ) -> Self { - self.state.serializer_registry = serializer_registry; + self.serializer_registry = Some(serializer_registry); self } - /// override default list of file formats with `file_formats` + /// Set the map of [`FileFormatFactory`]s pub fn with_file_formats( mut self, - file_formats: HashMap>, + file_formats: Vec>, ) -> Self { - self.state.file_formats = file_formats; + self.file_formats = Some(file_formats); self } - /// override the session config with `config` + /// Set the [`SessionConfig`] pub fn with_config(mut self, config: SessionConfig) -> Self { - self.state.config = config; + self.config = Some(config); self } - /// override default table options with `table_options` + /// Set the [`TableOptions`] pub fn with_table_options(mut self, table_options: TableOptions) -> Self { - self.state.table_options = table_options; + self.table_options = Some(table_options); self } - /// Adds a new [`ConfigExtension`] to TableOptions - pub fn with_table_options_extension( - mut self, - extension: T, - ) -> Self { - self.state.table_options.extensions.insert(extension); - self - } - - /// override default execution props with `execution_props` + /// Set the [`ExecutionProps`] pub fn with_execution_props(mut self, execution_props: ExecutionProps) -> Self { - self.state.execution_props = execution_props; + self.execution_props = Some(execution_props); self } - /// override default table factories with `table_factories` + /// Set the map of [`TableProviderFactory`]s pub fn with_table_factories( mut self, table_factories: HashMap>, ) -> Self { - self.state.table_factories = table_factories; + self.table_factories = Some(table_factories); self } - /// override the runtime env with `runtime_env` + /// Set the [`RuntimeEnv`] pub fn with_runtime_env(mut self, runtime_env: Arc) -> Self { - self.state.runtime_env = runtime_env; + self.runtime_env = Some(runtime_env); self } - /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + /// Set a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn with_function_factory( mut self, function_factory: Option>, ) -> Self { - self.state.function_factory = function_factory; + self.function_factory = function_factory; self } - /// build a [`SessionState`] with the current configuration - pub fn build(mut self) -> SessionState { - if self.use_defaults { - if self.state.table_factories.is_empty() { - self.state.table_factories = - SessionStateDefaults::default_table_factories(); - } - if self.state.expr_planners.is_empty() { - self.state.expr_planners = SessionStateDefaults::default_expr_planners(); - } + /// Builds a [`SessionState`] with the current configuration. + /// + /// Note that there is an explicit option for enabling catalog and schema defaults + /// in [SessionConfig::create_default_catalog_and_schema] which if enabled + /// will be built here. + pub fn build(self) -> SessionState { + let config = self.config.unwrap_or_default(); + let runtime_env = self.runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + + let mut state = SessionState { + session_id: self.session_id.unwrap_or(Uuid::new_v4().to_string()), + analyzer: self.analyzer.unwrap_or_default(), + expr_planners: self.expr_planners.unwrap_or_default(), + optimizer: self.optimizer.unwrap_or_default(), + physical_optimizers: self.physical_optimizers.unwrap_or_default(), + query_planner: self + .query_planner + .unwrap_or(Arc::new(DefaultQueryPlanner {})), + catalog_list: self + .catalog_list + .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) + as Arc), + table_functions: self.table_functions.unwrap_or_default(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + serializer_registry: self + .serializer_registry + .unwrap_or(Arc::new(EmptySerializerRegistry)), + file_formats: HashMap::new(), + table_options: self + .table_options + .unwrap_or(TableOptions::default_from_session_config(config.options())), + config, + execution_props: self.execution_props.unwrap_or_default(), + table_factories: self.table_factories.unwrap_or_default(), + runtime_env, + function_factory: self.function_factory, + }; - if self.state.config.create_default_catalog_and_schema() { - let default_catalog = SessionStateDefaults::default_catalog( - &self.state.config, - &self.state.table_factories, - &self.state.runtime_env, - ); - - self.state.catalog_list.register_catalog( - self.state.config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); + if let Some(file_formats) = self.file_formats { + for file_format in file_formats { + if let Err(e) = state.register_file_format(file_format, false) { + info!("Unable to register file format: {e}") + }; } + } - if self.state.file_formats.is_empty() { - SessionStateDefaults::register_default_file_formats(&mut self.state); + if let Some(scalar_functions) = self.scalar_functions { + scalar_functions.into_iter().for_each(|udf| { + let existing_udf = state.register_udf(udf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(aggregate_functions) = self.aggregate_functions { + aggregate_functions.into_iter().for_each(|udaf| { + let existing_udf = state.register_udaf(udaf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(window_functions) = self.window_functions { + window_functions.into_iter().for_each(|udwf| { + let existing_udf = state.register_udwf(udwf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if state.config.create_default_catalog_and_schema() { + let default_catalog = SessionStateDefaults::default_catalog( + &state.config, + &state.table_factories, + &state.runtime_env, + ); + + state.catalog_list.register_catalog( + state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ); + } + + if let Some(analyzer_rules) = self.analyzer_rules { + for analyzer_rule in analyzer_rules { + state.analyzer.rules.push(analyzer_rule); } + } - if self.state.scalar_functions.is_empty() { - SessionStateDefaults::register_scalar_functions(&mut self.state); - SessionStateDefaults::register_array_functions(&mut self.state); + if let Some(optimizer_rules) = self.optimizer_rules { + for optimizer_rule in optimizer_rules { + state.optimizer.rules.push(optimizer_rule); } + } - if self.state.aggregate_functions.is_empty() { - SessionStateDefaults::register_aggregate_functions(&mut self.state); + if let Some(physical_optimizer_rules) = self.physical_optimizer_rules { + for physical_optimizer_rule in physical_optimizer_rules { + state + .physical_optimizers + .rules + .push(physical_optimizer_rule); } } - self.state.clone() + state + } + + /// Returns the current session_id value + pub fn session_id(&self) -> &Option { + &self.session_id + } + + /// Returns the current analyzer value + pub fn analyzer(&mut self) -> &mut Option { + &mut self.analyzer + } + + /// Returns the current expr_planners value + pub fn expr_planners(&mut self) -> &mut Option>> { + &mut self.expr_planners + } + + /// Returns the current optimizer value + pub fn optimizer(&mut self) -> &mut Option { + &mut self.optimizer + } + + /// Returns the current physical_optimizers value + pub fn physical_optimizers(&mut self) -> &mut Option { + &mut self.physical_optimizers + } + + /// Returns the current query_planner value + pub fn query_planner(&mut self) -> &mut Option> { + &mut self.query_planner + } + + /// Returns the current catalog_list value + pub fn catalog_list(&mut self) -> &mut Option> { + &mut self.catalog_list + } + + /// Returns the current table_functions value + pub fn table_functions( + &mut self, + ) -> &mut Option>> { + &mut self.table_functions + } + + /// Returns the current scalar_functions value + pub fn scalar_functions(&mut self) -> &mut Option>> { + &mut self.scalar_functions + } + + /// Returns the current aggregate_functions value + pub fn aggregate_functions(&mut self) -> &mut Option>> { + &mut self.aggregate_functions + } + + /// Returns the current window_functions value + pub fn window_functions(&mut self) -> &mut Option>> { + &mut self.window_functions + } + + /// Returns the current serializer_registry value + pub fn serializer_registry(&mut self) -> &mut Option> { + &mut self.serializer_registry + } + + /// Returns the current file_formats value + pub fn file_formats(&mut self) -> &mut Option>> { + &mut self.file_formats + } + + /// Returns the current session_config value + pub fn config(&mut self) -> &mut Option { + &mut self.config + } + + /// Returns the current table_options value + pub fn table_options(&mut self) -> &mut Option { + &mut self.table_options + } + + /// Returns the current execution_props value + pub fn execution_props(&mut self) -> &mut Option { + &mut self.execution_props + } + + /// Returns the current table_factories value + pub fn table_factories( + &mut self, + ) -> &mut Option>> { + &mut self.table_factories + } + + /// Returns the current runtime_env value + pub fn runtime_env(&mut self) -> &mut Option> { + &mut self.runtime_env + } + + /// Returns the current function_factory value + pub fn function_factory(&mut self) -> &mut Option> { + &mut self.function_factory + } + + /// Returns the current analyzer_rules value + pub fn analyzer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.analyzer_rules + } + + /// Returns the current optimizer_rules value + pub fn optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.optimizer_rules + } + + /// Returns the current physical_optimizer_rules value + pub fn physical_optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.physical_optimizer_rules + } +} + +impl Default for SessionStateBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for SessionStateBuilder { + fn from(state: SessionState) -> Self { + SessionStateBuilder::new_from_existing(state) } } @@ -1213,6 +1475,33 @@ impl SessionStateDefaults { expr_planners } + /// returns a map of default [`ScalarUDF']'s keyed by name + pub fn default_scalar_functions() -> Vec> { + let mut functions: Vec> = functions::all_default_functions(); + functions.append(&mut functions_array::all_default_array_functions()); + + functions + } + + /// returns a map of default [`AggregateUDF']'s keyed by named + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns a map of default [`FileFormatFactory']'s keyed by extension + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + /// registers all builtin functions - scalar, array and aggregate pub fn register_builtin_functions(state: &mut SessionState) { Self::register_scalar_functions(state); @@ -1280,36 +1569,12 @@ impl SessionStateDefaults { /// registers the default [`FileFormatFactory`]s pub fn register_default_file_formats(state: &mut SessionState) { - #[cfg(feature = "parquet")] - if let Err(e) = - state.register_file_format(Arc::new(ParquetFormatFactory::new()), false) - { - log::info!("Unable to register default ParquetFormat: {e}") - }; - - if let Err(e) = - state.register_file_format(Arc::new(JsonFormatFactory::new()), false) - { - log::info!("Unable to register default JsonFormat: {e}") - }; - - if let Err(e) = - state.register_file_format(Arc::new(CsvFormatFactory::new()), false) - { - log::info!("Unable to register default CsvFormat: {e}") - }; - - if let Err(e) = - state.register_file_format(Arc::new(ArrowFormatFactory::new()), false) - { - log::info!("Unable to register default ArrowFormat: {e}") - }; - - if let Err(e) = - state.register_file_format(Arc::new(AvroFormatFactory::new()), false) - { - log::info!("Unable to register default AvroFormat: {e}") - }; + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6aad4d575532..7ae8b5247f19 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2269,6 +2269,7 @@ mod tests { use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; + use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; @@ -2282,7 +2283,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } async fn plan(logical_plan: &LogicalPlan) -> Result> { diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index bea6f7b9ceb7..6c0a2fc7bec4 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -16,9 +16,8 @@ // under the License. //! Object store implementation used for testing use crate::execution::context::SessionState; +use crate::execution::session_state::SessionStateBuilder; use crate::prelude::SessionContext; -use datafusion_execution::config::SessionConfig; -use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -44,10 +43,7 @@ pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, Sessi ( Arc::new(memory), - SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + SessionStateBuilder::new().with_default_features().build(), ) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2d1904d9e166..ee20f4df2f5e 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,7 +42,8 @@ use url::Url; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, SessionState}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::{parquet_test_data, populate_csv_partitions}; @@ -1544,7 +1545,11 @@ async fn unnest_non_nullable_list() -> Result<()> { async fn test_read_batches() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![ @@ -1594,7 +1599,11 @@ async fn test_read_batches() -> Result<()> { async fn test_read_batches_empty() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let batches = vec![]; @@ -1608,9 +1617,7 @@ async fn test_read_batches_empty() -> Result<()> { #[tokio::test] async fn consecutive_projection_same_schema() -> Result<()> { - let config = SessionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new().with_default_features().build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 4d28a2056692..60f126e6d73d 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -460,13 +460,16 @@ impl TestCase { let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution - let state = SessionStateBuilder::new_with_config_rt(config, Arc::new(runtime)); - let state = match scenario.rules() { - Some(rules) => state.with_physical_optimizer_rules(rules), - None => state, + let builder = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(Arc::new(runtime)) + .with_default_features(); + let builder = match scenario.rules() { + Some(rules) => builder.with_physical_optimizer_rules(rules), + None => builder, }; - let ctx = SessionContext::new_with_state(state.build()); + let ctx = SessionContext::new_with_state(builder.build()); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 9f94a59a3e59..bf25b36f48e8 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -35,6 +35,7 @@ use datafusion_execution::cache::cache_unit::{ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use tempfile::tempdir; #[tokio::test] @@ -167,10 +168,7 @@ async fn get_listing_table( ) -> ListingTable { let schema = opt .infer_schema( - &SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + &SessionStateBuilder::new().with_default_features().build(), table_path, ) .await diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 2174009b8557..83712053b954 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::test_util::TestTableFactory; use super::*; #[tokio::test] async fn create_custom_table() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); @@ -45,10 +41,7 @@ async fn create_custom_table() -> Result<()> { #[tokio::test] async fn create_external_table_with_ddl() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 20d315fe97ab..79dc28d4f44b 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -291,8 +291,10 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionStateBuilder::new_with_config_rt(config, runtime) - .with_defaults(true) + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})) .add_analyzer_rule(Arc::new(MyAnalyzerRule {})) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f764a050a6cd..d0209d811b7c 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -39,8 +39,7 @@ use prost::Message; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ @@ -202,10 +201,7 @@ async fn roundtrip_custom_tables() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // replace factories *state.table_factories_mut() = table_factories; let ctx = SessionContext::new_with_state(state); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 4866fa587939..5b2d0fbacaef 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -1121,13 +1121,12 @@ async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { } async fn create_context() -> Result { - let mut state = SessionStateBuilder::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ) - .with_defaults(true) - .with_serializer_registry(Arc::new(MockSerializerRegistry)) - .build(); + let mut state = SessionStateBuilder::new() + .with_config(SessionConfig::default()) + .with_runtime_env(Arc::new(RuntimeEnv::default())) + .with_default_features() + .with_serializer_registry(Arc::new(MockSerializerRegistry)) + .build(); // register udaf for test, e.g. `sum()` datafusion_functions_aggregate::register_all(&mut state) From 8a579cf05d1696c932afd5e8bfc601d95a1097f0 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 12 Jul 2024 18:38:37 -0400 Subject: [PATCH 6/8] Bug fix for missing array_expressions cfg feature. --- datafusion/core/src/execution/session_state.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index bdea5fde69bc..0c438cdff220 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -966,7 +966,7 @@ impl SessionStateBuilder { } } - /// Set defaults for table_factories, file formats, expr_planners and builtin + /// Set defaults for table_factories, file formats, expr_planners and builtin /// scalar and aggregate functions. pub fn with_default_features(mut self) -> Self { self.table_factories = Some(SessionStateDefaults::default_table_factories()); @@ -1478,6 +1478,7 @@ impl SessionStateDefaults { /// returns a map of default [`ScalarUDF']'s keyed by name pub fn default_scalar_functions() -> Vec> { let mut functions: Vec> = functions::all_default_functions(); + #[cfg(feature = "array_expressions")] functions.append(&mut functions_array::all_default_array_functions()); functions From 71decf6066d351d8198319919c2bbc8fc715bfb0 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 14 Jul 2024 13:56:26 -0400 Subject: [PATCH 7/8] Review feedback updates + doc fixes for SessionStateDefaults --- .../core/src/execution/session_state.rs | 126 ++++++++++-------- .../tests/user_defined/user_defined_plan.rs | 4 +- 2 files changed, 75 insertions(+), 55 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d5542de0d737..6ab4d1558800 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -872,6 +872,8 @@ impl SessionState { /// A builder to be used for building [`SessionState`]'s. Defaults will /// be used for all values unless explicitly provided. +/// +/// See example on [`SessionState`] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, @@ -932,30 +934,28 @@ impl SessionStateBuilder { /// The session id for the new builder will be unset; all other fields will /// be cloned from what is set in the provided session state pub fn new_from_existing(existing: SessionState) -> Self { - let cloned = existing.clone(); - Self { session_id: None, - analyzer: Some(cloned.analyzer), - expr_planners: Some(cloned.expr_planners), - optimizer: Some(cloned.optimizer), - physical_optimizers: Some(cloned.physical_optimizers), - query_planner: Some(cloned.query_planner), - catalog_list: Some(cloned.catalog_list), - table_functions: Some(cloned.table_functions), - scalar_functions: Some(cloned.scalar_functions.into_values().collect_vec()), + analyzer: Some(existing.analyzer), + expr_planners: Some(existing.expr_planners), + optimizer: Some(existing.optimizer), + physical_optimizers: Some(existing.physical_optimizers), + query_planner: Some(existing.query_planner), + catalog_list: Some(existing.catalog_list), + table_functions: Some(existing.table_functions), + scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), aggregate_functions: Some( - cloned.aggregate_functions.into_values().collect_vec(), + existing.aggregate_functions.into_values().collect_vec(), ), - window_functions: Some(cloned.window_functions.into_values().collect_vec()), - serializer_registry: Some(cloned.serializer_registry), - file_formats: Some(cloned.file_formats.into_values().collect_vec()), - config: Some(cloned.config), - table_options: Some(cloned.table_options), - execution_props: Some(cloned.execution_props), - table_factories: Some(cloned.table_factories), - runtime_env: Some(cloned.runtime_env), - function_factory: cloned.function_factory, + window_functions: Some(existing.window_functions.into_values().collect_vec()), + serializer_registry: Some(existing.serializer_registry), + file_formats: Some(existing.file_formats.into_values().collect_vec()), + config: Some(existing.config), + table_options: Some(existing.table_options), + execution_props: Some(existing.execution_props), + table_factories: Some(existing.table_factories), + runtime_env: Some(existing.runtime_env), + function_factory: existing.function_factory, // fields to support convenience functions analyzer_rules: None, @@ -993,7 +993,7 @@ impl SessionStateBuilder { /// Add `analyzer_rule` to the end of the list of /// [`AnalyzerRule`]s used to rewrite queries. - pub fn add_analyzer_rule( + pub fn with_analyzer_rule( mut self, analyzer_rule: Arc, ) -> Self { @@ -1014,7 +1014,7 @@ impl SessionStateBuilder { /// Add `optimizer_rule` to the end of the list of /// [`OptimizerRule`]s used to rewrite queries. - pub fn add_optimizer_rule( + pub fn with_optimizer_rule( mut self, optimizer_rule: Arc, ) -> Self { @@ -1045,7 +1045,7 @@ impl SessionStateBuilder { /// Add `physical_optimizer_rule` to the end of the list of /// [`PhysicalOptimizerRule`]s used to rewrite queries. - pub fn add_physical_optimizer_rule( + pub fn with_physical_optimizer_rule( mut self, physical_optimizer_rule: Arc, ) -> Self { @@ -1175,41 +1175,61 @@ impl SessionStateBuilder { /// in [SessionConfig::create_default_catalog_and_schema] which if enabled /// will be built here. pub fn build(self) -> SessionState { - let config = self.config.unwrap_or_default(); - let runtime_env = self.runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + let Self { + session_id, + analyzer, + expr_planners, + optimizer, + physical_optimizers, + query_planner, + catalog_list, + table_functions, + scalar_functions, + aggregate_functions, + window_functions, + serializer_registry, + file_formats, + table_options, + config, + execution_props, + table_factories, + runtime_env, + function_factory, + analyzer_rules, + optimizer_rules, + physical_optimizer_rules, + } = self; + + let config = config.unwrap_or_default(); + let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); let mut state = SessionState { - session_id: self.session_id.unwrap_or(Uuid::new_v4().to_string()), - analyzer: self.analyzer.unwrap_or_default(), - expr_planners: self.expr_planners.unwrap_or_default(), - optimizer: self.optimizer.unwrap_or_default(), - physical_optimizers: self.physical_optimizers.unwrap_or_default(), - query_planner: self - .query_planner - .unwrap_or(Arc::new(DefaultQueryPlanner {})), - catalog_list: self - .catalog_list + session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + analyzer: analyzer.unwrap_or_default(), + expr_planners: expr_planners.unwrap_or_default(), + optimizer: optimizer.unwrap_or_default(), + physical_optimizers: physical_optimizers.unwrap_or_default(), + query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) as Arc), - table_functions: self.table_functions.unwrap_or_default(), + table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), - serializer_registry: self - .serializer_registry + serializer_registry: serializer_registry .unwrap_or(Arc::new(EmptySerializerRegistry)), file_formats: HashMap::new(), - table_options: self - .table_options + table_options: table_options .unwrap_or(TableOptions::default_from_session_config(config.options())), config, - execution_props: self.execution_props.unwrap_or_default(), - table_factories: self.table_factories.unwrap_or_default(), + execution_props: execution_props.unwrap_or_default(), + table_factories: table_factories.unwrap_or_default(), runtime_env, - function_factory: self.function_factory, + function_factory, }; - if let Some(file_formats) = self.file_formats { + if let Some(file_formats) = file_formats { for file_format in file_formats { if let Err(e) = state.register_file_format(file_format, false) { info!("Unable to register file format: {e}") @@ -1217,7 +1237,7 @@ impl SessionStateBuilder { } } - if let Some(scalar_functions) = self.scalar_functions { + if let Some(scalar_functions) = scalar_functions { scalar_functions.into_iter().for_each(|udf| { let existing_udf = state.register_udf(udf); if let Ok(Some(existing_udf)) = existing_udf { @@ -1226,7 +1246,7 @@ impl SessionStateBuilder { }); } - if let Some(aggregate_functions) = self.aggregate_functions { + if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { let existing_udf = state.register_udaf(udaf); if let Ok(Some(existing_udf)) = existing_udf { @@ -1235,7 +1255,7 @@ impl SessionStateBuilder { }); } - if let Some(window_functions) = self.window_functions { + if let Some(window_functions) = window_functions { window_functions.into_iter().for_each(|udwf| { let existing_udf = state.register_udwf(udwf); if let Ok(Some(existing_udf)) = existing_udf { @@ -1257,19 +1277,19 @@ impl SessionStateBuilder { ); } - if let Some(analyzer_rules) = self.analyzer_rules { + if let Some(analyzer_rules) = analyzer_rules { for analyzer_rule in analyzer_rules { state.analyzer.rules.push(analyzer_rule); } } - if let Some(optimizer_rules) = self.optimizer_rules { + if let Some(optimizer_rules) = optimizer_rules { for optimizer_rule in optimizer_rules { state.optimizer.rules.push(optimizer_rule); } } - if let Some(physical_optimizer_rules) = self.physical_optimizer_rules { + if let Some(physical_optimizer_rules) = physical_optimizer_rules { for physical_optimizer_rule in physical_optimizer_rules { state .physical_optimizers @@ -1473,7 +1493,7 @@ impl SessionStateDefaults { expr_planners } - /// returns a map of default [`ScalarUDF']'s keyed by name + /// returns the list of default [`ScalarUDF']'s pub fn default_scalar_functions() -> Vec> { let mut functions: Vec> = functions::all_default_functions(); #[cfg(feature = "array_expressions")] @@ -1482,12 +1502,12 @@ impl SessionStateDefaults { functions } - /// returns a map of default [`AggregateUDF']'s keyed by named + /// returns the list of default [`AggregateUDF']'s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() } - /// returns a map of default [`FileFormatFactory']'s keyed by extension + /// returns the list of default [`FileFormatFactory']'s pub fn default_file_formats() -> Vec> { let file_formats: Vec> = vec![ #[cfg(feature = "parquet")] diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 79dc28d4f44b..a44f522ba95a 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -296,8 +296,8 @@ fn make_topk_context() -> SessionContext { .with_runtime_env(runtime) .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .add_optimizer_rule(Arc::new(TopKOptimizerRule {})) - .add_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) .build(); SessionContext::new_with_state(state) } From 380ce9813edebdb4cdb1c84213fc9b5638ebb1aa Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 14 Jul 2024 13:59:22 -0400 Subject: [PATCH 8/8] Cargo fmt update. --- datafusion/core/src/execution/session_state.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 6ab4d1558800..75eef4345487 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1502,12 +1502,12 @@ impl SessionStateDefaults { functions } - /// returns the list of default [`AggregateUDF']'s + /// returns the list of default [`AggregateUDF']'s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() } - /// returns the list of default [`FileFormatFactory']'s + /// returns the list of default [`FileFormatFactory']'s pub fn default_file_formats() -> Vec> { let file_formats: Vec> = vec![ #[cfg(feature = "parquet")]