diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index b6c9b92ce..c504c039f 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -23,7 +23,6 @@ use datafusion::datasource::{TableProvider, TableType, ViewTable}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::Expr; use datafusion::optimizer::analyzer::type_coercion::TypeCoercion; -use datafusion::optimizer::eliminate_cross_join::EliminateCrossJoin; use datafusion::optimizer::eliminate_duplicated_expr::EliminateDuplicatedExpr; use datafusion::optimizer::eliminate_filter::EliminateFilter; use datafusion::optimizer::eliminate_group_by_constant::EliminateGroupByConstant; @@ -247,7 +246,8 @@ fn optimize_rule_for_unparsing() -> Vec> { // Arc::new(SimplifyExpressions::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), - Arc::new(EliminateCrossJoin::new()), + // Disable EliminateCrossJoin to avoid generate invalid sql (expression should be rebased manually) + // Arc::new(EliminateCrossJoin::new()), // Disable CommonSubexprEliminate to avoid generate invalid projection plan // Arc::new(CommonSubexprEliminate::new()), // Arc::new(EliminateLimit::new()), diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 850902e81..6ed9c2870 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -3774,6 +3774,43 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_disable_eliminate_cross_join() -> Result<()> { + let ctx = create_wren_ctx(None); + + // test required property + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .build(), + ) + .model( + ModelBuilder::new("nation") + .table_reference("nation") + .column(ColumnBuilder::new("n_nationkey", "int").build()) + .column(ColumnBuilder::new("n_name", "string").build()) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT * FROM customer, nation WHERE customer.c_nationkey = nation.n_nationkey"; + let headers = Arc::new(HashMap::default()); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_nationkey, customer.c_name, nation.n_nationkey, nation.n_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer CROSS JOIN (SELECT nation.n_name, nation.n_nationkey FROM (SELECT __source.n_name AS n_name, __source.n_nationkey AS n_nationkey FROM nation AS __source) AS nation) AS nation WHERE customer.c_nationkey = nation.n_nationkey" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));