diff --git a/wren-core-base/src/mdl/utils.rs b/wren-core-base/src/mdl/utils.rs index af8df9eed..6dfd0b0e0 100644 --- a/wren-core-base/src/mdl/utils.rs +++ b/wren-core-base/src/mdl/utils.rs @@ -43,7 +43,7 @@ pub(crate) fn parse_identifiers_normalized( }) } -pub fn quote_identifier(s: &str) -> Cow { +pub fn quote_identifier(s: &str) -> Cow<'_, str> { if needs_quotes(s) { Cow::Owned(format!("\"{}\"", s.replace('"', "\"\""))) } else { diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index c0e2c9dbc..972eb9bab 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -97,12 +97,14 @@ impl ModelGenerationRule { let rls_filter = filters .into_iter() .reduce(|acc, filter| { - if acc.is_none() { - filter - } else if let Some(filter) = filter { - Some(acc.unwrap().and(filter)) + if let Some(acc) = acc { + if let Some(filter) = filter { + Some(acc.and(filter)) + } else { + Some(acc) + } } else { - acc + filter } }) .flatten(); @@ -231,7 +233,7 @@ impl ModelGenerationRule { .build()?; Ok(Transformed::yes(alias)) } else { - return plan_err!("measures should have an alias"); + plan_err!("measures should have an alias") } } else if let Some(partial_model) = extension .node diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs index be203e527..24a85e5b5 100644 --- a/wren-core/core/src/mdl/dialect/inner_dialect.rs +++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs @@ -19,11 +19,12 @@ use crate::mdl::dialect::utils::scalar_function_to_sql_internal; use crate::mdl::manifest::DataSource; -use datafusion::common::Result; +use datafusion::common::{plan_err, Result}; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::logical_expr::Expr; -use datafusion::sql::sqlparser::ast; +use datafusion::scalar::ScalarValue; +use datafusion::sql::sqlparser::ast::{self, ExtractSyntax, Ident, WindowFrameBound}; use datafusion::sql::unparser::Unparser; use regex::Regex; @@ -52,6 +53,15 @@ pub trait InnerDialect: Send + Sync { fn col_alias_overrides(&self, _alias: &str) -> Result> { Ok(None) } + + fn window_func_support_window_frame( + &self, + _func_name: &str, + _start_bound: &WindowFrameBound, + _end_bound: &WindowFrameBound, + ) -> bool { + true + } } /// [get_inner_dialect] returns the suitable InnerDialect for the given data source. @@ -116,6 +126,109 @@ impl InnerDialect for BigQueryDialect { Ok(Some(alias.to_string())) } } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + function_name: &str, + args: &[Expr], + ) -> Result> { + match function_name { + "date_part" => { + if args.len() != 2 { + return plan_err!( + "date_part requires exactly 2 arguments, found {}", + args.len() + ); + } + Ok(Some(ast::Expr::Extract { + field: self.datetime_field_from_expr(&args[0])?, + syntax: ExtractSyntax::From, + expr: Box::new(unparser.expr_to_sql(&args[1])?), + })) + } + _ => Ok(None), + } + } + + /// BigQuery only allow the aggregation function with window frame. + /// Other [window functions](https://cloud.google.com/bigquery/docs/reference/standard-sql/window-functions) are not supported. + fn window_func_support_window_frame( + &self, + func_name: &str, + _start_bound: &WindowFrameBound, + _end_bound: &WindowFrameBound, + ) -> bool { + !matches!( + func_name, + "cume_dist" + | "dense_rank" + | "first_value" + | "lag" + | "last_value" + | "lead" + | "nth_value" + | "ntile" + | "percent_rank" + | "percentile_cont" + | "percentile_disc" + | "rank" + | "row_number" + | "st_clusterdbscan" + ) + } +} + +impl BigQueryDialect { + fn datetime_field_from_expr(&self, expr: &Expr) -> Result { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s))) + | Expr::Literal(ScalarValue::LargeUtf8(Some(s))) => { + Ok(self.datetime_field_from_str(s)?) + } + _ => plan_err!( + "Invalid argument type for datetime field. Expected UTF8 string." + ), + } + } + + /// BigQuery supports only the following date part + /// + fn datetime_field_from_str(&self, s: &str) -> Result { + let s = s.to_uppercase(); + if s.starts_with("WEEK") { + if s.len() > 4 { + // Parse WEEK(MONDAY) format + if let Some(start) = s.find('(') { + if let Some(end) = s.find(')') { + let weekday = &s[start + 1..end]; + match weekday { + "SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY" + | "THURSDAY" | "FRIDAY" | "SATURDAY" => { + return Ok(ast::DateTimeField::Week(Some(Ident::new(weekday)))); + } + _ => return plan_err!("Invalid weekday '{}' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY", weekday), + } + } + } + return plan_err!("Invalid WEEK format '{}'. Expected WEEK(WEEKDAY)", s); + } + return Ok(ast::DateTimeField::Week(None)); + } + match s.as_str() { + "DAYOFWEEK" => Ok(ast::DateTimeField::DayOfWeek), + "DAY" => Ok(ast::DateTimeField::Day), + "DAYOFYEAR" => Ok(ast::DateTimeField::DayOfYear), + "ISOWEEK" => Ok(ast::DateTimeField::IsoWeek), + "MONTH" => Ok(ast::DateTimeField::Month), + "QUARTER" => Ok(ast::DateTimeField::Quarter), + "YEAR" => Ok(ast::DateTimeField::Year), + "ISOYEAR" => Ok(ast::DateTimeField::Isoyear), + _ => { + plan_err!("Unsupported date part '{}' for BigQuery", s) + } + } + } } pub struct OracleDialect {} diff --git a/wren-core/core/src/mdl/dialect/wren_dialect.rs b/wren-core/core/src/mdl/dialect/wren_dialect.rs index d54772195..e9bb406e3 100644 --- a/wren-core/core/src/mdl/dialect/wren_dialect.rs +++ b/wren-core/core/src/mdl/dialect/wren_dialect.rs @@ -22,7 +22,7 @@ use datafusion::common::{internal_err, plan_err, Result, ScalarValue}; use datafusion::logical_expr::sqlparser::ast::{Ident, Subscript}; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::logical_expr::Expr; -use datafusion::sql::sqlparser::ast; +use datafusion::sql::sqlparser::ast::{self, WindowFrameBound}; use datafusion::sql::sqlparser::ast::{AccessExpr, Array, Value}; use datafusion::sql::sqlparser::tokenizer::Span; use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle}; @@ -95,6 +95,19 @@ impl Dialect for WrenDialect { fn col_alias_overrides(&self, alias: &str) -> Result> { self.inner_dialect.col_alias_overrides(alias) } + + fn window_func_support_window_frame( + &self, + func_name: &str, + start_bound: &WindowFrameBound, + end_bound: &WindowFrameBound, + ) -> bool { + self.inner_dialect.window_func_support_window_frame( + func_name, + start_bound, + end_bound, + ) + } } impl Default for WrenDialect { diff --git a/wren-core/core/src/mdl/function.rs b/wren-core/core/src/mdl/function.rs index 2d96f99ed..37d1457ea 100644 --- a/wren-core/core/src/mdl/function.rs +++ b/wren-core/core/src/mdl/function.rs @@ -598,7 +598,9 @@ mod test { DataType::List(Arc::new(Field::new("element", DataType::Int32, false))); assert_eq!(udf.name, "test"); assert_eq!( - udf.return_type.to_data_type(&[list_type.clone()]).unwrap(), + udf.return_type + .to_data_type(std::slice::from_ref(&list_type)) + .unwrap(), DataType::Int32 ); assert_eq!( diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 8f8eca9a3..4640a0cbf 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -2930,6 +2930,95 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_extract_roundtrip_bigquery() -> Result<()> { + let ctx = SessionContext::new(); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("orders") + .table_reference("orders") + .column(ColumnBuilder::new("o_orderdate", "date").build()) + .build(), + ) + .data_source(DataSource::BigQuery) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let headers = Arc::new(HashMap::default()); + let sql = "SELECT EXTRACT(YEAR FROM o_orderdate) FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT EXTRACT(YEAR FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + + let sql = "SELECT EXTRACT(WEEK FROM o_orderdate) FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT EXTRACT(WEEK FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + + let sql = "SELECT EXTRACT(WEEK(MONDAY) FROM o_orderdate) FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT EXTRACT(WEEK(MONDAY) FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + + let sql = "SELECT EXTRACT(WEEK(NOTFOUND) FROM o_orderdate) FROM orders"; + match transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + Arc::clone(&headers), + sql, + ) + .await + { + Ok(_) => { + panic!("Expected error, but got SQL"); + } + Err(e) => assert_snapshot!( + e.to_string(), + @"Error during planning: Invalid weekday 'NOTFOUND' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY" + ), + } + Ok(()) + } + + #[tokio::test] + async fn test_window_functions_without_frame_bigquery() -> Result<()> { + let ctx = SessionContext::new(); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("orders") + .table_reference("orders") + .column(ColumnBuilder::new("o_orderkey", "int").build()) + .column(ColumnBuilder::new("o_custkey", "int").build()) + .column(ColumnBuilder::new("o_orderdate", "date").build()) + .build(), + ) + .data_source(DataSource::BigQuery) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let headers = Arc::new(HashMap::default()); + let sql = "SELECT rank() OVER (PARTITION BY o_custkey ORDER BY o_orderdate) FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?, + @"SELECT rank() OVER (PARTITION BY orders.o_custkey ORDER BY orders.o_orderdate ASC NULLS LAST) FROM (SELECT orders.o_custkey, orders.o_orderdate FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));