Skip to content

Commit

Permalink
Move sql_compound_identifier_to_expr to ExprPlanner (#11487)
Browse files Browse the repository at this point in the history
* move get_field to expr planner

* formatting

* formatting

* documentation

* refactor

* documentation & fixes

* move optimizer tests to core

* fix breaking test cases

* cleanup

* fix examples

* formatting

* rm datafusion-functions from optimizer

* update compound identifier

* update planner

* update planner

* formatting

* reverting optimizer tests

* formatting
  • Loading branch information
dharanad committed Jul 21, 2024
1 parent d232065 commit 36660fe
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 33 deletions.
19 changes: 18 additions & 1 deletion datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use std::sync::Arc;

use arrow::datatypes::{DataType, SchemaRef};
use arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
Expand Down Expand Up @@ -180,6 +180,23 @@ pub trait ExprPlanner: Send + Sync {
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
// Default behaviour: decline to plan and hand the arguments back untouched.
Ok(PlannerResult::Original(args))
}

/// Plans a compound identifier whose leading parts resolved to a field and
/// which still has trailing name parts, e.g. `tbl.struct_col.subfield`.
///
/// # Arguments
/// * `_field` - the field the leading identifier parts resolved to
/// * `_qualifier` - optional table qualifier of that field
/// * `_nested_names` - the remaining, non-empty name parts that follow the field
///
/// Note:
/// Currently compound identifier for outer query schema is not supported.
///
/// # Returns
/// [`PlannerResult::Planned`] with the planned expression when an
/// implementation handles the identifier.
///
/// # Errors
/// The default implementation always returns a "not implemented" error.
fn plan_compound_identifier(
&self,
_field: &Field,
_qualifier: Option<&TableReference>,
_nested_names: &[String],
) -> Result<PlannerResult<Vec<Expr>>> {
not_impl_err!(
"Default planner compound identifier hasn't been implemented for ExprPlanner"
)
}
}

/// An operator with two arguments to plan
Expand Down
1 change: 0 additions & 1 deletion datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
nvl2(),
arrow_typeof(),
named_struct(),
get_field(),
coalesce(),
map(),
]
Expand Down
27 changes: 25 additions & 2 deletions datafusion/functions/src/core/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::DFSchema;
use arrow::datatypes::Field;
use datafusion_common::Result;
use datafusion_common::{not_impl_err, Column, DFSchema, ScalarValue, TableReference};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr};
use datafusion_expr::Expr;
use datafusion_expr::{lit, Expr};

use super::named_struct;

Expand Down Expand Up @@ -62,4 +63,26 @@ impl ExprPlanner for CoreFunctionPlanner {
ScalarFunction::new_udf(crate::string::overlay(), args),
)))
}

fn plan_compound_identifier(
&self,
field: &Field,
qualifier: Option<&TableReference>,
nested_names: &[String],
) -> Result<PlannerResult<Vec<Expr>>> {
// TODO: remove when can support multiple nested identifiers
if nested_names.len() > 1 {
return not_impl_err!(
"Nested identifiers not yet supported for column {}",
Column::from((qualifier, field)).quoted_flat_name()
);
}
let nested_name = nested_names[0].to_string();

let col = Expr::Column(Column::from((qualifier, field)));
let get_field_args = vec![col, lit(ScalarValue::from(nested_name))];
Ok(PlannerResult::Planned(Expr::ScalarFunction(
ScalarFunction::new_udf(crate::core::get_field(), get_field_args),
)))
}
}
20 changes: 18 additions & 2 deletions datafusion/sql/examples/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,25 @@
// specific language governing permissions and limitations
// under the License.

use std::{collections::HashMap, sync::Arc};

use arrow_schema::{DataType, Field, Schema};

use datafusion_common::config::ConfigOptions;
use datafusion_common::{plan_err, Result};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::WindowUDF;
use datafusion_expr::{
logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
sqlparser::{dialect::GenericDialect, parser::Parser},
TableReference,
};
use std::{collections::HashMap, sync::Arc};

