diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 415af1bf94dc..c775427df138 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, @@ -180,6 +180,23 @@ pub trait ExprPlanner: Send + Sync { fn plan_make_map(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } + + /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// + /// Note: + /// Currently compound identifier for outer query schema is not supported. + /// + /// Returns planned expression + fn plan_compound_identifier( + &self, + _field: &Field, + _qualifier: Option<&TableReference>, + _nested_names: &[String], + ) -> Result>> { + not_impl_err!( + "Default planner compound identifier hasn't been implemented for ExprPlanner" + ) + } } /// An operator with two arguments to plan diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index cbfaa592b012..ee0309e59382 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -100,7 +100,6 @@ pub fn functions() -> Vec> { nvl2(), arrow_typeof(), named_struct(), - get_field(), coalesce(), map(), ] diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 63eaa9874c2b..889f191d592f 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion_common::DFSchema;
+use arrow::datatypes::Field;
 use datafusion_common::Result;
+use datafusion_common::{not_impl_err, Column, DFSchema, ScalarValue, TableReference};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr};
-use datafusion_expr::Expr;
+use datafusion_expr::{lit, Expr};

 use super::named_struct;

@@ -62,4 +63,34 @@ impl ExprPlanner for CoreFunctionPlanner {
             ScalarFunction::new_udf(crate::string::overlay(), args),
         )))
     }
+
+    /// Plans `column.nested_name` as a call to the `get_field` function with
+    /// the column and the nested name (as a string literal) as arguments.
+    ///
+    /// Exactly one entry in `nested_names` is supported. Any other count,
+    /// including an empty slice, returns a "not implemented" error rather
+    /// than indexing out of bounds.
+    fn plan_compound_identifier(
+        &self,
+        field: &Field,
+        qualifier: Option<&TableReference>,
+        nested_names: &[String],
+    ) -> Result<PlannerResult<Vec<Expr>>> {
+        // TODO: remove when can support multiple nested identifiers
+        let nested_name = match nested_names {
+            [name] => name.clone(),
+            _ => {
+                return not_impl_err!(
+                    "Nested identifiers not yet supported for column {}",
+                    Column::from((qualifier, field)).quoted_flat_name()
+                );
+            }
+        };
+
+        let col = Expr::Column(Column::from((qualifier, field)));
+        let get_field_args = vec![col, lit(ScalarValue::from(nested_name))];
+        Ok(PlannerResult::Planned(Expr::ScalarFunction(
+            ScalarFunction::new_udf(crate::core::get_field(), get_field_args),
+        )))
+    }
 }
diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs
index b724afabaf09..d9ee1b4db8e2 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -15,13 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
+use std::{collections::HashMap, sync::Arc};
+
 use arrow_schema::{DataType, Field, Schema};
+
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
+use datafusion_expr::planner::ExprPlanner;
 use datafusion_expr::WindowUDF;
 use datafusion_expr::{
     logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
 };
