diff --git a/Cargo.lock b/Cargo.lock
index e5981121..bca4bf06 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -680,6 +680,12 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
+
 [[package]]
 name = "crunchy"
 version = "0.2.2"
@@ -725,11 +731,12 @@ checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca"
 
 [[package]]
 name = "dashmap"
-version = "5.5.3"
+version = "6.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
+checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28"
 dependencies = [
  "cfg-if",
+ "crossbeam-utils",
  "hashbrown",
  "lock_api",
  "once_cell",
@@ -738,9 +745,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ab9d55a9cd2634818953809f75ebe5248b00dd43c3227efb2a51a2d5feaad54e"
+checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb"
 dependencies = [
  "ahash",
  "apache-avro",
@@ -754,16 +761,18 @@ dependencies = [
  "bzip2",
  "chrono",
  "dashmap",
+ "datafusion-catalog",
  "datafusion-common",
  "datafusion-common-runtime",
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions",
  "datafusion-functions-aggregate",
- "datafusion-functions-array",
+ "datafusion-functions-nested",
  "datafusion-optimizer",
  "datafusion-physical-expr",
  "datafusion-physical-expr-common",
+ "datafusion-physical-optimizer",
  "datafusion-physical-plan",
  "datafusion-sql",
  "flate2",
@@ -792,11 +801,25 @@ dependencies = [
  "zstd 0.13.2",
 ]
 
+[[package]]
+name = "datafusion-catalog"
+version = "41.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e13b3cfbd84c6003594ae1972314e3df303a27ce8ce755fcea3240c90f4c0529"
+dependencies = [
+ "arrow-schema",
+ "async-trait",
+ "datafusion-common",
+ "datafusion-execution",
+ "datafusion-expr",
+ "datafusion-physical-plan",
+]
+
 [[package]]
 name = "datafusion-common"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "def66b642959e7f96f5d2da22e1f43d3bd35598f821e5ce351a0553e0f1b7367"
+checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c"
 dependencies = [
  "ahash",
  "apache-avro",
@@ -818,18 +841,18 @@ dependencies = [
 
 [[package]]
 name = "datafusion-common-runtime"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f104bb9cb44c06c9badf8a0d7e0855e5f7fa5e395b887d7f835e8a9457dc1352"
+checksum = "8a7496d1f664179f6ce3a5cbef6566056ccaf3ea4aa72cc455f80e62c1dd86b1"
 dependencies = [
  "tokio",
 ]
 
 [[package]]
 name = "datafusion-execution"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2ac0fd8b5d80bbca3fc3b6f40da4e9f6907354824ec3b18bbd83fee8cf5c3c3e"
+checksum = "799e70968c815b611116951e3dd876aef04bf217da31b72eec01ee6a959336a1"
 dependencies = [
  "arrow",
  "chrono",
@@ -848,9 +871,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-expr"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2103d2cc16fb11ef1fa993a6cac57ed5cb028601db4b97566c90e5fa77aa1e68"
+checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2"
 dependencies = [
  "ahash",
  "arrow",
@@ -867,11 +890,12 @@ dependencies = [
 
 [[package]]
 name = "datafusion-functions"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f"
+checksum = "a8e481cf34d2a444bd8fa09b65945f0ce83dc92df8665b761505b3d9f351bebb"
 dependencies = [
  "arrow",
+ "arrow-buffer",
  "base64 0.22.1",
  "blake2",
  "blake3",
@@ -893,9 +917,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-functions-aggregate"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "92718db1aff70c47e5abf9fc975768530097059e5db7c7b78cd64b5e9a11fc77"
+checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb"
 dependencies = [
  "ahash",
  "arrow",
@@ -910,10 +934,10 @@ dependencies = [
 ]
 
 [[package]]
-name = "datafusion-functions-array"
-version = "40.0.0"
+name = "datafusion-functions-nested"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30bb80f46ff3dcf4bb4510209c2ba9b8ce1b716ac8b7bf70c6bf7dca6260c831"
+checksum = "a1474552cc824e8c9c88177d454db5781d4b66757d4aca75719306b8343a5e8d"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -928,13 +952,14 @@ dependencies = [
  "itertools 0.12.1",
  "log",
  "paste",
+ "rand",
 ]
 
 [[package]]
 name = "datafusion-optimizer"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "82f34692011bec4fdd6fc18c264bf8037b8625d801e6dd8f5111af15cb6d71d3"
+checksum = "791ff56f55608bc542d1ea7a68a64bdc86a9413f5a381d06a39fd49c2a3ab906"
 dependencies = [
  "arrow",
  "async-trait",
@@ -952,9 +977,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-physical-expr"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45538630defedb553771434a437f7ca8f04b9b3e834344aafacecb27dc65d5e5"
+checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4"
 dependencies = [
  "ahash",
  "arrow",
@@ -982,9 +1007,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-physical-expr-common"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d8a72b0ca908e074aaeca52c14ddf5c28d22361e9cb6bc79bb733cd6661b536"
+checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6"
 dependencies = [
  "ahash",
  "arrow",
@@ -994,11 +1019,23 @@ dependencies = [
  "rand",
 ]
 
+[[package]]
+name = "datafusion-physical-optimizer"
+version = "41.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fdb9c78f308e050f5004671039786a925c3fee83b90004e9fcfd328d7febdcc0"
+dependencies = [
+ "datafusion-common",
+ "datafusion-execution",
+ "datafusion-physical-expr",
+ "datafusion-physical-plan",
+]
+
 [[package]]
 name = "datafusion-physical-plan"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b504eae6107a342775e22e323e9103f7f42db593ec6103b28605b7b7b1405c4a"
+checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe"
 dependencies = [
  "ahash",
  "arrow",
@@ -1037,7 +1074,7 @@ dependencies = [
  "datafusion",
  "datafusion-common",
  "datafusion-expr",
- "datafusion-functions-array",
+ "datafusion-functions-nested",
  "datafusion-optimizer",
  "datafusion-sql",
  "datafusion-substrait",
@@ -1051,7 +1088,6 @@ dependencies = [
  "pyo3-build-config",
  "rand",
  "regex-syntax",
- "sqlparser",
  "syn 2.0.72",
  "tokio",
  "url",
@@ -1060,9 +1096,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-sql"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5db33f323f41b95ae201318ba654a9bf11113e58a51a1dff977b1a836d3d889"
+checksum = "b45d0180711165fe94015d7c4123eb3e1cf5fb60b1506453200b8d1ce666bef0"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -1077,9 +1113,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-substrait"
-version = "40.0.0"
+version = "41.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "434e52fbff22e6e04e6c787f603a6aba4961a7e249a29c743c5d4f609ec2dcef"
+checksum = "bf0a0055aa98246c79f98f0d03df11f16cb7adc87818d02d4413e3f3cdadbbee"
 dependencies = [
  "arrow-buffer",
  "async-recursion",
@@ -2898,9 +2934,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
 
 [[package]]
 name = "sqlparser"
-version = "0.47.0"
+version = "0.49.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25"
+checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e"
 dependencies = [
  "log",
  "sqlparser_derive",
diff --git a/Cargo.toml b/Cargo.toml
index 820118fa..8881884b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,13 +38,13 @@ tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync
 rand = "0.8"
 pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
 arrow = { version = "52", feature = ["pyarrow"] }
-datafusion = { version = "40.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
-datafusion-common = { version = "40.0.0", features = ["pyarrow"] }
-datafusion-expr = "40.0.0"
-datafusion-functions-array = "40.0.0"
-datafusion-optimizer = "40.0.0"
-datafusion-sql = "40.0.0"
-datafusion-substrait = { version = "40.0.0", optional = true }
+datafusion = { version = "41.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
+datafusion-common = { version = "41.0.0", features = ["pyarrow"] }
+datafusion-expr = { version = "41.0.0" }
+datafusion-functions-nested = { version = "41.0.0" }
+datafusion-optimizer = { version = "41.0.0" }
+datafusion-sql = { version = "41.0.0" }
+datafusion-substrait = { version = "41.0.0", optional = true }
 prost = "0.12" # keep in line with `datafusion-substrait`
 prost-types = "0.12" # keep in line with `datafusion-substrait`
 uuid = { version = "1.9", features = ["v4"] }
@@ -56,7 +56,6 @@ parking_lot = "0.12"
 regex-syntax = "0.8"
 syn = "2.0.68"
 url = "2"
-sqlparser = "0.47.0"
 
 [build-dependencies]
 pyo3-build-config = "0.21"
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index 29391232..b8ad9c0d 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -808,7 +808,7 @@ def test_regr_funcs_sql(df):
 
     assert result[0].column(0) == pa.array([None], type=pa.float64())
     assert result[0].column(1) == pa.array([None], type=pa.float64())
-    assert result[0].column(2) == pa.array([1], type=pa.float64())
+    assert result[0].column(2) == pa.array([1], type=pa.uint64())
     assert result[0].column(3) == pa.array([None], type=pa.float64())
     assert result[0].column(4) == pa.array([1], type=pa.float64())
     assert result[0].column(5) == pa.array([1], type=pa.float64())
@@ -840,7 +840,7 @@ def test_regr_funcs_sql_2():
     # Assertions for SQL results
     assert result_sql[0].column(0) == pa.array([2], type=pa.float64())
     assert result_sql[0].column(1) == pa.array([0], type=pa.float64())
-    assert result_sql[0].column(2) == pa.array([3], type=pa.float64()) # todo: i would not expect this to be float
+    assert result_sql[0].column(2) == pa.array([3], type=pa.uint64())
     assert result_sql[0].column(3) == pa.array([1], type=pa.float64())
     assert result_sql[0].column(4) == pa.array([2], type=pa.float64())
     assert result_sql[0].column(5) == pa.array([4], type=pa.float64())
@@ -852,7 +852,7 @@ def test_regr_funcs_sql_2():
 @pytest.mark.parametrize("func, expected", [
     pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
     pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"),
-    pytest.param(f.regr_count, pa.array([3], type=pa.float64()), id="regr_count"), # TODO: I would expect this to return an int array
+    pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"),
     pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
     pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
     pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
diff --git a/src/catalog.rs b/src/catalog.rs
index 49fe1404..1ce66a4d 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -25,7 +25,7 @@ use crate::errors::DataFusionError;
 use crate::utils::wait_for_future;
 use datafusion::{
     arrow::pyarrow::ToPyArrow,
-    catalog::{schema::SchemaProvider, CatalogProvider},
+    catalog::{CatalogProvider, SchemaProvider},
     datasource::{TableProvider, TableType},
 };
 
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 469bb789..21b085c0 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -18,6 +18,7 @@
 use datafusion::arrow::array::Array;
 use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
 use datafusion_common::{DataFusionError, ScalarValue};
+use datafusion_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
 use pyo3::{exceptions::PyValueError, prelude::*};
 
 use crate::errors::py_datafusion_err;
@@ -775,20 +776,20 @@ pub enum NullTreatment {
     RESPECT_NULLS,
 }
 
-impl From<NullTreatment> for sqlparser::ast::NullTreatment {
-    fn from(null_treatment: NullTreatment) -> sqlparser::ast::NullTreatment {
+impl From<NullTreatment> for DFNullTreatment {
+    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
         match null_treatment {
-            NullTreatment::IGNORE_NULLS => sqlparser::ast::NullTreatment::IgnoreNulls,
-            NullTreatment::RESPECT_NULLS => sqlparser::ast::NullTreatment::RespectNulls,
+            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
+            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
         }
     }
 }
 
-impl From<sqlparser::ast::NullTreatment> for NullTreatment {
-    fn from(null_treatment: sqlparser::ast::NullTreatment) -> NullTreatment {
+impl From<DFNullTreatment> for NullTreatment {
+    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
         match null_treatment {
-            sqlparser::ast::NullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
-            sqlparser::ast::NullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
+            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
+            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
         }
     }
 }
diff --git a/src/context.rs b/src/context.rs
index d7890e3f..a43599cf 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -20,6 +20,7 @@ use std::path::PathBuf;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use datafusion::execution::session_state::SessionStateBuilder;
 use object_store::ObjectStore;
 use url::Url;
 use uuid::Uuid;
@@ -49,9 +50,7 @@ use datafusion::datasource::listing::{
 };
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
-use datafusion::execution::context::{
-    SQLOptions, SessionConfig, SessionContext, SessionState, TaskContext,
-};
+use datafusion::execution::context::{SQLOptions, SessionConfig, SessionContext, TaskContext};
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
@@ -281,7 +280,11 @@ impl PySessionContext {
             RuntimeConfig::default()
         };
         let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-        let session_state = SessionState::new_with_config_rt(config, runtime);
+        let session_state = SessionStateBuilder::new()
+            .with_config(config)
+            .with_runtime_env(runtime)
+            .with_default_features()
+            .build();
         Ok(PySessionContext {
             ctx: SessionContext::new_with_state(session_state),
         })
diff --git a/src/dataset.rs b/src/dataset.rs
index 724b4af7..b5704164 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion::catalog::Session;
 use pyo3::exceptions::PyValueError;
 /// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
 /// This allows us to use PyArrow Datasets as Datafusion tables while pushing down projections and filters
@@ -30,7 +31,6 @@ use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::{DataFusionError, Result as DFResult};
-use datafusion::execution::context::SessionState;
 use datafusion::logical_expr::TableProviderFilterPushDown;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion_expr::Expr;
@@ -98,7 +98,7 @@ impl TableProvider for Dataset {
     /// parallelized or distributed.
     async fn scan(
         &self,
-        _ctx: &SessionState,
+        _ctx: &dyn Session,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
         // limit can be used to reduce the amount scanned
diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs
index 626d92c7..e3d1bb13 100644
--- a/src/expr/aggregate.rs
+++ b/src/expr/aggregate.rs
@@ -126,9 +126,9 @@ impl PyAggregate {
         match expr {
             // TODO: This Alias logic seems to be returning some strange results that we should investigate
             Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
-            Expr::AggregateFunction(AggregateFunction {
-                func_def: _, args, ..
-            }) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
+            Expr::AggregateFunction(AggregateFunction { func: _, args, .. }) => {
+                Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect())
+            }
             _ => Err(py_type_err(
                 "Encountered a non Aggregate type in aggregation_arguments",
             )),
@@ -138,9 +138,7 @@ impl PyAggregate {
     fn _agg_func_name(expr: &Expr) -> PyResult<String> {
         match expr {
             Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
-            Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
-                Ok(func_def.name().to_owned())
-            }
+            Expr::AggregateFunction(AggregateFunction { func, .. }) => Ok(func.name().to_owned()),
             _ => Err(py_type_err(
                 "Encountered a non Aggregate type in agg_func_name",
             )),
diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs
index 04ec29a1..15097e00 100644
--- a/src/expr/aggregate_expr.rs
+++ b/src/expr/aggregate_expr.rs
@@ -41,7 +41,7 @@ impl From<AggregateFunction> for PyAggregateFunction {
 impl Display for PyAggregateFunction {
     fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
         let args: Vec<String> = self.aggr.args.iter().map(|expr| expr.to_string()).collect();
-        write!(f, "{}({})", self.aggr.func_def.name(), args.join(", "))
+        write!(f, "{}({})", self.aggr.func.name(), args.join(", "))
     }
 }
 
@@ -49,7 +49,7 @@ impl Display for PyAggregateFunction {
 impl PyAggregateFunction {
     /// Get the aggregate type, such as "MIN", or "MAX"
     fn aggregate_type(&self) -> String {
-        self.aggr.func_def.name().to_string()
+        self.aggr.func.name().to_string()
     }
 
     /// is this a distinct aggregate such as `COUNT(DISTINCT expr)`
diff --git a/src/functions.rs b/src/functions.rs
index f8f47816..c53d4ad9 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use datafusion::functions_aggregate::all_default_aggregate_functions;
-use datafusion_expr::AggregateExt;
+use datafusion_expr::ExprFunctionExt;
 use pyo3::{prelude::*, wrap_pyfunction};
 
 use crate::common::data_type::NullTreatment;
@@ -30,16 +30,15 @@ use datafusion::functions;
 use datafusion::functions_aggregate;
 use datafusion_common::{Column, ScalarValue, TableReference};
 use datafusion_expr::expr::Alias;
+use datafusion_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
 use datafusion_expr::{
-    expr::{
-        find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction,
-    },
+    expr::{find_df_window_func, AggregateFunction, Sort, WindowFunction},
     lit, Expr, WindowFunctionDefinition,
 };
 
 #[pyfunction]
 pub fn approx_distinct(expression: PyExpr) -> PyExpr {
-    functions_aggregate::expr_fn::approx_distinct::approx_distinct(expression.expr).into()
+    functions_aggregate::expr_fn::approx_distinct(expression.expr).into()
 }
 
 #[pyfunction]
@@ -342,9 +341,8 @@ pub fn first_value(
         builder = builder.filter(filter.expr);
     }
 
-    if let Some(null_treatment) = null_treatment {
-        builder = builder.null_treatment(null_treatment.into())
-    }
+    // would be nice if all the options builder methods accepted Option ...
+    builder = builder.null_treatment(null_treatment.map(DFNullTreatment::from));
 
     Ok(builder.build()?.into())
 }
@@ -373,9 +371,7 @@ pub fn last_value(
         builder = builder.filter(filter.expr);
     }
 
-    if let Some(null_treatment) = null_treatment {
-        builder = builder.null_treatment(null_treatment.into())
-    }
+    builder = builder.null_treatment(null_treatment.map(DFNullTreatment::from));
 
     Ok(builder.build()?.into())
 }
@@ -392,14 +388,14 @@ fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
 
 #[pyfunction]
 fn make_array(exprs: Vec<PyExpr>) -> PyExpr {
-    datafusion_functions_array::expr_fn::make_array(exprs.into_iter().map(|x| x.into()).collect())
+    datafusion_functions_nested::expr_fn::make_array(exprs.into_iter().map(|x| x.into()).collect())
         .into()
 }
 
 #[pyfunction]
 fn array_concat(exprs: Vec<PyExpr>) -> PyExpr {
     let exprs = exprs.into_iter().map(|x| x.into()).collect();
-    datafusion_functions_array::expr_fn::array_concat(exprs).into()
+    datafusion_functions_nested::expr_fn::array_concat(exprs).into()
 }
 
 #[pyfunction]
@@ -411,12 +407,12 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
 fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
     let index = ScalarValue::Int64(index);
     let index = Expr::Literal(index);
-    datafusion_functions_array::expr_fn::array_position(array.into(), element.into(), index).into()
+    datafusion_functions_nested::expr_fn::array_position(array.into(), element.into(), index).into()
 }
 
 #[pyfunction]
 fn array_slice(array: PyExpr, begin: PyExpr, end: PyExpr, stride: Option<PyExpr>) -> PyExpr {
-    datafusion_functions_array::expr_fn::array_slice(
+    datafusion_functions_nested::expr_fn::array_slice(
         array.into(),
         begin.into(),
         end.into(),
@@ -638,18 +634,16 @@ fn window(
 }
 
 macro_rules! aggregate_function {
-    ($NAME: ident, $FUNC: ident) => {
+    ($NAME: ident, $FUNC: path) => {
         aggregate_function!($NAME, $FUNC, stringify!($NAME));
     };
-    ($NAME: ident, $FUNC: ident, $DOC: expr) => {
+    ($NAME: ident, $FUNC: path, $DOC: expr) => {
         #[doc = $DOC]
        #[pyfunction]
         #[pyo3(signature = (*args, distinct=false))]
         fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
             let expr = datafusion_expr::Expr::AggregateFunction(AggregateFunction {
-                func_def: AggregateFunctionDefinition::BuiltIn(
-                    datafusion_expr::aggregate_function::AggregateFunction::$FUNC,
-                ),
+                func: $FUNC(),
                 args: args.into_iter().map(|e| e.into()).collect(),
                 distinct,
                 filter: None,
@@ -701,7 +695,7 @@ macro_rules! expr_fn_vec {
     };
 }
 
-/// Generates a [pyo3] wrapper for [datafusion_functions_array::expr_fn]
+/// Generates a [pyo3] wrapper for [datafusion_functions_nested::expr_fn]
 ///
 /// These functions have explicit named arguments.
 macro_rules! array_fn {
@@ -718,7 +712,7 @@ macro_rules! array_fn {
         #[doc = $DOC]
         #[pyfunction]
         fn $FUNC($($arg: PyExpr),*) -> PyExpr {
-            datafusion_functions_array::expr_fn::$FUNC($($arg.into()),*).into()
+            datafusion_functions_nested::expr_fn::$FUNC($($arg.into()),*).into()
         }
     };
 }
@@ -884,9 +878,9 @@ array_fn!(array_resize, array size value);
 array_fn!(flatten, array);
 array_fn!(range, start stop step);
 
-aggregate_function!(array_agg, ArrayAgg);
-aggregate_function!(max, Max);
-aggregate_function!(min, Min);
+aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf);
+aggregate_function!(max, functions_aggregate::min_max::max_udaf);
+aggregate_function!(min, functions_aggregate::min_max::min_udaf);
 
 pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(abs))?;
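
Reviewer note: the src/context.rs hunk above tracks DataFusion 41's removal of SessionState::new_with_config_rt in favor of SessionStateBuilder. The sketch below shows that construction path in isolation for downstream users of the Rust API; it only combines APIs that already appear in this patch, except for with_information_schema and the build_context name, which are illustrative and not part of this change.

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::{SessionConfig, SessionContext};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::session_state::SessionStateBuilder;

fn build_context() -> Result<SessionContext> {
    // Session-level options; the information_schema flag is illustrative only.
    let config = SessionConfig::new().with_information_schema(true);
    let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default())?);

    // DataFusion 41: SessionState::new_with_config_rt is gone, so the state is
    // assembled through the builder. with_default_features() opts back into the
    // defaults (built-in functions, expression planners, and so on) that the old
    // constructor configured implicitly.
    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_runtime_env(runtime)
        .with_default_features()
        .build();

    Ok(SessionContext::new_with_state(state))
}

This mirrors what PySessionContext::new now does in the patch, so any embedder that previously built a SessionState directly will need an equivalent builder-based change.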