diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs index 24a85e5b5..73f019878 100644 --- a/wren-core/core/src/mdl/dialect/inner_dialect.rs +++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs @@ -24,7 +24,10 @@ use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::logical_expr::Expr; use datafusion::scalar::ScalarValue; -use datafusion::sql::sqlparser::ast::{self, ExtractSyntax, Ident, WindowFrameBound}; +use datafusion::sql::sqlparser::ast::{ + self, DataType, DateTimeField, Expr as AstExpr, ExtractSyntax, Function, + FunctionArg, FunctionArgExpr, Ident, Interval, TimezoneInfo, Value, WindowFrameBound, +}; use datafusion::sql::unparser::Unparser; use regex::Regex; @@ -135,18 +138,223 @@ impl InnerDialect for BigQueryDialect { ) -> Result> { match function_name { "date_part" => { - if args.len() != 2 { + if args.len() != 2 && args.len() != 3 { return plan_err!( - "date_part requires exactly 2 arguments, found {}", + "date_part requires 2 or 3 arguments, found {}", args.len() ); } + // Base timestamp/datetime expression + let mut source_expr = unparser.expr_to_sql(&args[1])?; + // Apply timezone if provided as 3rd arg + if args.len() == 3 { + if let Expr::Literal(ScalarValue::Utf8(Some(tz))) = &args[2] { + source_expr = AstExpr::AtTimeZone { + timestamp: Box::new(source_expr), + time_zone: TimezoneInfo::Tz(tz.clone()), + }; + } + } Ok(Some(ast::Expr::Extract { field: self.datetime_field_from_expr(&args[0])?, syntax: ExtractSyntax::From, - expr: Box::new(unparser.expr_to_sql(&args[1])?), + expr: Box::new(source_expr), })) } + "date_trunc" | "datetime_trunc" | "timestamp_trunc" | "time_trunc" => { + if args.len() != 2 { + return plan_err!( + "{} requires exactly 2 arguments, found {}", + function_name, + args.len() + ); + } + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: vec![ + 
FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[1])?, + )), + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[0])?, + )), + ], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } + "date_add" | "datetime_add" | "timestamp_add" | "time_add" | "date_sub" + | "datetime_sub" | "timestamp_sub" | "time_sub" => { + if args.len() != 2 { + return plan_err!( + "{} requires exactly 2 arguments, found {}", + function_name, + args.len() + ); + } + + let interval_expr = match &args[1] { + Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))) => { + let (days, ms) = (*interval >> 32, *interval as i32); + let use_day_unit = matches!(function_name, "date_add" | "date_sub"); + let (value_str, unit) = if use_day_unit { + (format!("{}", days), DateTimeField::Day) + } else { + ( + format!("{}", days * 24 * 3600 * 1000 + ms as i64), + DateTimeField::Millisecond, + ) + }; + AstExpr::Value(Value::Interval(Interval { + value: Box::new(AstExpr::Value(Value::Number(value_str, false))), + leading_field: Some(unit), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + })) + } + Expr::Literal(ScalarValue::IntervalYearMonth(Some(interval))) => { + let (years, months) = (*interval / 12, *interval % 12); + if function_name.starts_with("time_") { + return plan_err!( + "Cannot add/subtract YEAR/MONTH interval to/from a TIME value" + ); + } + AstExpr::Value(Value::Interval(Interval { + value: Box::new(AstExpr::Value(Value::Number( + format!("{}", years * 12 + months), + false, + ))), + leading_field: Some(DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + })) + } + _ => return plan_err!("Invalid interval for {}", function_name), + }; + + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: vec![ + 
FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[0])?, + )), + FunctionArg::Unnamed(FunctionArgExpr::Expr(interval_expr)), + ], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } + "date_diff" | "datetime_diff" | "timestamp_diff" | "time_diff" => { + if args.len() != 3 { + return plan_err!( + "{} requires exactly 3 arguments, found {}", + function_name, + args.len() + ); + } + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[1])?, + )), + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[2])?, + )), + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[0])?, + )), + ], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } + "parse_date" | "parse_datetime" | "parse_timestamp" | "format_date" + | "format_datetime" | "format_timestamp" => { + if args.len() != 2 { + return plan_err!( + "{} requires exactly 2 arguments, found {}", + function_name, + args.len() + ); + } + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[0])?, + )), + FunctionArg::Unnamed(FunctionArgExpr::Expr( + unparser.expr_to_sql(&args[1])?, + )), + ], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } + "current_date" | "current_datetime" | "current_timestamp" => { + if !args.is_empty() { + return plan_err!( + "{} requires no arguments, found {}", + function_name, + args.len() + ); + } + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: vec![], + filter: 
None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } + "generate_date_array" => { + if args.len() != 2 && args.len() != 3 { + return plan_err!( + "generate_date_array requires 2 or 3 arguments, found {}", + args.len() + ); + } + let mut fn_args = vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[0])?)), + FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[1])?)), + ]; + if args.len() == 3 { + fn_args.push(FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[2])?))); + } + + Ok(Some(AstExpr::Function(Function { + name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]), + args: fn_args, + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: Vec::new(), + }))) + } _ => Ok(None), } } @@ -203,9 +411,11 @@ impl BigQueryDialect { if let Some(end) = s.find(')') { let weekday = &s[start + 1..end]; match weekday { - "SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY" + "SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY" | "THURSDAY" | "FRIDAY" | "SATURDAY" => { - return Ok(ast::DateTimeField::Week(Some(Ident::new(weekday)))); + return Ok(ast::DateTimeField::Week(Some(Ident::new( + weekday, + )))); } _ => return plan_err!("Invalid weekday '{}' for WEEK. 
Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY", weekday), } @@ -224,6 +434,11 @@ impl BigQueryDialect { "QUARTER" => Ok(ast::DateTimeField::Quarter), "YEAR" => Ok(ast::DateTimeField::Year), "ISOYEAR" => Ok(ast::DateTimeField::Isoyear), + "HOUR" => Ok(ast::DateTimeField::Hour), + "MINUTE" => Ok(ast::DateTimeField::Minute), + "SECOND" => Ok(ast::DateTimeField::Second), + "MILLISECOND" => Ok(ast::DateTimeField::Millisecond), + "MICROSECOND" => Ok(ast::DateTimeField::Microsecond), _ => { plan_err!("Unsupported date part '{}' for BigQuery", s) } @@ -251,4 +466,4 @@ impl InnerDialect for OracleDialect { fn non_uppercase(sql: &str) -> bool { let uppsercase = sql.to_uppercase(); uppsercase != sql -} +} \ No newline at end of file diff --git a/wren-core/core/src/mdl/dialect/wren_dialect.rs b/wren-core/core/src/mdl/dialect/wren_dialect.rs index e9bb406e3..566251615 100644 --- a/wren-core/core/src/mdl/dialect/wren_dialect.rs +++ b/wren-core/core/src/mdl/dialect/wren_dialect.rs @@ -18,6 +18,7 @@ */ use crate::mdl::dialect::inner_dialect::{get_inner_dialect, InnerDialect}; use crate::mdl::manifest::DataSource; +use crate::mdl::utils::scalar_value_to_ast_value; use datafusion::common::{internal_err, plan_err, Result, ScalarValue}; use datafusion::logical_expr::sqlparser::ast::{Ident, Subscript}; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; @@ -84,7 +85,24 @@ impl Dialect for WrenDialect { let sql = self.named_struct_to_sql(args, unparser)?; Ok(Some(sql)) } - _ => Ok(None), + _ => { + if func_name == "lit" { + if args.len() != 1 { + return plan_err!("lit requires exactly 1 argument"); + } + match &args[0] { + Expr::Literal(value) => { + Ok(Some(ast::Expr::Value(scalar_value_to_ast_value(value)))) + } + other => { + // Fall back to the expression itself to avoid emitting `lit(...)` in SQL + Ok(Some(unparser.expr_to_sql(other)?)) + } + } + } else { + Ok(None) + } + } } } @@ -218,4 +236,4 @@ impl WrenDialect { fn 
non_lowercase(sql: &str) -> bool { let lowercase = sql.to_lowercase(); lowercase != sql -} +} \ No newline at end of file diff --git a/wren-core/core/src/mdl/utils.rs b/wren-core/core/src/mdl/utils.rs index ebc54fa47..457fb0a15 100644 --- a/wren-core/core/src/mdl/utils.rs +++ b/wren-core/core/src/mdl/utils.rs @@ -69,7 +69,7 @@ pub fn collect_identifiers(expr: &str) -> Result> { pub fn qualify_name_from_column_name(column: &Column) -> String { column .flat_name() - .split(".") + .split('.') .map(quoted) .collect::>() .join(".") @@ -256,6 +256,123 @@ fn collect_columns(expr: datafusion::logical_expr::sqlparser::ast::Expr) -> Vec< visited } +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use datafusion::scalar::ScalarValue; +use datafusion::sql::sqlparser::ast; + +pub fn scalar_value_to_ast_value(value: &ScalarValue) -> ast::Value { + match value { + ScalarValue::Null => ast::Value::Null, + ScalarValue::Boolean(Some(b)) => ast::Value::Boolean(*b), + ScalarValue::Float32(Some(f)) => ast::Value::Number(f.to_string(), false), + ScalarValue::Float64(Some(d)) => ast::Value::Number(d.to_string(), false), + ScalarValue::Int8(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::Int16(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::Int32(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::Int64(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::UInt8(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::UInt16(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::UInt32(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::UInt64(Some(i)) => ast::Value::Number(i.to_string(), false), + ScalarValue::Utf8(Some(s)) => ast::Value::SingleQuotedString(s.clone()), + ScalarValue::LargeUtf8(Some(s)) => ast::Value::SingleQuotedString(s.clone()), + ScalarValue::Date32(Some(days)) => { + // Date32 is days since UNIX epoch (1970-01-01) + match 
NaiveDate::from_ymd_opt(1970, 1, 1) + .and_then(|epoch| epoch.checked_add_signed(Duration::days(*days as i64))) + { + Some(date) => ast::Value::SingleQuotedString(date.format("%Y-%m-%d").to_string()), + None => ast::Value::Null, + } + } + ScalarValue::Date64(Some(ms)) => { + match NaiveDateTime::from_timestamp_millis(*ms) { + Some(dt) => ast::Value::SingleQuotedString(dt.date().format("%Y-%m-%d").to_string()), + None => ast::Value::Null, + } + } + ScalarValue::TimestampSecond(Some(s), tz) => { + match NaiveDateTime::from_timestamp_opt(*s, 0) { + Some(dt) => { + let formatted = dt.format("%Y-%m-%d %H:%M:%S").to_string(); + if tz.is_some() { + ast::Value::SingleQuotedString(format!("{}Z", formatted)) + } else { + ast::Value::SingleQuotedString(formatted) + } + } + None => ast::Value::Null, + } + } + ScalarValue::TimestampMillisecond(Some(ms), tz) => { + match NaiveDateTime::from_timestamp_millis(*ms) { + Some(dt) => { + let formatted = dt.format("%Y-%m-%d %H:%M:%S.%3f").to_string(); + if tz.is_some() { + ast::Value::SingleQuotedString(format!("{}Z", formatted)) + } else { + ast::Value::SingleQuotedString(formatted) + } + } + None => ast::Value::Null, + } + } + ScalarValue::TimestampMicrosecond(Some(us), tz) => { + match NaiveDateTime::from_timestamp_micros(*us) { + Some(dt) => { + let formatted = dt.format("%Y-%m-%d %H:%M:%S.%6f").to_string(); + if tz.is_some() { + ast::Value::SingleQuotedString(format!("{}Z", formatted)) + } else { + ast::Value::SingleQuotedString(formatted) + } + } + None => ast::Value::Null, + } + } + ScalarValue::TimestampNanosecond(Some(ns), _) => { + let secs = ns.div_euclid(1_000_000_000); + let nanos = ns.rem_euclid(1_000_000_000) as u32; + match NaiveDateTime::from_timestamp_opt(secs, nanos) { + Some(dt) => ast::Value::SingleQuotedString( + dt.format("%Y-%m-%d %H:%M:%S.%9f").to_string() + ), + None => ast::Value::Null, + } + } + + // Explicitly map None for all Option-bearing scalar types to SQL NULL + ScalarValue::Boolean(None) + | 
ScalarValue::Float32(None) + | ScalarValue::Float64(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::UInt8(None) + | ScalarValue::UInt16(None) + | ScalarValue::UInt32(None) + | ScalarValue::UInt64(None) + | ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Decimal128(None, _, _) + | ScalarValue::Decimal256(None, _, _) + | ScalarValue::Time32Second(None) + | ScalarValue::Time32Millisecond(None) + | ScalarValue::Time64Microsecond(None) + | ScalarValue::Time64Nanosecond(None) + | ScalarValue::Date32(None) + | ScalarValue::Date64(None) + | ScalarValue::TimestampSecond(None, _) + | ScalarValue::TimestampMillisecond(None, _) + | ScalarValue::TimestampMicrosecond(None, _) + | ScalarValue::TimestampNanosecond(None, _) => ast::Value::Null, + // Fallback for any other types to avoid panicking for non-None + _ => ast::Value::SingleQuotedString(value.to_string()), + } +} + + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -404,4 +521,4 @@ mod tests { assert_eq!(expr.to_string(), "customer.c_name"); Ok(()) } -} +} \ No newline at end of file diff --git a/wren-core/sqllogictest/test_files/bigquery_features.slt b/wren-core/sqllogictest/test_files/bigquery_features.slt new file mode 100644 index 000000000..a97c36fe7 --- /dev/null +++ b/wren-core/sqllogictest/test_files/bigquery_features.slt @@ -0,0 +1,187 @@ +-- sqllogictest/test_files/bigquery_features.slt + +-- +-- Section 1: Temporal Functions and Bug Fixes +-- + +-- Test: Validate the original bug fix for NULL date literals +statement ok +SELECT CAST(NULL AS DATE) + +---- +SELECT NULL + +-- Test: Validate the defensive lit() function override +statement ok +SELECT lit(123), lit('abc') + +---- +SELECT 123, 'abc' + +-- Test: Validate lit() pass-through for non-literal expression +statement ok +SELECT lit(1 + 2) + +---- +SELECT 1 + 2 + +-- Test: lit() with no args should error +statement error +SELECT lit() 
+ +-- Test: lit() with too many args should error +statement error +SELECT lit(1, 2) + +-- Test: Validate correct argument reordering for DATE_DIFF (start, end, part) +statement ok +SELECT DATE_DIFF(CAST('2025-01-01' AS DATE), CAST('2025-01-15' AS DATE), DAY) + +---- +SELECT DATE_DIFF(CAST('2025-01-01' AS DATE), CAST('2025-01-15' AS DATE), DAY) + +-- Test: Validate correct timezone placement for EXTRACT (inside the function) +statement ok +SELECT EXTRACT(HOUR FROM CAST('2025-08-18 12:00:00' AS TIMESTAMP) AT TIME ZONE 'America/New_York') + +---- +SELECT EXTRACT(HOUR FROM CAST('2025-08-18 12:00:00' AS TIMESTAMP) AT TIME ZONE 'America/New_York') + +-- Test: Validate WEEK(WEEKDAY) parsing and unparsing round-trip +statement ok +SELECT EXTRACT(WEEK(MONDAY) FROM CAST('2025-08-18' AS DATE)) + +---- +SELECT EXTRACT(WEEK(MONDAY) FROM CAST('2025-08-18' AS DATE)) + +-- Test: DATE_ADD with a DAY interval +statement ok +SELECT DATE_ADD(CAST('2025-01-01' AS DATE), INTERVAL 5 DAY) + +---- +SELECT DATE_ADD(CAST('2025-01-01' AS DATE), INTERVAL 5 DAY) + +-- Test: TIMESTAMP_ADD with a MONTH interval +statement ok +SELECT TIMESTAMP_ADD(CAST('2025-01-01 00:00:00' AS TIMESTAMP), INTERVAL 2 MONTH) + +---- +SELECT TIMESTAMP_ADD(CAST('2025-01-01 00:00:00' AS TIMESTAMP), INTERVAL 2 MONTH) + +-- Test: DATE_TRUNC +statement ok +SELECT DATE_TRUNC(CAST('2025-08-18' AS DATE), MONTH) + +---- +SELECT DATE_TRUNC(CAST('2025-08-18' AS DATE), MONTH) + +-- Test: PARSE_DATE +statement ok +SELECT PARSE_DATE('%Y%m%d', '20250818') + +---- +SELECT PARSE_DATE('%Y%m%d', '20250818') + +-- Test: FORMAT_TIMESTAMP +statement ok +SELECT FORMAT_TIMESTAMP('%Y-%m-%d %H:%M:%S', CAST('2025-08-18 12:30:00' AS TIMESTAMP)) + +---- +SELECT FORMAT_TIMESTAMP('%Y-%m-%d %H:%M:%S', CAST('2025-08-18 12:30:00' AS TIMESTAMP)) + +-- Test: CURRENT_DATE +statement ok +SELECT CURRENT_DATE() + +---- +SELECT CURRENT_DATE() + +-- Test: GENERATE_DATE_ARRAY (wrapped in UNNEST as is common) +statement ok +SELECT * FROM 
UNNEST(GENERATE_DATE_ARRAY('2025-01-01', '2025-01-05')) + +---- +SELECT * FROM UNNEST(GENERATE_DATE_ARRAY('2025-01-01', '2025-01-05')) + +-- Test: High-precision timestamp formatting (microseconds) +statement ok +SELECT CAST('2025-08-18 12:30:00.123456' AS TIMESTAMP) + +---- +SELECT CAST('2025-08-18 12:30:00.123456' AS TIMESTAMP) + +-- Test: High-precision timestamp formatting (nanoseconds) - will be rounded to micros by BigQuery +statement ok +SELECT CAST('2025-08-18 12:30:00.123456789' AS TIMESTAMP) + +---- +SELECT CAST('2025-08-18 12:30:00.123457' AS TIMESTAMP) + + +-- +-- Section 2: Nullability and Data Types +-- + +-- Test: Validate NULL handling for various types via lit() +statement ok +SELECT lit(CAST(NULL AS INT)), lit(CAST(NULL AS BOOLEAN)), lit(CAST(NULL AS VARCHAR)) + +---- +SELECT NULL, NULL, NULL + +-- Test: Validate NULL handling for temporal types via lit() +statement ok +SELECT lit(CAST(NULL AS DATE)), lit(CAST(NULL AS TIMESTAMP)) + +---- +SELECT NULL, NULL + +-- +-- Section 3: Column Aliasing +-- + +-- Test: Validate column alias override for special characters +statement ok +SELECT 1 AS `!@#$%^&*()` + +---- +SELECT 1 AS `_33_64_35_36_37_94_38_42_40_41` + +-- +-- Section 4: Advanced Array/Struct Operations +-- + +-- Test: Validate ARRAY constructor (make_array) +statement ok +SELECT make_array(1, 2, 3) + +---- +SELECT [1, 2, 3] + +-- Test: Validate STRUCT constructor (named_struct) and field access +statement ok +SELECT named_struct('a', 1, 'b', 'hello').a + +---- +SELECT STRUCT(1 AS a, 'hello' AS b).a + +-- Test: UNNEST table factor on a primitive array +statement ok +SELECT element FROM UNNEST([10, 20, 30]) AS element + +---- +SELECT element FROM UNNEST([10, 20, 30]) AS element + +-- Test: UNNEST over an array of STRUCTs with field projection +statement ok +SELECT item.id FROM UNNEST([STRUCT(1 AS id), STRUCT(2 AS id)]) AS item + +---- +SELECT item.id FROM UNNEST([STRUCT(1 AS id), STRUCT(2 AS id)]) AS item + +-- Test: UNNEST with WITH OFFSET 
+statement ok +SELECT element, offset FROM UNNEST(['a', 'b', 'c']) AS element WITH OFFSET + +---- +SELECT element, offset FROM UNNEST(['a', 'b', 'c']) AS element WITH OFFSET