Skip to content

Commit

Permalink
[CHORE] Use treenode for tree traversal in logical optimizer rules
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Sep 6, 2024
1 parent 6fe408c commit c9ea3bb
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 421 deletions.
25 changes: 25 additions & 0 deletions src/common/treenode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,31 @@ impl<T> Transformed<T> {
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
}

/// Picks between two results, preferring the receiver when it recorded a change.
///
/// Returns `self` if its `transformed` flag is set; otherwise returns `other`
/// as-is, whatever `other`'s own flag says. Both values are taken by value,
/// so `other` is always constructed by the caller even when it is discarded.
pub fn or(self, other: Self) -> Self {
    match self.transformed {
        true => self,
        false => other,
    }
}

/// Converts a `Transformed<T>` into a `Transformed<U>` by running exactly one
/// of two closures over the payload, chosen by the `transformed` flag.
///
/// * `yes_op` is applied when `self.transformed` is `true`; the result is
///   wrapped with `Transformed::yes`.
/// * `no_op` is applied when `self.transformed` is `false`; the result is
///   wrapped with `Transformed::no`.
///
/// Only `data` and `transformed` are read from `self`; `self.tnr` is not
/// consulted when building the result.
/// NOTE(review): confirm that not forwarding `tnr` is intended here.
#[inline]
pub fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>(
    self,
    yes_op: Y,
    no_op: N,
) -> Transformed<U> {
    let was_transformed = self.transformed;
    let payload = self.data;
    // Guard clause: the untouched case goes through `no_op` and stays "no".
    if !was_transformed {
        return Transformed::no(no_op(payload));
    }
    Transformed::yes(yes_op(payload))
}

/// Maps the [`Transformed`] object to the result of the given `f`.
pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
self,
Expand Down
84 changes: 39 additions & 45 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::sync::Arc;

use common_treenode::Transformed;
use daft_core::prelude::*;
use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use snafu::ResultExt;

use crate::logical_optimization::Transformed;
use crate::logical_plan::{CreationSnafu, Result};
use crate::LogicalPlan;

Expand Down Expand Up @@ -144,7 +144,7 @@ impl Project {
.map(|e| {
let new_expr =
replace_column_with_semantic_id(e.clone(), &subexprs_to_replace, schema);
let new_expr = new_expr.unwrap();
let new_expr = new_expr.data;
// The substitution can unintentionally change the expression's name
// (since the name depends on the first column referenced, which can be substituted away)
// so re-alias the original name here if it has changed.
Expand Down Expand Up @@ -185,10 +185,10 @@ fn replace_column_with_semantic_id(
Expr::Alias(_, name) => Expr::Alias(new_expr.into(), name.clone()),
_ => new_expr,
};
Transformed::Yes(new_expr.into())
Transformed::yes(new_expr.into())
} else {
match e.as_ref() {
Expr::Column(_) | Expr::Literal(_) => Transformed::No(e),
Expr::Column(_) | Expr::Literal(_) => Transformed::no(e),
Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr(
agg_expr.clone(),
subexprs_to_replace,
Expand Down Expand Up @@ -241,11 +241,11 @@ fn replace_column_with_semantic_id(
subexprs_to_replace,
schema,
);
if child.is_no() && fill_value.is_no() {
Transformed::No(e)
if !child.transformed && !fill_value.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(),
Transformed::yes(
Expr::FillNull(child.data.clone(), fill_value.data.clone()).into(),
)
}
}
Expand All @@ -254,12 +254,10 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
let items =
replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema);
if child.is_no() && items.is_no() {
Transformed::No(e)
if !child.transformed && !items.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::IsIn(child.unwrap().clone(), items.unwrap().clone()).into(),
)
Transformed::yes(Expr::IsIn(child.data.clone(), items.data.clone()).into())
}
}
Expr::Between(child, lower, upper) => {
Expand All @@ -269,16 +267,12 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(lower.clone(), subexprs_to_replace, schema);
let upper =
replace_column_with_semantic_id(upper.clone(), subexprs_to_replace, schema);
if child.is_no() && lower.is_no() && upper.is_no() {
Transformed::No(e)
if !child.transformed && !lower.transformed && !upper.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::Between(
child.unwrap().clone(),
lower.unwrap().clone(),
upper.unwrap().clone(),
)
.into(),
Transformed::yes(
Expr::Between(child.data.clone(), lower.data.clone(), upper.data.clone())
.into(),
)
}
}
Expand All @@ -287,14 +281,14 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema);
let right =
replace_column_with_semantic_id(right.clone(), subexprs_to_replace, schema);
if left.is_no() && right.is_no() {
Transformed::No(e)
if !left.transformed && !right.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::BinaryOp {
op: *op,
left: left.unwrap().clone(),
right: right.unwrap().clone(),
left: left.data.clone(),
right: right.data.clone(),
}
.into(),
)
Expand All @@ -311,14 +305,14 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(if_true.clone(), subexprs_to_replace, schema);
let if_false =
replace_column_with_semantic_id(if_false.clone(), subexprs_to_replace, schema);
if predicate.is_no() && if_true.is_no() && if_false.is_no() {
Transformed::No(e)
if !predicate.transformed && !if_true.transformed && !if_false.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::IfElse {
predicate: predicate.unwrap().clone(),
if_true: if_true.unwrap().clone(),
if_false: if_false.unwrap().clone(),
predicate: predicate.data.clone(),
if_true: if_true.data.clone(),
if_false: if_false.data.clone(),
}
.into(),
)
Expand All @@ -331,13 +325,13 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
})
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(e)
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::Function {
func: func.clone(),
inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
inputs: transforms.iter().map(|t| t.data.clone()).collect(),
}
.into(),
)
Expand All @@ -352,11 +346,11 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
})
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(e)
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(e)
} else {
func.inputs = transforms.iter().map(|t| t.unwrap()).cloned().collect();
Transformed::Yes(Expr::ScalarFunction(func).into())
func.inputs = transforms.iter().map(|t| t.data.clone()).collect();
Transformed::yes(Expr::ScalarFunction(func).into())
}
}
}
Expand Down Expand Up @@ -446,12 +440,12 @@ fn replace_column_with_semantic_id_aggexpr(
.iter()
.map(|e| replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema))
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(AggExpr::MapGroups { func, inputs })
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(AggExpr::MapGroups { func, inputs })
} else {
Transformed::Yes(AggExpr::MapGroups {
Transformed::yes(AggExpr::MapGroups {
func: func.clone(),
inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
inputs: transforms.iter().map(|t| t.data.clone()).collect(),
})
}
}
Expand Down
1 change: 0 additions & 1 deletion src/daft-plan/src/logical_optimization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ mod rules;
mod test;

pub use optimizer::{Optimizer, OptimizerConfig};
pub use rules::Transformed;
Loading

0 comments on commit c9ea3bb

Please sign in to comment.