Skip to content

Commit

Permalink
update AggregateFunction
Browse files Browse the repository at this point in the history
Upstream Changes:
- The field name was switched from `func_name` to func.
- AggregateFunctionDefinition was removed

Ref: apache/datafusion#11803
  • Loading branch information
Michael-J-Ward committed Aug 20, 2024
1 parent 7207433 commit 4c75286
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/expr/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl PyAggregate {
// TODO: This Alias logic seems to be returning some strange results that we should investigate
Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
Expr::AggregateFunction(AggregateFunction {
func_def: _, args, ..
func: _, args, ..
}) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
_ => Err(py_type_err(
"Encountered a non Aggregate type in aggregation_arguments",
Expand All @@ -138,8 +138,8 @@ impl PyAggregate {
fn _agg_func_name(expr: &Expr) -> PyResult<String> {
match expr {
Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
Ok(func_def.name().to_owned())
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
Ok(func.name().to_owned())
}
_ => Err(py_type_err(
"Encountered a non Aggregate type in agg_func_name",
Expand Down
4 changes: 2 additions & 2 deletions src/expr/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ impl From<AggregateFunction> for PyAggregateFunction {
impl Display for PyAggregateFunction {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
let args: Vec<String> = self.aggr.args.iter().map(|expr| expr.to_string()).collect();
write!(f, "{}({})", self.aggr.func_def.name(), args.join(", "))
write!(f, "{}({})", self.aggr.func.name(), args.join(", "))
}
}

#[pymethods]
impl PyAggregateFunction {
/// Get the aggregate type, such as "MIN", or "MAX"
fn aggregate_type(&self) -> String {
self.aggr.func_def.name().to_string()
self.aggr.func.name().to_string()
}

/// is this a distinct aggregate such as `COUNT(DISTINCT expr)`
Expand Down
20 changes: 8 additions & 12 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion_expr::AggregateExt;
use datafusion_expr::ExprFunctionExt as AggregateExt;
use pyo3::{prelude::*, wrap_pyfunction};

use crate::common::data_type::NullTreatment;
Expand All @@ -31,9 +31,7 @@ use datafusion::functions_aggregate;
use datafusion_common::{Column, ScalarValue, TableReference};
use datafusion_expr::expr::Alias;
use datafusion_expr::{
expr::{
find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction,
},
expr::{find_df_window_func, AggregateFunction, Sort, WindowFunction},
lit, Expr, WindowFunctionDefinition,
};

Expand Down Expand Up @@ -638,18 +636,16 @@ fn window(
}

macro_rules! aggregate_function {
($NAME: ident, $FUNC: ident) => {
($NAME: ident, $FUNC: path) => {
aggregate_function!($NAME, $FUNC, stringify!($NAME));
};
($NAME: ident, $FUNC: ident, $DOC: expr) => {
($NAME: ident, $FUNC: path, $DOC: expr) => {
#[doc = $DOC]
#[pyfunction]
#[pyo3(signature = (*args, distinct=false))]
fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
let expr = datafusion_expr::Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::$FUNC,
),
func: $FUNC(),
args: args.into_iter().map(|e| e.into()).collect(),
distinct,
filter: None,
Expand Down Expand Up @@ -884,9 +880,9 @@ array_fn!(array_resize, array size value);
array_fn!(flatten, array);
array_fn!(range, start stop step);

aggregate_function!(array_agg, ArrayAgg);
aggregate_function!(max, Max);
aggregate_function!(min, Min);
aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf);
aggregate_function!(max, functions_aggregate::min_max::max_udaf);
aggregate_function!(min, functions_aggregate::min_max::min_udaf);

pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
Expand Down

0 comments on commit 4c75286

Please sign in to comment.