Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 140 additions & 9 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,22 @@ impl Column {
/// For example, `foo` will be normalized to `t.foo` if there is a
/// column named `foo` in a relation named `t` found in `schemas`
pub fn normalize(self, plan: &LogicalPlan) -> Result<Self> {
let schemas = plan.all_schemas();
let using_columns = plan.using_columns()?;
self.normalize_with_schemas(&schemas, &using_columns)
}

// Internal implementation of normalize
fn normalize_with_schemas(
self,
schemas: &[&Arc<DFSchema>],
using_columns: &[HashSet<Column>],
) -> Result<Self> {
if self.relation.is_some() {
return Ok(self);
}

let schemas = plan.all_schemas();
let using_columns = plan.using_columns()?;

for schema in &schemas {
for schema in schemas {
let fields = schema.fields_with_unqualified_name(&self.name);
match fields.len() {
0 => continue,
Expand All @@ -118,7 +126,7 @@ impl Column {
// We will use the relation from the first matched field to normalize self.

// Compare matched fields with one USING JOIN clause at a time
for using_col in &using_columns {
for using_col in using_columns {
let all_matched = fields
.iter()
.all(|f| using_col.contains(&f.qualified_column()));
Expand Down Expand Up @@ -1171,22 +1179,39 @@ pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<E

/// Recursively call [`Column::normalize`] on all Column expressions
/// in the `expr` expression tree.
pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result<Expr> {
pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
}

/// Recursively call [`Column::normalize`] on all Column expressions
/// in the `expr` expression tree.
fn normalize_col_with_schemas(
expr: Expr,
schemas: &[&Arc<DFSchema>],
using_columns: &[HashSet<Column>],
) -> Result<Expr> {
struct ColumnNormalizer<'a> {
plan: &'a LogicalPlan,
schemas: &'a [&'a Arc<DFSchema>],
using_columns: &'a [HashSet<Column>],
}

impl<'a> ExprRewriter for ColumnNormalizer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if let Expr::Column(c) = expr {
Ok(Expr::Column(c.normalize(self.plan)?))
Ok(Expr::Column(c.normalize_with_schemas(
self.schemas,
self.using_columns,
)?))
} else {
Ok(expr)
}
}
}

e.rewrite(&mut ColumnNormalizer { plan })
expr.rewrite(&mut ColumnNormalizer {
schemas,
using_columns,
})
}

/// Recursively normalize all Column expressions in a list of expression trees
Expand All @@ -1198,6 +1223,38 @@ pub fn normalize_cols(
exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
}

/// Recursively 'unnormalize' (remove all qualifiers) from an
/// expression tree.
///
/// For example, if there were expressions like `foo.bar` this would
/// rewrite it to just `bar`.
pub fn unnormalize_col(expr: Expr) -> Expr {
    // Rewriter that strips the relation qualifier from every
    // `Expr::Column` it visits, leaving only the bare column name.
    struct RemoveQualifier;

    impl ExprRewriter for RemoveQualifier {
        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
            if let Expr::Column(col) = expr {
                Ok(Expr::Column(Column {
                    relation: None,
                    name: col.name,
                }))
            } else {
                Ok(expr)
            }
        }
    }

    // `mutate` never returns an error, so the rewrite cannot fail.
    expr.rewrite(&mut RemoveQualifier)
        .expect("Unnormalize is infallible")
}

/// Recursively un-normalize all Column expressions in a list of expression trees
#[inline]
pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}

/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction {
Expand Down Expand Up @@ -1810,4 +1867,78 @@ mod tests {
}
}
}

#[test]
fn normalize_cols() {
    // Additional unit-test coverage for normalize, exercising a mix of
    // matching and non-matching schemas.
    let expr = col("a") + col("b") + col("c");

    // Schemas with some matching and some non matching cols
    let schema_a =
        DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")])
            .unwrap();
    let schema_c =
        DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")])
            .unwrap();
    let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
    // non matching
    let schema_f =
        DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")])
            .unwrap();
    let schemas = vec![schema_c, schema_f, schema_b, schema_a]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();
    let schemas = schemas.iter().collect::<Vec<_>>();

    let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
    assert_eq!(
        normalized_expr,
        col("tableA.a") + col("tableB.b") + col("tableC.c")
    );
}

