diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index c0561ed8a..c0bff3749 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; use wren_core::array::AsArray; use wren_core::ast::{visit_statements_mut, Expr, Statement, Value, ValueWithSpan}; use wren_core::dialect::GenericDialect; -use wren_core::mdl::context::create_ctx_with_mdl; +use wren_core::mdl::context::apply_wren_on_ctx; use wren_core::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, @@ -91,11 +91,11 @@ impl PySessionContext { .collect::>(); let config = SessionConfig::default().with_information_schema(true); - let ctx = wren_core::SessionContext::new_with_config(config); + let ctx = wren_core::mdl::create_wren_ctx(Some(config)); let runtime = Runtime::new().map_err(CoreError::from)?; let registered_functions = runtime - .block_on(Self::get_regietered_functions(&ctx)) + .block_on(Self::get_registered_functions(&ctx)) .map(|functions| { functions .into_iter() @@ -169,7 +169,7 @@ impl PySessionContext { Ok(analyzed_mdl) => { let analyzed_mdl = Arc::new(analyzed_mdl); let unparser_ctx = runtime - .block_on(create_ctx_with_mdl( + .block_on(apply_wren_on_ctx( &ctx, Arc::clone(&analyzed_mdl), Arc::clone(&properties_ref), @@ -178,7 +178,7 @@ impl PySessionContext { .map_err(CoreError::from)?; let exec_ctx = runtime - .block_on(create_ctx_with_mdl( + .block_on(apply_wren_on_ctx( &ctx, Arc::clone(&analyzed_mdl), Arc::clone(&properties_ref), @@ -226,7 +226,7 @@ impl PySessionContext { pub fn get_available_functions(&self) -> PyResult> { let registered_functions: Vec = self .runtime - .block_on(Self::get_regietered_functions(&self.exec_ctx)) + .block_on(Self::get_registered_functions(&self.exec_ctx)) .map_err(CoreError::from)? .into_iter() .map(|f| f.into()) @@ -321,7 +321,7 @@ impl PySessionContext { /// The `name` is the name of the function. /// The `function_type` is the type of the function. (e.g. scalar, aggregate, window) /// The `description` is the description of the function. - async fn get_regietered_functions( + async fn get_registered_functions( ctx: &wren_core::SessionContext, ) -> PyResult> { let sql = r#" diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 54e1059af..b6c9b92ce 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -43,7 +43,7 @@ use parking_lot::RwLock; pub type SessionPropertiesRef = Arc>>; /// Apply Wren Rules to the context for sql generation. -pub async fn create_ctx_with_mdl( +pub async fn apply_wren_on_ctx( ctx: &SessionContext, analyzed_mdl: Arc, properties: SessionPropertiesRef, diff --git a/wren-core/core/src/mdl/function/aggregate/mod.rs b/wren-core/core/src/mdl/function/aggregate/mod.rs new file mode 100644 index 000000000..895920fe3 --- /dev/null +++ b/wren-core/core/src/mdl/function/aggregate/mod.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use datafusion::{ + functions_aggregate::{ + approx_percentile_cont::approx_percentile_cont_udaf, + approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf, *, + }, + logical_expr::AggregateUDF, +}; + +pub fn aggregate_functions() -> Vec> { + vec![ + array_agg::array_agg_udaf(), + first_last::first_value_udaf(), + first_last::last_value_udaf(), + covariance::covar_samp_udaf(), + covariance::covar_pop_udaf(), + correlation::corr_udaf(), + sum::sum_udaf(), + min_max::max_udaf(), + min_max::min_udaf(), + median::median_udaf(), + count::count_udaf(), + regr::regr_slope_udaf(), + regr::regr_intercept_udaf(), + regr::regr_count_udaf(), + regr::regr_r2_udaf(), + regr::regr_avgx_udaf(), + regr::regr_avgy_udaf(), + regr::regr_sxx_udaf(), + regr::regr_syy_udaf(), + regr::regr_sxy_udaf(), + variance::var_samp_udaf(), + variance::var_pop_udaf(), + stddev::stddev_udaf(), + stddev::stddev_pop_udaf(), + approx_median::approx_median_udaf(), + approx_distinct::approx_distinct_udaf(), + approx_percentile_cont_udaf(), + approx_percentile_cont_with_weight_udaf(), + string_agg::string_agg_udaf(), + bit_and_or_xor::bit_and_udaf(), + bit_and_or_xor::bit_or_udaf(), + bit_and_or_xor::bit_xor_udaf(), + bool_and_or::bool_and_udaf(), + bool_and_or::bool_or_udaf(), + average::avg_udaf(), + grouping::grouping_udaf(), + nth_value::nth_value_udaf(), + ] +} diff --git a/wren-core/core/src/mdl/function/mod.rs b/wren-core/core/src/mdl/function/mod.rs new file mode 100644 index 000000000..54526b1ac --- /dev/null +++ b/wren-core/core/src/mdl/function/mod.rs @@ -0,0 +1,10 @@ +mod aggregate; +mod remote_function; +mod scalar; +mod table; +mod window; +pub use aggregate::aggregate_functions; +pub use remote_function::*; +pub use scalar::scalar_functions; +pub use table::table_functions; +pub use window::window_functions; diff --git a/wren-core/core/src/mdl/function.rs b/wren-core/core/src/mdl/function/remote_function.rs similarity index 100% rename from wren-core/core/src/mdl/function.rs rename to wren-core/core/src/mdl/function/remote_function.rs diff --git a/wren-core/core/src/mdl/function/scalar/mod.rs b/wren-core/core/src/mdl/function/scalar/mod.rs new file mode 100644 index 000000000..7219efacf --- /dev/null +++ b/wren-core/core/src/mdl/function/scalar/mod.rs @@ -0,0 +1,186 @@ +use std::sync::Arc; + +use datafusion::{ + functions::{ + core::*, crypto::*, datetime::*, encoding::*, math::*, regex::*, string::*, + unicode::*, + }, + functions_nested::*, + logical_expr::ScalarUDF, +}; + +pub fn scalar_functions() -> Vec> { + vec![ + // datefusion core + nullif(), + arrow_cast(), + nvl(), + nvl2(), + overlay(), + arrow_typeof(), + named_struct(), + get_field(), + coalesce(), + greatest(), + least(), + union_extract(), + union_tag(), + version(), + r#struct(), + // datafusion crypto + digest(), + md5(), + sha224(), + sha256(), + sha384(), + sha512(), + // datafusion datetime + current_date(), + current_time(), + date_bin(), + date_part(), + date_trunc(), + date_diff(), + from_unixtime(), + make_date(), + now(), + to_char(), + to_date(), + to_local_time(), + to_unixtime(), + to_timestamp(), + to_timestamp_seconds(), + to_timestamp_millis(), + to_timestamp_micros(), + to_timestamp_nanos(), + // datafusion encoding + encode(), + decode(), + // datafusion math + abs(), + acos(), + acosh(), + asin(), + asinh(), + atan(), + atan2(), + atanh(), + cbrt(), + ceil(), + cos(), + cosh(), + cot(), + degrees(), + exp(), + factorial(), + floor(), + gcd(), + isnan(), + iszero(), + lcm(), + ln(), + log(), + log2(), + log10(), + nanvl(), + pi(), + power(), + radians(), + random(), + signum(), + sin(), + sinh(), + sqrt(), + tan(), + tanh(), + round(), + trunc(), + // datafusion regex + regexp_count(), + regexp_match(), + regexp_instr(), + regexp_like(), + regexp_replace(), + // datafusion string + ascii(), + bit_length(), + btrim(), + chr(), + concat(), + concat_ws(), + ends_with(), + levenshtein(), + lower(), + ltrim(), + octet_length(), + repeat(), + replace(), + rtrim(), + split_part(), + starts_with(), + to_hex(), + upper(), + uuid(), + contains(), + // datafusion unicode + character_length(), + find_in_set(), + initcap(), + left(), + lpad(), + reverse(), + right(), + rpad(), + strpos(), + substr(), + substr_index(), + translate(), + // datafusion nested + string::array_to_string_udf(), + string::string_to_array_udf(), + range::range_udf(), + range::gen_series_udf(), + dimension::array_dims_udf(), + cardinality::cardinality_udf(), + dimension::array_ndims_udf(), + datafusion::functions_nested::concat::array_append_udf(), + datafusion::functions_nested::concat::array_prepend_udf(), + datafusion::functions_nested::concat::array_concat_udf(), + except::array_except_udf(), + extract::array_element_udf(), + extract::array_pop_back_udf(), + extract::array_pop_front_udf(), + extract::array_slice_udf(), + extract::array_any_value_udf(), + make_array::make_array_udf(), + array_has::array_has_udf(), + array_has::array_has_all_udf(), + array_has::array_has_any_udf(), + empty::array_empty_udf(), + length::array_length_udf(), + distance::array_distance_udf(), + flatten::flatten_udf(), + min_max::array_max_udf(), + min_max::array_min_udf(), + sort::array_sort_udf(), + datafusion::functions_nested::repeat::array_repeat_udf(), + resize::array_resize_udf(), + datafusion::functions_nested::reverse::array_reverse_udf(), + set_ops::array_distinct_udf(), + set_ops::array_intersect_udf(), + set_ops::array_union_udf(), + position::array_position_udf(), + position::array_positions_udf(), + remove::array_remove_udf(), + remove::array_remove_all_udf(), + remove::array_remove_n_udf(), + datafusion::functions_nested::replace::array_replace_n_udf(), + datafusion::functions_nested::replace::array_replace_all_udf(), + datafusion::functions_nested::replace::array_replace_udf(), + map::map_udf(), + map_entries::map_entries_udf(), + map_extract::map_extract_udf(), + map_keys::map_keys_udf(), + map_values::map_values_udf(), + ] +} diff --git a/wren-core/core/src/mdl/function/table/mod.rs b/wren-core/core/src/mdl/function/table/mod.rs new file mode 100644 index 000000000..1a1e1a328 --- /dev/null +++ b/wren-core/core/src/mdl/function/table/mod.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use datafusion::{ + catalog::TableFunction, + functions_table::{generate_series, range}, +}; + +/// Returns all default table functions +pub fn table_functions() -> Vec> { + vec![generate_series(), range()] +} diff --git a/wren-core/core/src/mdl/function/window/mod.rs b/wren-core/core/src/mdl/function/window/mod.rs new file mode 100644 index 000000000..1d5916422 --- /dev/null +++ b/wren-core/core/src/mdl/function/window/mod.rs @@ -0,0 +1,19 @@ +use std::sync::Arc; + +use datafusion::{functions_window::*, logical_expr::WindowUDF}; + +pub fn window_functions() -> Vec> { + vec![ + cume_dist::cume_dist_udwf(), + row_number::row_number_udwf(), + lead_lag::lead_udwf(), + lead_lag::lag_udwf(), + rank::rank_udwf(), + rank::dense_rank_udwf(), + rank::percent_rank_udwf(), + ntile::ntile_udwf(), + nth_value::first_value_udwf(), + nth_value::last_value_udwf(), + nth_value::nth_value_udwf(), + ] +} diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index d28ccccef..3d42e62e3 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -2,7 +2,7 @@ use crate::logical_plan::analyze::access_control::validate_clac_rule; use crate::logical_plan::error::WrenError; use crate::logical_plan::utils::{from_qualified_name_str, try_map_data_type}; use crate::mdl::builder::ManifestBuilder; -use crate::mdl::context::{create_ctx_with_mdl, Mode, WrenDataSource}; +use crate::mdl::context::{apply_wren_on_ctx, Mode, WrenDataSource}; use crate::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, @@ -16,8 +16,9 @@ use datafusion::common::internal_datafusion_err; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; +use datafusion::execution::{SessionStateBuilder, SessionStateDefaults}; use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; -use datafusion::prelude::SessionContext; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion::sql::parser::DFParser; use datafusion::sql::sqlparser::ast::{Expr, ExprWithAlias, Ident}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -362,6 +363,24 @@ impl WrenMDL { } } +/// Create a SessionContext with the default functions registered +pub fn create_wren_ctx(config: Option) -> SessionContext { + let builder = SessionStateBuilder::new() + .with_expr_planners(SessionStateDefaults::default_expr_planners()) + .with_scalar_functions(crate::mdl::function::scalar_functions()) + .with_aggregate_functions(crate::mdl::function::aggregate_functions()) + .with_window_functions(crate::mdl::function::window_functions()) + .with_table_function_list(crate::mdl::function::table_functions()); + + let builder = if let Some(config) = config { + builder.with_config(config) + } else { + builder + }; + + SessionContext::new_with_state(builder.build()) +} + /// Transform the SQL based on the MDL pub fn transform_sql( analyzed_mdl: Arc, @@ -371,7 +390,7 @@ pub fn transform_sql( ) -> Result { let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), analyzed_mdl, remote_functions, Arc::new(properties), @@ -393,7 +412,7 @@ pub async fn transform_sql_with_ctx( register_remote_function(ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( ctx, Arc::clone(&analyzed_mdl), Arc::clone(&properties), @@ -458,15 +477,14 @@ async fn permission_analyze( Arc::clone(&properties), Mode::PermissionAnalyze, )?); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); remote_functions.iter().try_for_each(|remote_function| { debug!("Registering remote function: {remote_function:?}"); register_remote_function(&ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; - let ctx = - create_ctx_with_mdl(&ctx, analyzed_mdl, properties, Mode::PermissionAnalyze) - .await?; + let ctx = apply_wren_on_ctx(&ctx, analyzed_mdl, properties, Mode::PermissionAnalyze) + .await?; let plan = match ctx.state().create_logical_plan(sql).await { Ok(plan) => plan, @@ -550,11 +568,11 @@ mod test { use std::sync::Arc; use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; - use crate::mdl::context::{create_ctx_with_mdl, Mode, SessionPropertiesRef}; + use crate::mdl::context::{apply_wren_on_ctx, Mode, SessionPropertiesRef}; use crate::mdl::function::RemoteFunction; use crate::mdl::manifest::DataSource::MySQL; use crate::mdl::manifest::Manifest; - use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL}; + use crate::mdl::{self, create_wren_ctx, transform_sql_with_ctx, AnalyzedWrenMDL}; use datafusion::arrow::array::{ ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray, }; @@ -562,7 +580,6 @@ mod test { use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; use datafusion::common::not_impl_err; use datafusion::common::Result; - use datafusion::prelude::SessionContext; use datafusion::sql::unparser::plan_to_sql; use insta::assert_snapshot; use wren_core_base::mdl::{ @@ -627,7 +644,7 @@ mod test { for sql in tests { println!("Original: {sql}"); let actual = mdl::transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -660,7 +677,7 @@ mod test { let sql = "select * from test.test.customer_view"; println!("Original: {sql}"); let _ = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -692,7 +709,7 @@ mod test { )?); let sql = "select totalcost from profile"; let result = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -703,7 +720,7 @@ mod test { let sql = "select totalcost from profile where p_sex = 'M'"; let result = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -717,7 +734,7 @@ mod test { #[tokio::test] async fn test_uppercase_catalog_schema() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("customer", customer())?; let manifest = ManifestBuilder::new() .catalog("CTest") @@ -737,7 +754,7 @@ mod test { )?); let sql = r#"select * from CTest.STest.Customer"#; let actual = mdl::transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -758,7 +775,7 @@ mod test { [env!("CARGO_MANIFEST_DIR"), "tests", "data", "functions.csv"] .iter() .collect(); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let functions = csv::Reader::from_path(test_data) .unwrap() .into_deserialize::() @@ -817,7 +834,7 @@ mod test { #[tokio::test] async fn test_unicode_remote_column_name() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -856,7 +873,7 @@ mod test { )?); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -870,7 +887,7 @@ mod test { let sql = r#"select group from wren.test.artist"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -883,7 +900,7 @@ mod test { let sql = r#"select subscribe_plus from wren.test.artist"#; let actual = mdl::transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -897,7 +914,7 @@ mod test { #[tokio::test] async fn test_invalid_infer_remote_table() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -926,7 +943,7 @@ mod test { )?); let sql = r#"select name_append from wren.test.artist"#; let _ = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -942,7 +959,7 @@ mod test { let sql = r#"select lower_name from wren.test.artist"#; let _ = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -960,7 +977,7 @@ mod test { #[tokio::test] async fn test_query_hidden_column() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -985,7 +1002,7 @@ mod test { )?); let sql = r#"select 串接名字 from wren.test.artist"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -996,7 +1013,7 @@ mod test { @"SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1008,7 +1025,7 @@ mod test { let sql = r#"select "名字" from wren.test.artist"#; let _ = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1027,7 +1044,7 @@ mod test { async fn test_disable_simplify_expression() -> Result<()> { let sql = "select current_date"; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::new(AnalyzedWrenMDL::default()), &[], Arc::new(HashMap::new()), @@ -1058,7 +1075,7 @@ mod test { )?); let sql = r#"select * from wren.test.artist where 名字 in (SELECT 名字 FROM wren.test.artist)"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1073,7 +1090,7 @@ mod test { /// This test will be failed if the `出道時間` is not inferred as a timestamp column correctly. #[tokio::test] async fn test_infer_timestamp_column() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -1093,7 +1110,7 @@ mod test { )?); let sql = r#"select current_date > "出道時間" from wren.test.artist"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1108,7 +1125,7 @@ mod test { #[tokio::test] async fn test_disable_count_wildcard_rule() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select count(*) from (select 1)"; @@ -1127,7 +1144,7 @@ mod test { } async fn assert_sql_valid_executable(sql: &str) -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // To roundtrip testing, we should register the mock table for the planned sql. ctx.register_batch("orders", orders())?; ctx.register_batch("customer", customer())?; @@ -1149,7 +1166,7 @@ mod test { #[tokio::test] async fn test_mysql_style_interval() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select interval 1 day"; let actual = transform_sql_with_ctx( @@ -1191,7 +1208,7 @@ mod test { #[tokio::test] async fn test_unnest_as_table_factor() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new().build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1232,7 +1249,7 @@ mod test { #[tokio::test] async fn test_simplify_timestamp() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00 +08:00'"; let actual = transform_sql_with_ctx( @@ -1263,7 +1280,7 @@ mod test { let mut headers = HashMap::new(); headers.insert("x-wren-timezone".to_string(), Some("+08:00".to_string())); let headers_ref = Arc::new(headers); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00'"; let actual = transform_sql_with_ctx( @@ -1289,7 +1306,7 @@ mod test { // TIMESTAMP WITH TIME ZONE will be converted to the session timezone assert_snapshot!(actual, @"SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let mut headers = HashMap::new(); headers.insert( "x-wren-timezone".to_string(), @@ -1323,7 +1340,7 @@ mod test { let headers = HashMap::new(); let headers_ref = Arc::new(headers); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp with time zone '2011-01-01 18:00:00' - timestamp with time zone '2011-01-01 10:00:00'"; let actual = transform_sql_with_ctx( @@ -1342,7 +1359,7 @@ mod test { #[tokio::test] async fn test_disable_pushdown_filter() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("artist", artist())?; let manifest = ManifestBuilder::new() .catalog("wren") @@ -1371,7 +1388,7 @@ mod test { )?); let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1389,7 +1406,7 @@ mod test { #[tokio::test] async fn test_register_timestamptz() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("timestamp_table", timestamp_table())?; let provider = ctx .catalog("datafusion") @@ -1419,7 +1436,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, registers)?); let properties_ref = Arc::new(HashMap::new()); - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( &ctx, Arc::clone(&analyzed_mdl), properties_ref, @@ -1440,7 +1457,7 @@ mod test { #[tokio::test] async fn test_coercion_timestamptz() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_batch("timestamp_table", timestamp_table())?; for timezone_type in [ "timestamptz", @@ -1467,7 +1484,7 @@ mod test { )?); let sql = r#"select timestamp_col = timestamptz_col from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1482,7 +1499,7 @@ mod test { let sql = r#"select timestamptz_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1496,7 +1513,7 @@ mod test { let sql = r#"select timestamptz_col > '2011-01-01 18:00:00' from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1511,7 +1528,7 @@ mod test { let sql = r#"select timestamp_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( - &SessionContext::new(), + &create_wren_ctx(None), Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), @@ -1527,7 +1544,7 @@ mod test { #[tokio::test] async fn test_list() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1559,7 +1576,7 @@ mod test { #[tokio::test] async fn test_struct() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1658,7 +1675,7 @@ mod test { #[tokio::test] async fn test_disable_common_expression_eliminate() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let sql = "SELECT CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE) = \ CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE)"; @@ -1677,7 +1694,7 @@ mod test { #[tokio::test] async fn test_disable_eliminate_nested_union() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let sql = r#"SELECT * FROM (SELECT 1 x, 'a' y UNION ALL SELECT 1 x, 'b' y UNION ALL SELECT 2 x, 'a' y UNION ALL @@ -1708,7 +1725,7 @@ mod test { Arc::new(HashMap::default()), Mode::Unparse, )?); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let sql = "SELECT trim(' abc')"; let actual = transform_sql_with_ctx( &ctx, @@ -1724,7 +1741,7 @@ mod test { #[tokio::test] async fn test_disable_single_distinct_to_group_by() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1762,7 +1779,7 @@ mod test { #[tokio::test] async fn test_disable_distinct_to_group_by() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1797,7 +1814,7 @@ mod test { #[tokio::test] async fn test_disable_scalar_subquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1832,7 +1849,7 @@ mod test { #[tokio::test] async fn test_wildcard_where() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -1894,7 +1911,7 @@ mod test { } "#; let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1945,7 +1962,7 @@ mod test { } "#; let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, @@ -1969,7 +1986,7 @@ mod test { #[tokio::test] async fn test_rlac_with_requried_properties() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // test required property let manifest = ManifestBuilder::new() @@ -2181,7 +2198,7 @@ mod test { #[tokio::test] async fn test_rlac_with_optional_properties() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // test required property let manifest = ManifestBuilder::new() @@ -2372,7 +2389,7 @@ mod test { #[tokio::test] async fn test_rlac_on_calculated_field() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2543,7 +2560,7 @@ mod test { #[tokio::test] async fn test_rlac_alias_model() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2606,7 +2623,7 @@ mod test { #[tokio::test] async fn test_rlac_unicode_model_column_name() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2646,7 +2663,7 @@ mod test { #[tokio::test] async fn test_ralc_condition_contain_hidden() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2694,7 +2711,7 @@ mod test { #[tokio::test] async fn test_clac_with_required_properties() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -2791,7 +2808,7 @@ mod test { #[tokio::test] async fn test_clac_permission_denied() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2864,7 +2881,7 @@ mod test { #[tokio::test] async fn test_calc_primary_key() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2906,7 +2923,7 @@ mod test { #[tokio::test] async fn test_clac_with_optional_properties() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -3014,7 +3031,7 @@ mod test { #[tokio::test] async fn test_clac_on_calculated_field() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") @@ -3166,7 +3183,7 @@ mod test { #[tokio::test] async fn test_rlac_case_insensitive() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // test required property let manifest = ManifestBuilder::new() @@ -3204,7 +3221,7 @@ mod test { #[tokio::test] async fn test_disable_eliminate_limit() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // test required property let manifest = ManifestBuilder::new() @@ -3234,7 +3251,7 @@ mod test { #[tokio::test] async fn test_default_nulls_last() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); // test required property let manifest = ManifestBuilder::new() @@ -3295,7 +3312,7 @@ mod test { #[tokio::test] async fn test_extract_roundtrip_bigquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3354,7 +3371,7 @@ mod test { #[tokio::test] async fn test_date_diff_bigquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3438,7 +3455,7 @@ mod test { #[tokio::test] async fn test_window_function_frame() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3475,7 +3492,7 @@ mod test { #[tokio::test] async fn test_window_functions_without_frame_bigquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3505,7 +3522,7 @@ mod test { #[tokio::test] async fn test_cte_used_in_scalar_subquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3548,7 +3565,7 @@ mod test { #[tokio::test] async fn test_ambiguous_table_name() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -3611,7 +3628,7 @@ mod test { #[tokio::test] async fn test_unicode_literal() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::default().build(); let properties = SessionPropertiesRef::default(); @@ -3645,7 +3662,7 @@ mod test { #[tokio::test] async fn test_compatible_type() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::default().build(); let properties = SessionPropertiesRef::default(); @@ -3664,7 +3681,7 @@ mod test { #[tokio::test] async fn test_trim_function_bigquery() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") diff --git a/wren-core/sqllogictest/src/test_context.rs b/wren-core/sqllogictest/src/test_context.rs index 707e9f976..95b7e19d7 100644 --- a/wren-core/sqllogictest/src/test_context.rs +++ b/wren-core/sqllogictest/src/test_context.rs @@ -28,9 +28,9 @@ use tempfile::TempDir; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder, }; -use wren_core::mdl::context::{create_ctx_with_mdl, Mode, SessionPropertiesRef}; +use wren_core::mdl::context::{apply_wren_on_ctx, Mode, SessionPropertiesRef}; use wren_core::mdl::manifest::JoinType; -use wren_core::mdl::AnalyzedWrenMDL; +use wren_core::mdl::{create_wren_ctx, AnalyzedWrenMDL}; const TEST_RESOURCES: &str = "tests/resources"; @@ -63,7 +63,7 @@ impl TestContext { .with_target_partitions(4) .with_information_schema(true); - let ctx = SessionContext::new_with_config(config); + let ctx = create_wren_ctx(Some(config)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -78,7 +78,7 @@ impl TestContext { _ => { info!("Using default SessionContext"); let mdl = Arc::new(AnalyzedWrenMDL::default()); - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( &ctx, mdl.clone(), SessionPropertiesRef::default(), @@ -311,7 +311,7 @@ async fn register_ecommerce_mdl( manifest, register_tables, )?); - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( ctx, Arc::clone(&analyzed_mdl), Arc::new(HashMap::new()), @@ -547,7 +547,7 @@ async fn register_tpch_mdl( manifest, register_tables, )?); - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( ctx, Arc::clone(&analyzed_mdl), Arc::new(HashMap::new()), diff --git a/wren-core/wren-example/examples/to-many-calculation.rs b/wren-core/wren-example/examples/to-many-calculation.rs index 60f2a3fd1..9bd70a570 100644 --- a/wren-core/wren-example/examples/to-many-calculation.rs +++ b/wren-core/wren-example/examples/to-many-calculation.rs @@ -2,14 +2,14 @@ use std::collections::HashMap; use std::sync::Arc; use datafusion::error::Result; -use datafusion::prelude::{CsvReadOptions, SessionContext}; +use datafusion::prelude::CsvReadOptions; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, }; -use wren_core::mdl::context::{create_ctx_with_mdl, Mode}; +use wren_core::mdl::context::{apply_wren_on_ctx, Mode}; use wren_core::mdl::manifest::{JoinType, Manifest}; -use wren_core::mdl::AnalyzedWrenMDL; +use wren_core::mdl::{create_wren_ctx, AnalyzedWrenMDL}; #[tokio::main] async fn main() -> Result<()> { @@ -17,7 +17,7 @@ async fn main() -> Result<()> { let manifest = init_manifest(); // register the table - let ctx = SessionContext::new(); + let ctx = create_wren_ctx(None); ctx.register_csv( "orders", "sqllogictest/tests/resources/ecommerce/orders.csv", @@ -76,7 +76,7 @@ async fn main() -> Result<()> { ]); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); - let ctx = create_ctx_with_mdl( + let ctx = apply_wren_on_ctx( &ctx, analyzed_mdl, Arc::new(HashMap::new()),