diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 179fc108e6d2..ed23fada0cfb 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -33,7 +33,7 @@ use datafusion_functions::core::planner::CoreFunctionPlanner; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; -use crate::common::MockContextProvider; +use crate::common::{MockContextProvider, MockSessionState}; #[test] fn roundtrip_expr() { @@ -59,8 +59,8 @@ fn roundtrip_expr() { let roundtrip = |table, sql: &str| -> Result { let dialect = GenericDialect {}; let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; - - let context = MockContextProvider::default().with_udaf(sum_udaf()); + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; let schema = context.get_table_source(table)?.schema(); let df_schema = DFSchema::try_from(schema.as_ref().clone())?; let sql_to_rel = SqlToRel::new(&context); @@ -156,11 +156,11 @@ fn roundtrip_statement() -> Result<()> { let statement = Parser::new(&dialect) .try_with_sql(query)? .parse_statement()?; - - let context = MockContextProvider::default() - .with_udaf(sum_udaf()) - .with_udaf(count_udaf()) + let state = MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(count_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -189,8 +189,10 @@ fn roundtrip_crossjoin() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default() + let state = MockSessionState::default() .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + + let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -412,10 +414,12 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .try_with_sql(query.sql)? .parse_statement()?; - let context = MockContextProvider::default() - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) - .with_udaf(max_udaf()) - .with_udaf(min_udaf()); + let state = MockSessionState::default() + .with_aggregate_function(max_udaf()) + .with_aggregate_function(min_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + + let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel .sql_statement_to_plan(statement) @@ -443,7 +447,9 @@ fn test_unnest_logical_plan() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider { + state: MockSessionState::default(), + }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -516,7 +522,9 @@ fn test_pretty_roundtrip() -> Result<()> { let df_schema = DFSchema::try_from(schema)?; - let context = MockContextProvider::default(); + let context = MockContextProvider { + state: MockSessionState::default(), + }; let sql_to_rel = SqlToRel::new(&context); let unparser = Unparser::default().with_pretty(true); @@ -589,7 +597,9 @@ fn sql_round_trip(query: &str, expect: &str) { .parse_statement() .unwrap(); - let context = MockContextProvider::default(); + let context = MockContextProvider { + state: MockSessionState::default(), + }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 374aa9db6714..fe0e5f7283a4 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -50,36 +50,40 @@ impl Display for MockCsvType { } #[derive(Default)] -pub(crate) struct MockContextProvider { - options: ConfigOptions, - udfs: HashMap>, - udafs: HashMap>, +pub(crate) struct MockSessionState { + scalar_functions: HashMap>, + aggregate_functions: HashMap>, expr_planners: Vec>, + pub config_options: ConfigOptions, } -impl MockContextProvider { - // Suppressing dead code warning, as this is used in integration test crates - #[allow(dead_code)] - pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options +impl MockSessionState { + pub fn with_expr_planner(mut self, expr_planner: Arc) -> Self { + self.expr_planners.push(expr_planner); + self } - #[allow(dead_code)] - pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self { - self.udfs.insert(udf.name().to_string(), Arc::new(udf)); + pub fn with_scalar_function(mut self, scalar_function: Arc) -> Self { + self.scalar_functions + .insert(scalar_function.name().to_string(), scalar_function); self } - pub(crate) fn with_udaf(mut self, udaf: Arc) -> Self { + pub fn with_aggregate_function( + mut self, + aggregate_function: Arc, + ) -> Self { // TODO: change to to_string() if all the function name is converted to lowercase - self.udafs.insert(udaf.name().to_lowercase(), udaf); + self.aggregate_functions.insert( + aggregate_function.name().to_string().to_lowercase(), + aggregate_function, + ); self } +} - pub(crate) fn with_expr_planner(mut self, planner: Arc) -> Self { - self.expr_planners.push(planner); - self - } +pub(crate) struct MockContextProvider { + pub(crate) state: MockSessionState, } impl ContextProvider for MockContextProvider { @@ -202,11 +206,11 @@ impl ContextProvider for MockContextProvider { } fn get_function_meta(&self, name: &str) -> Option> { - self.udfs.get(name).cloned() + self.state.scalar_functions.get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { - self.udafs.get(name).cloned() + self.state.aggregate_functions.get(name).cloned() } fn get_variable_type(&self, _: &[String]) -> Option { @@ -218,7 +222,7 @@ impl ContextProvider for MockContextProvider { } fn options(&self) -> &ConfigOptions { - &self.options + &self.state.config_options } fn get_file_type( @@ -237,11 +241,11 @@ impl ContextProvider for MockContextProvider { } fn udf_names(&self) -> Vec { - self.udfs.keys().cloned().collect() + self.state.scalar_functions.keys().cloned().collect() } fn udaf_names(&self) -> Vec { - self.udafs.keys().cloned().collect() + self.state.aggregate_functions.keys().cloned().collect() } fn udwf_names(&self) -> Vec { @@ -249,7 +253,7 @@ impl ContextProvider for MockContextProvider { } fn get_expr_planners(&self) -> &[Arc] { - &self.expr_planners + &self.state.expr_planners } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 4d7e60805657..5a0317c47c85 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,6 +41,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; +use crate::common::MockSessionState; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, @@ -1495,8 +1496,9 @@ fn recursive_ctes_disabled() { select * from numbers;"; // manually setting up test here so that we can disable recursive ctes - let mut context = MockContextProvider::default(); - context.options_mut().execution.enable_recursive_ctes = false; + let mut state = MockSessionState::default(); + state.config_options.execution.enable_recursive_ctes = false; + let context = MockContextProvider { state }; let planner = SqlToRel::new_with_options(&context, ParserOptions::default()); let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {}); @@ -2727,7 +2729,8 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result Result { - let context = MockContextProvider::default().with_udaf(sum_udaf()); + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; let planner = SqlToRel::new(&context); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; @@ -2739,39 +2742,44 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let context = MockContextProvider::default() - .with_udf(unicode::character_length().as_ref().clone()) - .with_udf(string::concat().as_ref().clone()) - .with_udf(make_udf( + let state = MockSessionState::default() + .with_scalar_function(Arc::new(unicode::character_length().as_ref().clone())) + .with_scalar_function(Arc::new(string::concat().as_ref().clone())) + .with_scalar_function(Arc::new(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], DataType::Int32, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "round", vec![DataType::Float64, DataType::Int64], DataType::Float32, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "arrow_cast", vec![DataType::Int64, DataType::Utf8], DataType::Float64, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], DataType::Int32, - )) - .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) - .with_udaf(sum_udaf()) - .with_udaf(approx_median_udaf()) - .with_udaf(count_udaf()) - .with_udaf(avg_udaf()) - .with_udaf(min_udaf()) - .with_udaf(max_udaf()) - .with_udaf(grouping_udaf()) + ))) + .with_scalar_function(Arc::new(make_udf( + "sqrt", + vec![DataType::Int64], + DataType::Int64, + ))) + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(approx_median_udaf()) + .with_aggregate_function(count_udaf()) + .with_aggregate_function(avg_udaf()) + .with_aggregate_function(min_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + let context = MockContextProvider { state }; let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?;