fn main() {
let sql = "SELECT \
Expand All @@ -53,7 +57,8 @@ fn main() {
// create a logical query plan
let context_provider = MyContextProvider::new()
.with_udaf(sum_udaf())
.with_udaf(count_udaf());
.with_udaf(count_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand All @@ -65,6 +70,7 @@ struct MyContextProvider {
options: ConfigOptions,
tables: HashMap<String, Arc<dyn TableSource>>,
udafs: HashMap<String, Arc<AggregateUDF>>,
expr_planners: Vec<Arc<dyn ExprPlanner>>,
}

impl MyContextProvider {
Expand All @@ -73,6 +79,11 @@ impl MyContextProvider {
self
}

/// Appends `planner` to this provider's expression planners and returns the
/// provider so calls can be chained.
fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
self.expr_planners.push(planner);
self
}

fn new() -> Self {
let mut tables = HashMap::new();
tables.insert(
Expand Down Expand Up @@ -105,6 +116,7 @@ impl MyContextProvider {
tables,
options: Default::default(),
udafs: Default::default(),
expr_planners: vec![],
}
}
}
Expand Down Expand Up @@ -154,4 +166,8 @@ impl ContextProvider for MyContextProvider {
fn udwf_names(&self) -> Vec<String> {
// This example provider reports no window function names.
Vec::new()
}

fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
// Planners are exposed in the order they were added via `with_expr_planner`.
&self.expr_planners
}
}
45 changes: 22 additions & 23 deletions datafusion/sql/src/expr/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::Field;
use sqlparser::ast::{Expr as SQLExpr, Ident};

use datafusion_common::{
internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError,
Result, ScalarValue, TableReference,
Result, TableReference,
};
use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
use sqlparser::ast::{Expr as SQLExpr, Ident};
use datafusion_expr::planner::PlannerResult;
use datafusion_expr::{Case, Expr};

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(super) fn sql_identifier_to_expr(
Expand Down Expand Up @@ -125,26 +128,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
match search_result {
// found matching field with spare identifier(s) for nested field(s) in structure
Some((field, qualifier, nested_names)) if !nested_names.is_empty() => {
// TODO: remove when can support multiple nested identifiers
if nested_names.len() > 1 {
return not_impl_err!(
"Nested identifiers not yet supported for column {}",
Column::from((qualifier, field)).quoted_flat_name()
);
}
let nested_name = nested_names[0].to_string();

let col = Expr::Column(Column::from((qualifier, field)));
if let Some(udf) =
self.context_provider.get_function_meta("get_field")
{
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![col, lit(ScalarValue::from(nested_name))],
)))
} else {
internal_err!("get_field not found")
// found matching field with spare identifier(s) for nested field(s) in structure
for planner in self.context_provider.get_expr_planners() {
if let Ok(planner_result) = planner.plan_compound_identifier(
field,
qualifier,
nested_names,
) {
match planner_result {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(_args) => {}
}
}
}
not_impl_err!(
"Compound identifiers not supported by ExprPlanner: {ids:?}"
)
}
// found matching field with no spare identifier(s)
Some((field, qualifier, _nested_names)) => {
Expand Down
11 changes: 8 additions & 3 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::vec;

use arrow_schema::*;
Expand All @@ -28,6 +29,7 @@ use datafusion_sql::unparser::dialect::{
};
use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};

use datafusion_functions::core::planner::CoreFunctionPlanner;
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;

Expand Down Expand Up @@ -155,7 +157,8 @@ fn roundtrip_statement() -> Result<()> {

let context = MockContextProvider::default()
.with_udaf(sum_udaf())
.with_udaf(count_udaf());
.with_udaf(count_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down Expand Up @@ -184,7 +187,8 @@ fn roundtrip_crossjoin() -> Result<()> {
.try_with_sql(query)?
.parse_statement()?;

let context = MockContextProvider::default();
let context = MockContextProvider::default()
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down Expand Up @@ -276,7 +280,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
.try_with_sql(query.sql)?
.parse_statement()?;

let context = MockContextProvider::default();
let context = MockContextProvider::default()
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel
.sql_statement_to_plan(statement)
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sql/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow_schema::*;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{plan_err, GetExt, Result, TableReference};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_sql::planner::ContextProvider;

Expand Down Expand Up @@ -53,6 +54,7 @@ pub(crate) struct MockContextProvider {
options: ConfigOptions,
udfs: HashMap<String, Arc<ScalarUDF>>,
udafs: HashMap<String, Arc<AggregateUDF>>,
expr_planners: Vec<Arc<dyn ExprPlanner>>,
}

impl MockContextProvider {
Expand All @@ -73,6 +75,11 @@ impl MockContextProvider {
self.udafs.insert(udaf.name().to_lowercase(), udaf);
self
}

/// Appends `planner` to the mock provider's expression planners and returns
/// the provider so calls can be chained.
pub(crate) fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
self.expr_planners.push(planner);
self
}
}

impl ContextProvider for MockContextProvider {
Expand Down Expand Up @@ -240,6 +247,10 @@ impl ContextProvider for MockContextProvider {
fn udwf_names(&self) -> Vec<String> {
// The mock provider reports no window function names.
Vec::new()
}

fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
// Planners are exposed in the order they were added via `with_expr_planner`.
&self.expr_planners
}
}

struct EmptyTable {
Expand Down
5 changes: 4 additions & 1 deletion datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::sync::Arc;
use std::vec;

use arrow_schema::TimeUnit::Nanosecond;
Expand All @@ -37,6 +38,7 @@ use datafusion_sql::{
planner::{ParserOptions, SqlToRel},
};

use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_aggregate::{
approx_median::approx_median_udaf, count::count_udaf,
};
Expand Down Expand Up @@ -2694,7 +2696,8 @@ fn logical_plan_with_dialect_and_options(
.with_udaf(approx_median_udaf())
.with_udaf(count_udaf())
.with_udaf(avg_udaf())
.with_udaf(grouping_udaf());
.with_udaf(grouping_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));

let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
Expand Down

0 comments on commit 36660fe

Please sign in to comment.