diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index e3788066d33f..4696d21852fc 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -378,12 +378,18 @@ enum JoinType { ANTI = 5; } +enum JoinConstraint { + ON = 0; + USING = 1; +} + message JoinNode { LogicalPlanNode left = 1; LogicalPlanNode right = 2; JoinType join_type = 3; - repeated Column left_join_column = 4; - repeated Column right_join_column = 5; + JoinConstraint join_constraint = 4; + repeated Column left_join_column = 5; + repeated Column right_join_column = 6; } message LimitNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index a1136cf4a7d6..cad054392308 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{ }; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, - sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, + sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType, + LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; @@ -257,23 +257,32 @@ impl TryInto for &protobuf::LogicalPlanNode { join.join_type )) })?; - let join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; - LogicalPlanBuilder::from(convert_box_required!(join.left)?) 
- .join( + let join_constraint = protobuf::JoinConstraint::from_i32( + join.join_constraint, + ) + .ok_or_else(|| { + proto_error(format!( + "Received a JoinNode message with unknown JoinConstraint {}", + join.join_constraint + )) + })?; + + let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?); + let builder = match join_constraint.into() { + JoinConstraint::On => builder.join( &convert_box_required!(join.right)?, - join_type, + join_type.into(), left_keys, right_keys, - )? - .build() - .map_err(|e| e.into()) + )?, + JoinConstraint::Using => builder.join_using( + &convert_box_required!(join.right)?, + join_type.into(), + left_keys, + )?, + }; + + builder.build().map_err(|e| e.into()) } } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4049622b83dc..07d7a59c114c 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn use datafusion::datasource::CsvFile; use datafusion::logical_plan::{ window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, - Column, Expr, JoinType, LogicalPlan, + Column, Expr, JoinConstraint, JoinType, LogicalPlan, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; @@ -804,26 +804,23 @@ impl TryInto for &LogicalPlan { right, on, join_type, + join_constraint, .. 
} => { let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?; let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?; - let join_type = match join_type { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; let (left_join_column, right_join_column) = on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); + let join_type: protobuf::JoinType = join_type.to_owned().into(); + let join_constraint: protobuf::JoinConstraint = + join_constraint.to_owned().into(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), right: Some(Box::new(right)), join_type: join_type.into(), + join_constraint: join_constraint.into(), left_join_column, right_join_column, }, diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index af83660baab5..1df0675ecae5 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,7 +20,7 @@ use std::{convert::TryInto, io::Cursor}; -use datafusion::logical_plan::Operator; +use datafusion::logical_plan::{JoinConstraint, JoinType, Operator}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; @@ -291,3 +291,47 @@ impl Into for protobuf::PrimitiveScalarT } } } + +impl From for JoinType { + fn from(t: protobuf::JoinType) -> Self { + match t { + protobuf::JoinType::Inner => JoinType::Inner, + protobuf::JoinType::Left => JoinType::Left, + protobuf::JoinType::Right => JoinType::Right, + protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, + } + } +} + +impl From for protobuf::JoinType { + fn from(t: 
JoinType) -> Self { + match t { + JoinType::Inner => protobuf::JoinType::Inner, + JoinType::Left => protobuf::JoinType::Left, + JoinType::Right => protobuf::JoinType::Right, + JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, + } + } +} + +impl From for JoinConstraint { + fn from(t: protobuf::JoinConstraint) -> Self { + match t { + protobuf::JoinConstraint::On => JoinConstraint::On, + protobuf::JoinConstraint::Using => JoinConstraint::Using, + } + } +} + +impl From for protobuf::JoinConstraint { + fn from(t: JoinConstraint) -> Self { + match t { + JoinConstraint::On => protobuf::JoinConstraint::On, + JoinConstraint::Using => protobuf::JoinConstraint::Using, + } + } +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 717ee209dbe9..12c1743c0747 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; -use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr}; +use datafusion::logical_plan::{ + window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, +}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; @@ -57,7 +59,6 @@ use datafusion::physical_plan::{ filter::FilterExec, functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, hash_join::HashJoinExec, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, parquet::ParquetExec, projection::ProjectionExec, @@ -348,14 +349,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { 
hashjoin.join_type )) })?; - let join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; + let partition_mode = protobuf::PartitionMode::from_i32(hashjoin.partition_mode) .ok_or_else(|| { @@ -372,7 +366,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { left, right, on, - &join_type, + &join_type.into(), partition_mode, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index a393d7fdab1f..3bf7e9c3063b 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -27,7 +27,7 @@ mod roundtrip_tests { compute::kernels::sort::SortOptions, datatypes::{DataType, Field, Schema}, }, - logical_plan::Operator, + logical_plan::{JoinType, Operator}, physical_plan::{ empty::EmptyExec, expressions::{binary, col, lit, InListExpr, NotExpr}, @@ -35,7 +35,6 @@ mod roundtrip_tests { filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::{HashJoinExec, PartitionMode}, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 0fc27850074c..875dbf213441 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -26,6 +26,7 @@ use std::{ sync::Arc, }; +use datafusion::logical_plan::JoinType; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ @@ -35,7 +36,6 @@ use 
datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; -use datafusion::physical_plan::hash_utils::JoinType; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; @@ -135,18 +135,13 @@ impl TryInto for Arc { }), }) .collect(); - let join_type = match exec.join_type() { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let partition_mode = match exec.partition_mode() { PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, }; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { diff --git a/benchmarks/queries/q7.sql b/benchmarks/queries/q7.sql index d53877c8dde6..512e5be55a2d 100644 --- a/benchmarks/queries/q7.sql +++ b/benchmarks/queries/q7.sql @@ -36,4 +36,4 @@ group by order by supp_nation, cust_nation, - l_year; \ No newline at end of file + l_year; diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index d5a84869ad94..77b3fbeb851b 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1277,6 +1277,96 @@ mod tests { Ok(()) } + #[tokio::test] + async fn left_join_using() -> Result<()> { + let results = execute( + "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + 
assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join_using_join_key_projection() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn window() -> Result<()> { let results = execute( diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 7cf779740c47..4edd01c2c0a9 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -264,7 +264,7 @@ mod tests { #[tokio::test] async fn join() -> Result<()> { let left = test_table()?.select_columns(&["c1", "c2"])?; - let 
right = test_table()?.select_columns(&["c1", "c3"])?; + let right = test_table_with_name("c2")?.select_columns(&["c1", "c3"])?; let left_rows = left.collect().await?; let right_rows = right.collect().await?; let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; @@ -315,7 +315,7 @@ mod tests { #[test] fn registry() -> Result<()> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; // declare the udf let my_fn: ScalarFunctionImplementation = @@ -366,21 +366,28 @@ mod tests { /// Create a logical plan from a SQL query fn create_plan(sql: &str) -> Result { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; ctx.create_logical_plan(sql) } - fn test_table() -> Result> { + fn test_table_with_name(name: &str) -> Result> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; - ctx.table("aggregate_test_100") + register_aggregate_csv(&mut ctx, name)?; + ctx.table(name) + } + + fn test_table() -> Result> { + test_table_with_name("aggregate_test_100") } - fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { + fn register_aggregate_csv( + ctx: &mut ExecutionContext, + table_name: &str, + ) -> Result<()> { let schema = test::aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); ctx.register_csv( - "aggregate_test_100", + table_name, &format!("{}/csv/aggregate_test_100.csv", testdata), CsvReadOptions::new().schema(schema.as_ref()), )?; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 1a53e2185a4b..41f29c4b9905 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -40,7 +40,6 @@ use crate::logical_plan::{ columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema, DFSchemaRef, Partitioning, }; -use std::collections::HashSet; /// Default 
table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -217,7 +216,6 @@ impl LogicalPlanBuilder { /// * An invalid expression is used (e.g. a `sort` expression) pub fn project(&self, expr: impl IntoIterator) -> Result { let input_schema = self.plan.schema(); - let all_schemas = self.plan.all_schemas(); let mut projected_expr = vec![]; for e in expr { match e { @@ -227,10 +225,8 @@ impl LogicalPlanBuilder { .push(Expr::Column(input_schema.field(i).qualified_column())) }); } - _ => projected_expr.push(columnize_expr( - normalize_col(e, &all_schemas)?, - input_schema, - )), + _ => projected_expr + .push(columnize_expr(normalize_col(e, &self.plan)?, input_schema)), } } @@ -247,7 +243,7 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(&self, expr: Expr) -> Result { - let expr = normalize_col(expr, &self.plan.all_schemas())?; + let expr = normalize_col(expr, &self.plan)?; Ok(Self::from(LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -264,9 +260,8 @@ impl LogicalPlanBuilder { /// Apply a sort pub fn sort(&self, exprs: impl IntoIterator) -> Result { - let schemas = self.plan.all_schemas(); Ok(Self::from(LogicalPlan::Sort { - expr: normalize_cols(exprs, &schemas)?, + expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), })) } @@ -292,20 +287,15 @@ impl LogicalPlanBuilder { let left_keys: Vec = left_keys .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = right_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::On, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; 
Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -327,21 +317,16 @@ impl LogicalPlanBuilder { let left_keys: Vec = using_keys .clone() .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = using_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::Using, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -394,9 +379,8 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator, aggr_expr: impl IntoIterator, ) -> Result { - let schemas = self.plan.all_schemas(); - let group_expr = normalize_cols(group_expr, &schemas)?; - let aggr_expr = normalize_cols(aggr_expr, &schemas)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; + let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; let all_expr = group_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?; @@ -440,33 +424,12 @@ impl LogicalPlanBuilder { pub fn build_join_schema( left: &DFSchema, right: &DFSchema, - on: &[(Column, Column)], join_type: &JoinType, - join_constraint: &JoinConstraint, ) -> Result { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // right join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.1.clone()).collect::>() - } - 
}; - + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let right_fields = right.fields().iter(); let left_fields = left.fields().iter(); - - // remove right-side join keys if they have the same names as the left-side - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); - // left then right left_fields.chain(right_fields).cloned().collect() } @@ -474,31 +437,6 @@ pub fn build_join_schema( // Only use the left side for the schema left.fields().clone() } - JoinType::Right => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // left join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.0.clone()).collect::>() - } - }; - - // remove left-side join keys if they have the same names as the right-side - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); - - let right_fields = right.fields().iter(); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } }; DFSchema::new(fields) diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index b4d864f55ebd..b4bde87f3471 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -48,6 +48,7 @@ impl DFSchema { pub fn new(fields: Vec) -> Result { let mut qualified_names = HashSet::new(); let mut unqualified_names = HashSet::new(); + for field in &fields { if let Some(qualifier) = field.qualifier() { if !qualified_names.insert((qualifier, field.name())) { @@ -94,10 +95,7 @@ impl DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: Some(qualifier.to_owned()), - }) + .map(|f| DFField::from_qualified(qualifier, f.clone())) .collect(), ) } @@ 
-149,47 +147,80 @@ impl DFSchema { ))) } - /// Find the index of the column with the given qualifer and name - pub fn index_of_column(&self, col: &Column) -> Result { - for i in 0..self.fields.len() { - let field = &self.fields[i]; - if field.qualifier() == col.relation.as_ref() && field.name() == &col.name { - return Ok(i); - } + fn index_of_column_by_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result { + let matches: Vec = self + .fields + .iter() + .enumerate() + .filter(|(_, field)| match (qualifier, &field.qualifier) { + // field to lookup is qualified. + // current field is qualified and not shared between relations, compare both + // qualifier and name. + (Some(q), Some(field_q)) => q == field_q && field.name() == name, + // field to lookup is qualified but current field is unqualified. + (Some(_), None) => false, + // field to lookup is unqualified, no need to compare qualifier + _ => field.name() == name, + }) + .map(|(idx, _)| idx) + .collect(); + + match matches.len() { + 0 => Err(DataFusionError::Plan(format!( + "No field named '{}.{}'. Valid fields are {}.", + qualifier.unwrap_or(""), + name, + self.get_field_names() + ))), + 1 => Ok(matches[0]), + _ => Err(DataFusionError::Internal(format!( + "Ambiguous reference to qualified field named '{}.{}'", + qualifier.unwrap_or(""), + name + ))), } - Err(DataFusionError::Plan(format!( - "No field matches column '{}'.
Available fields: {}", - col, self - ))) + } + + /// Find the index of the column with the given qualifier and name + pub fn index_of_column(&self, col: &Column) -> Result { + self.index_of_column_by_name(col.relation.as_deref(), &col.name) } /// Find the field with the given name pub fn field_with_name( &self, - relation_name: Option<&str>, + qualifier: Option<&str>, name: &str, - ) -> Result { - if let Some(relation_name) = relation_name { - self.field_with_qualified_name(relation_name, name) + ) -> Result<&DFField> { + if let Some(qualifier) = qualifier { + self.field_with_qualified_name(qualifier, name) } else { self.field_with_unqualified_name(name) } } - /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result { - let matches: Vec<&DFField> = self - .fields + /// Find all fields matching the given unqualified name + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { + self.fields .iter() .filter(|field| field.name() == name) - .collect(); + .collect() + } + + /// Find the single field with the given unqualified name + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { + let matches = self.fields_with_unqualified_name(name); match matches.len() { 0 => Err(DataFusionError::Plan(format!( "No field with unqualified name '{}'. Valid fields are {}.", name, self.get_field_names() ))), - 1 => Ok(matches[0].to_owned()), + 1 => Ok(matches[0]), _ => Err(DataFusionError::Plan(format!( "Ambiguous reference to field named '{}'", name @@ -200,33 +231,15 @@ impl DFSchema { /// Find the field with the given qualified name pub fn field_with_qualified_name( &self, - relation_name: &str, + qualifier: &str, name: &str, - ) -> Result { - let matches: Vec<&DFField> = self - .fields - .iter() - .filter(|field| { - field.qualifier == Some(relation_name.to_string()) && field.name() == name - }) - .collect(); - match matches.len() { - 0 => Err(DataFusionError::Plan(format!( - "No field named '{}.{}'.
Valid fields are {}.", - relation_name, - name, - self.get_field_names() - ))), - 1 => Ok(matches[0].to_owned()), - _ => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - relation_name, name - ))), - } + ) -> Result<&DFField> { + let idx = self.index_of_column_by_name(Some(qualifier), name)?; + Ok(self.field(idx)) } /// Find the field with the given qualified column - pub fn field_from_qualified_column(&self, column: &Column) -> Result { + pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { match &column.relation { Some(r) => self.field_with_qualified_name(r, &column.name), None => self.field_with_unqualified_name(&column.name), @@ -247,31 +260,20 @@ impl DFSchema { fields: self .fields .into_iter() - .map(|f| { - if f.qualifier().is_some() { - DFField::new( - None, - f.name(), - f.data_type().to_owned(), - f.is_nullable(), - ) - } else { - f - } - }) + .map(|f| f.strip_qualifier()) .collect(), } } /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifer: &str) -> Self { + pub fn replace_qualifier(self, qualifier: &str) -> Self { DFSchema { fields: self .fields .into_iter() .map(|f| { DFField::new( - Some(qualifer), + Some(qualifier), f.name(), f.data_type().to_owned(), f.is_nullable(), @@ -328,10 +330,7 @@ impl TryFrom for DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: None, - }) + .map(|f| DFField::from(f.clone())) .collect(), ) } @@ -454,8 +453,8 @@ impl DFField { /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { - if let Some(relation_name) = &self.qualifier { - format!("{}.{}", relation_name, self.field.name()) + if let Some(qualifier) = &self.qualifier { + format!("{}.{}", qualifier, self.field.name()) } else { self.field.name().to_owned() } @@ -469,6 +468,14 @@ impl DFField { } } + /// Builds an unqualified column based on self + pub fn unqualified_column(&self) 
-> Column { + Column { + relation: None, + name: self.field.name().to_string(), + } + } + /// Get the optional qualifier pub fn qualifier(&self) -> Option<&String> { self.qualifier.as_ref() @@ -478,6 +485,12 @@ impl DFField { pub fn field(&self) -> &Field { &self.field } + + /// Return field with qualifier stripped + pub fn strip_qualifier(mut self) -> Self { + self.qualifier = None; + self + } } #[cfg(test)] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 1fab9bb875ae..9454d7593c3f 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,7 +20,7 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef}; +use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, @@ -29,7 +29,7 @@ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -89,14 +89,46 @@ impl Column { /// /// For example, `foo` will be normalized to `t.foo` if there is a /// column named `foo` in a relation named `t` found in `schemas` - pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result { + pub fn normalize(self, plan: &LogicalPlan) -> Result { if self.relation.is_some() { return Ok(self); } - for schema in schemas { - if let Ok(field) = schema.field_with_unqualified_name(&self.name) { - return Ok(field.qualified_column()); + let schemas = plan.all_schemas(); + let using_columns = plan.using_columns()?; + + for schema in &schemas { + let fields = 
schema.fields_with_unqualified_name(&self.name); + match fields.len() { + 0 => continue, + 1 => { + return Ok(fields[0].qualified_column()); + } + _ => { + // More than one field in this schema has its name set to self.name. + // + // This should only happen when a JOIN query with USING constraint references + // join columns using unqualified column name. For example: + // + // ```sql + // SELECT id FROM t1 JOIN t2 USING(id) + // ``` + // + // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. + // We will use the relation from the first matched field to normalize self. + + // Compare matched fields with one USING JOIN clause at a time + for using_col in &using_columns { + let all_matched = fields + .iter() + .all(|f| using_col.contains(&f.qualified_column())); + // All matched fields belong to the same using column set, in other words + // the same join clause. We simply pick the qualifier from the first match. + if all_matched { + return Ok(fields[0].qualified_column()); + } + } + } } } @@ -321,9 +353,7 @@ impl Expr { pub fn get_type(&self, schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), - Expr::Column(c) => { - Ok(schema.field_from_qualified_column(c)?.data_type().clone()) - } + Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, ..
} => when_then_expr[0].1.get_type(schema), @@ -395,9 +425,7 @@ impl Expr { pub fn nullable(&self, input_schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.nullable(input_schema), - Expr::Column(c) => { - Ok(input_schema.field_from_qualified_column(c)?.is_nullable()) - } + Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), Expr::Literal(value) => Ok(value.is_null()), Expr::ScalarVariable(_) => Ok(true), Expr::Case { @@ -1118,36 +1146,56 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { } } +/// Recursively replace all Column expressions in a given expression tree with Column expressions +/// provided by the hash map argument. +pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + struct ColumnReplacer<'a> { + replace_map: &'a HashMap<&'a Column, &'a Column>, + } + + impl<'a> ExprRewriter for ColumnReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = &expr { + match self.replace_map.get(c) { + Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), + None => Ok(expr), + } + } else { + Ok(expr) + } + } + } + + e.rewrite(&mut ColumnReplacer { replace_map }) +} + /// Recursively call [`Column::normalize`] on all Column expressions /// in the `expr` expression tree. 
-pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result { - struct ColumnNormalizer<'a, 'b> { - schemas: &'a [&'b DFSchemaRef], +pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result { + struct ColumnNormalizer<'a> { + plan: &'a LogicalPlan, } - impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> { + impl<'a> ExprRewriter for ColumnNormalizer<'a> { fn mutate(&mut self, expr: Expr) -> Result { if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize(self.schemas)?)) + Ok(Expr::Column(c.normalize(self.plan)?)) } else { Ok(expr) } } } - e.rewrite(&mut ColumnNormalizer { schemas }) + e.rewrite(&mut ColumnNormalizer { plan }) } /// Recursively normalize all Column expressions in a list of expression trees #[inline] pub fn normalize_cols( exprs: impl IntoIterator, - schemas: &[&DFSchemaRef], + plan: &LogicalPlan, ) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e, schemas)) - .collect() + exprs.into_iter().map(|e| normalize_col(e, plan)).collect() } /// Create an expression to represent the min() aggregate function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 69d03d22bb21..86a2f567d7de 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -41,10 +41,10 @@ pub use expr::{ cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, - sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, - to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, - ExpressionVisitor, Literal, Recursion, + regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, + sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, 
starts_with, strpos, + substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr, + ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 99f0fa14a2d9..b954b6a97950 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -21,9 +21,11 @@ use super::display::{GraphvizVisitor, IndentVisitor}; use super::expr::{Column, Expr}; use super::extension::UserDefinedLogicalNode; use crate::datasource::TableProvider; +use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::collections::HashSet; use std::{ fmt::{self, Display}, sync::Arc, @@ -354,6 +356,43 @@ impl LogicalPlan { | LogicalPlan::CreateExternalTable { .. } => vec![], } } + + /// returns all `Using` join columns in a logical plan + pub fn using_columns(&self) -> Result>, DataFusionError> { + struct UsingJoinColumnVisitor { + using_columns: Vec>, + } + + impl PlanVisitor for UsingJoinColumnVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + if let LogicalPlan::Join { + join_constraint: JoinConstraint::Using, + on, + .. + } = plan + { + self.using_columns.push( + on.iter() + .map(|entry| { + std::iter::once(entry.0.clone()) + .chain(std::iter::once(entry.1.clone())) + }) + .flatten() + .collect::>(), + ); + } + Ok(true) + } + } + + let mut visitor = UsingJoinColumnVisitor { + using_columns: vec![], + }; + self.accept(&mut visitor)?; + Ok(visitor.using_columns) + } } /// Logical partitioning schemes supported by the repartition operator. @@ -709,10 +748,21 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Join { on: ref keys, .. } => { + LogicalPlan::Join { + on: ref keys, + join_constraint, + .. 
+ } => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); - write!(f, "Join: {}", join_expr.join(", ")) + match join_constraint { + JoinConstraint::On => { + write!(f, "Join: {}", join_expr.join(", ")) + } + JoinConstraint::Using => { + write!(f, "Join: Using {}", join_expr.join(", ")) + } + } } LogicalPlan::CrossJoin { .. } => { write!(f, "CrossJoin:") diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index c1d81fe62934..76d8c05bed4c 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -16,7 +16,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{and, Column, LogicalPlan}; +use crate::logical_plan::{and, replace_col, Column, LogicalPlan}; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -96,12 +96,21 @@ fn get_join_predicates<'a>( let left_columns = &left .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + // we need to push down filter using unqualified column as well + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let right_columns = &right .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let filters = state @@ -232,6 +241,38 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { } } +fn optimize_join( + mut state: State, + plan: &LogicalPlan, + left: &LogicalPlan, + right: &LogicalPlan, +) -> Result { + let (pushable_to_left, pushable_to_right, keep) = + get_join_predicates(&state, left.schema(), right.schema()); + + let mut left_state = state.clone(); + left_state.filters = keep_filters(&left_state.filters, 
&pushable_to_left); + let left = optimize(left, left_state)?; + + let mut right_state = state.clone(); + right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); + let right = optimize(right, right_state)?; + + // create a new Join with the new `left` and `right` + let expr = plan.expressions(); + let plan = utils::from_plan(plan, &expr, &[left, right])?; + + if keep.0.is_empty() { + Ok(plan) + } else { + // wrap the join on the filter whose predicates must be kept + let plan = add_filter(plan, &keep.0); + state.filters = remove_filters(&state.filters, &keep.1); + + Ok(plan) + } +} + fn optimize(plan: &LogicalPlan, mut state: State) -> Result { match plan { LogicalPlan::Explain { .. } => { @@ -336,32 +377,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { .collect::>(); issue_filters(state, used_columns, plan) } - LogicalPlan::Join { left, right, .. } - | LogicalPlan::CrossJoin { left, right, .. } => { - let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, left.schema(), right.schema()); - - let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); - let right = optimize(right, right_state)?; - - // create a new Join with the new `left` and `right` - let expr = plan.expressions(); - let plan = utils::from_plan(plan, &expr, &[left, right])?; + LogicalPlan::CrossJoin { left, right, .. } => { + optimize_join(state, plan, left, right) + } + LogicalPlan::Join { + left, right, on, .. + } => { + // duplicate filters for joined columns so filters can be pushed down to both sides. 
+ // Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + let join_side_filters = state + .filters + .iter() + .filter_map(|(predicate, columns)| { + let mut join_cols_to_replace = HashMap::new(); + for col in columns.iter() { + for (l, r) in on { + if col == l { + join_cols_to_replace.insert(col, r); + break; + } else if col == r { + join_cols_to_replace.insert(col, l); + break; + } + } + } - if keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = add_filter(plan, &keep.0); - state.filters = remove_filters(&state.filters, &keep.1); + if join_cols_to_replace.is_empty() { + return None; + } - Ok(plan) - } + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + let join_side_columns = columns + .clone() + .into_iter() + // replace keys in join_cols_to_replace with values in resulting column + // set + .filter(|c| !join_cols_to_replace.contains_key(c)) + .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) + .collect(); + + Some(Ok((join_side_predicate, join_side_columns))) + }) + .collect::>>()?; + state.filters.extend(join_side_filters); + + optimize_join(state, plan, left, right) } LogicalPlan::TableScan { source, @@ -878,12 +955,13 @@ mod tests { Ok(()) } - /// post-join predicates on a column common to both sides is pushed to both sides + /// post-on-join predicates on a column common to both sides is pushed to both sides #[test] - fn filter_join_on_common_independent() -> Result<()> { + fn 
filter_on_join_on_common_independent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()).build()?; - let right = LogicalPlanBuilder::from(table_scan) + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a")])? .build()?; let plan = LogicalPlanBuilder::from(left) @@ -901,20 +979,61 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + \n Join: #test.a = #test2.a\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n TableScan: test projection=None" + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" ); // filter sent to side before the join let expected = "\ - Join: #test.a = #test.a\ + Join: #test.a = #test2.a\ \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n Filter: #test.a LtEq Int64(1)\ - \n TableScan: test projection=None"; + \n Projection: #test2.a\ + \n Filter: #test2.a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-using-join predicates on a column common to both sides is pushed to both sides + #[test] + fn filter_using_join_on_common_independent() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + )? + .filter(col("a").lt_eq(lit(1i64)))? 
+ .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test.a LtEq Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter sent to side before the join + let expected = "\ + Join: Using #test.a = #test2.a\ + \n Filter: #test.a LtEq Int64(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n Filter: #test2.a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -923,10 +1042,11 @@ mod tests { #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("c")])? .build()?; - let right = LogicalPlanBuilder::from(table_scan) + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a"), col("b")])? .build()?; let plan = LogicalPlanBuilder::from(left) @@ -944,12 +1064,12 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #test.c LtEq #test.b\ - \n Join: #test.a = #test.a\ + Filter: #test.c LtEq #test2.b\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.b\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.b\ + \n TableScan: test2 projection=None" ); // expected is equal: no push-down @@ -962,12 +1082,14 @@ mod tests { #[test] fn filter_join_on_one_side() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("b")])? 
.build()?; - let right = LogicalPlanBuilder::from(table_scan) + let table_scan_right = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(table_scan_right) .project(vec![col("a"), col("c")])? .build()?; + let plan = LogicalPlanBuilder::from(left) .join( &right, @@ -983,20 +1105,20 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.b LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None" ); let expected = "\ - Join: #test.a = #test.a\ + Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n Filter: #test.b LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None"; + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 3c8f1ee4ceb5..0272b9f7872c 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -216,9 +216,7 @@ fn optimize_plan( let schema = build_join_schema( optimized_left.schema(), optimized_right.schema(), - on, join_type, - join_constraint, )?; Ok(LogicalPlan::Join { @@ -499,7 +497,7 @@ mod tests { } #[test] - fn join_schema_trim() -> Result<()> { + fn join_schema_trim_full_join_column_projection() -> Result<()> { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); @@ -511,7 +509,7 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? 
.build()?; - // make sure projections are pushed down to table scan + // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b, #test2.c1\ \n Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ @@ -521,7 +519,48 @@ mod tests { let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); - // make sure schema for join node doesn't include c1 column + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "c1", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + + #[test] + fn join_schema_trim_partial_join_column_projection() -> Result<()> { + // test join column push down without explicit column projections + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + // projecting joined column `a` should push the right side column `c1` projection as + // well into test2 table even though `c1` is not referenced in projection. + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to both table scans + let expected = "Projection: #test.a, #test.b\ + \n Join: #test.a = #test2.c1\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; assert_eq!( **optimized_join.schema(), @@ -535,6 +574,45 @@ mod tests { Ok(()) } + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join colums from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(&table2_scan, JoinType::Left, vec!["a"])? + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: #test.a, #test.b\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "a", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + #[test] fn cast() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index ae3e196c2225..1d19f0681b35 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -215,13 +215,8 @@ pub fn from_plan( on, .. 
} => { - let schema = build_join_schema( - inputs[0].schema(), - inputs[1].schema(), - on, - join_type, - join_constraint, - )?; + let schema = + build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; Ok(LogicalPlan::Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 195a19c54070..bd93f1bd195b 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -55,9 +55,10 @@ use arrow::array::{ use super::expressions::Column; use super::{ coalesce_partitions::CoalescePartitionsExec, - hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, + hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, }; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; use super::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -136,12 +137,7 @@ impl HashJoinExec { let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema( - &left_schema, - &right_schema, - &on, - join_type, - )); + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); let random_state = RandomState::with_seeds(0, 0, 0, 0); @@ -1408,16 +1404,16 @@ mod tests { join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 
80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1449,16 +1445,16 @@ mod tests { ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1526,18 +1522,18 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1578,18 +1574,18 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 
80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1626,7 +1622,7 @@ mod tests { let join = join(left, right, on, &JoinType::Inner)?; let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part let stream = join.execute(0).await?; @@ -1634,11 +1630,11 @@ mod tests { assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1647,12 +1643,12 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 2 | 5 | 8 | 30 | 90 |", - "| 3 | 5 | 9 | 30 | 90 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1692,21 +1688,21 @@ mod tests { let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = 
join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1772,19 +1768,19 @@ mod tests { let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | | |", - "| 2 | 5 | 8 | | |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1845,16 +1841,16 @@ mod tests { let (columns, batches) = join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - 
"| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1885,16 +1881,16 @@ mod tests { &JoinType::Left, ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1996,16 +1992,16 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; - assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2033,16 +2029,16 @@ mod tests { let (columns, batches) = partitioned_join_collect(left, right, on, &JoinType::Right).await?; - assert_eq!(columns, 
vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 0cf0b9212cd2..9243affe9cfc 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -21,25 +21,9 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use std::collections::HashSet; +use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; -/// All valid types of joins. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum JoinType { - /// Inner Join - Inner, - /// Left Join - Left, - /// Right Join - Right, - /// Full Join - Full, - /// Semi Join - Semi, - /// Anti Join - Anti, -} - /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. @@ -104,46 +88,11 @@ fn check_join_set_is_valid( /// Creates a schema for a join operation. 
/// The fields from the left side are first -pub fn build_join_schema( - left: &Schema, - right: &Schema, - on: JoinOnRef, - join_type: &JoinType, -) -> Schema { +pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - // remove right-side join keys if they have the same names as the left-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { let left_fields = left.fields().iter(); - - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinType::Right => { - // remove left-side join keys if they have the same names as the right-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - let right_fields = right.fields().iter(); - // left then right left_fields.chain(right_fields).cloned().collect() } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 5b43ec12bbf0..12f563618d85 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -40,7 +40,6 @@ use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{hash_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; -use crate::prelude::JoinType; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; use crate::variable::VarType; @@ -661,14 +660,6 @@ impl DefaultPhysicalPlanner { let 
physical_left = self.create_initial_plan(left, ctx_state)?; let right_df_schema = right.schema(); let physical_right = self.create_initial_plan(right, ctx_state)?; - let physical_join_type = match join_type { - JoinType::Inner => hash_utils::JoinType::Inner, - JoinType::Left => hash_utils::JoinType::Left, - JoinType::Right => hash_utils::JoinType::Right, - JoinType::Full => hash_utils::JoinType::Full, - JoinType::Semi => hash_utils::JoinType::Semi, - JoinType::Anti => hash_utils::JoinType::Anti, - }; let join_on = keys .iter() .map(|(l, r)| { @@ -702,7 +693,7 @@ impl DefaultPhysicalPlanner { Partitioning::Hash(right_expr, ctx_state.config.concurrency), )?), join_on, - &physical_join_type, + join_type, PartitionMode::Partitioned, )?)) } else { @@ -710,7 +701,7 @@ impl DefaultPhysicalPlanner { physical_left, physical_right, join_on, - &physical_join_type, + join_type, PartitionMode::CollectLeft, )?)) } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 213ae890d7d0..b633e6e8ca22 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -27,8 +27,8 @@ use crate::datasource::TableProvider; use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ - and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, PlanType, StringifiedPlan, ToDFSchema, + and, col, lit, normalize_col, union_with_alias, Column, DFSchema, Expr, LogicalPlan, + LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, ToDFSchema, }; use crate::prelude::JoinType; use crate::scalar::ScalarValue; @@ -477,12 +477,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if left_schema.field_from_qualified_column(l).is_ok() - && right_schema.field_from_qualified_column(r).is_ok() + if left_schema.field_from_column(l).is_ok() + && 
right_schema.field_from_column(r).is_ok() { join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_qualified_column(r).is_ok() - && right_schema.field_from_qualified_column(l).is_ok() + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() { join_keys.push((r.clone(), l.clone())); } @@ -560,7 +560,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // SELECT c1 AS m FROM t HAVING c1 > 10; // SELECT c1, MAX(c2) AS m FROM t GROUP BY c1 HAVING MAX(c2) > 10; // - resolve_aliases_to_exprs(&having_expr, &alias_map) + let having_expr = resolve_aliases_to_exprs(&having_expr, &alias_map)?; + normalize_col(having_expr, &projected_plan) }) .transpose()?; @@ -584,6 +585,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let group_by_expr = resolve_positions_to_exprs(&group_by_expr, &select_exprs) .unwrap_or(group_by_expr); + let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( plan.schema(), &[group_by_expr.clone()], @@ -662,13 +664,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result> { let input_schema = plan.schema(); - Ok(projection + projection .iter() .map(|expr| self.sql_select_to_rex(expr, input_schema)) .collect::>>()? 
.iter() .flat_map(|expr| expand_wildcard(expr, input_schema)) - .collect::>()) + .map(|expr| normalize_col(expr, plan)) + .collect::>>() } /// Wrap a plan in a projection @@ -816,20 +819,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(col) => { - match &col.relation { - Some(r) => schema.field_with_qualified_name(r, &col.name), - None => schema.field_with_unqualified_name(&col.name), + Expr::Column(col) => match &col.relation { + Some(r) => { + schema.field_with_qualified_name(r, &col.name)?; + Ok(()) + } + None => { + if !schema.fields_with_unqualified_name(&col.name).is_empty() { + Ok(()) + } else { + Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'", + &col.name + ))) + } } - .map_err(|_| { - DataFusionError::Plan(format!( - "Invalid identifier '{}' for schema {}", - col, - schema.to_string() - )) - })?; - Ok(()) } + .map_err(|_: DataFusionError| { + DataFusionError::Plan(format!( + "Invalid identifier '{}' for schema {}", + col, + schema.to_string() + )) + }), _ => Err(DataFusionError::Internal("Not a column".to_string())), }) } @@ -907,11 +919,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - Ok(Expr::Column( - schema - .field_with_unqualified_name(&id.value)? - .qualified_column(), - )) + // create a column expression based on raw user input, this column will be + // normalized with qualifer later by the SQL planner. 
+ Ok(col(&id.value)) } } @@ -1651,7 +1661,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1709,7 +1719,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1719,7 +1729,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'x'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#x' for schema "), )); } @@ -2190,7 +2200,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2280,7 +2290,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Column #doesnotexist not found in provided schemas"), )); } @@ -2290,7 +2300,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2722,7 +2732,7 @@ mod tests { JOIN person as person2 \ 
USING (id)"; let expected = "Projection: #person.first_name, #person.id\ - \n Join: #person.id = #person2.id\ + \n Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 projection=None"; quick_test(sql, expected); diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 7ca7cc12d9ef..c06feffd9f99 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -110,14 +110,19 @@ pub fn aggr_test_schema() -> SchemaRef { ])) } -/// some tests share a common table -pub fn test_table_scan() -> Result { +/// some tests share a common table with different names +pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::UInt32, false), Field::new("c", DataType::UInt32, false), ]); - LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() + LogicalPlanBuilder::scan_empty(Some(name), &schema, None)?.build() +} + +/// some tests share a common table +pub fn test_table_scan() -> Result { + test_table_scan_with_name("test") } pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {