diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 83096e966f628..a0c75772e566c 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -18,13 +18,10 @@ //! Optimizer rule to replace nested unions to single union. use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; -use datafusion_expr::{ - builder::project_with_column_index, - expr_rewriter::coerce_plan_expr_for_schema, - logical_plan::{LogicalPlan, Projection, Union}, -}; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; use crate::optimizer::ApplyOrder; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use std::sync::Arc; #[derive(Default)] @@ -38,6 +35,8 @@ impl EliminateNestedUnion { } } +pub fn get_union_schema() {} + impl OptimizerRule for EliminateNestedUnion { fn try_optimize( &self, @@ -48,32 +47,24 @@ impl OptimizerRule for EliminateNestedUnion { LogicalPlan::Union(union) => { let Union { inputs, schema } = union; - let union_schema = schema.clone(); - let inputs = inputs .into_iter() - .flat_map(|plan| match Arc::as_ref(plan) { - LogicalPlan::Union(Union { inputs, .. }) => inputs.clone(), - _ => vec![Arc::clone(plan)], - }) - .map(|plan| { - let plan = coerce_plan_expr_for_schema(&plan, &union_schema)?; - match plan { - LogicalPlan::Projection(Projection { - expr, input, .. - }) => Ok(Arc::new(project_with_column_index( - expr, - input, - union_schema.clone(), - )?)), - _ => Ok(Arc::new(plan)), - } + .flat_map(|plan| match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .into_iter() + .map(|plan| { + Arc::new( + coerce_plan_expr_for_schema(plan, schema).unwrap(), + ) + }) + .collect::>(), + _ => vec![plan.clone()], }) - .collect::>>()?; + .collect::>(); Ok(Some(LogicalPlan::Union(Union { inputs, - schema: union_schema, + schema: schema.clone(), }))) } _ => Ok(None), @@ -94,13 +85,13 @@ mod tests { use super::*; use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::logical_plan::table_scan; + use datafusion_expr::{col, logical_plan::table_scan}; fn schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Int32, false), + Field::new("value", DataType::Float64, false), ]) } @@ -143,4 +134,81 @@ mod tests { \n TableScan: table"; assert_optimized_plan_equal(&plan, expected) } + + // We don't need to use project_with_column_index in logical optimizer, + // after LogicalPlanBuilder::union, we already have all equal expression aliases + #[test] + fn eliminate_nested_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union(table_2.build()?)? + .union(table_3.build()?)? + .build()?; + + let expected = "Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index c343e2576c041..70ee490346ffb 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -40,9 +40,7 @@ impl OptimizerRule for EliminateOneUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(union) if union.inputs.len() == 1 => { - let Union { inputs, schema: _ } = union; - + LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => { Ok(inputs.first().map(|input| input.as_ref().clone())) } _ => Ok(None), @@ -103,7 +101,7 @@ mod tests { } #[test] - fn eliminate_nested_union() -> Result<()> { + fn eliminate_one_union() -> Result<()> { let table_plan = coerce_plan_expr_for_schema( &table_scan(Some("table"), &schema(), None)?.build()?, &schema().to_dfschema()?, diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index a1db9dcc171d6..5231dc8698751 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -222,6 +222,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(EliminateNestedUnion::new()), Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -229,7 +230,6 @@ impl Optimizer { Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateNestedUnion::new()), // simplify expressions does not simplify expressions in subqueries, so we // run it again after running the optimizations that potentially converted // subqueries to joins @@ -242,9 +242,10 @@ impl Optimizer { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(PropagateEmptyRelation::new()), + // Must be after PropagateEmptyRelation + Arc::new(EliminateOneUnion::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), - Arc::new(EliminateOneUnion::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f4ffea06b7a13..e97b7cd38c10c 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2053,24 +2053,6 @@ fn union_all() { quick_test(sql, expected); } -#[test] -fn union_4_combined_in_one() { - let sql = "SELECT order_id from orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); -} - #[test] fn union_with_different_column_names() { let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders"; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 23055cd978ffe..0e190a6acd620 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -184,6 +184,7 @@ logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -200,6 +201,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE @@ -213,6 +215,7 @@ Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -229,6 +232,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE