Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wren-core-base/src/mdl/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(crate) fn parse_identifiers_normalized(
})
}

pub fn quote_identifier(s: &str) -> Cow<str> {
pub fn quote_identifier(s: &str) -> Cow<'_, str> {
if needs_quotes(s) {
Cow::Owned(format!("\"{}\"", s.replace('"', "\"\"")))
} else {
Expand Down
14 changes: 8 additions & 6 deletions wren-core/core/src/logical_plan/analyze/model_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ impl ModelGenerationRule {
let rls_filter = filters
.into_iter()
.reduce(|acc, filter| {
if acc.is_none() {
filter
} else if let Some(filter) = filter {
Some(acc.unwrap().and(filter))
if let Some(acc) = acc {
if let Some(filter) = filter {
Some(acc.and(filter))
} else {
Some(acc)
}
} else {
acc
filter
}
})
.flatten();
Expand Down Expand Up @@ -231,7 +233,7 @@ impl ModelGenerationRule {
.build()?;
Ok(Transformed::yes(alias))
} else {
return plan_err!("measures should have an alias");
plan_err!("measures should have an alias")
}
} else if let Some(partial_model) = extension
.node
Expand Down
117 changes: 115 additions & 2 deletions wren-core/core/src/mdl/dialect/inner_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

use crate::mdl::dialect::utils::scalar_function_to_sql_internal;
use crate::mdl::manifest::DataSource;
use datafusion::common::Result;
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::Expr;

use datafusion::sql::sqlparser::ast;
use datafusion::scalar::ScalarValue;
use datafusion::sql::sqlparser::ast::{self, ExtractSyntax, Ident, WindowFrameBound};
use datafusion::sql::unparser::Unparser;
use regex::Regex;

Expand Down Expand Up @@ -52,6 +53,15 @@ pub trait InnerDialect: Send + Sync {
/// Optionally rewrites a column alias for the target dialect.
///
/// Returns `Ok(Some(rewritten))` when the dialect needs to transform the
/// alias (e.g. quoting or renaming), or `Ok(None)` to keep it unchanged.
/// The default implementation performs no rewriting.
fn col_alias_overrides(&self, _alias: &str) -> Result<Option<String>> {
    Ok(None)
}

/// Reports whether the dialect supports unparsing `func_name` as a window
/// function with an explicit window frame (the given start/end bounds).
///
/// The default implementation accepts every function; dialects override
/// this to suppress the frame clause for functions that reject one.
fn window_func_support_window_frame(
    &self,
    _func_name: &str,
    _start_bound: &WindowFrameBound,
    _end_bound: &WindowFrameBound,
) -> bool {
    true
}
}

/// [get_inner_dialect] returns the suitable InnerDialect for the given data source.
Expand Down Expand Up @@ -116,6 +126,109 @@ impl InnerDialect for BigQueryDialect {
Ok(Some(alias.to_string()))
}
}

/// Rewrites selected scalar functions into BigQuery-compatible SQL.
///
/// Currently only `date_part(part, expr)` is overridden: it is emitted as
/// `EXTRACT(part FROM expr)`. Any other function returns `Ok(None)` so the
/// default unparsing applies.
fn scalar_function_to_sql_overrides(
    &self,
    unparser: &Unparser,
    function_name: &str,
    args: &[Expr],
) -> Result<Option<ast::Expr>> {
    // Everything except `date_part` falls through to default handling.
    if function_name != "date_part" {
        return Ok(None);
    }
    if args.len() != 2 {
        return plan_err!(
            "date_part requires exactly 2 arguments, found {}",
            args.len()
        );
    }
    // args[0] is the date part literal, args[1] the source expression.
    let field = self.datetime_field_from_expr(&args[0])?;
    let source = unparser.expr_to_sql(&args[1])?;
    Ok(Some(ast::Expr::Extract {
        field,
        syntax: ExtractSyntax::From,
        expr: Box::new(source),
    }))
}

/// BigQuery allows an explicit window frame only for aggregate functions;
/// the [window functions](https://cloud.google.com/bigquery/docs/reference/standard-sql/window-functions)
/// listed below must be unparsed without a frame clause.
fn window_func_support_window_frame(
    &self,
    func_name: &str,
    _start_bound: &WindowFrameBound,
    _end_bound: &WindowFrameBound,
) -> bool {
    // Numbering/navigation functions that reject a frame specification.
    const FRAMELESS_FUNCS: &[&str] = &[
        "cume_dist",
        "dense_rank",
        "first_value",
        "lag",
        "last_value",
        "lead",
        "nth_value",
        "ntile",
        "percent_rank",
        "percentile_cont",
        "percentile_disc",
        "rank",
        "row_number",
        "st_clusterdbscan",
    ];
    !FRAMELESS_FUNCS.contains(&func_name)
}
}

impl BigQueryDialect {
    /// Converts the first argument of `date_part` into a BigQuery `EXTRACT`
    /// datetime field.
    ///
    /// The argument must be a UTF-8 string literal (e.g. `'YEAR'`); any
    /// other expression kind yields a planning error.
    fn datetime_field_from_expr(&self, expr: &Expr) -> Result<ast::DateTimeField> {
        match expr {
            Expr::Literal(ScalarValue::Utf8(Some(s)))
            | Expr::Literal(ScalarValue::LargeUtf8(Some(s))) => {
                self.datetime_field_from_str(s)
            }
            _ => plan_err!(
                "Invalid argument type for datetime field. Expected UTF8 string."
            ),
        }
    }

    /// BigQuery supports only the following date part
    /// <https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#extract>
    ///
    /// Matching is case-insensitive; `WEEK` may optionally carry a weekday
    /// argument in the `WEEK(WEEKDAY)` form.
    fn datetime_field_from_str(&self, s: &str) -> Result<ast::DateTimeField> {
        let s = s.to_uppercase();
        if let Some(rest) = s.strip_prefix("WEEK") {
            // Bare WEEK with no weekday argument.
            if rest.is_empty() {
                return Ok(ast::DateTimeField::Week(None));
            }
            // Parse the WEEK(WEEKDAY) form. strip_prefix/strip_suffix cannot
            // panic on malformed input such as "WEEK)(" (unlike slicing
            // between two independent `find` results) and rejects trailing
            // garbage after the closing parenthesis.
            return match rest.strip_prefix('(').and_then(|r| r.strip_suffix(')')) {
                Some(weekday) => match weekday {
                    "SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY"
                    | "THURSDAY" | "FRIDAY" | "SATURDAY" => {
                        Ok(ast::DateTimeField::Week(Some(Ident::new(weekday))))
                    }
                    _ => plan_err!("Invalid weekday '{}' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY", weekday),
                },
                None => plan_err!("Invalid WEEK format '{}'. Expected WEEK(WEEKDAY)", s),
            };
        }
        match s.as_str() {
            "DAYOFWEEK" => Ok(ast::DateTimeField::DayOfWeek),
            "DAY" => Ok(ast::DateTimeField::Day),
            "DAYOFYEAR" => Ok(ast::DateTimeField::DayOfYear),
            "ISOWEEK" => Ok(ast::DateTimeField::IsoWeek),
            "MONTH" => Ok(ast::DateTimeField::Month),
            "QUARTER" => Ok(ast::DateTimeField::Quarter),
            "YEAR" => Ok(ast::DateTimeField::Year),
            "ISOYEAR" => Ok(ast::DateTimeField::Isoyear),
            _ => {
                plan_err!("Unsupported date part '{}' for BigQuery", s)
            }
        }
    }
}

pub struct OracleDialect {}
Expand Down
15 changes: 14 additions & 1 deletion wren-core/core/src/mdl/dialect/wren_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion::common::{internal_err, plan_err, Result, ScalarValue};
use datafusion::logical_expr::sqlparser::ast::{Ident, Subscript};
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::Expr;
use datafusion::sql::sqlparser::ast;
use datafusion::sql::sqlparser::ast::{self, WindowFrameBound};
use datafusion::sql::sqlparser::ast::{AccessExpr, Array, Value};
use datafusion::sql::sqlparser::tokenizer::Span;
use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle};
Expand Down Expand Up @@ -95,6 +95,19 @@ impl Dialect for WrenDialect {
/// Delegates column-alias rewriting to the data-source-specific inner
/// dialect.
fn col_alias_overrides(&self, alias: &str) -> Result<Option<String>> {
    self.inner_dialect.col_alias_overrides(alias)
}

/// Delegates window-frame support to the data-source-specific inner
/// dialect (e.g. BigQuery rejects explicit frames on pure window
/// functions such as `rank`).
fn window_func_support_window_frame(
    &self,
    func_name: &str,
    start_bound: &WindowFrameBound,
    end_bound: &WindowFrameBound,
) -> bool {
    self.inner_dialect.window_func_support_window_frame(
        func_name,
        start_bound,
        end_bound,
    )
}
}

