diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 9454d7593c3f3..59c99797e0cd8 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -90,14 +90,22 @@ impl Column {
     /// For example, `foo` will be normalized to `t.foo` if there is a
     /// column named `foo` in a relation named `t` found in `schemas`
     pub fn normalize(self, plan: &LogicalPlan) -> Result<Self> {
+        let schemas = plan.all_schemas();
+        let using_columns = plan.using_columns()?;
+        self.normalize_with_schemas(&schemas, &using_columns)
+    }
+
+    // Internal implementation of normalize
+    fn normalize_with_schemas(
+        self,
+        schemas: &[&Arc<DFSchema>],
+        using_columns: &[HashSet<Column>],
+    ) -> Result<Self> {
         if self.relation.is_some() {
             return Ok(self);
         }
 
-        let schemas = plan.all_schemas();
-        let using_columns = plan.using_columns()?;
-
-        for schema in &schemas {
+        for schema in schemas {
             let fields = schema.fields_with_unqualified_name(&self.name);
             match fields.len() {
                 0 => continue,
@@ -118,7 +126,7 @@ impl Column {
                     // We will use the relation from the first matched field to normalize self.
 
                     // Compare matched fields with one USING JOIN clause at a time
-                    for using_col in &using_columns {
+                    for using_col in using_columns {
                         let all_matched = fields
                             .iter()
                             .all(|f| using_col.contains(&f.qualified_column()));
@@ -1171,22 +1179,39 @@ pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<
 
 /// Recursively call [`Column::normalize`] on all Column expressions
 /// in the `expr` expression tree
-pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result<Expr> {
+pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+    normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
+}
+
+/// Recursively call [`Column::normalize`] on all Column expressions
+/// in the `expr` expression tree.
+fn normalize_col_with_schemas(
+    expr: Expr,
+    schemas: &[&Arc<DFSchema>],
+    using_columns: &[HashSet<Column>],
+) -> Result<Expr> {
     struct ColumnNormalizer<'a> {
-        plan: &'a LogicalPlan,
+        schemas: &'a [&'a Arc<DFSchema>],
+        using_columns: &'a [HashSet<Column>],
     }
 
     impl<'a> ExprRewriter for ColumnNormalizer<'a> {
         fn mutate(&mut self, expr: Expr) -> Result<Expr> {
             if let Expr::Column(c) = expr {
-                Ok(Expr::Column(c.normalize(self.plan)?))
+                Ok(Expr::Column(c.normalize_with_schemas(
+                    self.schemas,
+                    self.using_columns,
+                )?))
             } else {
                 Ok(expr)
             }
         }
     }
 
-    e.rewrite(&mut ColumnNormalizer { plan })
+    expr.rewrite(&mut ColumnNormalizer {
+        schemas,
+        using_columns,
+    })
 }
 
 /// Recursively normalize all Column expressions in a list of expression trees
@@ -1198,6 +1223,38 @@ pub fn normalize_cols(
     exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
 }
 
+/// Recursively 'unnormalize' (remove all qualifiers from) an
+/// expression tree.
+///
+/// For example, if there were expressions like `foo.bar` this would
+/// rewrite it to just `bar`.
+pub fn unnormalize_col(expr: Expr) -> Expr {
+    struct RemoveQualifier {}
+
+    impl ExprRewriter for RemoveQualifier {
+        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+            if let Expr::Column(col) = expr {
+                // Drop the relation qualifier, keeping only the bare name
+                Ok(Expr::Column(Column {
+                    relation: None,
+                    name: col.name,
+                }))
+            } else {
+                Ok(expr)
+            }
+        }
+    }
+
+    expr.rewrite(&mut RemoveQualifier {})
+        .expect("Unnormalize is infallible")
+}
+
+/// Recursively unnormalize all Column expressions in a list of expression trees
+#[inline]
+pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
+    exprs.into_iter().map(unnormalize_col).collect()
+}
+
 /// Create an expression to represent the min() aggregate function
 pub fn min(expr: Expr) -> Expr {
     Expr::AggregateFunction {
@@ -1810,4 +1867,78 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn normalize_cols() {
+        let expr = col("a") + col("b") + col("c");
+
+        // Schemas with some matching and some non-matching cols
+        let schema_a =
+            DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")])
+                .unwrap();
+        let schema_c =
+            DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")])
+                .unwrap();
+        let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+        // non-matching
+        let schema_f =
+            DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")])
+                .unwrap();
+        let schemas = vec![schema_c, schema_f, schema_b, schema_a]
+            .into_iter()
+            .map(Arc::new)
+            .collect::<Vec<_>>();
+        let schemas = schemas.iter().collect::<Vec<_>>();
+
+        let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+        assert_eq!(
+            normalized_expr,
+            col("tableA.a") + col("tableB.b") + col("tableC.c")
+        );
+    }
+
+    #[test]
+    fn normalize_cols_priority() {
+        let expr = col("a") + col("b");
+        // Schemas with multiple matches for column a, first takes priority
+        let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+        let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+        let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap();
+        let schemas = vec![schema_a2, schema_b, schema_a]
+            .into_iter()
+            .map(Arc::new)
+            .collect::<Vec<_>>();
+        let schemas = schemas.iter().collect::<Vec<_>>();
+
+        let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+        assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
+    }
+
+    #[test]
+    fn normalize_cols_non_exist() {
+        // test normalizing columns when the name doesn't exist
+        let expr = col("a") + col("b");
+        let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+        let schemas = vec![schema_a].into_iter().map(Arc::new).collect::<Vec<_>>();
+        let schemas = schemas.iter().collect::<Vec<_>>();
+
+        let error = normalize_col_with_schemas(expr, &schemas, &[])
+            .unwrap_err()
+            .to_string();
+        assert_eq!(
+            error,
+            "Error during planning: Column #b not found in provided schemas"
+        );
+    }
+
+    #[test]
+    fn unnormalize_cols() {
+        let expr = col("tableA.a") + col("tableB.b");
+        let unnormalized_expr = unnormalize_col(expr);
+        assert_eq!(unnormalized_expr, col("a") + col("b"));
+    }
+
+    fn make_field(relation: &str, column: &str) -> DFField {
+        DFField::new(Some(relation), column, DataType::Int8, false)
+    }
 }
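Both `normalize_col_with_schemas` and `unnormalize_col` above are built on the existing `ExprRewriter` trait, which walks an expression tree and lets `mutate` replace each node on the way back up. Here is a minimal sketch of that pattern, not part of the patch (the toy `UppercaseColumns` rewriter is hypothetical):

```rust
use datafusion::error::Result;
use datafusion::logical_plan::{col, Column, Expr, ExprRewriter};

/// Toy rewriter that upper-cases every column name.
struct UppercaseColumns {}

impl ExprRewriter for UppercaseColumns {
    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
        if let Expr::Column(c) = expr {
            // Rebuild the column, keeping its qualifier but changing the name
            Ok(Expr::Column(Column {
                relation: c.relation,
                name: c.name.to_uppercase(),
            }))
        } else {
            // Leave every non-column node untouched
            Ok(expr)
        }
    }
}

fn main() -> Result<()> {
    let expr = col("a") + col("b");
    // `rewrite` consumes the expression and applies `mutate` to each node
    let rewritten = expr.rewrite(&mut UppercaseColumns {})?;
    assert_eq!(rewritten, col("A") + col("B"));
    Ok(())
}
```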
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index 86a2f567d7de4..2c751abdad349 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -43,8 +43,8 @@ pub use expr::{
     min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
    regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim,
     sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos,
-    substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr,
-    ExprRewriter, ExpressionVisitor, Literal, Recursion,
+    substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, unnormalize_cols,
+    upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
 };
 pub use extension::UserDefinedLogicalNode;
 pub use operators::Operator;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 73b2f362989f6..df4168370003a 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -23,8 +23,9 @@ use super::{
 };
 use crate::execution::context::ExecutionContextState;
 use crate::logical_plan::{
-    DFSchema, Expr, LogicalPlan, Operator, Partitioning as LogicalPartitioning, PlanType,
-    StringifiedPlan, UserDefinedLogicalNode,
+    unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator,
+    Partitioning as LogicalPartitioning, PlanType, StringifiedPlan,
+    UserDefinedLogicalNode,
 };
 use crate::physical_plan::explain::ExplainExec;
 use crate::physical_plan::expressions;
@@ -311,7 +312,13 @@ impl DefaultPhysicalPlanner {
                 filters,
                 limit,
                 ..
-            } => source.scan(projection, batch_size, filters, *limit),
+            } => {
+                // Remove all qualifiers from the scan as the provider
+                // doesn't know (nor should it care) how the relation
+                // was referred to in the query
+                let filters = unnormalize_cols(filters.iter().cloned());
+                source.scan(projection, batch_size, &filters, *limit)
+            }
             LogicalPlan::Window {
                 input, window_expr, ..
             } => {
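The effect at the `TableProvider` boundary can be sketched with a hypothetical filter (the `t.nanos` name mirrors the pruning tests below; this snippet is an illustration, not code from the patch):

```rust
use datafusion::logical_plan::{col, lit, unnormalize_cols};

fn main() {
    // A pushed-down filter as the optimizer produces it, qualified by relation
    let filters = vec![col("t.nanos").lt(lit(100i64))];

    // What the physical planner now hands to TableProvider::scan: the same
    // predicate with the bare column name the provider's schema actually uses
    let unqualified = unnormalize_cols(filters);
    assert_eq!(unqualified, vec![col("nanos").lt(lit(100i64))]);
}
```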
diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs
index 86b3946e47121..f5486afc7aa4a 100644
--- a/datafusion/tests/parquet_pruning.rs
+++ b/datafusion/tests/parquet_pruning.rs
@@ -44,9 +44,9 @@ async fn prune_timestamps_nanos() {
         .query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')")
         .await;
     println!("{}", output.description());
-    // TODO This should prune one metrics without error
-    assert_eq!(output.predicate_evaluation_errors(), Some(1));
-    assert_eq!(output.row_groups_pruned(), Some(0));
+    // This should prune one row group without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(1));
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
 
@@ -59,9 +59,9 @@ async fn prune_timestamps_micros() {
         .query(
            "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')"
         )
         .await;
     println!("{}", output.description());
-    // TODO This should prune one metrics without error
-    assert_eq!(output.predicate_evaluation_errors(), Some(1));
-    assert_eq!(output.row_groups_pruned(), Some(0));
+    // This should prune one row group without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(1));
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
 
@@ -74,9 +74,9 @@ async fn prune_timestamps_millis() {
         .query(
            "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')"
         )
         .await;
     println!("{}", output.description());
-    // TODO This should prune one metrics without error
-    assert_eq!(output.predicate_evaluation_errors(), Some(1));
-    assert_eq!(output.row_groups_pruned(), Some(0));
+    // This should prune one row group without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(1));
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
 
@@ -89,9 +89,9 @@ async fn prune_timestamps_seconds() {
         .query(
            "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')"
         )
         .await;
     println!("{}", output.description());
-    // TODO This should prune one metrics without error
-    assert_eq!(output.predicate_evaluation_errors(), Some(1));
-    assert_eq!(output.row_groups_pruned(), Some(0));
+    // This should prune one row group without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(1));
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
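As for why the qualified names previously produced predicate evaluation errors instead of pruning: the pruning predicate has to be resolved against the parquet file's Arrow schema, which stores bare field names only, so a qualified column like `t.nanos` simply cannot be found there. A hedged illustration of that lookup (the field layout here is hypothetical):

```rust
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

fn main() {
    // The schema as stored in the parquet file: unqualified field names only
    let file_schema = Schema::new(vec![Field::new(
        "nanos",
        DataType::Timestamp(TimeUnit::Nanosecond, None),
        false,
    )]);

    // The bare name resolves against the file schema...
    assert!(file_schema.field_with_name("nanos").is_ok());

    // ...but the qualified name from the logical plan does not, which is
    // what previously surfaced as a predicate evaluation error
    assert!(file_schema.field_with_name("t.nanos").is_err());
}
```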