diff --git a/ibis-server/app/config.py b/ibis-server/app/config.py index ee36e4224..0de5460a6 100644 --- a/ibis-server/app/config.py +++ b/ibis-server/app/config.py @@ -61,6 +61,11 @@ def update(self, diagnose: bool): def get_remote_function_list_path(self, data_source: str) -> str: if not self.remote_function_list_path: return None + + # The function list has been defined by Wren Core + if data_source in {"bigquery"}: + return None + if data_source in {"local_file", "s3_file", "minio_file", "gcs_file"}: data_source = "duckdb" base_path = os.path.normpath(self.remote_function_list_path) diff --git a/ibis-server/app/mdl/core.py b/ibis-server/app/mdl/core.py index e23d7b750..ba7419da9 100644 --- a/ibis-server/app/mdl/core.py +++ b/ibis-server/app/mdl/core.py @@ -5,9 +5,14 @@ @cache def get_session_context( - manifest_str: str | None, function_path: str, properties: frozenset | None = None + manifest_str: str | None, + function_path: str, + properties: frozenset | None = None, + data_source: str | None = None, ) -> wren_core.SessionContext: - return wren_core.SessionContext(manifest_str, function_path, properties) + return wren_core.SessionContext( + manifest_str, function_path, properties, data_source + ) def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor: diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index e170190af..3fe7e8af9 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -42,7 +42,9 @@ def __init__( self.properties = properties if experiment: function_path = get_config().get_remote_function_list_path(data_source) - self._rewriter = EmbeddedEngineRewriter(function_path) + self._rewriter = EmbeddedEngineRewriter( + function_path=function_path, data_source=data_source + ) else: self._rewriter = ExternalEngineRewriter(java_engine_connector) @@ -130,7 +132,8 @@ def handle_extract_exception(e: Exception): class EmbeddedEngineRewriter: - def __init__(self, function_path: 
str): + def __init__(self, function_path: str, data_source: DataSource = None): + self.data_source = data_source self.function_path = function_path @tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL) @@ -140,7 +143,10 @@ async def rewrite( try: processed_properties = self.get_session_properties(properties) session_context = get_session_context( - manifest_str, self.function_path, processed_properties + manifest_str, + self.function_path, + processed_properties, + self.data_source.name if self.data_source else None, ) return await to_thread.run_sync( session_context.transform_sql, @@ -151,12 +157,18 @@ async def rewrite( @tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL) def rewrite_sync( - self, manifest_str: str, sql: str, properties: dict | None = None + self, + manifest_str: str, + sql: str, + properties: dict | None = None, ) -> str: try: processed_properties = self.get_session_properties(properties) session_context = get_session_context( - manifest_str, self.function_path, processed_properties + manifest_str, + self.function_path, + processed_properties, + self.data_source.name if self.data_source else None, ) return session_context.transform_sql(sql) except Exception as e: diff --git a/ibis-server/app/model/error.py b/ibis-server/app/model/error.py index d64b3c16d..59d38b0f0 100644 --- a/ibis-server/app/model/error.py +++ b/ibis-server/app/model/error.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from starlette.status import ( HTTP_404_NOT_FOUND, - HTTP_422_UNPROCESSABLE_ENTITY, + HTTP_422_UNPROCESSABLE_CONTENT, HTTP_500_INTERNAL_SERVER_ERROR, HTTP_501_NOT_IMPLEMENTED, HTTP_502_BAD_GATEWAY, @@ -109,7 +109,7 @@ def get_http_status_code(self) -> int: return HTTP_504_GATEWAY_TIMEOUT case e: if e.value < 100: - return HTTP_422_UNPROCESSABLE_ENTITY + return HTTP_422_UNPROCESSABLE_CONTENT return HTTP_500_INTERNAL_SERVER_ERROR diff --git a/ibis-server/resources/function_list/bigquery.csv 
b/ibis-server/resources/function_list/bigquery.csv deleted file mode 100644 index e44c70ed5..000000000 --- a/ibis-server/resources/function_list/bigquery.csv +++ /dev/null @@ -1,44 +0,0 @@ -function_type,name,return_type,param_names,param_types,description -aggregate,countif,int64,,boolean,"Counts the rows where a condition is true." -aggregate,any_value,same_as_input,,"any","Returns any arbitrary value from the input values." -scalar,format,text,,"text","Formats values into a string." -scalar,safe_divide,float64,,"float64,float64","Divides two numbers, returning NULL if the divisor is zero." -scalar,safe_multiply,float64,,"float64,float64","Multiplies two numbers, returning NULL if an overflow occurs." -scalar,safe_add,float64,,"float64,float64","Adds two numbers, returning NULL if an overflow occurs." -scalar,safe_subtract,float64,,"float64,float64","Subtracts two numbers, returning NULL if an overflow occurs." -scalar,current_datetime,timestamp,,"","Returns current date and time." -scalar,current_timestamp,timestamptz,,"","Returns current timestamp." -scalar,date_add,date,,"date,interval","Adds a number of day to a date." -scalar,date_sub,date,,"date,interval","Subtracts a specified interval from a date." -scalar,date_diff,int64,,"date,date,granularity","Returns the difference between two dates." -scalar,datediff,int64,,"date,date,granularity","Returns the difference between two dates." -scalar,timestamp_add,timestamp,,"timestamp,granularity","Adds a specified interval to a timestamp." -scalar,timestamp_sub,timestamp,,"timestamp,granularity","Subtracts a specified interval from a timestamp." -scalar,timestamp_diff,int64,,"timestamp,timestamp,granularity","Returns the difference between two timestamps." -scalar,timestamp_trunc,timestamp,,"timestamp,granularity","Truncates a timestamp to a specified granularity." -scalar,timestamp_micros,timestamp,,"int64","Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." 
-scalar,timestamp_millis,timestamp,,"int64","Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." -scalar,timestamp_seconds,timestamp,,"int64","Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." -scalar,timestamp,timestamp,,"text","Converts a string to a TIMESTAMP." -scalar,format_date,string,,"string,date","Formats a date according to the specified format string." -scalar,format_timestamp,string,,"string,timestamp","Formats a timestamp according to the specified format string." -scalar,parse_date,date,,"text,text","Parses a date from a string." -scalar,parse_datetime,datetime,,"text,text","Converts a STRING value to a DATETIME value." -scalar,json_query,text,,"json,text","Extracts a JSON value from a JSON string." -scalar,json_value,text,,"json,text","Extracts a scalar JSON value as a string." -scalar,json_query_array,array,,"json,text","Extracts a JSON array from a JSON string." -scalar,json_value_array,array,,"json,text","Extracts an array of scalar JSON values as strings." -scalar,lax_bool,boolean,,"any","Converts a value to boolean with relaxed type checking." -scalar,lax_float64,float64,,"any","Converts a value to float with relaxed type checking." -scalar,lax_int64,int64,,"any","Converts a value to int with relaxed type checking." -scalar,lax_string,text,,"any","Converts a value to text with relaxed type checking." -scalar,bool,boolean,,"any","Converts a JSON value to SQL boolean type." -scalar,float64,float64,,"any","Converts a JSON value to SQL float type." -scalar,int64,int64,,"any","Converts a JSON value to SQL int type." -scalar,string,text,,"any","Converts a JSON value to SQL text type." -scalar,regexp_contains,boolean,,"text,text","Returns TRUE if value is a partial match for the regular expression, regexp." -scalar,datetime,timestamp,,"text","Converts a string to a DATETIME." -scalar,datetime_add,datetime,,"datetime,interval","Adds a specified interval to a datetime." 
-scalar,datetime_sub,datetime,,"datetime,interval","Subtracts a specified interval from a datetime." -scalar,date_add,date,,"date,interval","Adds a specified interval to a date." -aggregate,group_concat,string,,"any","(backward compatible)(deprecated) Concatenates multiple strings into a single string, where each value is separated by the optional separator parameter. If separator is omitted, BigQuery returns a comma-separated string. This function has been deprecated in favor of STRING_AGG." diff --git a/ibis-server/tools/query_local_run.py b/ibis-server/tools/query_local_run.py index 2e846e721..dc69981c1 100644 --- a/ibis-server/tools/query_local_run.py +++ b/ibis-server/tools/query_local_run.py @@ -77,7 +77,7 @@ print("### Session Properties ###") for key, value in properties: print(f"# {key}: {value}") -session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties) +session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties, data_source) planned_sql = session_context.transform_sql(sql) print("# Planned SQL:\n", planned_sql) diff --git a/ibis-server/wren/session/__init__.py b/ibis-server/wren/session/__init__.py index 63ce75b7c..ef4c1051e 100644 --- a/ibis-server/wren/session/__init__.py +++ b/ibis-server/wren/session/__init__.py @@ -97,7 +97,9 @@ def plan(self, input_sql): ) self.planned_sql = self.context.rewriter.rewrite_sync( - self.manifest, self.wren_sql, self.properties + self.manifest, + self.wren_sql, + self.properties, ) read = self._get_read_dialect() diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 36f22a6af..49537b1b1 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -1,3 +1,4 @@ +use std::error::Error; /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -17,6 +18,7 @@ * under the License. 
*/ use std::fmt::Display; +use std::str::FromStr; use std::sync::Arc; #[cfg(not(feature = "python-binding"))] @@ -99,6 +101,32 @@ mod manifest_impl { pub use crate::mdl::manifest::manifest_impl::*; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParsedDataSourceError { + pub message: String, +} + +impl ParsedDataSourceError { + pub fn new(msg: &str) -> ParsedDataSourceError { + ParsedDataSourceError { + message: msg.to_string(), + } + } +} + +impl Display for ParsedDataSourceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ParsedDataSourceError: {}", self.message) + } +} + +impl Error for ParsedDataSourceError { + #[allow(deprecated)] + fn description(&self) -> &str { + &self.message + } +} + impl Display for DataSource { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -124,6 +152,37 @@ impl Display for DataSource { } } +impl FromStr for DataSource { + type Err = ParsedDataSourceError; + + fn from_str(s: &str) -> Result { + match s.to_uppercase().as_str() { + "BIGQUERY" => Ok(DataSource::BigQuery), + "CLICKHOUSE" => Ok(DataSource::Clickhouse), + "CANNER" => Ok(DataSource::Canner), + "TRINO" => Ok(DataSource::Trino), + "MSSQL" => Ok(DataSource::MSSQL), + "MYSQL" => Ok(DataSource::MySQL), + "POSTGRES" => Ok(DataSource::Postgres), + "SNOWFLAKE" => Ok(DataSource::Snowflake), + "DATAFUSION" => Ok(DataSource::Datafusion), + "DUCKDB" => Ok(DataSource::DuckDB), + "LOCAL_FILE" => Ok(DataSource::LocalFile), + "S3_FILE" => Ok(DataSource::S3File), + "GCS_FILE" => Ok(DataSource::GcsFile), + "MINIO_FILE" => Ok(DataSource::MinioFile), + "ORACLE" => Ok(DataSource::Oracle), + "ATHENA" => Ok(DataSource::Athena), + "REDSHIFT" => Ok(DataSource::Redshift), + "DATABRICKS" => Ok(DataSource::Databricks), + _ => Err(ParsedDataSourceError::new(&format!( + "Unknown data source: {}", + s + ))), + } + } +} + mod table_reference { use serde::{self, Deserialize, Deserializer, Serialize, Serializer}; @@ -260,7 
+319,7 @@ impl Model { self.columns .iter() .filter(|c| c.relationship.is_none()) - .map(|c| Arc::clone(&c)) + .map(Arc::clone) .collect() } } @@ -286,7 +345,7 @@ impl Model { self.columns .iter() .find(|c| c.name == column_name) - .map(|c| Arc::clone(&c)) + .map(Arc::clone) } /// Return the primary key of the model diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index c0bff3749..02148175e 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -40,6 +40,7 @@ use wren_core::mdl::function::{ use wren_core::{ mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, SessionConfig, WindowUDF, }; +use wren_core_base::mdl::DataSource; /// The Python wrapper for the Wren Core session context. #[pyclass(name = "SessionContext")] @@ -77,45 +78,27 @@ impl PySessionContext { /// if `mdl_base64` is provided, the session context will be created with the given MDL. Otherwise, an empty MDL will be created. /// if `remote_functions_path` is provided, the session context will be created with the remote functions defined in the CSV file. 
#[new] - #[pyo3(signature = (mdl_base64=None, remote_functions_path=None, properties=None))] + #[pyo3(signature = (mdl_base64=None, remote_functions_path=None, properties=None, data_source=None))] pub fn new( mdl_base64: Option<&str>, remote_functions_path: Option<&str>, properties: Option>, + data_source: Option<&str>, ) -> PyResult { - let remote_functions = Self::read_remote_function_list(remote_functions_path) - .map_err(CoreError::from)?; - let remote_functions: Vec = remote_functions - .into_iter() - .map(|f| f.into()) - .collect::>(); - - let config = SessionConfig::default().with_information_schema(true); - let ctx = wren_core::mdl::create_wren_ctx(Some(config)); let runtime = Runtime::new().map_err(CoreError::from)?; - let registered_functions = runtime - .block_on(Self::get_registered_functions(&ctx)) - .map(|functions| { - functions - .into_iter() - .map(|f| f.name) - .collect::>() - }) - .map_err(CoreError::from)?; - - remote_functions - .into_iter() - .try_for_each(|remote_function| { - debug!("Registering remote function: {:?}", remote_function); - // TODO: check not only the name but also the return type and the parameter types - if !registered_functions.contains(&remote_function.name) { - Self::register_remote_function(&ctx, remote_function)?; - } - Ok::<(), CoreError>(()) - })?; - let Some(mdl_base64) = mdl_base64 else { + let data_source = data_source + .map(|ds| DataSource::from_str(ds).map_err(CoreError::from)) + .transpose()?; + let config = SessionConfig::default().with_information_schema(true); + let ctx = wren_core::mdl::create_wren_ctx(Some(config), data_source.as_ref()); + Self::register_function_by_data_source( + data_source.as_ref(), + remote_functions_path, + &runtime, + &ctx, + )?; return Ok(Self { ctx: ctx.clone(), exec_ctx: ctx, @@ -125,7 +108,30 @@ impl PySessionContext { }); }; - Python::attach(|py| { + let manifest = to_manifest(mdl_base64)?; + + // If the manifest has a data source, use it. 
+ // Otherwise, if the data_source parameter is provided, use it. + // Otherwise, use None. + let data_source = if let Some(ds) = &manifest.data_source { + Some(*ds) + } else if let Some(ds_str) = data_source { + Some(DataSource::from_str(ds_str).map_err(CoreError::from)?) + } else { + None + }; + + let config = SessionConfig::default().with_information_schema(true); + let ctx = wren_core::mdl::create_wren_ctx(Some(config), data_source.as_ref()); + + Self::register_function_by_data_source( + data_source.as_ref(), + remote_functions_path, + &runtime, + &ctx, + )?; + + Python::attach(|py: Python<'_>| { let properties_map = if let Some(obj) = properties { let obj = obj.as_ref(); if obj.is_none(py) { @@ -159,7 +165,6 @@ impl PySessionContext { } else { HashMap::new() }; - let manifest = to_manifest(mdl_base64)?; let properties_ref = Arc::new(properties_map); match AnalyzedWrenMDL::analyze( manifest, @@ -360,6 +365,48 @@ impl PySessionContext { } Ok(functions) } + + fn register_function_by_data_source( + data_source: Option<&DataSource>, + remote_functions_path: Option<&str>, + runtime: &Runtime, + ctx: &wren_core::SessionContext, + ) -> PyResult<()> { + match data_source { + Some(DataSource::BigQuery) => {} + _ => { + let remote_functions = + Self::read_remote_function_list(remote_functions_path) + .map_err(CoreError::from)?; + let remote_functions: Vec = remote_functions + .into_iter() + .map(|f| f.into()) + .collect::>(); + + let registered_functions = runtime + .block_on(Self::get_registered_functions(ctx)) + .map(|functions| { + functions + .into_iter() + .map(|f| f.name) + .collect::>() + }) + .map_err(CoreError::from)?; + + remote_functions + .into_iter() + .try_for_each(|remote_function| { + debug!("Registering remote function: {:?}", remote_function); + // TODO: check not only the name but also the return type and the parameter types + if !registered_functions.contains(&remote_function.name) { + Self::register_remote_function(ctx, remote_function)?; + } + 
Ok::<(), CoreError>(()) + })?; + } + } + Ok(()) + } } struct RemoteFunctionDto { diff --git a/wren-core-py/src/errors.rs b/wren-core-py/src/errors.rs index 3e89aeb02..7f6f1b851 100644 --- a/wren-core-py/src/errors.rs +++ b/wren-core-py/src/errors.rs @@ -6,6 +6,7 @@ use std::string::FromUtf8Error; use thiserror::Error; use wren_core::DataFusionError; use wren_core::WrenError; +use wren_core_base::mdl::ParsedDataSourceError; #[derive(Error, Debug, PartialEq)] #[error("{message}")] @@ -87,3 +88,9 @@ impl From for CoreError { CoreError::new(&format!("IO error: {}", err)) } } + +impl From for CoreError { + fn from(err: ParsedDataSourceError) -> Self { + CoreError::new(&format!("DataSource error: {}", err)) + } +} diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index c0155365d..aacb9a0ba 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -17,7 +17,7 @@ manifest = { "catalog": "my_catalog", "schema": "my_schema", - "dataSource": "bigquery", + "dataSource": "datafusion", "models": [ { "name": "customer", @@ -260,7 +260,7 @@ def test_extract_by(dataset, expected_models): extracted_manifest = ManifestExtractor(manifest_str).extract_by(dataset) assert len(extracted_manifest.models) == len(expected_models) assert [m.name for m in extracted_manifest.models] == expected_models - assert extracted_manifest.data_source.__str__() == "DataSource.BigQuery" + assert extracted_manifest.data_source.__str__() == "DataSource.Datafusion" def test_to_json_base64(): diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs index 91d47f3fc..5a9605380 100644 --- a/wren-core/core/src/mdl/dialect/inner_dialect.rs +++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs @@ -17,7 +17,13 @@ * under the License. 
*/ +use std::sync::Arc; + use crate::mdl::dialect::utils::scalar_function_to_sql_internal; +use crate::mdl::function::dialect::bigquery::{ + bigquery_aggregate_functions, bigquery_scalar_functions, bigquery_window_functions, +}; +use crate::mdl::function::{aggregate_functions, scalar_functions, window_functions}; use crate::mdl::manifest::DataSource; use datafusion::common::{plan_err, Result}; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; @@ -88,6 +94,21 @@ pub trait InnerDialect: Send + Sync { ) -> bool { false } + + /// Define the supported UDFs for the dialect which will be registered in the execution context. + fn supported_udfs(&self) -> Vec> { + scalar_functions() + } + + /// Define the supported UDAFs for the dialect which will be registered in the execution context. + fn supported_udafs(&self) -> Vec> { + aggregate_functions() + } + + /// Define the supported UDWFs for the dialect which will be registered in the execution context. + fn supported_udwfs(&self) -> Vec> { + window_functions() + } } /// [get_inner_dialect] returns the suitable InnerDialect for the given data source. 
@@ -128,6 +149,18 @@ impl InnerDialect for MySQLDialect { pub struct BigQueryDialect {} impl InnerDialect for BigQueryDialect { + fn supported_udafs(&self) -> Vec> { + bigquery_aggregate_functions() + } + + fn supported_udfs(&self) -> Vec> { + bigquery_scalar_functions() + } + + fn supported_udwfs(&self) -> Vec> { + bigquery_window_functions() + } + fn unnest_as_table_factor(&self) -> bool { true } @@ -175,49 +208,15 @@ impl InnerDialect for BigQueryDialect { expr: Box::new(unparser.expr_to_sql(&args[1])?), })) } - "date_diff" => { - if args.len() != 3 { - return plan_err!( - "date_diff requires exactly 3 arguments, found {}", - args.len() - ); - } - let Expr::Literal(ScalarValue::Utf8(Some(s)), _) = args[0].clone() else { - return plan_err!( - "date_diff requires a string literal as the third argument" - ); - }; - let granularity = ast::Expr::Identifier(Ident::new( - self.datetime_field_from_str(&s)?.to_string(), - )); - // DATE_DIFF(end_date, start_date, granularity) - // https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date_diff - Ok(Some(ast::Expr::Function(Function { - name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( - "DATE_DIFF", - ))]), - args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, - args: vec![ - unparser.expr_to_sql(&args[2]).map(|e| { - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - })?, - unparser.expr_to_sql(&args[1]).map(|e| { - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - })?, - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( - granularity, - )), - ], - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - uses_odbc_syntax: false, - }))) + // DATE_DIFF(end_date, start_date, granularity) + // https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date_diff + "date_diff" => self.transform_diff_function("DATE_DIFF", args, 
unparser), + "time_diff" => self.transform_diff_function("TIME_DIFF", args, unparser), + "timestamp_diff" => { + self.transform_diff_function("TIMESTAMP_DIFF", args, unparser) + } + "datetime_diff" => { + self.transform_diff_function("DATETIME_DIFF", args, unparser) } "now" => { scalar_function_to_sql_internal(unparser, None, "CURRENT_TIMESTAMP", args) @@ -256,6 +255,52 @@ impl InnerDialect for BigQueryDialect { } impl BigQueryDialect { + fn transform_diff_function( + &self, + func_name: &str, + args: &[Expr], + unparser: &Unparser, + ) -> Result> { + if args.len() != 3 { + return plan_err!( + "{} requires exactly 3 arguments, found {}", + func_name, + args.len() + ); + } + let Expr::Literal(ScalarValue::Utf8(Some(s)), _) = args[0].clone() else { + return plan_err!( + "{} requires a string literal as the third argument (granularity)", + func_name + ); + }; + let granularity = ast::Expr::Identifier(Ident::new( + self.datetime_field_from_str(&s)?.to_string(), + )); + Ok(Some(ast::Expr::Function(Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new(func_name))]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + unparser.expr_to_sql(&args[2]).map(|e| { + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) + })?, + unparser.expr_to_sql(&args[1]).map(|e| { + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) + })?, + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(granularity)), + ], + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + uses_odbc_syntax: false, + }))) + } + fn datetime_field_from_expr(&self, expr: &Expr) -> Result { match expr { Expr::Literal(ScalarValue::Utf8(Some(s)), _) diff --git a/wren-core/core/src/mdl/dialect/mod.rs b/wren-core/core/src/mdl/dialect/mod.rs index b1bc4c742..7f1f926f4 100644 --- a/wren-core/core/src/mdl/dialect/mod.rs +++ 
b/wren-core/core/src/mdl/dialect/mod.rs @@ -17,7 +17,7 @@ * under the License. */ -mod inner_dialect; +pub mod inner_dialect; mod utils; mod wren_dialect; diff --git a/wren-core/core/src/mdl/function/dialect/bigquery/aggregate.rs b/wren-core/core/src/mdl/function/dialect/bigquery/aggregate.rs new file mode 100644 index 000000000..73d5cca09 --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/bigquery/aggregate.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; + +use datafusion::{ + arrow::datatypes::{DataType, Field, Fields}, + common::types::logical_boolean, + logical_expr::{Coercion, Signature, TypeSignature, TypeSignatureClass, Volatility}, +}; + +use crate::{ + make_udaf_function, + mdl::function::{dialect::utils::build_document, ByPassAggregateUDF, ReturnType}, +}; + +make_udaf_function!( + ByPassAggregateUDF::new( + "any_value", + ReturnType::SameAsInput, + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Gets an expression for some row.", + "SELECT ANY_VALUE(column_name) FROM table;" + )), + ), + any_value +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "approx_count_distinct", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Gets the approximate result for COUNT(DISTINCT expression).", + "SELECT APPROX_COUNT_DISTINCT(column_name) FROM table;" + )), + ), + approx_count_distinct +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "approx_quantiles", + ReturnType::ArrayOfInputFirstArgument, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Gets the approximate quantile boundaries.", + "SELECT APPROX_QUANTILES(column_name, 100) FROM table;" + )), + ), + approx_quantiles +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "approx_top_count", + ReturnType::Specific(DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("value", DataType::Utf8, true), + Field::new("count", DataType::Int64, true), + ])), + 
true, + )))), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Gets the approximate top elements and their approximate count.", + "SELECT APPROX_TOP_COUNT(column_name, 10) FROM table;" + )), + ), + approx_top_count +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "approx_top_sum", + ReturnType::Specific(DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("value", DataType::Utf8, true), + Field::new("sum", DataType::Float64, true), + ])), + true, + )))), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Gets the approximate top elements and their approximate sum.", + "SELECT APPROX_TOP_SUM(column_name, 10) FROM table;" + )), + ), + approx_top_sum +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "array_concat_agg", + ReturnType::SameAsInput, + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Concatenates all input arrays into a single array.", + "SELECT ARRAY_CONCAT_AGG(column_name) FROM table;" + )), + ), + array_concat_agg +); + +make_udaf_function!( + ByPassAggregateUDF::new_with_alias( + "countif", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + vec!["count_if".to_string()], + Some(build_document( + "Counts the number of input rows for which the given condition is true.", + "SELECT COUNTIF(column_name > 10) FROM table;" + )), + ), + countif +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "logical_and", + ReturnType::Specific(DataType::Boolean), + Signature::coercible(vec![Coercion::new_exact(TypeSignatureClass::Native(logical_boolean()))], Volatility::Immutable), + Some(build_document( + "Returns the logical AND of all non-NULL expressions. 
Returns NULL if there are zero input rows or expression evaluates to NULL for all rows.", + "SELECT LOGICAL_AND(column_name) FROM table;" + )), + ), + logical_and +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "logical_or", + ReturnType::Specific(DataType::Boolean), + Signature::coercible(vec![Coercion::new_exact(TypeSignatureClass::Native(logical_boolean()))], Volatility::Immutable), + Some(build_document( + "Returns the logical OR of all non-NULL expressions. Returns NULL if there are zero input rows or expression evaluates to NULL for all rows.", + "SELECT LOGICAL_OR(column_name) FROM table;" + )), + ), + logical_or +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "max_by", + ReturnType::SameAsInput, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Synonym for ANY_VALUE(x HAVING MAX y).", + "SELECT MAX_BY(value_column, order_column) FROM table;" + )), + ), + max_by +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "min_by", + ReturnType::SameAsInput, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Synonym for ANY_VALUE(x HAVING MIN y).", + "SELECT MIN_BY(value_column, order_column) FROM table;" + )), + ), + min_by +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "stddev_samp", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Calculates the sample standard deviation of a set of values.", + "SELECT STDDEV_SAMP(column_name) FROM table;" + )), + ), + stddev_samp +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "variance", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Calculates the variance of a set of values.", + "SELECT VARIANCE(column_name) FROM table;" + )), + ), + variance +); + +make_udaf_function!( + ByPassAggregateUDF::new( + "group_concat", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![TypeSignature::Any(1), 
TypeSignature::Any(2)], Volatility::Immutable), + Some(build_document( + "(DEPRECATED)(BigQuery Legacy SQL)Concatenates values from a group into a single string with a specified separator.", + "SELECT GROUP_CONCAT(column_name, ', ') FROM table;" + )), + ), + group_concat +); diff --git a/wren-core/core/src/mdl/function/dialect/bigquery/mod.rs b/wren-core/core/src/mdl/function/dialect/bigquery/mod.rs new file mode 100644 index 000000000..7cb77c1b5 --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/bigquery/mod.rs @@ -0,0 +1,356 @@ +use std::sync::Arc; + +use datafusion::{ + functions::{ + core::{coalesce, greatest, least, named_struct, nullif, nvl, nvl2}, + crypto::{md5, sha256, sha512}, + datetime::{ + current_date, current_time, date_diff, date_trunc, from_unixtime, now, + }, + math::{ + abs, acos, acosh, asin, asinh, atan, atan2, atanh, cbrt, ceil, cos, cosh, + cot, floor, ln, log, log10, log2, power, random, round, signum, sin, sinh, + sqrt, tan, tanh, trunc, + }, + regex::{regexp_instr, regexp_like, regexp_replace}, + string::{ + concat, contains, lower, ltrim, octet_length, repeat, replace, rtrim, + starts_with, to_hex, upper, uuid, + }, + unicode::{ + left, lpad, reverse, right, rpad, strpos, substr, substring, translate, + }, + }, + functions_aggregate::{ + approx_distinct::approx_distinct_udaf, + array_agg::array_agg_udaf, + average::avg_udaf, + bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}, + bool_and_or::{bool_and_udaf, bool_or_udaf}, + correlation::corr_udaf, + count::count_udaf, + covariance::{covar_pop_udaf, covar_samp_udaf}, + grouping::grouping_udaf, + min_max::{max_udaf, min_udaf}, + stddev::{stddev_pop_udaf, stddev_udaf}, + string_agg::string_agg_udaf, + sum::sum_udaf, + variance::{var_pop_udaf, var_samp_udaf}, + }, + functions_array::{ + array_has::array_has_udf, + cardinality::cardinality_udf, + concat::array_concat_udf, + extract::{array_element_udf, array_slice_udf}, + length::array_length_udf, + make_array::make_array_udf, 
+ range::range_udf, + reverse::array_reverse_udf, + string::array_to_string_udf, + }, + functions_window::{ + cume_dist::cume_dist_udwf, + lead_lag::{lag_udwf, lead_udwf}, + nth_value::{first_value_udwf, last_value_udwf, nth_value_udwf}, + ntile::ntile_udwf, + rank::{dense_rank_udwf, percent_rank_udwf, rank_udwf}, + row_number::row_number_udwf, + }, + logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, +}; + +use aggregate::*; +use scalar::*; +use window::*; + +use crate::mdl::function::scalar::to_char; + +mod aggregate; +mod scalar; +mod window; + +/// https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-all#function_list +pub fn bigquery_scalar_functions() -> Vec> { + vec![ + // array() isn't supported by Wren + // list_cat, list_concat + array_concat_udf(), + array_first(), + array_last(), + array_reverse_udf(), + array_slice_udf(), + array_to_string_udf(), + array_length_udf(), + generate_array(), + generate_date_array(), + generate_timestamp_array(), + bit_count(), + parse_bignumeric(), + parse_numeric(), + current_date(), + date(), + date_add(), + date_diff(), + date_from_unix_date(), + date_sub(), + date_trunc(), + format_date(), + parse_date(), + unix_date(), + farm_fingerprint(), + md5(), + sha1(), + sha256(), + sha512(), + justify_days(), + justify_hours(), + justify_interval(), + bool(), + float64(), + int64(), + json_array(), + json_array_append(), + json_array_insert(), + json_extract(), + json_extract_array(), + json_extract_scalar(), + json_extract_string_array(), + json_keys(), + json_object(), + json_query(), + json_query_array(), + json_remove(), + json_set(), + json_remove(), + json_set(), + json_strip_nulls(), + json_type(), + json_value(), + json_value_array(), + lax_bool(), + lax_float64(), + lax_int64(), + lax_string(), + parse_json(), + string(), + to_json(), + to_json_string(), + abs(), + acos(), + acosh(), + asin(), + asinh(), + atan(), + atan2(), + atanh(), + cbrt(), + ceil(), + ceiling(), + cos(), + cosh(), + 
cosine_distance(), + cot(), + coth(), + csc(), + csch(), + div(), + exp(), + euclidean_distance(), + floor(), + greatest(), + is_inf(), + is_nan(), + least(), + ln(), + log(), + log10(), + r#mod(), + power(), + rand(), + safe_add(), + safe_divide(), + safe_multiply(), + safe_subtract(), + safe_negate(), + sec(), + sech(), + sign(), + sin(), + sinh(), + sqrt(), + tan(), + tanh(), + trunc(), + generate_range_array(), + range_udf(), + range_contains(), + range_end(), + range_intersect(), + range_overlaps(), + range_start(), + ascii(), + byte_length(), + char_length(), + character_length(), + chr(), + code_points_to_bytes(), + code_points_to_string(), + collate(), + concat(), + contains_substr(), + edit_distance(), + ends_with(), + format(), + from_base32(), + from_base64(), + from_hex(), + initcap(), + left(), + length(), + lower(), + lpad(), + ltrim(), + normalize(), + normalize_and_casefold(), + octet_length(), + regexp_contains(), + regexp_extract(), + regexp_extract_all(), + regexp_instr(), + regexp_replace(), + regexp_substr(), + repeat(), + replace(), + reverse(), + right(), + rpad(), + rtrim(), + safe_convert_bytes_to_string(), + soundex(), + split(), + starts_with(), + // instr, position + strpos(), + substr(), + substring(), + to_base32(), + to_base64(), + to_hex(), + to_code_points(), + translate(), + trim(), + unicode(), + upper(), + current_time(), + format_time(), + parse_time(), + time(), + time_add(), + time_diff(), + time_sub(), + time_trunc(), + format_timestamp(), + parse_timestamp(), + timestamp(), + timestamp_add(), + timestamp_diff(), + timestamp_micros(), + timestamp_millis(), + timestamp_seconds(), + timestamp_sub(), + timestamp_trunc(), + unix_micros(), + unix_millis(), + unix_seconds(), + current_datetime(), + datetime(), + datetime_add(), + datetime_diff(), + datetime_sub(), + datetime_trunc(), + format_datetime(), + parse_datetime(), + round(), + nvl(), + nullif(), + array_has_udf(), + cardinality_udf(), + signum(), + contains(), + 
+ array_element_udf(), + now(), + uuid(), + from_unixtime(), + make_array_udf(), + to_char(), + power(), + regexp_like(), + nvl2(), + named_struct(), + random(), + coalesce(), + log2(), + offset(), + ordinal(), + safe_offset(), + safe_ordinal(), + ] +} + +pub fn bigquery_aggregate_functions() -> Vec<Arc<AggregateUDF>> { + vec![ + any_value(), + approx_count_distinct(), + approx_quantiles(), + approx_top_count(), + approx_top_sum(), + approx_distinct_udaf(), + array_agg_udaf(), + array_concat_agg(), + avg_udaf(), + bit_and_udaf(), + bit_or_udaf(), + bit_xor_udaf(), + count_udaf(), + countif(), + grouping_udaf(), + logical_and(), + logical_or(), + max_udaf(), + max_by(), + min_udaf(), + min_by(), + string_agg_udaf(), + sum_udaf(), + corr_udaf(), + covar_pop_udaf(), + covar_samp_udaf(), + stddev_udaf(), + stddev_pop_udaf(), + stddev_samp(), + var_pop_udaf(), + var_samp_udaf(), + variance(), + bool_or_udaf(), + bool_and_udaf(), + group_concat(), + ] +} + +pub fn bigquery_window_functions() -> Vec<Arc<WindowUDF>> { + vec![ + first_value_udwf(), + lag_udwf(), + last_value_udwf(), + lead_udwf(), + nth_value_udwf(), + percentile_cont_udwf(), + percentile_disc_udwf(), + cume_dist_udwf(), + dense_rank_udwf(), + ntile_udwf(), + percent_rank_udwf(), + rank_udwf(), + row_number_udwf(), + ] +} diff --git a/wren-core/core/src/mdl/function/dialect/bigquery/scalar.rs b/wren-core/core/src/mdl/function/dialect/bigquery/scalar.rs new file mode 100644 index 000000000..f5388fc1e --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/bigquery/scalar.rs @@ -0,0 +1,1928 @@ +use std::sync::Arc; + +use datafusion::{ + arrow::datatypes::{DataType, Field}, + common::types::{logical_binary, logical_string}, + logical_expr::{Coercion, Signature, TypeSignature, Volatility}, +}; + +use crate::{ + make_udf_function, + mdl::function::{dialect::utils::build_document, ByPassScalarUDF, ReturnType}, +}; + +make_udf_function!( + ByPassScalarUDF::new( + "array_first", + ReturnType::SameAsInputFirstArrayElement, 
Signature::array(Volatility::Immutable), + Some(build_document( + "Returns the first element of the array.", + "SELECT ARRAY_FIRST([1, 2, 3]); -- returns 1" + )), + ), + array_first +); + +make_udf_function!( + ByPassScalarUDF::new( + "array_last", + ReturnType::SameAsInputFirstArrayElement, + Signature::array(Volatility::Immutable), + Some(build_document( + "Returns the last element of the array.", + "SELECT ARRAY_LAST([1, 2, 3]); -- returns 3" + )), + ), + array_last +); + +make_udf_function!( + ByPassScalarUDF::new( + "generate_array", + ReturnType::ArrayOfInputFirstArgument, + Signature::one_of( + vec![TypeSignature::Any(3), TypeSignature::Any(2),], + Volatility::Immutable + ), + Some(build_document( + "Generates an array of values from start to end with an optional step.", + "SELECT GENERATE_ARRAY(1, 5); -- returns [1, 2, 3, 4, 5]" + )), + ), + generate_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "generate_date_array", + ReturnType::ArrayOfInputFirstArgument, + Signature::one_of(vec![ + TypeSignature::Any(3), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Generates an array of dates from start to end with an optional step.", + "SELECT GENERATE_DATE_ARRAY(DATE '2021-01-01', DATE '2021-01-05'); -- returns [2021-01-01, 2021-01-02, 2021-01-03, 2021-01-04, 2021-01-05]" + )), + ), + generate_date_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "generate_timestamp_array", + ReturnType::ArrayOfInputFirstArgument, + Signature::one_of(vec![ + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Returns an ARRAY of TIMESTAMPS separated by a given interval. 
The start_timestamp and end_timestamp parameters determine the inclusive lower and upper bounds of the ARRAY.", + "SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-07 00:00:00', INTERVAL 1 DAY) AS timestamp_array;" + )), + ), + generate_timestamp_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "bit_count", + ReturnType::Specific(DataType::Int64), + Signature::one_of(vec![ + TypeSignature::Coercible(vec![Coercion::new_exact( + datafusion::logical_expr::TypeSignatureClass::Integer, + )]), + TypeSignature::Coercible(vec![Coercion::new_exact( + datafusion::logical_expr::TypeSignatureClass::Native(logical_binary()), + )]), + ], Volatility::Immutable), + Some(build_document( + "Returns the number of bits set to 1 in the binary representation of the input integer.", + "SELECT BIT_COUNT(29); -- returns 4, since 29 in binary is 11101" + )), + ), + bit_count +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_bignumeric", + ReturnType::Specific(DataType::Decimal128(38, 9)), + Signature::coercible( + vec![Coercion::new_exact( + datafusion::logical_expr::TypeSignatureClass::Native(logical_string()), + )], + Volatility::Immutable, + ), + Some(build_document( + "Parses a string and returns a BIGNUMERIC value.", + "SELECT PARSE_BIGNUMERIC('1234567890.123456789'); -- returns 1234567890.123456789" + )), + ), + parse_bignumeric +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_numeric", + ReturnType::Specific(DataType::Decimal128(38, 9)), + Signature::coercible( + vec![Coercion::new_exact( + datafusion::logical_expr::TypeSignatureClass::Native(logical_string()), + )], + Volatility::Immutable, + ), + Some(build_document( + "Parses a string and returns a NUMERIC value.", + "SELECT PARSE_NUMERIC('12345.6789'); -- returns 12345.6789" + )), + ), + parse_numeric +); + +make_udf_function!( + ByPassScalarUDF::new( + "date", + ReturnType::Specific(DataType::Date32), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + 
Volatility::Immutable, + ), + Some(build_document( + "Extracts the date part from a timestamp or string.", + "SELECT DATE('2021-10-05 12:34:56'); -- returns 2021-10-05" + )), + ), + date +); + +make_udf_function!( + ByPassScalarUDF::new( + "date_add", + ReturnType::Specific(DataType::Date32), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Adds a specified interval to a date.", + "SELECT DATE_ADD(DATE '2021-01-01', INTERVAL 5 DAY); -- returns 2021-01-06" + )), + ), + date_add +); + +make_udf_function!( + ByPassScalarUDF::new( + "date_from_unix_date", + ReturnType::Specific(DataType::Date32), + Signature::coercible( + vec![Coercion::new_exact( + datafusion::logical_expr::TypeSignatureClass::Integer + ),], + Volatility::Immutable, + ), + Some(build_document( + "Converts a Unix date (number of days since 1970-01-01) to a DATE.", + "SELECT DATE_FROM_UNIX_DATE(18628); -- returns 2021-01-01" + )), + ), + date_from_unix_date +); + +make_udf_function!( + ByPassScalarUDF::new( + "date_sub", + ReturnType::Specific(DataType::Date32), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Subtracts a specified interval from a date.", + "SELECT DATE_SUB(DATE '2021-01-10', INTERVAL 5 DAY); -- returns 2021-01-05" + )), + ), + date_sub +); + +make_udf_function!( + ByPassScalarUDF::new( + "format_date", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Formats a DATE according to the specified format string.", + "SELECT FORMAT_DATE('%Y-%m-%d', DATE '2021-01-05'); -- returns '2021-01-05'" + )), + ), + format_date +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_date", + ReturnType::Specific(DataType::Date32), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Parses a string into a DATE according to the specified format string.", + "SELECT PARSE_DATE('%Y-%m-%d', '2021-01-05'); -- returns 2021-01-05" + )), + ), + parse_date +); + +make_udf_function!( + 
ByPassScalarUDF::new( + "unix_date", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a DATE to a Unix date (number of days since 1970-01-01).", + "SELECT UNIX_DATE(DATE '2021-01-01'); -- returns 18628" + )), + ), + unix_date +); + +make_udf_function!( + ByPassScalarUDF::new( + "farm_fingerprint", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes a fingerprint for a string.", + "SELECT FARM_FINGERPRINT('Hello, world!'); -- returns a 64-bit integer" + )), + ), + farm_fingerprint +); + +make_udf_function!( + ByPassScalarUDF::new( + "sha1", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the SHA-1 hash of a string.", + "SELECT SHA1('Hello, world!'); -- returns '2ef7bdecadad9f73dffb5fbdc4f1b3e6eed8c5'" + )), + ), + sha1 +); + +make_udf_function!( + ByPassScalarUDF::new( + "justify_days", + ReturnType::Specific(DataType::Interval( + datafusion::arrow::datatypes::IntervalUnit::DayTime + )), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Justifies a number of days.", + "SELECT JUSTIFY_DAYS(5); -- returns INTERVAL '5' DAY" + )), + ), + justify_days +); + +make_udf_function!( + ByPassScalarUDF::new( + "justify_hours", + ReturnType::Specific(DataType::Interval( + datafusion::arrow::datatypes::IntervalUnit::DayTime + )), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Justifies a number of hours.", + "SELECT JUSTIFY_HOURS(48); -- returns INTERVAL '2' DAY" + )), + ), + justify_hours +); + +make_udf_function!( + ByPassScalarUDF::new( + "justify_interval", + ReturnType::Specific(DataType::Interval(datafusion::arrow::datatypes::IntervalUnit::MonthDayNano)), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Justifies an interval.", + "SELECT JUSTIFY_INTERVAL(INTERVAL '36' HOUR); -- returns INTERVAL '1' DAY 
'12' HOUR" + )), + ), + justify_interval +); + +make_udf_function!( + ByPassScalarUDF::new( + "ceiling", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Synonym of CEIL(X)", + "SELECT CEILING(3.14); -- returns 4.0" + )), + ), + ceiling +); + +make_udf_function!( + ByPassScalarUDF::new( + "cosine_distance", + ReturnType::Specific(DataType::Float64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Computes the cosine distance between two vectors.", + "SELECT COSINE_DISTANCE(ARRAY[1, 2, 3], ARRAY[4, 5, 6]); -- returns approximately 0.0254" + )), + ), + cosine_distance +); + +make_udf_function!( + ByPassScalarUDF::new( + "coth", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the hyperbolic cotangent of a number.", + "SELECT COTH(1.0); -- returns 1.3130352854993312" + )), + ), + coth +); + +make_udf_function!( + ByPassScalarUDF::new( + "csc", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the cosecant of the input angle, which is in radians.", + "SELECT CSC(1.0); -- returns 1.1883951057781212" + )), + ), + csc +); + +make_udf_function!( + ByPassScalarUDF::new( + "csch", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the hyperbolic cosecant of a number.", + "SELECT CSCH(1.0); -- returns 0.8509181282393216" + )), + ), + csch +); + +make_udf_function!( + ByPassScalarUDF::new( + "div", + ReturnType::Specific(DataType::Int64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Computes the integer division of two numbers.", + "SELECT DIV(5, 2); -- returns 2" + )), + ), + div +); + +make_udf_function!( + ByPassScalarUDF::new( + "exp", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + 
"Computes the exponential of a number.", + "SELECT EXP(1.0); -- returns 2.718281828459045" + )), + ), + exp +); + +make_udf_function!( + ByPassScalarUDF::new( + "euclidean_distance", + ReturnType::Specific(DataType::Float64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Computes the Euclidean distance between two points.", + "SELECT EUCLIDEAN_DISTANCE(ARRAY[1, 2], ARRAY[4, 6]); -- returns 5.0" + )), + ), + euclidean_distance +); + +make_udf_function!( + ByPassScalarUDF::new( + "is_inf", + ReturnType::Specific(DataType::Boolean), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns true if the input is positive or negative infinity.", + "SELECT IS_INF(1.0 / 0.0); -- returns true" + )), + ), + is_inf +); + +make_udf_function!( + ByPassScalarUDF::new( + "is_nan", + ReturnType::Specific(DataType::Boolean), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns true if the input is NaN (Not a Number).", + "SELECT IS_NAN(0.0 / 0.0); -- returns true" + )), + ), + is_nan +); + +make_udf_function!( + ByPassScalarUDF::new( + "mod", + ReturnType::Specific(DataType::Int64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Computes the modulus of two numbers.", + "SELECT MOD(5, 2); -- returns 1" + )), + ), + r#mod +); + +make_udf_function!( + ByPassScalarUDF::new( + "rand", + ReturnType::Specific(DataType::Float64), + Signature::nullary(Volatility::Volatile), + Some(build_document( + "Returns a random float value in the range [0, 1). 
The random seed is unique to each row.", + "SELECT RAND(); -- returns a random float between 0 and 1" + )), + ), + rand +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_add", + ReturnType::Specific(DataType::Int64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Performs addition and returns NULL if overflow occurs.", + "SELECT SAFE_ADD(9223372036854775807, 1); -- returns NULL" + )), + ), + safe_add +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_subtract", + ReturnType::Specific(DataType::Int64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Performs subtraction and returns NULL if overflow occurs.", + "SELECT SAFE_SUBTRACT(-9223372036854775808, 1); -- returns NULL" + )), + ), + safe_subtract +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_multiply", + ReturnType::Specific(DataType::Int64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Performs multiplication and returns NULL if overflow occurs.", + "SELECT SAFE_MULTIPLY(3037000499, 3037000499); -- returns NULL" + )), + ), + safe_multiply +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_divide", + ReturnType::Specific(DataType::Float64), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Performs division and returns NULL if division by zero occurs.", + "SELECT SAFE_DIVIDE(1, 0); -- returns NULL" + )), + ), + safe_divide +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_negate", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Performs negation and returns NULL if overflow occurs.", + "SELECT SAFE_NEGATE(-9223372036854775808); -- returns NULL" + )), + ), + safe_negate +); + +make_udf_function!( + ByPassScalarUDF::new( + "sec", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the secant of the input angle, which is in radians.", + "SELECT 
SEC(1.0); -- returns 1.8508157176809257" + )), + ), + sec +); + +make_udf_function!( + ByPassScalarUDF::new( + "sech", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Computes the hyperbolic secant of a number.", + "SELECT SECH(1.0); -- returns 0.6480542736638855" + )), + ), + sech +); + +make_udf_function!( + ByPassScalarUDF::new( + "sign", + ReturnType::Specific(DataType::Int8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the sign of a number: -1 for negative, 0 for zero, and 1 for positive.", + "SELECT SIGN(-10); -- returns -1" + )), + ), + sign +); + +make_udf_function!( + ByPassScalarUDF::new( + "generate_range_array", + ReturnType::ArrayOfInputFirstArgument, + Signature::one_of(vec![ + TypeSignature::Any(3), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Generates an array of numbers within a specified range with an optional step.", + "SELECT GENERATE_RANGE_ARRAY(1, 10); -- returns [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]" + )), + ), + generate_range_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "range_contains", + ReturnType::Specific(DataType::Boolean), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Checks if the inner range is in the outer range.", + "SELECT RANGE_CONTAINS(RANGE(1, 10), RANGE(3, 7)); -- returns true" + )), + ), + range_contains +); + +make_udf_function!( + ByPassScalarUDF::new( + "range_end", + ReturnType::SameAsInputFirstArrayElement, + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Gets the upper bound of a range.", + "SELECT RANGE_END(RANGE(1, 10)); -- returns 10" + )), + ), + range_end +); + +make_udf_function!( + ByPassScalarUDF::new( + "range_start", + ReturnType::SameAsInputFirstArrayElement, + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Gets the lower bound of a range.", + "SELECT RANGE_START(RANGE(1, 10)); -- returns 1" + )), + ), 
+ range_start +); + +make_udf_function!( + ByPassScalarUDF::new( + "range_intersect", + ReturnType::SameAsInputFirstArrayElement, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Computes the intersection of two ranges.", + "SELECT RANGE_INTERSECT(RANGE(1, 10), RANGE(5, 15)); -- returns RANGE(5, 10)" + )), + ), + range_intersect +); + +make_udf_function!( + ByPassScalarUDF::new( + "range_overlaps", + ReturnType::Specific(DataType::Boolean), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Checks if two ranges overlap.", + "SELECT RANGE_OVERLAPS(RANGE(1, 10), RANGE(5, 15)); -- returns true" + )), + ), + range_overlaps +); + +make_udf_function!( + ByPassScalarUDF::new( + "ascii", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the ASCII code of the first character of the input string.", + "SELECT ASCII(column_name) FROM table;" + )), + ), + ascii +); + +make_udf_function!( + ByPassScalarUDF::new( + "byte_length", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the length of the input string in bytes.", + "SELECT BYTE_LENGTH(column_name) FROM table;" + )), + ), + byte_length +); + +make_udf_function!( + ByPassScalarUDF::new( + "char_length", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the length of the input string in characters.", + "SELECT CHAR_LENGTH(column_name) FROM table;" + )), + ), + char_length +); + +make_udf_function!( + ByPassScalarUDF::new( + "character_length", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the length of the input string in characters. 
Synonym for CHAR_LENGTH.", + "SELECT CHARACTER_LENGTH(column_name) FROM table;" + )), + ), + character_length +); + +make_udf_function!( + ByPassScalarUDF::new( + "chr", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the character corresponding to the given ASCII code.", + "SELECT CHR(65); -- returns 'A'" + )), + ), + chr +); + +make_udf_function!( + ByPassScalarUDF::new( + "code_points_to_bytes", + ReturnType::Specific(DataType::Binary), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts an array of Unicode code points to a byte array.", + "SELECT CODE_POINTS_TO_BYTES([72, 101, 108, 108, 111]); -- returns b'Hello'" + )), + ), + code_points_to_bytes +); + +make_udf_function!( + ByPassScalarUDF::new( + "code_points_to_string", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts an array of Unicode code points to a string.", + "SELECT CODE_POINTS_TO_STRING([72, 101, 108, 108, 111]); -- returns 'Hello'" + )), + ), + code_points_to_string +); + +make_udf_function!( + ByPassScalarUDF::new( + "collate", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Applies a collation to a string.", + "SELECT COLLATE('straße', 'de_DE'); -- returns 'straße' with German collation" + )), + ), + collate +); + +make_udf_function!( + ByPassScalarUDF::new( + "contains_substr", + ReturnType::Specific(DataType::Boolean), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Checks if the first string contains the second substring.", + "SELECT CONTAINS_SUBSTR('Hello, world!', 'world'); -- returns true" + )), + ), + contains_substr +); + +make_udf_function!( + ByPassScalarUDF::new( + "edit_distance", + ReturnType::Specific(DataType::Int32), + Signature::one_of( + vec![TypeSignature::Any(2), TypeSignature::Any(3),], + Volatility::Immutable + ), + 
Some(build_document( + "Computes the Levenshtein edit distance between two strings.", + "SELECT EDIT_DISTANCE('kitten', 'sitting'); -- returns 3" + )), + ), + edit_distance +); + +make_udf_function!( + ByPassScalarUDF::new( + "ends_with", + ReturnType::Specific(DataType::Boolean), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Checks if the first string ends with the second substring.", + "SELECT ENDS_WITH('Hello, world!', 'world!'); -- returns true" + )), + ), + ends_with +); + +make_udf_function!( + ByPassScalarUDF::new( + "format", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Formats a string using the specified format and arguments.", + "SELECT FORMAT('date: %s!', FORMAT_DATE('%B %d, %Y', date '2015-01-02'));" + )), + ), + format +); + +make_udf_function!( + ByPassScalarUDF::new( + "from_base32", + ReturnType::Specific(DataType::Binary), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Decodes a base32-encoded string to a byte array.", + "SELECT FROM_BASE32('JBSWY3DPEBLW64TMMQQ===='); -- returns b'Hello, world!'" + )), + ), + from_base32 +); + +make_udf_function!( + ByPassScalarUDF::new( + "from_base64", + ReturnType::Specific(DataType::Binary), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Decodes a base64-encoded string to a byte array.", + "SELECT FROM_BASE64('SGVsbG8sIHdvcmxkIQ=='); -- returns b'Hello, world!'" + )), + ), + from_base64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "from_hex", + ReturnType::Specific(DataType::Binary), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Decodes a hexadecimal-encoded string to a byte array.", + "SELECT FROM_HEX('48656c6c6f2c20776f726c6421'); -- returns b'Hello, world!'" + )), + ), + from_hex +); + +make_udf_function!( + ByPassScalarUDF::new( + "initcap", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + 
Some(build_document( + "Capitalizes the first letter of each word in the input string.", + "SELECT INITCAP('hello world!'); -- returns 'Hello World!'" + )), + ), + initcap +); + +make_udf_function!( + ByPassScalarUDF::new( + "length", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the length of the input string in characters.", + "SELECT LENGTH('Hello, world!'); -- returns 13" + )), + ), + length +); + +make_udf_function!( + ByPassScalarUDF::new( + "normalize", + ReturnType::Specific(DataType::Utf8), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + Volatility::Immutable + ), + Some(build_document( + "Normalizes a string to the specified Unicode normalization form.", + "SELECT NORMALIZE('é'); -- returns 'é' in NFC form" + )), + ), + normalize +); + +make_udf_function!( + ByPassScalarUDF::new( + "normalize_and_casefold", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Takes a string value and returns it as a normalized string. 
If you don't provide a normalization mode, NFC is used.", + "SELECT NORMALIZE_AND_CASEFOLD('Straße'); -- returns 'strasse'" + )), + ), + normalize_and_casefold +); + +make_udf_function!( + ByPassScalarUDF::new( + "regexp_contains", + ReturnType::Specific(DataType::Boolean), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Checks if the input string matches the specified regular expression.", + "SELECT REGEXP_CONTAINS('Hello, world!', r'world'); -- returns true" + )), + ), + regexp_contains +); + +make_udf_function!( + ByPassScalarUDF::new( + "regexp_extract", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + TypeSignature::Any(4), + ], Volatility::Immutable), + Some(build_document( + "Extracts a substring from the input string that matches the specified regular expression.", + "SELECT REGEXP_EXTRACT('Hello, world!', r'world'); -- returns 'world'" + )), + ), + regexp_extract +); + +make_udf_function!( + ByPassScalarUDF::new( + "regexp_extract_all", + ReturnType::Specific(DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))), + Signature::one_of(vec![ + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Extracts all substrings from the input string that match the specified regular expression.", + "SELECT REGEXP_EXTRACT_ALL('ababab', r'ab'); -- returns ['ab', 'ab', 'ab']" + )), + ), + regexp_extract_all +); + +make_udf_function!( + ByPassScalarUDF::new( + "regexp_substr", + ReturnType::Specific(DataType::Utf8), + Signature::one_of( + vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + TypeSignature::Any(4), + ], + Volatility::Immutable + ), + Some(build_document( + "Returns the substring that matches the specified regular expression.", + "SELECT REGEXP_SUBSTR('Hello, world!', r'world'); -- returns 'world'" + )), + ), + regexp_substr +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_convert_bytes_to_string", + 
ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a byte array to a string, returning NULL if the byte array is not valid UTF-8.", + "SELECT SAFE_CONVERT_BYTES_TO_STRING(b'Hello, world!'); -- returns 'Hello, world!'" + )), + ), + safe_convert_bytes_to_string +); + +make_udf_function!( + ByPassScalarUDF::new( + "soundex", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the Soundex code for the input string.", + "SELECT SOUNDEX('Robert'); -- returns 'R163'" + )), + ), + soundex +); + +make_udf_function!( + ByPassScalarUDF::new( + "split", + ReturnType::Specific(DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))), + Signature::one_of(vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Splits the input string into an array of substrings based on the specified delimiter.", + "SELECT SPLIT('apple,banana,cherry', ','); -- returns ['apple', 'banana', 'cherry']" + )), + ), + split +); + +make_udf_function!( + ByPassScalarUDF::new( + "to_base32", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Encodes a byte array to a base32-encoded string.", + "SELECT TO_BASE32(b'Hello, world!'); -- returns 'JBSWY3DPEBLW64TMMQQ===='" + )), + ), + to_base32 +); + +make_udf_function!( + ByPassScalarUDF::new( + "to_base64", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Encodes a byte array to a base64-encoded string.", + "SELECT TO_BASE64(b'Hello, world!'); -- returns 'SGVsbG8sIHdvcmxkIQ=='" + )), + ), + to_base64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "to_code_points", + ReturnType::Specific(DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + Signature::any(1, Volatility::Immutable), + Some(build_document( + 
"Converts a string to an array of Unicode code points.", + "SELECT TO_CODE_POINTS('Hello'); -- returns [72, 101, 108, 108, 111]" + )), + ), + to_code_points +); + +make_udf_function!( + ByPassScalarUDF::new( + "trim", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Removes leading and trailing spaces or specified characters from the input string.", + "SELECT TRIM(' Hello, world! '); -- returns 'Hello, world!'" + )), + ), + trim +); + +make_udf_function!( + ByPassScalarUDF::new( + "unicode", + ReturnType::Specific(DataType::Int32), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the Unicode code point of the first character of the input string.", + "SELECT UNICODE('A'); -- returns 65" + )), + ), + unicode +); + +make_udf_function!( + ByPassScalarUDF::new( + "format_time", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Formats a TIME according to the specified format string.", + r#"SELECT FORMAT_TIME("%R", TIME "15:30:00") as formatted_time; -- returns '15:30'"# + )), + ), + format_time +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_time", + ReturnType::Specific(DataType::Time64(datafusion::arrow::datatypes::TimeUnit::Microsecond)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Parses a string into a TIME according to the specified format string.", + "SELECT PARSE_TIME('%I:%M:%S %p', '2:23:38 pm') AS parsed_time; -- returns TIME '14:23:38'" + )), + ), + parse_time +); + +make_udf_function!( + ByPassScalarUDF::new( + "time", + ReturnType::Specific(DataType::Time64( + datafusion::arrow::datatypes::TimeUnit::Microsecond + )), + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable + ), + Some(build_document( + "Converts a string to a TIME 
value.", + "SELECT TIME('15:30:00') AS time_value;" + )), + ), + time +); + +make_udf_function!( + ByPassScalarUDF::new( + "time_add", + ReturnType::Specific(DataType::Time64(datafusion::arrow::datatypes::TimeUnit::Microsecond)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Adds an INTERVAL to a TIME value.", + "SELECT TIME_ADD(TIME '15:30:00', INTERVAL '02:00:00' HOUR TO SECOND) AS new_time;" + )), + ), + time_add +); + +make_udf_function!( + ByPassScalarUDF::new( + "time_diff", + ReturnType::Specific(DataType::Int64), + Signature::any(3, Volatility::Immutable), + Some(build_document( + "Calculates the difference between two TIME values as an INTERVAL.", + "SELECT TIME_DIFF('SECOND', TIME '18:30:00', TIME '15:30:00') AS time_difference;" + )), + ), + time_diff +); + +make_udf_function!( + ByPassScalarUDF::new( + "time_sub", + ReturnType::Specific(DataType::Time64(datafusion::arrow::datatypes::TimeUnit::Microsecond)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Subtracts an INTERVAL from a TIME value.", + "SELECT TIME_SUB(TIME '15:30:00', INTERVAL '02:00:00' HOUR TO SECOND) AS new_time;" + )), + ), + time_sub +); + +make_udf_function!( + ByPassScalarUDF::new( + "time_trunc", + ReturnType::Specific(DataType::Time64( + datafusion::arrow::datatypes::TimeUnit::Microsecond + )), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Truncates a TIME value to the specified part.", + "SELECT TIME_TRUNC('HOUR', TIME '15:45:30') AS truncated_time;" + )), + ), + time_trunc +); + +make_udf_function!( + ByPassScalarUDF::new( + "format_timestamp", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Formats a TIMESTAMP according to the specified format string with optional timezone.", + r#"SELECT FORMAT_TIMESTAMP("%c", TIMESTAMP "2050-12-25 15:30:55+00", "UTC"); -- returns 'Sun Dec 25 15:30:55
2050'"# + )), + ), + format_timestamp +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_timestamp", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::one_of(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Parses a string into a TIMESTAMP according to the specified format string with optional timezone.", + r#"SELECT PARSE_TIMESTAMP("%c", "Thu Dec 25 07:30:00 2008") AS parsed;"# + )), + ), + parse_timestamp +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp", + ReturnType::Specific(DataType::Timestamp( + datafusion::arrow::datatypes::TimeUnit::Microsecond, + None + )), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + Volatility::Immutable + ), + Some(build_document( + "Converts a string to a TIMESTAMP value with optional timezone.", + "SELECT TIMESTAMP('2023-10-05 14:30:00', 'UTC') AS timestamp_value;" + )), + ), + timestamp +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_add", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Adds an INTERVAL to a TIMESTAMP value.", + "SELECT TIMESTAMP_ADD(TIMESTAMP '2023-10-05 14:30:00', INTERVAL 10 MINUTE) AS new_timestamp;" + )), + ), + timestamp_add +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_diff", + ReturnType::Specific(DataType::Int64), + Signature::any(3, Volatility::Immutable), + Some(build_document( + "Calculates the difference between two TIMESTAMP values as an INTERVAL.", + "SELECT TIMESTAMP_DIFF('MINUTE', TIMESTAMP '2023-10-05 15:30:00', TIMESTAMP '2023-10-05 14:30:00') AS timestamp_difference;" + )), + ), + timestamp_diff +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_micros", + 
ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Interprets int64_expression as the number of microseconds since 1970-01-01 00:00:00 UTC and returns a timestamp.", + "SELECT TIMESTAMP_MICROS(1230219000000000) AS timestamp_value;" + )), + ), + timestamp_micros +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_millis", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Interprets int64_expression as the number of milliseconds since 1970-01-01 00:00:00 UTC and returns a timestamp.", + "SELECT TIMESTAMP_MILLIS(1230219000000) AS timestamp_value;" + )), + ), + timestamp_millis +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_seconds", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Interprets int64_expression as the number of seconds since 1970-01-01 00:00:00 UTC and returns a timestamp.", + "SELECT TIMESTAMP_SECONDS(1230219000) AS timestamp_value;" + )), + ), + timestamp_seconds +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_sub", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Subtracts an INTERVAL from a TIMESTAMP value.", + "SELECT TIMESTAMP_SUB(TIMESTAMP '2023-10-05 14:30:00', INTERVAL 10 MINUTE) AS new_timestamp;" + )), + ), + timestamp_sub +); + +make_udf_function!( + ByPassScalarUDF::new( + "timestamp_trunc", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::one_of(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + 
Some(build_document( + "Truncates a TIMESTAMP value to the specified part.", + "SELECT TIMESTAMP_TRUNC('HOUR', TIMESTAMP '2023-10-05 14:45:30') AS truncated_timestamp;" + )), + ), + timestamp_trunc +); + +make_udf_function!( + ByPassScalarUDF::new( + "unix_micros", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the number of microseconds since 1970-01-01 00:00:00 UTC for the given TIMESTAMP.", + "SELECT UNIX_MICROS(TIMESTAMP '2023-10-05 14:30:00') AS microseconds_since_epoch;" + )), + ), + unix_micros +); + +make_udf_function!( + ByPassScalarUDF::new( + "unix_millis", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the number of milliseconds since 1970-01-01 00:00:00 UTC for the given TIMESTAMP.", + "SELECT UNIX_MILLIS(TIMESTAMP '2023-10-05 14:30:00') AS milliseconds_since_epoch;" + )), + ), + unix_millis +); + +make_udf_function!( + ByPassScalarUDF::new( + "unix_seconds", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the number of seconds since 1970-01-01 00:00:00 UTC for the given TIMESTAMP.", + "SELECT UNIX_SECONDS(TIMESTAMP '2023-10-05 14:30:00') AS seconds_since_epoch;" + )), + ), + unix_seconds +); + +make_udf_function!( + ByPassScalarUDF::new( + "current_datetime", + ReturnType::Specific(DataType::Timestamp( + datafusion::arrow::datatypes::TimeUnit::Microsecond, + None + )), + Signature::nullary(Volatility::Volatile), + Some(build_document( + "Returns the current date and time.", + "SELECT CURRENT_DATETIME() AS now;" + )), + ), + current_datetime +); + +make_udf_function!( + ByPassScalarUDF::new( + "datetime", + ReturnType::Specific(DataType::Timestamp( + datafusion::arrow::datatypes::TimeUnit::Microsecond, + None + )), + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + 
Volatility::Immutable + ), + Some(build_document( + "Converts a string to a DATETIME value with optional timezone.", + "SELECT DATETIME('2023-10-05 14:30:00', 'UTC') AS datetime_value;" + )), + ), + datetime +); + +make_udf_function!( + ByPassScalarUDF::new( + "datetime_add", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Adds an INTERVAL to a DATETIME value.", + "SELECT DATETIME_ADD(DATETIME '2023-10-05 14:30:00', INTERVAL 10 MINUTE) AS new_datetime;" + )), + ), + datetime_add +); + +make_udf_function!( + ByPassScalarUDF::new( + "datetime_diff", + ReturnType::Specific(DataType::Int64), + Signature::any(3, Volatility::Immutable), + Some(build_document( + "Calculates the difference between two DATETIME values as an INTERVAL.", + "SELECT DATETIME_DIFF('MINUTE', DATETIME '2023-10-05 15:30:00', DATETIME '2023-10-05 14:30:00') AS datetime_difference;" + )), + ), + datetime_diff +); + +make_udf_function!( + ByPassScalarUDF::new( + "datetime_sub", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Subtracts an INTERVAL from a DATETIME value.", + "SELECT DATETIME_SUB(DATETIME '2023-10-05 14:30:00', INTERVAL 10 MINUTE) AS new_datetime;" + )), + ), + datetime_sub +); + +make_udf_function!( + ByPassScalarUDF::new( + "datetime_trunc", + ReturnType::Specific(DataType::Timestamp( + datafusion::arrow::datatypes::TimeUnit::Microsecond, + None + )), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Truncates a DATETIME value to the specified part.", + "SELECT DATETIME_TRUNC('HOUR', DATETIME '2023-10-05 14:45:30') AS truncated_datetime;" + )), + ), + datetime_trunc +); + +make_udf_function!( + ByPassScalarUDF::new( + "format_datetime", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + 
TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Formats a DATETIME according to the specified format string with optional timezone.", + r#"SELECT FORMAT_DATETIME("%c", DATETIME "2050-12-25 15:30:55", "UTC"); -- returns 'Sun Dec 25 15:30:55 2050'"# + )), + ), + format_datetime +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_datetime", + ReturnType::Specific(DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Microsecond, None)), + Signature::one_of(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Parses a string into a DATETIME according to the specified format string with optional timezone.", + r#"SELECT PARSE_DATETIME("%c", "Thu Dec 25 07:30:00 2008") AS parsed;"# + )), + ), + parse_datetime +); + +make_udf_function!( + ByPassScalarUDF::new( + "offset", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "To access the array element for zero-based indexing.", + "SELECT [10, 20, 30][OFFSET(1)] AS second_element; -- returns 20" + )), + ), + offset +); + +make_udf_function!( + ByPassScalarUDF::new( + "ordinal", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "To access the array element for one-based indexing.", + "SELECT [10, 20, 30][ORDINAL(2)] AS second_element; -- returns 20" + )), + ), + ordinal +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_offset", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document("To safely access the array element for zero-based indexing, returning NULL if out of bounds.", + "SELECT [10, 20, 30][SAFE_OFFSET(5)] AS out_of_bounds_element; -- returns NULL" + )), + ), + safe_offset +); + +make_udf_function!( + ByPassScalarUDF::new( + "safe_ordinal", + ReturnType::Specific(DataType::Int64), + Signature::any(1, 
Volatility::Immutable), + Some(build_document("To safely access the array element for one-based indexing, returning NULL if out of bounds.", + "SELECT [10, 20, 30][SAFE_ORDINAL(5)] AS out_of_bounds_element; -- returns NULL" + )), + ), + safe_ordinal +); + +// JSON functions +make_udf_function!( + ByPassScalarUDF::new( + "bool", + ReturnType::Specific(DataType::Boolean), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON boolean to a SQL BOOL value.", + "SELECT BOOL(JSON 'true') AS vacancy;" + )), + ), + r#bool +); + +make_udf_function!( + ByPassScalarUDF::new( + "float64", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON number to a SQL FLOAT64 value.", + "SELECT FLOAT64(JSON '12345.6789') AS gdp;" + )), + ), + float64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "int64", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON number to a SQL INT64 value.", + "SELECT INT64(JSON '123456789') AS population;" + )), + ), + int64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_array", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Creates a JSON array from the input values.", + "SELECT JSON_ARRAY(1, 'two', TRUE) AS json_array;" + )), + ), + json_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_array_append", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Appends values to a JSON array.", + "SELECT JSON_ARRAY_APPEND(JSON_ARRAY(1, 2), 3) AS json_array;" + )), + ), + json_array_append +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_array_insert", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Inserts values into a JSON array at
specified positions.", + r#"SELECT JSON_ARRAY_INSERT(JSON '["a", ["b", "c"], "d"]', '$[1]', 1) AS json_data;"# + )), + ), + json_array_insert +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_extract", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a value from a JSON string using a JSONPath expression.", + r#"SELECT JSON_EXTRACT(JSON '{"a": {"b": [1, 2, 3]}}', '$.a.b[1]') AS json_value;"# + )), + ), + json_extract +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_extract_array", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a JSON array from a JSON string using a JSONPath expression.", + r#"SELECT JSON_EXTRACT_ARRAY(JSON '{"a": {"b": [1, 2, 3]}}', '$.a.b') AS json_array;"# + )), + ), + json_extract_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_extract_scalar", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a scalar value from a JSON string using a JSONPath expression.", + r#"SELECT JSON_EXTRACT_SCALAR(JSON '{"a": {"b": 42}}', '$.a.b') AS json_value;"# + )), + ), + json_extract_scalar +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_extract_string_array", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a JSON string array from a JSON string using a JSONPath expression.", + r#"SELECT JSON_EXTRACT_STRING_ARRAY(JSON '{"a": {"b": ["x", "y", "z"]}}', '$.a.b') AS json_string_array;"# + )), + ), + json_extract_string_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_keys", + ReturnType::Specific(DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))), + Signature::one_of(vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable), + Some(build_document( + "Extracts 
unique JSON keys from a JSON expression with optional max_depth and mode parameters ('strict', 'lax', 'lax recursive').", + r#"SELECT JSON_KEYS(JSON '{"name": "Alice", "age": 30}') AS keys;"# + )), + ), + json_keys +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_object", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Creates a JSON object from key-value pairs.", + "SELECT JSON_OBJECT('name', 'Alice', 'age', 30) AS json_object;" + )), + ), + json_object +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_query", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts JSON values based on a JSONPath expression.", + r#"SELECT JSON_QUERY(JSON '{"a": {"b": [1, 2, 3]}}', '$.a.b') AS json_value;"# + )), + ), + json_query +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_query_array", + ReturnType::Specific(DataType::Utf8), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + Volatility::Immutable + ), + Some(build_document( + "Extracts JSON arrays based on a JSONPath expression.", + r#"SELECT JSON_QUERY_ARRAY(JSON '{"a": {"b": [1, 2, 3]}}', '$.a.b') AS json_array;"# + )), + ), + json_query_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_remove", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Removes specified paths from a JSON string.", + r#"SELECT JSON_REMOVE(JSON '{"a": 1, "b": 2, "c": 3}', '$.b') AS json_data;"# + )), + ), + json_remove +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_set", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Sets values at specified paths in a JSON string.", + r#"SELECT JSON_SET(JSON '{"a": 1, "b": 2}', '$.b', 20, '$.c', 30) AS json_data;"# + )), + ), + json_set +); + +make_udf_function!( +
ByPassScalarUDF::new( + "json_strip_nulls", + ReturnType::Specific(DataType::Utf8), + Signature::variadic_any(Volatility::Immutable), + Some(build_document( + "Removes all null values from a JSON string with optional path, include_array, and remove_empty_object parameters.", + r#"SELECT JSON_STRIP_NULLS(JSON '{"a": 1, "b": null, "c": 3, "d": null}') AS json_data;"# + )), + ), + json_strip_nulls +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_type", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Returns the type of the outermost JSON value as a string.", + r#"SELECT JSON_TYPE(JSON '{"a": 1, "b": 2}') AS json_type;"# + )), + ), + json_type +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_value", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a scalar value from a JSON string using a JSONPath expression.", + r#"SELECT JSON_VALUE(JSON '{"a": {"b": 42}}', '$.a.b') AS json_value;"# + )), + ), + json_value +); + +make_udf_function!( + ByPassScalarUDF::new( + "json_value_array", + ReturnType::Specific(DataType::Utf8), + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Extracts a JSON value array from a JSON string using a JSONPath expression.", + r#"SELECT JSON_VALUE_ARRAY(JSON '{"a": {"b": ["x", "y", "z"]}}', '$.a.b') AS json_value_array;"# + )), + ), + json_value_array +); + +make_udf_function!( + ByPassScalarUDF::new( + "lax_bool", + ReturnType::Specific(DataType::Boolean), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON boolean to a SQL BOOL value in lax mode.", + "SELECT LAX_BOOL(JSON 'true') AS vacancy;" + )), + ), + lax_bool +); + +make_udf_function!( + ByPassScalarUDF::new( + "lax_float64", + ReturnType::Specific(DataType::Float64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON number to a SQL FLOAT64 value in lax 
mode.", + "SELECT LAX_FLOAT64(JSON '12345.6789') AS gdp;" + )), + ), + lax_float64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "lax_int64", + ReturnType::Specific(DataType::Int64), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON number to a SQL INT64 value in lax mode.", + "SELECT LAX_INT64(JSON '123456789') AS population;" + )), + ), + lax_int64 +); + +make_udf_function!( + ByPassScalarUDF::new( + "lax_string", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON string to a SQL STRING value in lax mode.", + "SELECT LAX_STRING(JSON 'Hello, world!') AS greeting;" + )), + ), + lax_string +); + +make_udf_function!( + ByPassScalarUDF::new( + "parse_json", + ReturnType::Specific(DataType::Utf8), + Signature::one_of(vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + ], Volatility::Immutable), + Some(build_document( + "Parses a string into a JSON value with optional mode parameter ('strict' or 'lax').", + "SELECT PARSE_JSON('{\"name\": \"Alice\", \"age\": 30}') AS json_data;" + )), + ), + parse_json +); + +make_udf_function!( + ByPassScalarUDF::new( + "string", + ReturnType::Specific(DataType::Utf8), + Signature::any(1, Volatility::Immutable), + Some(build_document( + "Converts a JSON string to a SQL STRING value.", + "SELECT STRING(JSON 'Hello, world!') AS greeting;" + )), + ), + string +); + +make_udf_function!( + ByPassScalarUDF::new( + "to_json", + ReturnType::Specific(DataType::Utf8), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + Volatility::Immutable + ), + Some(build_document( + "Converts a SQL value to a JSON string.", + "SELECT TO_JSON(12345) AS json_data;" + )), + ), + to_json +); + +make_udf_function!( + ByPassScalarUDF::new( + "to_json_string", + ReturnType::Specific(DataType::Utf8), + Signature::one_of( + vec![TypeSignature::Any(1), TypeSignature::Any(2),], + Volatility::Immutable + ), + 
Some(build_document( + "Converts a SQL value to a JSON string.", + "SELECT TO_JSON_STRING(12345) AS json_string;" + )), + ), + to_json_string +); diff --git a/wren-core/core/src/mdl/function/dialect/bigquery/window.rs b/wren-core/core/src/mdl/function/dialect/bigquery/window.rs new file mode 100644 index 000000000..67a1d39a8 --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/bigquery/window.rs @@ -0,0 +1,32 @@ +use datafusion::logical_expr::{Signature, Volatility}; + +use crate::{ + make_udwf_function, + mdl::function::{dialect::utils::build_document, ByPassWindowFunction, ReturnType}, +}; + +make_udwf_function!( + ByPassWindowFunction::new( + "percentile_cont", + ReturnType::SameAsInput, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Calculates the value at the given percentile of a set of values.", + "SELECT PERCENTILE_CONT(column_name, 0.5) OVER() FROM table;" + )), + ), + percentile_cont_udwf +); + +make_udwf_function!( + ByPassWindowFunction::new( + "percentile_disc", + ReturnType::SameAsInput, + Signature::any(2, Volatility::Immutable), + Some(build_document( + "Calculates the discrete value at the given percentile of a set of values.", + "SELECT PERCENTILE_DISC(column_name, 0.5) OVER() FROM table;" + )), + ), + percentile_disc_udwf +); diff --git a/wren-core/core/src/mdl/function/dialect/mod.rs b/wren-core/core/src/mdl/function/dialect/mod.rs new file mode 100644 index 000000000..cf9c5a919 --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/mod.rs @@ -0,0 +1,2 @@ +pub mod bigquery; +mod utils; diff --git a/wren-core/core/src/mdl/function/dialect/utils.rs b/wren-core/core/src/mdl/function/dialect/utils.rs new file mode 100644 index 000000000..a555e28d2 --- /dev/null +++ b/wren-core/core/src/mdl/function/dialect/utils.rs @@ -0,0 +1,5 @@ +use datafusion::logical_expr::{DocSection, Documentation, DocumentationBuilder}; + +pub fn build_document(desc: &str, example: &str) -> Documentation { + 
DocumentationBuilder::new_with_details(DocSection::default(), desc, example).build() +} diff --git a/wren-core/core/src/mdl/function/macros.rs b/wren-core/core/src/mdl/function/macros.rs index fdc7f5076..9e0cd10ad 100644 --- a/wren-core/core/src/mdl/function/macros.rs +++ b/wren-core/core/src/mdl/function/macros.rs @@ -3,7 +3,7 @@ /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. #[macro_export] -macro_rules! make_udf_function { +macro_rules! make_datafusion_udf_function { ($UDF:ty, $NAME:ident) => { #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { @@ -19,3 +19,57 @@ macro_rules! make_udf_function { } }; } + +#[macro_export] +macro_rules! make_udf_function { + ($UDF:expr, $NAME:ident) => { + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] + pub fn $NAME() -> std::sync::Arc { + // Singleton instance of the function + static INSTANCE: std::sync::LazyLock< + std::sync::Arc, + > = std::sync::LazyLock::new(|| { + std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl( + $UDF + )) + }); + std::sync::Arc::clone(&INSTANCE) + } + }; +} + +#[macro_export] +macro_rules! make_udaf_function { + ($UDF:expr, $NAME:ident) => { + #[doc = concat!("Return a [`AggregateUDF`](datafusion_expr::AggregateUDF) implementation of ", stringify!($NAME))] + pub fn $NAME() -> std::sync::Arc { + // Singleton instance of the function + static INSTANCE: std::sync::LazyLock< + std::sync::Arc, + > = std::sync::LazyLock::new(|| { + std::sync::Arc::new(datafusion::logical_expr::AggregateUDF::new_from_impl( + $UDF + )) + }); + std::sync::Arc::clone(&INSTANCE) + } + }; +} + +#[macro_export] +macro_rules! 
make_udwf_function { + ($UDF:expr, $NAME:ident) => { + #[doc = concat!("Return a [`WindowUDF`](datafusion_expr::WindowUDF) implementation of ", stringify!($NAME))] + pub fn $NAME() -> std::sync::Arc { + // Singleton instance of the function + static INSTANCE: std::sync::LazyLock< + std::sync::Arc, + > = std::sync::LazyLock::new(|| { + std::sync::Arc::new(datafusion::logical_expr::WindowUDF::new_from_impl( + $UDF + )) + }); + std::sync::Arc::clone(&INSTANCE) + } + }; +} diff --git a/wren-core/core/src/mdl/function/mod.rs b/wren-core/core/src/mdl/function/mod.rs index 1cae051d7..7f15792e1 100644 --- a/wren-core/core/src/mdl/function/mod.rs +++ b/wren-core/core/src/mdl/function/mod.rs @@ -1,4 +1,5 @@ mod aggregate; +pub mod dialect; mod macros; mod remote_function; mod scalar; diff --git a/wren-core/core/src/mdl/function/remote_function.rs b/wren-core/core/src/mdl/function/remote_function.rs index 849f7df6b..f8102b288 100644 --- a/wren-core/core/src/mdl/function/remote_function.rs +++ b/wren-core/core/src/mdl/function/remote_function.rs @@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize}; use std::any::Any; use std::fmt::Display; use std::str::FromStr; +use std::sync::Arc; use crate::logical_plan::utils::{get_coercion_type_signature, map_data_type}; @@ -109,6 +110,8 @@ pub enum ReturnType { /// If the input type is array, the return type is the same as the element type of the first array argument /// e.g. 
`greatest(array)` will return `int` SameAsInputFirstArrayElement, + /// The return type is the array of the first argument type + ArrayOfInputFirstArgument, } impl Display for ReturnType { @@ -119,6 +122,9 @@ impl Display for ReturnType { ReturnType::SameAsInputFirstArrayElement => { write!(f, "same_as_input_first_array_element") } + ReturnType::ArrayOfInputFirstArgument => { + write!(f, "array_of_input_first_argument") + } } } } @@ -131,6 +137,7 @@ impl FromStr for ReturnType { "same_as_input_first_array_element" => { Ok(ReturnType::SameAsInputFirstArrayElement) } + "array_of_input_first_argument" => Ok(ReturnType::ArrayOfInputFirstArgument), _ => map_data_type(s) .map(ReturnType::Specific) .map_err(|e| e.to_string()), @@ -155,6 +162,12 @@ impl ReturnType { return not_impl_err!("Input type is not array"); } } + ReturnType::ArrayOfInputFirstArgument => { + if arg_types.is_empty() { + return not_impl_err!("No input type"); + } + DataType::List(Arc::new(Field::new("item", arg_types[0].clone(), true))) + } }) } } @@ -171,7 +184,21 @@ pub struct ByPassScalarUDF { } impl ByPassScalarUDF { - pub fn new(name: &str, return_type: DataType) -> Self { + pub fn new( + name: &str, + return_type: ReturnType, + signature: Signature, + doc: Option, + ) -> Self { + Self { + name: name.to_string(), + return_type, + signature, + doc, + } + } + + pub fn new_with_return_type(name: &str, return_type: DataType) -> Self { Self { name: name.to_string(), return_type: ReturnType::Specific(return_type), @@ -251,11 +278,43 @@ pub struct ByPassAggregateUDF { name: String, return_type: ReturnType, signature: Signature, + aliases: Vec, doc: Option, } impl ByPassAggregateUDF { - pub fn new(name: &str, return_type: DataType) -> Self { + pub fn new( + name: &str, + return_type: ReturnType, + signature: Signature, + doc: Option, + ) -> Self { + Self { + name: name.to_string(), + return_type, + signature, + aliases: vec![], + doc, + } + } + + pub fn new_with_alias( + name: &str, + return_type: 
ReturnType, + signature: Signature, + aliases: Vec, + doc: Option, + ) -> Self { + Self { + name: name.to_string(), + return_type, + signature, + aliases, + doc, + } + } + + pub fn new_with_return_type(name: &str, return_type: DataType) -> Self { Self { name: name.to_string(), return_type: ReturnType::Specific(return_type), @@ -263,6 +322,7 @@ impl ByPassAggregateUDF { vec![TypeSignature::VariadicAny, TypeSignature::Nullary], Volatility::Volatile, ), + aliases: vec![], doc: None, } } @@ -277,6 +337,7 @@ impl From for ByPassAggregateUDF { signature: func.get_signature(), doc: Some(build_document(&func)), name: func.name, + aliases: vec![], } } } @@ -290,6 +351,10 @@ impl AggregateUDFImpl for ByPassAggregateUDF { &self.name } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn signature(&self) -> &Signature { &self.signature } @@ -318,7 +383,21 @@ pub struct ByPassWindowFunction { } impl ByPassWindowFunction { - pub fn new(name: &str, return_type: DataType) -> Self { + pub fn new( + name: &str, + return_type: ReturnType, + signature: Signature, + doc: Option, + ) -> Self { + Self { + name: name.to_string(), + return_type, + signature, + doc, + } + } + + pub fn new_with_return_type(name: &str, return_type: DataType) -> Self { Self { name: name.to_string(), return_type: ReturnType::Specific(return_type), @@ -399,7 +478,7 @@ mod test { #[tokio::test] async fn test_by_pass_scalar_udf() -> Result<()> { - let udf = ByPassScalarUDF::new("date_test", DataType::Int64); + let udf = ByPassScalarUDF::new_with_return_type("date_test", DataType::Int64); let ctx = SessionContext::new(); ctx.register_udf(ScalarUDF::new_from_impl(udf)); @@ -410,10 +489,9 @@ mod test { let expected = "Projection: date_test(Int64(1), Int64(2))\n EmptyRelation"; assert_eq!(format!("{plan}"), expected); - ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - "today", - DataType::Utf8, - ))); + ctx.register_udf(ScalarUDF::new_from_impl( + ByPassScalarUDF::new_with_return_type("today", 
DataType::Utf8), + )); let plan_2 = ctx.sql("SELECT today()").await?.into_unoptimized_plan(); assert_eq!(format!("{plan_2}"), "Projection: today()\n EmptyRelation"); @@ -422,7 +500,7 @@ mod test { #[tokio::test] async fn test_by_pass_agg_udf() -> Result<()> { - let udf = ByPassAggregateUDF::new("count_self", DataType::Int64); + let udf = ByPassAggregateUDF::new_with_return_type("count_self", DataType::Int64); let ctx = SessionContext::new(); ctx.register_udaf(AggregateUDF::new_from_impl(udf)); @@ -434,10 +512,9 @@ mod test { \n Values: (Int64(1), Int64(2)), (Int64(2), Int64(3)), (Int64(3), Int64(4))"; assert_eq!(format!("{plan}"), expected); - ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new( - "total_count", - DataType::Int64, - ))); + ctx.register_udaf(AggregateUDF::new_from_impl( + ByPassAggregateUDF::new_with_return_type("total_count", DataType::Int64), + )); let plan_2 = ctx .sql("SELECT total_count() AS total_count FROM (VALUES (1), (2), (3)) AS val(x)") .await? @@ -455,7 +532,8 @@ mod test { #[tokio::test] async fn test_by_pass_window_udf() -> Result<()> { - let udf = ByPassWindowFunction::new("custom_window", DataType::Int64); + let udf = + ByPassWindowFunction::new_with_return_type("custom_window", DataType::Int64); let ctx = SessionContext::new(); ctx.register_udwf(WindowUDF::new_from_impl(udf)); @@ -468,10 +546,9 @@ mod test { \n EmptyRelation"; assert_eq!(format!("{plan}"), expected); - ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new( - "cume_dist", - DataType::Int64, - ))); + ctx.register_udwf(WindowUDF::new_from_impl( + ByPassWindowFunction::new_with_return_type("cume_dist", DataType::Int64), + )); let plan_2 = ctx .sql("SELECT cume_dist() OVER ()") .await? 
diff --git a/wren-core/core/src/mdl/function/scalar/mod.rs b/wren-core/core/src/mdl/function/scalar/mod.rs index ab3388082..dec085dc3 100644 --- a/wren-core/core/src/mdl/function/scalar/mod.rs +++ b/wren-core/core/src/mdl/function/scalar/mod.rs @@ -9,11 +9,11 @@ use datafusion::{ logical_expr::ScalarUDF, }; -use crate::make_udf_function; +use crate::make_datafusion_udf_function; mod to_char; -make_udf_function!(to_char::ToCharFunc, to_char); +make_datafusion_udf_function!(to_char::ToCharFunc, to_char); pub fn scalar_functions() -> Vec> { vec![ diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index df215b3e1..87e21297d 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -3,6 +3,7 @@ use crate::logical_plan::error::WrenError; use crate::logical_plan::utils::{from_qualified_name_str, try_map_data_type}; use crate::mdl::builder::ManifestBuilder; use crate::mdl::context::{apply_wren_on_ctx, Mode, WrenDataSource}; +use crate::mdl::dialect::inner_dialect::get_inner_dialect; use crate::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, @@ -364,14 +365,27 @@ impl WrenMDL { } /// Create a SessionContext with the default functions registered -pub fn create_wren_ctx(config: Option) -> SessionContext { +pub fn create_wren_ctx( + config: Option, + data_source: Option<&DataSource>, +) -> SessionContext { let builder = SessionStateBuilder::new() .with_expr_planners(SessionStateDefaults::default_expr_planners()) - .with_scalar_functions(crate::mdl::function::scalar_functions()) - .with_aggregate_functions(crate::mdl::function::aggregate_functions()) - .with_window_functions(crate::mdl::function::window_functions()) .with_table_function_list(crate::mdl::function::table_functions()); + let builder = if let Some(data_source) = data_source { + let dialect = get_inner_dialect(data_source); + builder + .with_scalar_functions(dialect.supported_udfs()) + 
.with_aggregate_functions(dialect.supported_udafs()) + .with_window_functions(dialect.supported_udwfs()) + } else { + builder + .with_scalar_functions(crate::mdl::function::scalar_functions()) + .with_aggregate_functions(crate::mdl::function::aggregate_functions()) + .with_window_functions(crate::mdl::function::window_functions()) + }; + let builder = if let Some(config) = config { builder.with_config(config) } else { @@ -390,7 +404,7 @@ pub fn transform_sql( ) -> Result { let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), analyzed_mdl, remote_functions, Arc::new(properties), @@ -472,12 +486,12 @@ async fn permission_analyze( remote_functions: &[RemoteFunction], properties: SessionPropertiesRef, ) -> Result<()> { + let ctx = create_wren_ctx(None, manifest.data_source.as_ref()); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::clone(&properties), Mode::PermissionAnalyze, )?); - let ctx = create_wren_ctx(None); remote_functions.iter().try_for_each(|remote_function| { debug!("Registering remote function: {remote_function:?}"); register_remote_function(&ctx, remote_function)?; @@ -517,24 +531,24 @@ fn register_remote_function( remote_function: &RemoteFunction, ) -> Result<()> { match &remote_function.function_type { - FunctionType::Scalar => { - ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + FunctionType::Scalar => ctx.register_udf(ScalarUDF::new_from_impl( + ByPassScalarUDF::new_with_return_type( &remote_function.name, try_map_data_type(&remote_function.return_type)?, - ))) - } - FunctionType::Aggregate => { - ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new( + ), + )), + FunctionType::Aggregate => ctx.register_udaf(AggregateUDF::new_from_impl( + ByPassAggregateUDF::new_with_return_type( &remote_function.name, try_map_data_type(&remote_function.return_type)?, - ))) - 
} - FunctionType::Window => { - ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new( + ), + )), + FunctionType::Window => ctx.register_udwf(WindowUDF::new_from_impl( + ByPassWindowFunction::new_with_return_type( &remote_function.name, try_map_data_type(&remote_function.return_type)?, - ))) - } + ), + )), }; Ok(()) } @@ -644,7 +658,7 @@ mod test { for sql in tests { println!("Original: {sql}"); let actual = mdl::transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -677,7 +691,7 @@ mod test { let sql = "select * from test.test.customer_view"; println!("Original: {sql}"); let _ = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -709,7 +723,7 @@ mod test { )?); let sql = "select totalcost from profile"; let result = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -720,7 +734,7 @@ mod test { let sql = "select totalcost from profile where p_sex = 'M'"; let result = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -734,8 +748,6 @@ mod test { #[tokio::test] async fn test_uppercase_catalog_schema() -> Result<()> { - let ctx = create_wren_ctx(None); - ctx.register_batch("customer", customer())?; let manifest = ManifestBuilder::new() .catalog("CTest") .schema("STest") @@ -754,7 +766,7 @@ mod test { )?); let sql = r#"select * from CTest.STest.Customer"#; let actual = mdl::transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()), Arc::clone(&analyzed_mdl), &[], 
Arc::new(HashMap::new()), @@ -775,7 +787,6 @@ mod test { [env!("CARGO_MANIFEST_DIR"), "tests", "data", "functions.csv"] .iter() .collect(); - let ctx = create_wren_ctx(None); let functions = csv::Reader::from_path(test_data) .unwrap() .into_deserialize::() @@ -797,6 +808,7 @@ mod test { Arc::new(HashMap::default()), Mode::Unparse, )?); + let ctx = create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()); let actual = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), @@ -834,7 +846,7 @@ mod test { #[tokio::test] async fn test_unicode_remote_column_name() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -873,7 +885,7 @@ mod test { )?); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -887,7 +899,7 @@ mod test { let sql = r#"select group from wren.test.artist"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -900,7 +912,7 @@ mod test { let sql = r#"select subscribe_plus from wren.test.artist"#; let actual = mdl::transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -914,7 +926,7 @@ mod test { #[tokio::test] async fn test_invalid_infer_remote_table() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -943,7 +955,7 @@ mod test { )?); let sql = r#"select name_append from wren.test.artist"#; let _ = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], 
Arc::new(HashMap::new()), @@ -959,7 +971,7 @@ mod test { let sql = r#"select lower_name from wren.test.artist"#; let _ = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -977,7 +989,7 @@ mod test { #[tokio::test] async fn test_query_hidden_column() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -1002,7 +1014,7 @@ mod test { )?); let sql = r#"select 串接名字 from wren.test.artist"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1013,7 +1025,7 @@ mod test { @"SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1025,7 +1037,7 @@ mod test { let sql = r#"select "名字" from wren.test.artist"#; let _ = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1044,7 +1056,7 @@ mod test { async fn test_disable_simplify_expression() -> Result<()> { let sql = "select current_date"; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::new(AnalyzedWrenMDL::default()), &[], Arc::new(HashMap::new()), @@ -1075,7 +1087,7 @@ mod test { )?); let sql = r#"select * from wren.test.artist where 名字 in (SELECT 名字 FROM wren.test.artist)"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1090,7 +1102,7 @@ mod 
test { /// This test will be failed if the `出道時間` is not inferred as a timestamp column correctly. #[tokio::test] async fn test_infer_timestamp_column() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -1110,7 +1122,7 @@ mod test { )?); let sql = r#"select current_date > "出道時間" from wren.test.artist"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1125,7 +1137,7 @@ mod test { #[tokio::test] async fn test_disable_count_wildcard_rule() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select count(*) from (select 1)"; @@ -1144,7 +1156,7 @@ mod test { } async fn assert_sql_valid_executable(sql: &str) -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // To roundtrip testing, we should register the mock table for the planned sql. 
ctx.register_batch("orders", orders())?; ctx.register_batch("customer", customer())?; @@ -1166,7 +1178,7 @@ mod test { #[tokio::test] async fn test_mysql_style_interval() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, Some(&DataSource::MySQL)); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select interval 1 day"; let actual = transform_sql_with_ctx( @@ -1208,7 +1220,7 @@ mod test { #[tokio::test] async fn test_unnest_as_table_factor() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new().build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1249,7 +1261,7 @@ mod test { #[tokio::test] async fn test_simplify_timestamp() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00 +08:00'"; let actual = transform_sql_with_ctx( @@ -1280,7 +1292,7 @@ mod test { let mut headers = HashMap::new(); headers.insert("x-wren-timezone".to_string(), Some("+08:00".to_string())); let headers_ref = Arc::new(headers); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00'"; let actual = transform_sql_with_ctx( @@ -1306,7 +1318,7 @@ mod test { // TIMESTAMP WITH TIME ZONE will be converted to the session timezone assert_snapshot!(actual, @"SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let mut headers = HashMap::new(); headers.insert( "x-wren-timezone".to_string(), @@ -1340,7 +1352,7 @@ mod test { let headers = HashMap::new(); let headers_ref = Arc::new(headers); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let 
analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp with time zone '2011-01-01 18:00:00' - timestamp with time zone '2011-01-01 10:00:00'"; let actual = transform_sql_with_ctx( @@ -1359,7 +1371,7 @@ mod test { #[tokio::test] async fn test_disable_pushdown_filter() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -1388,7 +1400,7 @@ mod test { )?); let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1406,7 +1418,7 @@ mod test { #[tokio::test] async fn test_register_timestamptz() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("timestamp_table", timestamp_table())?; let provider = ctx .catalog("datafusion") @@ -1457,7 +1469,7 @@ mod test { #[tokio::test] async fn test_coercion_timestamptz() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_batch("timestamp_table", timestamp_table())?; for timezone_type in [ "timestamptz", @@ -1484,7 +1496,7 @@ mod test { )?); let sql = r#"select timestamp_col = timestamptz_col from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1499,7 +1511,7 @@ mod test { let sql = r#"select timestamptz_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1513,7 +1525,7 @@ mod test { let sql = r#"select 
timestamptz_col > '2011-01-01 18:00:00' from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1528,7 +1540,7 @@ mod test { let sql = r#"select timestamp_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &create_wren_ctx(None), + &create_wren_ctx(None, None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1544,7 +1556,7 @@ mod test { #[tokio::test] async fn test_list() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1576,7 +1588,7 @@ mod test { #[tokio::test] async fn test_struct() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1675,7 +1687,7 @@ mod test { #[tokio::test] async fn test_disable_common_expression_eliminate() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let sql = "SELECT CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE) = \ CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE)"; @@ -1694,7 +1706,7 @@ mod test { #[tokio::test] async fn test_disable_eliminate_nested_union() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let sql = r#"SELECT * FROM (SELECT 1 x, 'a' y UNION ALL SELECT 1 x, 'b' y UNION ALL SELECT 2 x, 'a' y UNION ALL @@ -1725,7 +1737,7 @@ mod test { Arc::new(HashMap::default()), Mode::Unparse, )?); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let sql = "SELECT trim(' abc')"; let actual = transform_sql_with_ctx( &ctx, @@ -1741,7 +1753,7 @@ mod test { #[tokio::test] async fn test_disable_single_distinct_to_group_by() -> 
Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1779,7 +1791,7 @@ mod test { #[tokio::test] async fn test_disable_distinct_to_group_by() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1814,7 +1826,7 @@ mod test { #[tokio::test] async fn test_disable_scalar_subquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1849,7 +1861,7 @@ mod test { #[tokio::test] async fn test_wildcard_where() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1911,7 +1923,7 @@ mod test { } "#; let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1962,7 +1974,7 @@ mod test { } "#; let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1986,7 +1998,7 @@ mod test { #[tokio::test] async fn test_rlac_with_requried_properties() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = ManifestBuilder::new() @@ -2198,7 +2210,7 @@ mod test { #[tokio::test] async fn test_rlac_with_optional_properties() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = 
ManifestBuilder::new() @@ -2389,7 +2401,7 @@ mod test { #[tokio::test] async fn test_rlac_on_calculated_field() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2560,7 +2572,7 @@ mod test { #[tokio::test] async fn test_rlac_alias_model() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2623,7 +2635,7 @@ mod test { #[tokio::test] async fn test_rlac_unicode_model_column_name() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2663,7 +2675,7 @@ mod test { #[tokio::test] async fn test_ralc_condition_contain_hidden() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2711,7 +2723,7 @@ mod test { #[tokio::test] async fn test_clac_with_required_properties() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2808,7 +2820,7 @@ mod test { #[tokio::test] async fn test_clac_permission_denied() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2881,7 +2893,7 @@ mod test { #[tokio::test] async fn test_calc_primary_key() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2923,7 +2935,7 @@ mod test { #[tokio::test] async fn test_clac_with_optional_properties() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -3031,7 +3043,7 @@ 
mod test { #[tokio::test] async fn test_clac_on_calculated_field() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -3183,7 +3195,7 @@ mod test { #[tokio::test] async fn test_rlac_case_insensitive() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = ManifestBuilder::new() @@ -3221,7 +3233,7 @@ mod test { #[tokio::test] async fn test_disable_eliminate_limit() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = ManifestBuilder::new() @@ -3251,7 +3263,7 @@ mod test { #[tokio::test] async fn test_default_nulls_last() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = ManifestBuilder::new() @@ -3312,7 +3324,7 @@ mod test { #[tokio::test] async fn test_extract_roundtrip_bigquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, Some(&DataSource::BigQuery)); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3371,7 +3383,7 @@ mod test { #[tokio::test] async fn test_date_diff_bigquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, Some(&DataSource::BigQuery)); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3455,7 +3467,7 @@ mod test { #[tokio::test] async fn test_window_function_frame() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3492,7 +3504,7 @@ mod test { #[tokio::test] async fn test_window_functions_without_frame_bigquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") 
.schema("test") @@ -3522,7 +3534,7 @@ mod test { #[tokio::test] async fn test_cte_used_in_scalar_subquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3565,7 +3577,7 @@ mod test { #[tokio::test] async fn test_ambiguous_table_name() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3628,7 +3640,7 @@ mod test { #[tokio::test] async fn test_unicode_literal() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::default().build(); let properties = SessionPropertiesRef::default(); @@ -3662,7 +3674,7 @@ mod test { #[tokio::test] async fn test_compatible_type() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::default().build(); let properties = SessionPropertiesRef::default(); @@ -3681,7 +3693,7 @@ mod test { #[tokio::test] async fn test_trim_function_bigquery() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, Some(&DataSource::BigQuery)); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3706,6 +3718,7 @@ mod test { @"SELECT trim(customer.c_name) FROM (SELECT customer.c_name FROM (SELECT __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" ); + let ctx = create_wren_ctx(None, None); // normal data source will be transformed to btrim let manifest = ManifestBuilder::new() .catalog("wren") @@ -3734,7 +3747,7 @@ mod test { #[tokio::test] async fn test_to_char() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3776,7 +3789,7 @@ mod test { #[tokio::test] async fn 
test_disable_eliminate_cross_join() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); // test required property let manifest = ManifestBuilder::new() @@ -3813,7 +3826,7 @@ mod test { #[tokio::test] async fn test_snowflake_unnest() -> Result<()> { - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, Some(&DataSource::Snowflake)); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") diff --git a/wren-core/sqllogictest/src/test_context.rs b/wren-core/sqllogictest/src/test_context.rs index 95b7e19d7..42b406bd5 100644 --- a/wren-core/sqllogictest/src/test_context.rs +++ b/wren-core/sqllogictest/src/test_context.rs @@ -63,7 +63,7 @@ impl TestContext { .with_target_partitions(4) .with_information_schema(true); - let ctx = create_wren_ctx(Some(config)); + let ctx = create_wren_ctx(Some(config), None); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { diff --git a/wren-core/wren-example/examples/plan-sql.rs b/wren-core/wren-example/examples/plan-sql.rs index 130ed9693..7fa876d92 100644 --- a/wren-core/wren-example/examples/plan-sql.rs +++ b/wren-core/wren-example/examples/plan-sql.rs @@ -1,12 +1,11 @@ -use datafusion::prelude::SessionContext; use std::collections::HashMap; use std::sync::Arc; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, }; use wren_core::mdl::context::Mode; -use wren_core::mdl::manifest::{JoinType, Manifest}; -use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; +use wren_core::mdl::manifest::{DataSource, JoinType, Manifest}; +use wren_core::mdl::{create_wren_ctx, transform_sql_with_ctx, AnalyzedWrenMDL}; #[tokio::main] async fn main() -> datafusion::common::Result<()> { @@ -20,7 +19,7 @@ async fn main() -> datafusion::common::Result<()> { let sql = "select customer_state from wrenai.public.orders_model"; println!("Original SQL: \n{sql}"); let sql = transform_sql_with_ctx( - 
&SessionContext::new(), + &create_wren_ctx(None, Some(&DataSource::BigQuery)), analyzed_mdl, &[], HashMap::new().into(), diff --git a/wren-core/wren-example/examples/to-many-calculation.rs b/wren-core/wren-example/examples/to-many-calculation.rs index 9bd70a570..4f09b8a00 100644 --- a/wren-core/wren-example/examples/to-many-calculation.rs +++ b/wren-core/wren-example/examples/to-many-calculation.rs @@ -17,7 +17,7 @@ async fn main() -> Result<()> { let manifest = init_manifest(); // register the table - let ctx = create_wren_ctx(None); + let ctx = create_wren_ctx(None, None); ctx.register_csv( "orders", "sqllogictest/tests/resources/ecommerce/orders.csv",