+use datafusion_functions::core::planner::CoreFunctionPlanner;
 use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_functions_aggregate::sum::sum_udaf;
 use datafusion_sql::{
@@ -29,7 +34,6 @@ use datafusion_sql::{
     sqlparser::{dialect::GenericDialect, parser::Parser},
     TableReference,
 };
-use std::{collections::HashMap, sync::Arc};

 fn main() {
     let sql = "SELECT \
@@ -53,7 +57,8 @@ fn main() {
     // create a logical query plan
     let context_provider = MyContextProvider::new()
         .with_udaf(sum_udaf())
-        .with_udaf(count_udaf());
+        .with_udaf(count_udaf())
+        .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

@@ -65,6 +70,7 @@ struct MyContextProvider {
     options: ConfigOptions,
     tables: HashMap<String, Arc<dyn TableSource>>,
     udafs: HashMap<String, Arc<AggregateUDF>>,
+    expr_planners: Vec<Arc<dyn ExprPlanner>>,
 }

 impl MyContextProvider {
@@ -73,6 +79,12 @@ impl MyContextProvider {
         self
     }

+    /// Appends `planner` to the expression planners exposed by this provider.
+    fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
+        self.expr_planners.push(planner);
+        self
+    }
+
     fn new() -> Self {
         let mut tables = HashMap::new();
         tables.insert(
@@ -105,6 +117,7 @@ impl MyContextProvider {
             tables,
             options: Default::default(),
             udafs: Default::default(),
+            expr_planners: vec![],
         }
     }
 }
@@ -154,4 +167,9 @@ impl ContextProvider for MyContextProvider {
     fn udwf_names(&self) -> Vec<String> {
         Vec::new()
     }
+
+    /// Returns the registered expression planners, in insertion order.
+    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
+        &self.expr_planners
+    }
 }
diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs
index 39736b1fbba5..f8979bde3086 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++
b/datafusion/sql/src/expr/identifier.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; +use sqlparser::ast::{Expr as SQLExpr, Ident}; + use datafusion_common::{ internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError, - Result, ScalarValue, TableReference, + Result, TableReference, }; -use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; -use sqlparser::ast::{Expr as SQLExpr, Ident}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::{Case, Expr}; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -125,26 +128,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match search_result { // found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support multiple nested identifiers - if nested_names.len() > 1 { - return not_impl_err!( - "Nested identifiers not yet supported for column {}", - Column::from((qualifier, field)).quoted_flat_name() - ); - } - let nested_name = nested_names[0].to_string(); - - let col = Expr::Column(Column::from((qualifier, field))); - if let Some(udf) = - self.context_provider.get_function_meta("get_field") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![col, lit(ScalarValue::from(nested_name))], - ))) - } else { - internal_err!("get_field not found") + // found matching field with spare identifier(s) for nested field(s) in structure + for planner in self.context_provider.get_expr_planners() { + if let Ok(planner_result) = planner.plan_compound_identifier( + field, + qualifier, + nested_names, + ) { + match planner_result { + PlannerResult::Planned(expr) => return Ok(expr), + 
PlannerResult::Original(_args) => {} + } + } } + not_impl_err!( + "Compound identifiers not supported by ExprPlanner: {ids:?}" + ) } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 91295b2e8aae..66f568224c3b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; use std::vec; use arrow_schema::*; @@ -28,6 +29,7 @@ use datafusion_sql::unparser::dialect::{ }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; +use datafusion_functions::core::planner::CoreFunctionPlanner; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -155,7 +157,8 @@ fn roundtrip_statement() -> Result<()> { let context = MockContextProvider::default() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -184,7 +187,8 @@ fn roundtrip_crossjoin() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -247,7 +251,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .try_with_sql(query.sql)? 
.parse_statement()?;

-        let context = MockContextProvider::default();
+        let context = MockContextProvider::default()
+            .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
         let sql_to_rel = SqlToRel::new(&context);
         let plan = sql_to_rel
             .sql_statement_to_plan(statement)
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index b8d8bd12d28b..d9e672a842ce 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -25,6 +25,7 @@ use arrow_schema::*;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::{plan_err, GetExt, Result, TableReference};
+use datafusion_expr::planner::ExprPlanner;
 use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
 use datafusion_sql::planner::ContextProvider;

@@ -53,6 +54,7 @@ pub(crate) struct MockContextProvider {
     options: ConfigOptions,
     udfs: HashMap<String, Arc<ScalarUDF>>,
     udafs: HashMap<String, Arc<AggregateUDF>>,
+    expr_planners: Vec<Arc<dyn ExprPlanner>>,
 }

 impl MockContextProvider {
@@ -73,6 +75,12 @@ impl MockContextProvider {
         self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
+
+    /// Appends `planner` to the expression planners exposed by this provider.
+    pub(crate) fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
+        self.expr_planners.push(planner);
+        self
+    }
 }

 impl ContextProvider for MockContextProvider {
@@ -240,6 +248,11 @@ impl ContextProvider for MockContextProvider {
     fn udwf_names(&self) -> Vec<String> {
         Vec::new()
     }
+
+    /// Returns the registered expression planners, in insertion order.
+    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
+        &self.expr_planners
+    }
 }

 struct EmptyTable {
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index e34e7e20a0f3..b4f8b4e6d01c 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -18,6 +18,7 @@
 use std::any::Any;
 #[cfg(test)]
 use std::collections::HashMap;
+use std::sync::Arc;
 use std::vec;

 use arrow_schema::TimeUnit::Nanosecond;
@@ -37,6 +38,7 @@ use datafusion_sql::{
     planner::{ParserOptions, SqlToRel},
 };

+use
datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; @@ -2694,7 +2696,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) .with_udaf(avg_udaf()) - .with_udaf(grouping_udaf()); + .with_udaf(grouping_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect);