Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use datafusion::datasource::{TableProvider, TableType, ViewTable};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::Expr;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::eliminate_cross_join::EliminateCrossJoin;
use datafusion::optimizer::eliminate_duplicated_expr::EliminateDuplicatedExpr;
use datafusion::optimizer::eliminate_filter::EliminateFilter;
use datafusion::optimizer::eliminate_group_by_constant::EliminateGroupByConstant;
Expand Down Expand Up @@ -247,7 +246,8 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
// Arc::new(SimplifyExpressions::new()),
Arc::new(EliminateDuplicatedExpr::new()),
Arc::new(EliminateFilter::new()),
Arc::new(EliminateCrossJoin::new()),
// Disable EliminateCrossJoin to avoid generate invalid sql (expression should be rebased manually)
// Arc::new(EliminateCrossJoin::new()),
// Disable CommonSubexprEliminate to avoid generate invalid projection plan
// Arc::new(CommonSubexprEliminate::new()),
// Arc::new(EliminateLimit::new()),
Expand Down
37 changes: 37 additions & 0 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3774,6 +3774,43 @@ mod test {
Ok(())
}

#[tokio::test]
async fn test_disable_eliminate_cross_join() -> Result<()> {
let ctx = create_wren_ctx(None);

// test required property
let manifest = ManifestBuilder::new()
.catalog("wren")
.schema("test")
.model(
ModelBuilder::new("customer")
.table_reference("customer")
.column(ColumnBuilder::new("c_nationkey", "int").build())
.column(ColumnBuilder::new("c_name", "string").build())
.build(),
)
.model(
ModelBuilder::new("nation")
.table_reference("nation")
.column(ColumnBuilder::new("n_nationkey", "int").build())
.column(ColumnBuilder::new("n_name", "string").build())
.build(),
)
.build();
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
manifest,
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let sql = "SELECT * FROM customer, nation WHERE customer.c_nationkey = nation.n_nationkey";
let headers = Arc::new(HashMap::default());
assert_snapshot!(
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?,
@"SELECT customer.c_nationkey, customer.c_name, nation.n_nationkey, nation.n_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer CROSS JOIN (SELECT nation.n_name, nation.n_nationkey FROM (SELECT __source.n_name AS n_name, __source.n_nationkey AS n_nationkey FROM nation AS __source) AS nation) AS nation WHERE customer.c_nationkey = nation.n_nationkey"
);
Ok(())
}

/// Return a RecordBatch with made up data about customer
fn customer() -> RecordBatch {
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
Expand Down