Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Float64(1) - #lineitem.l_discount)]]\

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice

\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ async fn multiple_or_predicates() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Float64(1) AND #lineitem.l_quantity <= Float64(11) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Float64(10) AND #lineitem.l_quantity <= Float64(20) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Float64(20) AND #lineitem.l_quantity <= Float64(30) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
Expand Down
69 changes: 52 additions & 17 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Optimizer rule for type validation and coercion

use crate::simplify_expressions::ConstEvaluator;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_expr::binary_rule::coerce_types;
Expand All @@ -27,6 +28,7 @@ use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
use datafusion_physical_expr::execution_props::ExecutionProps;

#[derive(Default)]
pub struct TypeCoercion {}
Expand Down Expand Up @@ -64,7 +66,15 @@ impl OptimizerRule for TypeCoercion {
_ => DFSchemaRef::new(DFSchema::empty()),
};

let mut expr_rewrite = TypeCoercionRewriter { schema };
let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time =
optimizer_config.query_execution_start_time;
let const_evaluator = ConstEvaluator::try_new(&execution_props)?;
Comment thread
Dandandan marked this conversation as resolved.

let mut expr_rewrite = TypeCoercionRewriter {
schema,
const_evaluator,
};

let new_expr = plan
.expressions()
Expand All @@ -76,11 +86,12 @@ impl OptimizerRule for TypeCoercion {
}
}

struct TypeCoercionRewriter {
struct TypeCoercionRewriter<'a> {
schema: DFSchemaRef,
const_evaluator: ConstEvaluator<'a>,
}

impl ExprRewriter for TypeCoercionRewriter {
impl ExprRewriter for TypeCoercionRewriter<'_> {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
Expand All @@ -91,22 +102,26 @@ impl ExprRewriter for TypeCoercionRewriter {
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
Ok(Expr::BinaryExpr {

let expr = Expr::BinaryExpr {
left: Box::new(left.cast_to(&coerced_type, &self.schema)?),
op,
right: Box::new(right.cast_to(&coerced_type, &self.schema)?),
})
};

expr.rewrite(&mut self.const_evaluator)
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
Ok(Expr::ScalarUDF {
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
})
};
expr.rewrite(&mut self.const_evaluator)
}
expr => Ok(expr),
}
Expand Down Expand Up @@ -145,7 +160,8 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, Result};
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::{col, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expand All @@ -156,28 +172,40 @@ mod test {

#[test]
fn simple_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation",
"Projection: #a < Float64(2)\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

#[test]
fn nested_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
Expand All @@ -187,8 +215,11 @@ mod test {
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\
\n EmptyRelation", &format!("{:?}", plan));
assert_eq!(
"Projection: #a < Float64(2) OR #a < Float64(2)\
\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

Expand All @@ -197,7 +228,11 @@ mod test {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
let fun: ScalarFunctionImplementation = Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
"a".to_string(),
))))
});
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
Expand All @@ -212,7 +247,7 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
"Projection: Utf8(\"a\")\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
Expand Down