diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0f..d526b63ae5d2c 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -803,23 +803,39 @@ mod test { use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; use crate::test::*; - use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - fn assert_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - let optimizer = - Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, formatted_plan); + macro_rules! assert_optimized_plan_equal { + ( + $config:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + assert_optimized_plan_eq_snapshot!( + $config, + rules, + $plan, + @ $expected, + ) + }}; + + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -844,13 +860,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ - \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]] + Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -864,13 +881,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -917,11 +935,14 @@ mod test { )? .build()?; - let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]] + TableScan: test + " + )?; // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -936,11 +957,14 @@ mod test { )? .build()?; - let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]] + TableScan: test + " + )?; // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -953,11 +977,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -970,11 +997,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -991,14 +1021,15 @@ mod test { )? .build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a) + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1018,14 +1049,15 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ - \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ - \n TableScan: table.test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a) + Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]] + Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a + TableScan: table.test + " + ) } #[test] @@ -1039,13 +1071,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS first, __common_expr_1 AS second + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1056,13 +1089,14 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? .build()?; - let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1) + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1074,12 +1108,14 @@ mod test { .project(vec![lit(1) + col("a")])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n Projection: Int32(1) + test.a, test.a\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + test.a + Projection: Int32(1) + test.a, test.a + TableScan: test + " + ) } #[test] @@ -1193,14 +1229,15 @@ mod test { .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - Int32(10) > __common_expr_1 + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1226,7 +1263,7 @@ mod test { fn test_alias_collision() -> Result<()> { let table_scan = test_table_scan()?; - let config = &OptimizerContext::new(); + let config = OptimizerContext::new(); let common_expr_1 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![ @@ -1241,14 +1278,18 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\ - \n Projection: test.a + test.b AS __common_expr_1, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); - - let config = &OptimizerContext::new(); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4 + Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c + Projection: test.a + test.b AS __common_expr_1, test.c + TableScan: test + " + )?; + + let config = OptimizerContext::new(); let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); let common_expr_2 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan) @@ -1264,12 +1305,16 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4 + Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c + Projection: test.a + test.b AS __common_expr_2, test.c + TableScan: test + " + )?; Ok(()) } @@ -1308,13 +1353,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\ - \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5 + Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1331,13 +1377,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1360,13 +1407,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ - \n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4 + Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1382,14 +1430,15 @@ mod test { .project(vec![col("c1"), col("c2")])? .build()?; - let expected = "Projection: c1, c2\ - \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: c1, c2 + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1405,14 +1454,15 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c + Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1422,13 +1472,15 @@ mod test { let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1438,13 +1490,15 @@ mod test { let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1454,13 +1508,15 @@ mod test { let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1470,13 +1526,15 @@ mod test { let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1486,13 +1544,15 @@ mod test { let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1502,13 +1562,15 @@ mod test { let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1518,13 +1580,15 @@ mod test { let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1535,11 +1599,15 @@ mod test { .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - __common_expr_1 = Int32(30) + Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) let table_scan = test_table_scan()?; @@ -1548,11 +1616,16 @@ mod test { + col("a")) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ - \n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30) + Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c2 / (c1 + c3) <=> c2 / (c3 + c1) let table_scan = test_table_scan()?; @@ -1560,11 +1633,15 @@ mod test { * (col("b") / (col("c") + col("a")))) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; Ok(()) } @@ -1612,10 +1689,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ - \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a + Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // is_null(a == b) <=> is_null(b == a) let table_scan = test_table_scan()?; @@ -1624,10 +1705,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ - \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL + Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // a + b between 0 and 10 <=> b + a between 0 and 10 let table_scan = test_table_scan()?; @@ -1636,10 +1721,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ - \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10) + Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c between a + b and 10 <=> c between b + a and 10 let table_scan = test_table_scan()?; @@ -1648,10 +1737,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ - \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10) + Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // function call with argument <=> function call with argument let udf = ScalarUDF::from(TestUdf::new()); @@ -1661,11 +1754,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ - \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a) + Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } /// returns a "random" function that is marked volatile (aka each invocation diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index d35572e6d34a3..d465faf0c5e83 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -440,22 +440,28 @@ mod tests { logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; + use insta::assert_snapshot; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let starting_schema = Arc::clone($plan.schema()); + let rule = EliminateCrossJoin::new(); + let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap(); + let formatted_plan = optimized_plan.display_indent_schema(); + // Ensure the rule was actually applied + assert!(is_plan_transformed, "failed to optimize plan"); + // Verify the schema remains unchanged + assert_eq!(&starting_schema, optimized_plan.schema()); + assert_snapshot!( + formatted_plan, + @ $expected, + ); - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { - let starting_schema = Arc::clone(plan.schema()); - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(transformed_plan.transformed, "failed to optimize plan"); - let optimized_plan = transformed_plan.data; - let formatted = optimized_plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - assert_eq!(&starting_schema, optimized_plan.schema()) + Ok(()) + }}; } #[test] @@ -473,16 +479,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -501,16 +506,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -528,16 +532,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -559,15 +562,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -589,15 +592,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -615,15 +618,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -644,19 +647,18 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - let expected = vec![ - "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -691,19 +693,18 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -765,22 +766,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -840,22 +840,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -915,22 +914,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -994,22 +992,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1083,21 +1080,20 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1177,20 +1173,19 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1208,15 +1203,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1235,16 +1230,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,16 +1257,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1291,16 +1284,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1328,18 +1320,17 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 6a5b29062e948..a6651df938a70 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -120,6 +120,7 @@ mod tests { use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; @@ -128,9 +129,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(EliminateDuplicatedExpr::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateDuplicatedExpr::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index db2136e5e4e5e..452df6e8331f8 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -82,6 +82,7 @@ mod tests { use std::sync::Arc; use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; @@ -94,9 +95,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(EliminateFilter::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateFilter::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index bd5e6910201cc..604f083b37090 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -117,6 +117,7 @@ mod tests { use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -135,9 +136,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(EliminateGroupByConstant::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateGroupByConstant::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index bac82a2ee1316..2aad889b2fcbe 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -76,6 +76,7 @@ impl OptimizerRule for EliminateJoin { mod tests { use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; + use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; @@ -84,11 +85,13 @@ mod tests { macro_rules! assert_optimized_plan_equal { ( $plan:expr, - @$expected:literal $(,)? + @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(EliminateJoin::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateJoin::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index fe835afbaa542..f8f93727cd9ba 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -117,6 +117,7 @@ mod tests { use crate::analyzer::type_coercion::TypeCoercion; use crate::analyzer::Analyzer; use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -137,9 +138,11 @@ mod tests { let options = ConfigOptions::default(); let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) .execute_and_check($plan, &options, |_, _| {})?; - let rule: Arc = Arc::new(EliminateNestedUnion::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateNestedUnion::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, analyzed_plan, @ $expected, ) diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 704a9e7e53414..621086e4a28a9 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -306,6 +306,7 @@ mod tests { use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_expr::{ binary_expr, cast, col, lit, @@ -317,11 +318,13 @@ mod tests { macro_rules! assert_optimized_plan_equal { ( $plan:expr, - @$expected:literal $(,)? + @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(EliminateOuterJoin::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateOuterJoin::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 314b439cb51ee..14a424b32687f 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -108,6 +108,7 @@ fn create_not_null_predicate(filters: Vec) -> Expr { mod tests { use super::*; use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; @@ -118,9 +119,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(FilterNullJoinKeys {}); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(FilterNullJoinKeys {})]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index a443c4cc81ef3..1093c0b3cf691 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -847,9 +847,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(OptimizeProjections::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(OptimizeProjections::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 10821f08fbf1c..4fb9e117e2afc 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -249,6 +249,7 @@ mod tests { assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; + use crate::OptimizerContext; use super::*; @@ -257,9 +258,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(PropagateEmptyRelation::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PropagateEmptyRelation::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 757a1b1646e67..bbf0b0dd810e7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1418,9 +1418,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(PushDownFilter::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownFilter::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 0ed4e05d8594f..ec042dd350ca1 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -279,6 +279,7 @@ mod test { use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::DFSchemaRef; use datafusion_expr::{ col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, @@ -291,9 +292,11 @@ mod test { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(PushDownLimit::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownLimit::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index c7c9d03a51ae7..2383787fa0e8a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -190,6 +190,7 @@ mod tests { use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; use datafusion_functions_aggregate::sum::sum; @@ -199,9 +200,11 @@ mod tests { $plan:expr, @ $expected:literal $(,)? ) => {{ - let rule: Arc = Arc::new(ReplaceDistinctWithAggregate::new()); + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(ReplaceDistinctWithAggregate::new())]; assert_optimized_plan_eq_snapshot!( - rule, + optimizer_ctx, + rules, $plan, @ $expected, ) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 20e6d2b61252d..6e0b734bb9280 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -147,21 +147,6 @@ macro_rules! assert_optimized_plan_eq_snapshot { Ok::<(), datafusion_common::DataFusionError>(()) }}; - - ( - $rule:expr, - $plan:expr, - @ $expected:literal $(,)? - ) => {{ - // Apply the rule once - let opt_context = $crate::OptimizerContext::new().with_max_passes(1); - - let optimizer = $crate::Optimizer::with_rules(vec![Arc::clone(&$rule)]); - let optimized_plan = optimizer.optimize($plan, &opt_context, |_, _| {})?; - insta::assert_snapshot!(optimized_plan, @ $expected); - - Ok::<(), datafusion_common::DataFusionError>(()) - }}; } fn generate_optimized_plan_with_rules(