Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 74 additions & 24 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1385,29 +1385,26 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
// String disjunction of all the when predicates encountered so far. Not nullable.
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);

for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
.and(filter_expr.clone().not())
.and(*then);
let when = is_exactly_true(*when, info)?;
let case_expr =
when.clone().and(filter_expr.clone().not()).and(*then);

out_expr = out_expr.or(case_expr);
filter_expr = filter_expr.or(*when);
filter_expr = filter_expr.or(when);
}

if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
let case_expr = filter_expr.not().and(else_expr);
out_expr = out_expr.or(case_expr);

// Do a first pass at simplification
out_expr.rewrite(self)?
Expand Down Expand Up @@ -1881,6 +1878,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
if !info.nullable(&expr)? {
Ok(expr)
} else {
Ok(Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr),
op: Operator::IsNotDistinctFrom,
right: Box::new(lit(true)),
}))
}
}

#[cfg(test)]
mod tests {
use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -3272,12 +3282,12 @@ mod tests {
simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("not_ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
))),
col("c2").not().and(col("c2")) // #1716
lit(false) // #1716
);

// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
Expand All @@ -3292,12 +3302,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
)))),
col("c2")
col("c2_non_null")
);

// CASE WHEN ISNULL(c2) THEN true ELSE c2
Expand Down Expand Up @@ -3328,12 +3338,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c1 then true WHEN c2 then true ELSE false
Expand All @@ -3347,13 +3357,53 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c > 0 THEN true END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
None,
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
col("c3").gt(lit(0_i64)),
lit(true)
)
.and(lit_bool_null()))
);

// CASE WHEN c > 0 THEN true ELSE false END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
Some(Box::new(lit(false))),
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
);
}

fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsDistinctFrom,
right: Box::new(right.into()),
})
}

fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsNotDistinctFrom,
right: Box::new(right.into()),
})
}

#[test]
Expand Down
20 changes: 15 additions & 5 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,22 @@ query B
select case when a=1 then false end from foo;
----
false
false
false
false
false
false
NULL
NULL
NULL
NULL
NULL

query IBB
SELECT c,
CASE WHEN c > 0 THEN true END AS c1,
CASE WHEN c > 0 THEN true ELSE false END AS c2
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
----
1 true true
0 NULL false
-1 NULL false
NULL NULL false

statement ok
drop table foo
Loading