Skip to content

Commit

Permalink
Union schema coercion refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgeny Maruschenko committed Oct 3, 2023
1 parent 326cc9e commit dfd4dab
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 49 deletions.
56 changes: 29 additions & 27 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,34 @@ impl LogicalPlanBuilder {
}
}

/// Creates the schema for a union operation.
///
/// Fields are paired positionally. Each output field keeps the qualifier and
/// name of the left-hand field, takes the data type returned by
/// `comparison_coercion` for the two input types, and is nullable if either
/// input field is nullable.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` if the two schemas have a different
/// number of fields, or if some pair of field types has no common type.
pub fn build_union_schema(left: &DFSchema, right: &DFSchema) -> Result<DFSchema> {
    let left_fields = left.fields();
    let right_fields = right.fields();

    // `zip` stops at the shorter input, so without this check a column-count
    // mismatch would silently drop the trailing fields of the wider side.
    if left_fields.len() != right_fields.len() {
        return Err(DataFusionError::Plan(format!(
            "UNION inputs must have the same number of columns (left has {}, right has {})",
            left_fields.len(),
            right_fields.len()
        )));
    }

    zip(left_fields.iter(), right_fields.iter())
        .map(|(left_field, right_field)| {
            // The result may contain NULLs as soon as one side can.
            let nullable = left_field.is_nullable() || right_field.is_nullable();
            let data_type =
                comparison_coercion(left_field.data_type(), right_field.data_type())
                    .ok_or_else(|| {
                        DataFusionError::Plan(format!(
                            "UNION Column {} (type: {}) is not compatible with column {} (type: {})",
                            right_field.name(),
                            right_field.data_type(),
                            left_field.name(),
                            left_field.data_type()
                        ))
                    })?;
            // Output naming follows the left input.
            Ok(DFField::new(
                left_field.qualifier().cloned(),
                left_field.name(),
                data_type,
                nullable,
            ))
        })
        .collect::<Result<Vec<_>>>()?
        .to_dfschema()
}

/// Creates a schema for a join operation.
/// The fields from the left side are first
pub fn build_join_schema(
Expand Down Expand Up @@ -1197,33 +1225,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalP
}

// create union schema
let union_schema = zip(
left_plan.schema().fields().iter(),
right_plan.schema().fields().iter(),
)
.map(|(left_field, right_field)| {
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type =
comparison_coercion(left_field.data_type(), right_field.data_type())
.ok_or_else(|| {
DataFusionError::Plan(format!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
))
})?;

Ok(DFField::new(
left_field.qualifier().cloned(),
left_field.name(),
data_type,
nullable,
))
})
.collect::<Result<Vec<_>>>()?
.to_dfschema()?;
let union_schema = build_union_schema(left_plan.schema(), right_plan.schema())?;

let inputs = vec![left_plan, right_plan]
.into_iter()
Expand Down
16 changes: 12 additions & 4 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::{
builder::project_with_column_index,
builder::{build_union_schema, project_with_column_index},
expr_rewriter::coerce_plan_expr_for_schema,
logical_plan::{LogicalPlan, Projection, Union},
};
Expand All @@ -38,6 +38,8 @@ impl EliminateNestedUnion {
}
}

// NOTE(review): empty placeholder — takes no arguments, returns nothing, and
// nothing in the visible code calls it. Confirm whether it can be removed.
pub fn get_union_schema() {}

impl OptimizerRule for EliminateNestedUnion {
fn try_optimize(
&self,
Expand All @@ -46,9 +48,15 @@ impl OptimizerRule for EliminateNestedUnion {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Union(union) => {
let Union { inputs, schema } = union;
let Union { inputs, schema: _ } = union;

let union_schema = schema.clone();
let union_schema = inputs
.iter()
.map(|input| Arc::clone(input.schema()))
.reduce(|acc, el| {
Arc::new(build_union_schema(acc.as_ref(), el.as_ref()).unwrap())
})
.unwrap();

let inputs = inputs
.into_iter()
Expand All @@ -66,7 +74,7 @@ impl OptimizerRule for EliminateNestedUnion {
input,
union_schema.clone(),
)?)),
_ => Ok(Arc::new(plan)),
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;
Expand Down
18 changes: 0 additions & 18 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2053,24 +2053,6 @@ fn union_all() {
quick_test(sql, expected);
}

// Checks that four `SELECT`s chained with `UNION ALL` are planned as a single
// `Union` node with four `Projection` inputs, rather than as nested unions.
#[test]
fn union_4_combined_in_one() {
let sql = "SELECT order_id from orders
UNION ALL SELECT order_id FROM orders
UNION ALL SELECT order_id FROM orders
UNION ALL SELECT order_id FROM orders";
let expected = "Union\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders";
quick_test(sql, expected);
}

#[test]
fn union_with_different_column_names() {
let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders";
Expand Down

0 comments on commit dfd4dab

Please sign in to comment.