Skip to content

Commit

Permalink
Stop copying LogicalPlan and Exprs in TypeCoercion
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed May 7, 2024
1 parent 826d51f commit 5ed976b
Showing 1 changed file with 88 additions and 37 deletions.
125 changes: 88 additions & 37 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DataFusionError, Result, ScalarValue,
Expand All @@ -31,8 +31,8 @@ use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
InSubquery, Like, ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
comparison_coercion, get_input_types, like_coercion,
Expand All @@ -51,6 +51,7 @@ use datafusion_expr::{
};

use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;

#[derive(Default)]
pub struct TypeCoercion {}
Expand All @@ -67,26 +68,28 @@ impl AnalyzerRule for TypeCoercion {
}

fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
analyze_internal(&DFSchema::empty(), &plan)
let empty_schema = DFSchema::empty();

let transformed_plan = plan
.transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
.data;

Ok(transformed_plan)
}
}

/// use the external schema to handle the correlated subqueries case
///
/// Assumes that children have already been optimized
fn analyze_internal(
// use the external schema to handle the correlated subqueries case
external_schema: &DFSchema,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(external_schema, p))
.collect::<Result<Vec<_>>>()?;
plan: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
Expand All @@ -99,25 +102,75 @@ fn analyze_internal(
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
schema.merge(external_schema);

let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
let mut expr_rewrite = TypeCoercionRewriter::new(&schema);

let name_preserver = NamePreserver::new(&plan);
// apply coercion rewrite all expressions in the plan indivdually
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
// coerce join expressions specially
.map_data(|plan| expr_rewrite.coerce_joins(plan))?
// recompute the schema after the expressions have been rewritten as the types may have changed
.map_data(|plan| plan.recompute_schema())
}

pub(crate) struct TypeCoercionRewriter<'a> {
pub(crate) schema: &'a DFSchema,
}

impl<'a> TypeCoercionRewriter<'a> {
fn new(schema: &'a DFSchema) -> Self {
Self { schema }
}

/// Coerce join equality expressions
///
/// Joins must be treated specially as their equality expressions are stored
/// as a parallel list of left and right expressions, rather than a single
/// equality expression
///
/// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
/// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
fn coerce_joins(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Join(mut join) = plan else {
return Ok(plan);
};

join.on = join
.on
.into_iter()
.map(|(lhs, rhs)| {
// coerce the arguments as though they were a single binary equality
// expression
let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?;
Ok((lhs, rhs))
})
.collect::<Result<Vec<_>>>()?;

Ok(LogicalPlan::Join(join))
}

fn coerce_binary_op(
&self,
left: Expr,
op: Operator,
right: Expr,
) -> Result<(Expr, Expr)> {
let (left_type, right_type) = get_input_types(
&left.get_type(self.schema)?,
&op,
&right.get_type(self.schema)?,
)?;
Ok((
left.cast_to(&left_type, self.schema)?,
right.cast_to(&right_type, self.schema)?,
))
}
}

impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
type Node = Expr;

Expand All @@ -130,14 +183,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
subquery,
outer_ref_columns,
}) => {
let new_plan = analyze_internal(self.schema, &subquery)?;
let new_plan = analyze_internal(self.schema, unwrap_arc(subquery))?.data;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
Expand All @@ -151,7 +205,8 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
subquery,
negated,
}) => {
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data;
let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
Expand Down Expand Up @@ -220,15 +275,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
))))
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let (left_type, right_type) = get_input_types(
&left.get_type(self.schema)?,
&op,
&right.get_type(self.schema)?,
)?;
let (left, right) = self.coerce_binary_op(*left, op, *right)?;
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left.cast_to(&left_type, self.schema)?),
Box::new(left),
op,
Box::new(right.cast_to(&right_type, self.schema)?),
Box::new(right),
))))
}
Expr::Between(Between {
Expand Down

0 comments on commit 5ed976b

Please sign in to comment.