diff --git a/wren-core/core/src/mdl/dialect/wren_dialect.rs b/wren-core/core/src/mdl/dialect/wren_dialect.rs index c918078c4..43a0f98f9 100644 --- a/wren-core/core/src/mdl/dialect/wren_dialect.rs +++ b/wren-core/core/src/mdl/dialect/wren_dialect.rs @@ -84,11 +84,17 @@ impl Dialect for WrenDialect { start_bound: &WindowFrameBound, end_bound: &WindowFrameBound, ) -> bool { - self.inner_dialect.window_func_support_window_frame( - func_name, - start_bound, - end_bound, - ) + if matches!(start_bound, WindowFrameBound::Preceding(None)) + && matches!(end_bound, WindowFrameBound::CurrentRow) + { + false + } else { + self.inner_dialect.window_func_support_window_frame( + func_name, + start_bound, + end_bound, + ) + } } } diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 5c4e9dc45..6e3e8aec3 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -3389,6 +3389,43 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_window_function_frame() -> Result<()> { + let ctx = SessionContext::new(); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("orders") + .table_reference("orders") + .column(ColumnBuilder::new("o_orderkey", "int").build()) + .column(ColumnBuilder::new("o_custkey", "int").build()) + .column(ColumnBuilder::new("o_orderdate", "date").build()) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let headers = Arc::new(HashMap::default()); + // assert default won't generate the window frame + let sql = "SELECT rank() OVER (PARTITION BY o_custkey ORDER BY o_orderdate) FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT rank() OVER (PARTITION BY orders.o_custkey ORDER BY orders.o_orderdate ASC NULLS LAST) FROM (SELECT orders.o_custkey, orders.o_orderdate FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + + // assert generate window frame if given + let sql = "SELECT count(*) OVER (PARTITION BY o_custkey ORDER BY o_orderdate ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING) as window_col FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT count(1) OVER (PARTITION BY orders.o_custkey ORDER BY orders.o_orderdate ASC NULLS LAST ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING) AS window_col FROM (SELECT orders.o_custkey, orders.o_orderdate FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + Ok(()) + } + #[tokio::test] async fn test_window_functions_without_frame_bigquery() -> Result<()> { let ctx = SessionContext::new();