Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent API to set parameters of aggregate and window functions (AggregateExt --> ExprFunctionExt) #11550

Merged
merged 22 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
76b6655
Moving over AggregateExt to ExprFunctionExt and adding in function se…
timsaucer-may Jul 18, 2024
4b124c6
Switch WindowFrame to only need the window function definition and ar…
timsaucer-may Jul 19, 2024
31a82dd
Changing null_treatment to take an option, but this is mostly for cod…
timsaucer-may Jul 19, 2024
d290d2e
Moving functions in ExprFuncBuilder over to be explicitly implementin…
timsaucer-may Jul 19, 2024
d267d0e
Apply cargo fmt
timsaucer-may Jul 19, 2024
2e758ad
Add deprecated trait AggregateExt so that users get a warning but sti…
timsaucer Jul 23, 2024
fbde31f
Window helper functions should return Expr
timsaucer Jul 23, 2024
99f1c79
Update documentation to show window function example
timsaucer Jul 23, 2024
fd9ebdf
Add license info
timsaucer Jul 23, 2024
4344a9f
Update comments that are no longer applicable
timsaucer Jul 24, 2024
a154ddc
Remove first_value and last_value since these are already implemented…
timsaucer Jul 24, 2024
64cbc36
Update to use WindowFunction::new to set additional parameters for o…
timsaucer Jul 24, 2024
532f262
Apply cargo fmt
timsaucer Jul 24, 2024
6436499
Merge remote-tracking branch 'apache/main' into feature/expr-function…
alamb Jul 24, 2024
acfcece
Fix up clippy
alamb Jul 24, 2024
039f427
fix doc example
alamb Jul 24, 2024
75e364a
fmt
alamb Jul 24, 2024
1a801f6
doc tweaks
alamb Jul 24, 2024
d689872
more doc tweaks
alamb Jul 24, 2024
4884c8a
fix up links
alamb Jul 24, 2024
9bfd1dd
fix integration test
alamb Jul 24, 2024
77726eb
fix anothr doc example
alamb Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ async fn main() -> Result<()> {
df.show().await?;

// Now, run the function using the DataFrame API:
let window_expr = smooth_it.call(
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(None),
);
let window_expr = smooth_it
.call(vec![col("speed")]) // smooth_it(speed)
.partition_by(vec![col("car")]) // PARTITION BY car
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
.window_frame(WindowFrame::new(None))
.build()?;
let df = ctx.table("cars").await?.window(vec![window_expr])?;

// print the results
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};
use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
///
Expand Down Expand Up @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> {
let agg = first_value.call(vec![col("price")]);
assert_eq!(agg.to_string(), "first_value(price)");

// You can use the AggregateExt trait to create more complex aggregates
// You can use the ExprFunctionExt trait to create more complex aggregates
// such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts )
let agg = first_value
.call(vec![col("price")])
Expand Down
12 changes: 6 additions & 6 deletions datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ async fn main() -> Result<()> {
df.show().await?;

// Now, run the function using the DataFrame API:
let window_expr = smooth_it.call(
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(None),
);
let window_expr = smooth_it
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so much nicer

.call(vec![col("speed")]) // smooth_it(speed)
.partition_by(vec![col("car")]) // PARTITION BY car
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
.window_frame(WindowFrame::new(None))
.build()?;
let df = ctx.table("cars").await?.window(vec![window_expr])?;

// print the results
Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,8 +1696,8 @@ mod tests {
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
Volatility, WindowFrame, WindowFunctionDefinition,
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
Expand Down Expand Up @@ -1867,11 +1867,10 @@ mod tests {
BuiltInWindowFunction::FirstValue,
),
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(None),
None,
));
))
.partition_by(vec![col("aggregate_test_100.c2")])
.build()
.unwrap();
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();

Expand Down
22 changes: 11 additions & 11 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};

Expand Down Expand Up @@ -183,15 +183,15 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
))
.order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])
.window_frame(WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
))
.build()
.unwrap()])?
.explain(false, false)?
.collect()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::prelude::*;
use datafusion_common::{assert_contains, DFSchema, ScalarValue};
use datafusion_expr::AggregateExt;
use datafusion_expr::ExprFunctionExt;
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::first_last::first_value_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
Expand Down
85 changes: 65 additions & 20 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
use crate::{
aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
Signature,
aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction,
ExprSchemable, Operator, Signature, WindowFrame, WindowUDF,
};
use crate::{window_frame, Volatility};

Expand Down Expand Up @@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment;
/// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or
/// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]).
///
/// See also [`ExprFunctionExt`] for creating aggregate and window functions.
///
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
///
/// # Schema Access
///
/// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability
Expand Down Expand Up @@ -283,15 +287,17 @@ pub enum Expr {
/// This expression is guaranteed to have a fixed type.
TryCast(TryCast),
/// A sort expression, that can be used to sort values.
///
/// See [Expr::sort] for more details
Sort(Sort),
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Calls an aggregate function with arguments, and optional
/// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
///
/// See also [`AggregateExt`] to set these fields.
/// See also [`ExprFunctionExt`] to set these fields.
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
Expand Down Expand Up @@ -641,9 +647,9 @@ impl AggregateFunctionDefinition {

/// Aggregate function
///
/// See also [`AggregateExt`] to set these fields on `Expr`
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
Expand Down Expand Up @@ -769,7 +775,52 @@ impl fmt::Display for WindowFunctionDefinition {
}
}

impl From<aggregate_function::AggregateFunction> for WindowFunctionDefinition {
fn from(value: aggregate_function::AggregateFunction) -> Self {
Self::AggregateFunction(value)
}
}

impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
fn from(value: BuiltInWindowFunction) -> Self {
Self::BuiltInWindowFunction(value)
}
}

impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
fn from(value: Arc<crate::AggregateUDF>) -> Self {
Self::AggregateUDF(value)
}
}

impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
fn from(value: Arc<WindowUDF>) -> Self {
Self::WindowUDF(value)
}
}

/// Window function
///
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

/// Holds the actual actual function to call [`WindowFunction`] as well as its
/// arguments (`args`) and the contents of the `OVER` clause:
///
/// 1. `PARTITION BY`
/// 2. `ORDER BY`
/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`)
///
/// # Example
/// ```
/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
/// # use datafusion_expr::expr::WindowFunction;
/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
/// let expr = Expr::WindowFunction(
/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
/// )
/// .partition_by(vec![col("b")])
/// .order_by(vec![col("b").sort(true, true)])
/// .build()
/// .unwrap();
/// ```
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct WindowFunction {
/// Name of the function
Expand All @@ -787,22 +838,16 @@ pub struct WindowFunction {
}

impl WindowFunction {
/// Create a new Window expression
pub fn new(
fun: WindowFunctionDefinition,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
null_treatment: Option<NullTreatment>,
) -> Self {
/// Create a new Window expression with the specified argument an
/// empty `OVER` clause
pub fn new(fun: impl Into<WindowFunctionDefinition>, args: Vec<Expr>) -> Self {
Self {
fun,
fun: fun.into(),
args,
partition_by,
order_by,
window_frame,
null_treatment,
partition_by: Vec::default(),
order_by: Vec::default(),
window_frame: WindowFrame::new(None),
null_treatment: None,
}
}
}
Expand Down
Loading