#[test]
fn normalize_cols_priority() {
    // Two schemas provide column `a`; normalization must pick the
    // first matching schema in the list (tableA2).
    let expr = col("a") + col("b");
    let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
    let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
    let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap();
    let owned: Vec<_> = vec![schema_a2, schema_b, schema_a]
        .into_iter()
        .map(Arc::new)
        .collect();
    let schemas: Vec<_> = owned.iter().collect();

    let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
    assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
}

#[test]
fn normalize_cols_non_exist() {
    // Normalizing a column whose name appears in no schema must
    // produce a planning error naming the missing column.
    let expr = col("a") + col("b");
    let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
    let owned = vec![Arc::new(schema_a)];
    let schemas: Vec<_> = owned.iter().collect();

    let error = normalize_col_with_schemas(expr, &schemas, &[])
        .unwrap_err()
        .to_string();
    assert_eq!(
        error,
        "Error during planning: Column #b not found in provided schemas"
    );
}

#[test]
fn unnormalize_cols() {
    // Qualified references are reduced to bare column names.
    let qualified = col("tableA.a") + col("tableB.b");
    let bare = unnormalize_col(qualified);
    assert_eq!(bare, col("a") + col("b"));
}

/// Test helper: build a nullable `Int8` field qualified by `relation`.
fn make_field(relation: &str, column: &str) -> DFField {
    let qualifier = Some(relation);
    DFField::new(qualifier, column, DataType::Int8, false)
}
}
4 changes: 2 additions & 2 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ pub use expr::{
min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim,
sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos,
substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr,
ExprRewriter, ExpressionVisitor, Literal, Recursion,
substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, unnormalize_cols,
upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
13 changes: 10 additions & 3 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ use super::{
};
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::{
DFSchema, Expr, LogicalPlan, Operator, Partitioning as LogicalPartitioning, PlanType,
StringifiedPlan, UserDefinedLogicalNode,
unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator,
Partitioning as LogicalPartitioning, PlanType, StringifiedPlan,
UserDefinedLogicalNode,
};
use crate::physical_plan::explain::ExplainExec;
use crate::physical_plan::expressions;
Expand Down Expand Up @@ -311,7 +312,13 @@ impl DefaultPhysicalPlanner {
filters,
limit,
..
} => source.scan(projection, batch_size, filters, *limit),
} => {
// Remove all qualifiers from the scan as the provider
// doesn't know (nor should care) how the relation was
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the actual fix for parquet pruning.

source.scan(projection, batch_size, &filters, *limit)
}
LogicalPlan::Window {
input, window_expr, ..
} => {
Expand Down
24 changes: 12 additions & 12 deletions datafusion/tests/parquet_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ async fn prune_timestamps_nanos() {
.query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')")
.await;
println!("{}", output.description());
// TODO This should prune one metrics without error
assert_eq!(output.predicate_evaluation_errors(), Some(1));
assert_eq!(output.row_groups_pruned(), Some(0));
// This should prune one row group without any predicate evaluation errors
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(1));
assert_eq!(output.result_rows, 10, "{}", output.description());
}

Expand All @@ -59,9 +59,9 @@ async fn prune_timestamps_micros() {
)
.await;
println!("{}", output.description());
// TODO This should prune one metrics without error
assert_eq!(output.predicate_evaluation_errors(), Some(1));
assert_eq!(output.row_groups_pruned(), Some(0));
// This should prune one row group without any predicate evaluation errors
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(1));
assert_eq!(output.result_rows, 10, "{}", output.description());
}

Expand All @@ -74,9 +74,9 @@ async fn prune_timestamps_millis() {
)
.await;
println!("{}", output.description());
// TODO This should prune one metrics without error
assert_eq!(output.predicate_evaluation_errors(), Some(1));
assert_eq!(output.row_groups_pruned(), Some(0));
// This should prune one row group without any predicate evaluation errors
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(1));
assert_eq!(output.result_rows, 10, "{}", output.description());
}

Expand All @@ -89,9 +89,9 @@ async fn prune_timestamps_seconds() {
)
.await;
println!("{}", output.description());
// TODO This should prune one metrics without error
assert_eq!(output.predicate_evaluation_errors(), Some(1));
assert_eq!(output.row_groups_pruned(), Some(0));
// This should prune one row group without any predicate evaluation errors
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(1));
assert_eq!(output.result_rows, 10, "{}", output.description());
}

Expand Down