Introduce expr builder for aggregate function #10560

Merged: 12 commits, Jun 9, 2024
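In short, the PR adds an `AggregateExt` trait (re-exported from `datafusion_expr`) so aggregate `Expr`s built from UDAF calls can be decorated with `ORDER BY`, `FILTER`, `DISTINCT`, and null treatment through a fluent builder. A minimal sketch of the new API, pieced together from the example and tests in this diff (the `main` wrapper and the `ASC NULLS FIRST` sort flags are illustrative assumptions, not part of the change):

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::prelude::*;
use datafusion_expr::AggregateExt;

fn main() -> Result<()> {
    // FIRST_VALUE(price) FILTER (WHERE quantity > 100) ORDER BY ts
    let agg = first_value_udaf()
        .call(vec![col("price")])                   // plain aggregate call -> Expr
        .order_by(vec![col("ts").sort(true, true)]) // ASC NULLS FIRST (assumed flags)
        .filter(col("quantity").gt(lit(100)))       // FILTER (WHERE quantity > 100)
        .build()?;                                  // -> Expr::AggregateFunction
    println!("{agg}");
    Ok(())
}
```

The built `Expr` can then be passed to `DataFrame::aggregate` like any other aggregate expression, which is what the new tests in `datafusion/core/tests/expr_api/mod.rs` exercise.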
44 changes: 38 additions & 6 deletions datafusion-examples/examples/expr_api.rs
@@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
use datafusion::prelude::*;
@@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
///
@@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
/// also comes with APIs for evaluation, simplification, and analysis.
///
/// The code in this example shows how to:
/// 1. Create [`Exprs`] using different APIs: [`main`]`
/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`]
/// 3. Simplify expressions: [`simplify_demo`]
/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 5. Get the types of the expressions: [`expression_type_demo`]
/// 1. Create [`Expr`]s using different APIs: [`main`]
/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`]
/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`]
/// 4. Simplify expressions: [`simplify_demo`]
/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 6. Get the types of the expressions: [`expression_type_demo`]
#[tokio::main]
async fn main() -> Result<()> {
// The easiest way to create expressions is to use the
@@ -63,6 +65,9 @@ async fn main() -> Result<()> {
));
assert_eq!(expr, expr2);

// See how to build aggregate functions with the expr_fn API
expr_fn_demo()?;

// See how to evaluate expressions
evaluate_demo()?;

@@ -78,6 +83,33 @@ async fn main() -> Result<()> {
Ok(())
}

/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the
/// full range of expression types such as aggregates and window functions.
fn expr_fn_demo() -> Result<()> {
// Let's say you want to call the "first_value" aggregate function
let first_value = first_value_udaf();

// For example, to create the expression `FIRST_VALUE(price)`
// These expressions can be passed to `DataFrame::aggregate` and other
// APIs that take aggregate expressions.
let agg = first_value.call(vec![col("price")]);
assert_eq!(agg.to_string(), "first_value(price)");

// You can use the AggregateExt trait to create more complex aggregates
// such as `FIRST_VALUE(price) FILTER (WHERE quantity > 100) ORDER BY ts`
let agg = first_value
.call(vec![col("price")])
.order_by(vec![col("ts")])
.filter(col("quantity").gt(lit(100)))
.build()?; // build the aggregate
assert_eq!(
agg.to_string(),
"first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts]"
);

Ok(())
}

/// DataFusion can also evaluate arbitrary expressions on Arrow arrays.
fn evaluate_demo() -> Result<()> {
// For example, let's say you have some integers in an array
190 changes: 187 additions & 3 deletions datafusion/core/tests/expr_api/mod.rs
@@ -15,14 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use arrow::util::pretty::pretty_format_columns;
use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
use arrow_array::builder::{ListBuilder, StringBuilder};
use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray};
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::prelude::*;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{assert_contains, DFSchema, ScalarValue};
use datafusion_expr::AggregateExt;
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::first_last::first_value_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor};
use sqlparser::ast::NullTreatment;
/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
use std::sync::{Arc, OnceLock};

@@ -162,6 +166,183 @@ fn test_list_range() {
);
}

#[tokio::test]
async fn test_aggregate_error() {
let err = first_value_udaf()
.call(vec![col("props")])
// not a sort column
.order_by(vec![col("id")])
.build()
.unwrap_err()
.to_string();
assert_contains!(
err,
"Error during planning: ORDER BY expressions must be Expr::Sort"
);
}

#[tokio::test]
async fn test_aggregate_ext_order_by() {
let agg = first_value_udaf().call(vec![col("props")]);

// ORDER BY id ASC
let agg_asc = agg
.clone()
.order_by(vec![col("id").sort(true, true)])
.build()
.unwrap()
.alias("asc");

// ORDER BY id DESC
let agg_desc = agg
.order_by(vec![col("id").sort(false, true)])
.build()
.unwrap()
.alias("desc");

evaluate_agg_test(
agg_asc,
vec![
"+-----------------+",
"| asc             |",
"+-----------------+",
"| {a: 2021-02-01} |",
"+-----------------+",
],
)
.await;

evaluate_agg_test(
agg_desc,
vec![
"+-----------------+",
"| desc            |",
"+-----------------+",
"| {a: 2021-02-03} |",
"+-----------------+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_filter() {
let agg = first_value_udaf()
.call(vec![col("i")])
.order_by(vec![col("i").sort(true, true)])
.filter(col("i").is_not_null())
.build()
.unwrap()
.alias("val");

#[rustfmt::skip]
evaluate_agg_test(
agg,
vec![
"+-----+",
"| val |",
"+-----+",
"| 5   |",
"+-----+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_distinct() {
let agg = sum_udaf()
.call(vec![lit(5)])
// distinct sum should be 5, not 15
.distinct()
.build()
.unwrap()
.alias("distinct");

evaluate_agg_test(
agg,
vec![
"+----------+",
"| distinct |",
"+----------+",
"| 5        |",
"+----------+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_null_treatment() {
let agg = first_value_udaf()
.call(vec![col("i")])
.order_by(vec![col("i").sort(true, true)]);

let agg_respect = agg
.clone()
.null_treatment(NullTreatment::RespectNulls)
.build()
.unwrap()
.alias("respect");

let agg_ignore = agg
.null_treatment(NullTreatment::IgnoreNulls)
.build()
.unwrap()
.alias("ignore");

evaluate_agg_test(
agg_respect,
vec![
"+---------+",
"| respect |",
"+---------+",
"|         |",
"+---------+",
],
)
.await;

evaluate_agg_test(
agg_ignore,
vec![
"+--------+",
"| ignore |",
"+--------+",
"| 5      |",
"+--------+",
],
)
.await;
}

/// Evaluates the specified expr as an aggregate and compares the result to the
/// expected result.
async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) {
let batch = test_batch();

let ctx = SessionContext::new();
let group_expr = vec![];
let agg_expr = vec![expr];
let result = ctx
.read_batch(batch)
.unwrap()
.aggregate(group_expr, agg_expr)
.unwrap()
.collect()
.await
.unwrap();

let result = pretty_format_batches(&result).unwrap().to_string();
let actual_lines = result.lines().collect::<Vec<_>>();

assert_eq!(
expected_lines, actual_lines,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected_lines, actual_lines
);
}

/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided
/// `RecordBatch` and compares the result to the expected result.
fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) {
@@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch {
TEST_BATCH
.get_or_init(|| {
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"]));
let int_array: ArrayRef =
Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)]));

// { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" }
let struct_array: ArrayRef = Arc::from(StructArray::from(vec![(
@@ -209,6 +392,7 @@

RecordBatch::try_from_iter(vec![
("id", string_array),
("i", int_array),
("props", struct_array),
("list", list_array),
])
15 changes: 14 additions & 1 deletion datafusion/expr/src/expr.rs
@@ -255,19 +255,23 @@ pub enum Expr {
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// ```text
/// CASE WHEN condition THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
/// ```
///
/// The second form uses a base expression and then a series of "when" clauses that match on a
/// literal value.
///
/// ```text
/// CASE expression
/// WHEN value THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
/// ```
Case(Case),
/// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
@@ -279,7 +283,12 @@ pub enum Expr {
Sort(Sort),
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Represents the call of an aggregate built-in function with arguments.
/// Calls an aggregate function with arguments, and optional
/// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
///
/// See also [`AggregateExt`] to set these fields.
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
@@ -623,6 +632,10 @@ impl AggregateFunctionDefinition {
}

/// Aggregate function
///
/// See also [`AggregateExt`] to set these fields on `Expr`
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
@@ -81,7 +81,7 @@ pub use signature::{
ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};