impl Default for WrenDialect {
Expand Down
4 changes: 3 additions & 1 deletion wren-core/core/src/mdl/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,9 @@ mod test {
DataType::List(Arc::new(Field::new("element", DataType::Int32, false)));
assert_eq!(udf.name, "test");
assert_eq!(
udf.return_type.to_data_type(&[list_type.clone()]).unwrap(),
udf.return_type
.to_data_type(std::slice::from_ref(&list_type))
.unwrap(),
DataType::Int32
);
assert_eq!(
Expand Down
89 changes: 89 additions & 0 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2930,6 +2930,95 @@ mod test {
Ok(())
}

// Verifies that `EXTRACT(<part> FROM <date>)` round-trips through the
// BigQuery dialect: a plain date part, bare WEEK, the BigQuery-specific
// WEEK(WEEKDAY) form, and the planning error raised for an invalid weekday.
#[tokio::test]
async fn test_extract_roundtrip_bigquery() -> Result<()> {
    let ctx = SessionContext::new();
    // Minimal model: a single `orders` table with one date column.
    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(
            ModelBuilder::new("orders")
                .table_reference("orders")
                .column(ColumnBuilder::new("o_orderdate", "date").build())
                .build(),
        )
        .data_source(DataSource::BigQuery)
        .build();
    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
        manifest,
        Arc::new(HashMap::default()),
        Mode::Unparse,
    )?);
    let headers = Arc::new(HashMap::default());
    // Plain date part unparses back to EXTRACT(YEAR FROM ...).
    let sql = "SELECT EXTRACT(YEAR FROM o_orderdate) FROM orders";
    assert_snapshot!(
        transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
        @"SELECT EXTRACT(YEAR FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
    );

    // Bare WEEK (no weekday argument).
    let sql = "SELECT EXTRACT(WEEK FROM o_orderdate) FROM orders";
    assert_snapshot!(
        transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
        @"SELECT EXTRACT(WEEK FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
    );

    // BigQuery-specific WEEK(WEEKDAY) form must be preserved verbatim.
    let sql = "SELECT EXTRACT(WEEK(MONDAY) FROM o_orderdate) FROM orders";
    assert_snapshot!(
        transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
        @"SELECT EXTRACT(WEEK(MONDAY) FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
    );

    // An unknown weekday must surface a planning error, not generated SQL.
    let sql = "SELECT EXTRACT(WEEK(NOTFOUND) FROM o_orderdate) FROM orders";
    match transform_sql_with_ctx(
        &ctx,
        Arc::clone(&analyzed_mdl),
        &[],
        Arc::clone(&headers),
        sql,
    )
    .await
    {
        Ok(_) => {
            panic!("Expected error, but got SQL");
        }
        Err(e) => assert_snapshot!(
            e.to_string(),
            @"Error during planning: Invalid weekday 'NOTFOUND' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY"
        ),
    }
    Ok(())
}

// Verifies that, for BigQuery, a pure window function such as `rank()` is
// unparsed WITHOUT an explicit window frame clause (BigQuery rejects frames
// on non-aggregate window functions).
#[tokio::test]
async fn test_window_functions_without_frame_bigquery() -> Result<()> {
    let ctx = SessionContext::new();
    // Minimal model with the columns used in the OVER clause.
    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(
            ModelBuilder::new("orders")
                .table_reference("orders")
                .column(ColumnBuilder::new("o_orderkey", "int").build())
                .column(ColumnBuilder::new("o_custkey", "int").build())
                .column(ColumnBuilder::new("o_orderdate", "date").build())
                .build(),
        )
        .data_source(DataSource::BigQuery)
        .build();
    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
        manifest,
        Arc::new(HashMap::default()),
        Mode::Unparse,
    )?);
    let headers = Arc::new(HashMap::default());
    let sql = "SELECT rank() OVER (PARTITION BY o_custkey ORDER BY o_orderdate) FROM orders";
    // Snapshot contains no ROWS/RANGE frame clause after the ORDER BY.
    assert_snapshot!(
        transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
        @"SELECT rank() OVER (PARTITION BY orders.o_custkey ORDER BY orders.o_orderdate ASC NULLS LAST) FROM (SELECT orders.o_custkey, orders.o_orderdate FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
    );
    Ok(())
}

/// Return a RecordBatch with made up data about customer
fn customer() -> RecordBatch {
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
Expand Down
Loading