Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use tokio::runtime::Runtime;
use wren_core::array::AsArray;
use wren_core::ast::{visit_statements_mut, Expr, Statement, Value, ValueWithSpan};
use wren_core::dialect::GenericDialect;
use wren_core::mdl::context::create_ctx_with_mdl;
use wren_core::mdl::context::apply_wren_on_ctx;
use wren_core::mdl::function::{
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
RemoteFunction,
Expand Down Expand Up @@ -91,11 +91,11 @@ impl PySessionContext {
.collect::<Vec<_>>();

let config = SessionConfig::default().with_information_schema(true);
let ctx = wren_core::SessionContext::new_with_config(config);
let ctx = wren_core::mdl::create_wren_ctx(Some(config));
let runtime = Runtime::new().map_err(CoreError::from)?;

let registered_functions = runtime
.block_on(Self::get_regietered_functions(&ctx))
.block_on(Self::get_registered_functions(&ctx))
.map(|functions| {
functions
.into_iter()
Expand Down Expand Up @@ -169,7 +169,7 @@ impl PySessionContext {
Ok(analyzed_mdl) => {
let analyzed_mdl = Arc::new(analyzed_mdl);
let unparser_ctx = runtime
.block_on(create_ctx_with_mdl(
.block_on(apply_wren_on_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
Arc::clone(&properties_ref),
Expand All @@ -178,7 +178,7 @@ impl PySessionContext {
.map_err(CoreError::from)?;

let exec_ctx = runtime
.block_on(create_ctx_with_mdl(
.block_on(apply_wren_on_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
Arc::clone(&properties_ref),
Expand Down Expand Up @@ -226,7 +226,7 @@ impl PySessionContext {
pub fn get_available_functions(&self) -> PyResult<Vec<PyRemoteFunction>> {
let registered_functions: Vec<PyRemoteFunction> = self
.runtime
.block_on(Self::get_regietered_functions(&self.exec_ctx))
.block_on(Self::get_registered_functions(&self.exec_ctx))
.map_err(CoreError::from)?
.into_iter()
.map(|f| f.into())
Expand Down Expand Up @@ -321,7 +321,7 @@ impl PySessionContext {
/// The `name` is the name of the function.
/// The `function_type` is the type of the function. (e.g. scalar, aggregate, window)
/// The `description` is the description of the function.
async fn get_regietered_functions(
async fn get_registered_functions(
ctx: &wren_core::SessionContext,
) -> PyResult<Vec<RemoteFunctionDto>> {
let sql = r#"
Expand Down
2 changes: 1 addition & 1 deletion wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use parking_lot::RwLock;
pub type SessionPropertiesRef = Arc<HashMap<String, Option<String>>>;

/// Apply Wren Rules to the context for sql generation.
pub async fn create_ctx_with_mdl(
pub async fn apply_wren_on_ctx(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
properties: SessionPropertiesRef,
Expand Down
51 changes: 51 additions & 0 deletions wren-core/core/src/mdl/function/aggregate/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::sync::Arc;

use datafusion::{
functions_aggregate::{
approx_percentile_cont::approx_percentile_cont_udaf,
approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf, *,
},
logical_expr::AggregateUDF,
};

pub fn aggregate_functions() -> Vec<Arc<AggregateUDF>> {
vec![
array_agg::array_agg_udaf(),
first_last::first_value_udaf(),
first_last::last_value_udaf(),
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
correlation::corr_udaf(),
sum::sum_udaf(),
min_max::max_udaf(),
min_max::min_udaf(),
median::median_udaf(),
count::count_udaf(),
regr::regr_slope_udaf(),
regr::regr_intercept_udaf(),
regr::regr_count_udaf(),
regr::regr_r2_udaf(),
regr::regr_avgx_udaf(),
regr::regr_avgy_udaf(),
regr::regr_sxx_udaf(),
regr::regr_syy_udaf(),
regr::regr_sxy_udaf(),
variance::var_samp_udaf(),
variance::var_pop_udaf(),
stddev::stddev_udaf(),
stddev::stddev_pop_udaf(),
approx_median::approx_median_udaf(),
approx_distinct::approx_distinct_udaf(),
approx_percentile_cont_udaf(),
approx_percentile_cont_with_weight_udaf(),
string_agg::string_agg_udaf(),
bit_and_or_xor::bit_and_udaf(),
bit_and_or_xor::bit_or_udaf(),
bit_and_or_xor::bit_xor_udaf(),
bool_and_or::bool_and_udaf(),
bool_and_or::bool_or_udaf(),
average::avg_udaf(),
grouping::grouping_udaf(),
nth_value::nth_value_udaf(),
]
}
10 changes: 10 additions & 0 deletions wren-core/core/src/mdl/function/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Submodules of the function module. They are private; the items callers
// need are re-exported below.
mod aggregate;
mod remote_function;
mod scalar;
mod table;
mod window;
// Re-exports: one list-returning function per function kind, plus everything
// public from `remote_function` (glob re-export).
pub use aggregate::aggregate_functions;
pub use remote_function::*;
pub use scalar::scalar_functions;
pub use table::table_functions;
pub use window::window_functions;
186 changes: 186 additions & 0 deletions wren-core/core/src/mdl/function/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
use std::sync::Arc;

use datafusion::{
functions::{
core::*, crypto::*, datetime::*, encoding::*, math::*, regex::*, string::*,
unicode::*,
},
functions_nested::*,
logical_expr::ScalarUDF,
};

/// Returns the DataFusion scalar UDFs exposed by this module.
///
/// Functions are appended one family at a time (core, crypto, datetime,
/// encoding, math, regex, string, unicode, nested); the resulting order is
/// the order in which the entries are appended below.
pub fn scalar_functions() -> Vec<Arc<ScalarUDF>> {
    let mut udfs: Vec<Arc<ScalarUDF>> = Vec::new();

    // datafusion core
    udfs.extend([
        nullif(),
        arrow_cast(),
        nvl(),
        nvl2(),
        overlay(),
        arrow_typeof(),
        named_struct(),
        get_field(),
        coalesce(),
        greatest(),
        least(),
        union_extract(),
        union_tag(),
        version(),
        r#struct(),
    ]);

    // datafusion crypto
    udfs.extend([digest(), md5(), sha224(), sha256(), sha384(), sha512()]);

    // datafusion datetime
    udfs.extend([
        current_date(),
        current_time(),
        date_bin(),
        date_part(),
        date_trunc(),
        date_diff(),
        from_unixtime(),
        make_date(),
        now(),
        to_char(),
        to_date(),
        to_local_time(),
        to_unixtime(),
        to_timestamp(),
        to_timestamp_seconds(),
        to_timestamp_millis(),
        to_timestamp_micros(),
        to_timestamp_nanos(),
    ]);

    // datafusion encoding
    udfs.extend([encode(), decode()]);

    // datafusion math
    udfs.extend([
        abs(),
        acos(),
        acosh(),
        asin(),
        asinh(),
        atan(),
        atan2(),
        atanh(),
        cbrt(),
        ceil(),
        cos(),
        cosh(),
        cot(),
        degrees(),
        exp(),
        factorial(),
        floor(),
        gcd(),
        isnan(),
        iszero(),
        lcm(),
        ln(),
        log(),
        log2(),
        log10(),
        nanvl(),
        pi(),
        power(),
        radians(),
        random(),
        signum(),
        sin(),
        sinh(),
        sqrt(),
        tan(),
        tanh(),
        round(),
        trunc(),
    ]);

    // datafusion regex
    udfs.extend([
        regexp_count(),
        regexp_match(),
        regexp_instr(),
        regexp_like(),
        regexp_replace(),
    ]);

    // datafusion string
    udfs.extend([
        ascii(),
        bit_length(),
        btrim(),
        chr(),
        concat(),
        concat_ws(),
        ends_with(),
        levenshtein(),
        lower(),
        ltrim(),
        octet_length(),
        repeat(),
        replace(),
        rtrim(),
        split_part(),
        starts_with(),
        to_hex(),
        upper(),
        uuid(),
        contains(),
    ]);

    // datafusion unicode
    udfs.extend([
        character_length(),
        find_in_set(),
        initcap(),
        left(),
        lpad(),
        reverse(),
        right(),
        rpad(),
        strpos(),
        substr(),
        substr_index(),
        translate(),
    ]);

    // datafusion nested: array functions
    udfs.extend([
        string::array_to_string_udf(),
        string::string_to_array_udf(),
        range::range_udf(),
        range::gen_series_udf(),
        dimension::array_dims_udf(),
        cardinality::cardinality_udf(),
        dimension::array_ndims_udf(),
        datafusion::functions_nested::concat::array_append_udf(),
        datafusion::functions_nested::concat::array_prepend_udf(),
        datafusion::functions_nested::concat::array_concat_udf(),
        except::array_except_udf(),
        extract::array_element_udf(),
        extract::array_pop_back_udf(),
        extract::array_pop_front_udf(),
        extract::array_slice_udf(),
        extract::array_any_value_udf(),
        make_array::make_array_udf(),
        array_has::array_has_udf(),
        array_has::array_has_all_udf(),
        array_has::array_has_any_udf(),
        empty::array_empty_udf(),
        length::array_length_udf(),
        distance::array_distance_udf(),
        flatten::flatten_udf(),
        min_max::array_max_udf(),
        min_max::array_min_udf(),
        sort::array_sort_udf(),
        datafusion::functions_nested::repeat::array_repeat_udf(),
        resize::array_resize_udf(),
        datafusion::functions_nested::reverse::array_reverse_udf(),
        set_ops::array_distinct_udf(),
        set_ops::array_intersect_udf(),
        set_ops::array_union_udf(),
        position::array_position_udf(),
        position::array_positions_udf(),
        remove::array_remove_udf(),
        remove::array_remove_all_udf(),
        remove::array_remove_n_udf(),
        datafusion::functions_nested::replace::array_replace_n_udf(),
        datafusion::functions_nested::replace::array_replace_all_udf(),
        datafusion::functions_nested::replace::array_replace_udf(),
    ]);

    // datafusion nested: map functions
    udfs.extend([
        map::map_udf(),
        map_entries::map_entries_udf(),
        map_extract::map_extract_udf(),
        map_keys::map_keys_udf(),
        map_values::map_values_udf(),
    ]);

    udfs
}
11 changes: 11 additions & 0 deletions wren-core/core/src/mdl/function/table/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use std::sync::Arc;

use datafusion::{
catalog::TableFunction,
functions_table::{generate_series, range},
};

/// Returns the DataFusion table functions exposed by this module:
/// `generate_series` followed by `range`.
pub fn table_functions() -> Vec<Arc<TableFunction>> {
    let series_fn = generate_series();
    let range_fn = range();
    vec![series_fn, range_fn]
}
19 changes: 19 additions & 0 deletions wren-core/core/src/mdl/function/window/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use std::sync::Arc;

use datafusion::{functions_window::*, logical_expr::WindowUDF};

/// Returns the DataFusion window UDFs exposed by this module.
///
/// The resulting order is the order of the entries in the array below.
pub fn window_functions() -> Vec<Arc<WindowUDF>> {
    Vec::from([
        // Distribution and numbering.
        cume_dist::cume_dist_udwf(),
        row_number::row_number_udwf(),
        // Offset access to neighbouring rows.
        lead_lag::lead_udwf(),
        lead_lag::lag_udwf(),
        // Ranking.
        rank::rank_udwf(),
        rank::dense_rank_udwf(),
        rank::percent_rank_udwf(),
        ntile::ntile_udwf(),
        // Positional values within the window.
        nth_value::first_value_udwf(),
        nth_value::last_value_udwf(),
        nth_value::nth_value_udwf(),
    ])
}
Loading