Skip to content

Commit

Permalink
refactor api
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jun 6, 2024
1 parent f25c1df commit c689980
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 33 deletions.
7 changes: 3 additions & 4 deletions datafusion-examples/examples/udaf_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion::{
};

use datafusion_common::Result;
use datafusion_expr::col;
use datafusion_expr::{col, AggregateUDFExprBuilder};

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -33,11 +33,10 @@ async fn main() -> Result<()> {
let mut state = SessionState::new_with_config_rt(config, ctx.runtime_env());
let _ = register_all(&mut state);

let first_value_udaf = state.aggregate_functions().get("FIRST_VALUE").unwrap();
let first_value_udaf = state.aggregate_functions().get("first_value").unwrap();
let first_value_builder = first_value_udaf
.call(vec![col("a")])
.order_by(vec![col("b")])
.build();
.order_by(vec![col("b")]);

let first_value_fn = first_value(col("a"), Some(vec![col("b")]));
assert_eq!(first_value_builder, first_value_fn);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub use signature::{
ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, AggregateUDFExprBuilder};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
58 changes: 56 additions & 2 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
Expand Down Expand Up @@ -139,8 +140,15 @@ impl AggregateUDF {
///
/// This utility allows using the UDAF without requiring access to
/// the registry, such as with the DataFrame API.
pub fn call(&self, args: Vec<Expr>) -> AggregateFunction {
AggregateFunction::new_udf(Arc::new(self.clone()), args, false, None, None, None)
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
None,
None,
None,
))
}

/// Returns this function's name
Expand Down Expand Up @@ -599,3 +607,49 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
(self.accumulator)(acc_args)
}
}

pub trait AggregateUDFExprBuilder {
fn order_by(self, order_by: Vec<Expr>) -> Expr;
fn filter(self, filter: Box<Expr>) -> Expr;
fn null_treatment(self, null_treatment: NullTreatment) -> Expr;
fn distinct(self) -> Expr;
}

impl AggregateUDFExprBuilder for Expr {
fn order_by(self, order_by: Vec<Expr>) -> Expr {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.order_by = Some(order_by);
Expr::AggregateFunction(udaf)
}
_ => self,
}
}
fn filter(self, filter: Box<Expr>) -> Expr {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.filter = Some(filter);
Expr::AggregateFunction(udaf)
}
_ => self,
}
}
fn null_treatment(self, null_treatment: NullTreatment) -> Expr {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.null_treatment = Some(null_treatment);
Expr::AggregateFunction(udaf)
}
_ => self,
}
}
fn distinct(self) -> Expr {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.distinct = true;
Expr::AggregateFunction(udaf)
}
_ => self,
}
}
}
17 changes: 6 additions & 11 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature,
Volatility,
Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility
};
use datafusion_physical_expr_common::aggregate::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::{
Expand All @@ -44,14 +42,11 @@ create_func!(FirstValue, first_value_udaf);

/// Returns the first value in a group of values.
pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
first_value_udaf(),
vec![expression],
false,
None,
order_by,
None,
))
if let Some(order_by) = order_by {
first_value_udaf().call(vec![expression]).order_by(order_by)
} else {
first_value_udaf().call(vec![expression])
}
}

pub struct FirstValue {
Expand Down
18 changes: 7 additions & 11 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{col, AggregateUDFExprBuilder, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};

/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
Expand Down Expand Up @@ -95,17 +94,14 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let first_value_udaf =
let first_value_udaf: std::sync::Arc<datafusion_expr::AggregateUDF> =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
Expr::AggregateFunction(AggregateFunction::new_udf(
first_value_udaf.clone(),
vec![e],
false,
None,
sort_expr.clone(),
None,
))
if let Some(order_by) = &sort_expr {
first_value_udaf.call(vec![e]).order_by(order_by.clone())
} else {
first_value_udaf.call(vec![e])
}
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
first_value(lit(1), None),
first_value(lit(1), Some(vec![lit(2)])),
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
Expand Down
8 changes: 4 additions & 4 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,13 @@ select log(-1), log(0), sqrt(-1);

## Aggregate Function Builder

Another builder expression that ends with `build()`, it is useful if the functions has multiple optional arguments
Import trait `AggregateUDFExprBuilder` and update the arguments directly in `Expr`

See datafusion-examples/examples/udaf_expr.rs for example usage.

| Syntax | Equivalent to |
| -------------------------------------------------------------- | ----------------------------------- |
| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build() | first_value(expr, Some(vec![expr])) |
| Syntax | Equivalent to |
| ------------------------------------------------------ | ----------------------------------- |
| first_value_udaf.call(vec![expr]).order_by(vec![expr]) | first_value(expr, Some(vec![expr])) |

## Subquery Expressions

Expand Down

0 comments on commit c689980

Please sign in to comment.