From 18f06aec948cbd286f98a3299ec76f25c6243256 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 22 Jun 2023 12:52:52 -0400 Subject: [PATCH] Make it easier to call window functions via expression API --- datafusion-examples/examples/simple_udwf.rs | 14 +-- datafusion/core/src/dataframe.rs | 17 ++++ datafusion/expr/src/expr.rs | 67 ++++++++++++--- datafusion/expr/src/expr_fn.rs | 86 ++++++++++++++++++- datafusion/expr/src/tree_node/expr.rs | 10 +-- datafusion/expr/src/udwf.rs | 29 ++----- datafusion/expr/src/window_function.rs | 24 ++++++ .../optimizer/src/analyzer/type_coercion.rs | 4 +- datafusion/sql/src/expr/function.rs | 12 +-- datafusion/sql/src/utils.rs | 15 ++-- 10 files changed, 216 insertions(+), 62 deletions(-) diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index c5c166a5796b..4160ea70f2de 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -112,12 +112,14 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it().call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(false), - ); + let window_expr = smooth_it() + // smooth_it(speed) + .call(vec![col("speed")]) + .with_partition_by(vec![col("car")]) + // ORDER BY time ASC + .with_order_by(vec![col("time").sort(true, true)]) + .build(); + let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 5da72f96bd65..2aa2abfa74bd 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -219,6 +219,23 @@ impl DataFrame { } /// Apply one or more window functions ([`Expr::WindowFunction`]) to extend the schema + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use 
datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// + /// // The following is the equivalent of "SELECT FIRST_VALUE(b) OVER(PARTITION BY a)" + /// let first_value = first_value(col("b")) + /// .with_partition_by(vec![col("a")]) + /// .build(); + /// let _ = df.window(vec![first_value]); + /// # Ok(()) + /// # } + /// ``` pub fn window(self, window_exprs: Vec) -> Result { let plan = LogicalPlanBuilder::from(self.plan) .window(window_exprs)? diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9c3b53906ac9..377b63dfd449 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -443,7 +443,30 @@ impl AggregateFunction { } } -/// Window function +/// Window Function Expression (part of `Expr::WindowFunction`). +/// +/// Holds the actual function to call +/// [`window_function::WindowFunction`] as well as its arguments +/// (`args`) and the contents of the `OVER` clause: +/// +/// 1. `PARTITION BY` +/// 2. `ORDER BY` +/// 3. Window frame (e.g. 
`ROWS 1 PRECEDING AND 1 FOLLOWING`) +/// +/// See [`Self::build`] to create an [`Expr`] +/// +/// # Example +/// ``` +/// # use datafusion_expr::expr::WindowFunction; +/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) +/// let expr: Expr = WindowFunction::new( +/// BuiltInWindowFunction::FirstValue, +/// vec![col("a")] +/// ) +/// .with_partition_by(vec![col("b")]) +/// .with_order_by(vec![col("c")]) +/// .build(); +/// ``` #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function @@ -459,22 +482,40 @@ pub struct WindowFunction { } impl WindowFunction { - /// Create a new Window expression - pub fn new( - fun: window_function::WindowFunction, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: window_frame::WindowFrame, - ) -> Self { + /// Create a new Window expression with the specified arguments and an + /// empty `OVER` clause + pub fn new(fun: impl Into, args: Vec) -> Self { Self { - fun, + fun: fun.into(), args, - partition_by, - order_by, - window_frame, + partition_by: vec![], + order_by: vec![], + window_frame: window_frame::WindowFrame::new(false), } } + + /// set the partition by expressions + pub fn with_partition_by(mut self, partition_by: Vec) -> Self { + self.partition_by = partition_by; + self + } + + /// set the order by expressions + pub fn with_order_by(mut self, order_by: Vec) -> Self { + self.order_by = order_by; + self + } + + /// set the window frame + pub fn with_window_frame(mut self, window_frame: window_frame::WindowFrame) -> Self { + self.window_frame = window_frame; + self + } + + /// convert this WindowFunction into an [`Expr`] + pub fn build(self) -> Expr { + Expr::WindowFunction(self) + } } // Exists expression. diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ef782b319cd7..276dc7c7b72d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,17 +18,17 @@ //! 
Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, + self, AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; -use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; +use crate::{BuiltInWindowFunction, WindowUDF}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::sync::Arc; @@ -158,6 +158,83 @@ pub fn count(expr: Expr) -> Expr { )) } +/// Create an expression to represent the `row_number` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn row_number() -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![]) +} + +/// Create an expression to represent the `rank` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn rank() -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::Rank, vec![]) +} + +/// Create an expression to represent the `dense_rank` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn dense_rank() -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![]) +} + +/// Create an expression to represent the `percent_rank` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn percent_rank() -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![]) +} + +/// Create an expression to represent the `cume_dist` window function +/// +/// Note: call 
[`expr::WindowFunction::build`] to create an [`Expr`] +pub fn cume_dist(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![arg]) +} + +/// Create an expression to represent the `ntile` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn ntile(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]) +} + +/// Create an expression to represent the `lag` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn lag(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::Lag, vec![arg]) +} + +/// Create an expression to represent the `lead` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn lead(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::Lead, vec![arg]) +} + +/// Create an expression to represent the `first_value` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn first_value(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg]) +} + +/// Create an expression to represent the `last_value` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn last_value(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg]) +} + +/// Create an expression to represent the `nth_value` window function +/// +/// Note: call [`expr::WindowFunction::build`] to create an [`Expr`] +pub fn nth_value(arg: Expr) -> expr::WindowFunction { + expr::WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg]) +} + /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -751,6 +828,11 @@ pub fn when(when: Expr, 
then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } +// /// Create a window expr from +// pub fn window_expr(window_function: impl Into) -> Expr { +// e Expr::WindowFunction(expr.into()) +// } + /// Creates a new UDF with a specific signature and specific return type. /// This is a helper function to create a new UDF. /// The function `create_udf` returns a subset of all possible `ScalarFunction`: diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 3ecf54c9ce26..f72025500fd7 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -283,13 +283,13 @@ impl TreeNode for Expr { partition_by, order_by, window_frame, - }) => Expr::WindowFunction(WindowFunction::new( + }) => Expr::WindowFunction(WindowFunction { fun, - transform_vec(args, &mut transform)?, - transform_vec(partition_by, &mut transform)?, - transform_vec(order_by, &mut transform)?, + args: transform_vec(args, &mut transform)?, + partition_by: transform_vec(partition_by, &mut transform)?, + order_by: transform_vec(order_by, &mut transform)?, window_frame, - )), + }), Expr::AggregateFunction(AggregateFunction { args, fun, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0a2a8205a08..59070253df48 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -22,9 +22,7 @@ use std::{ sync::Arc, }; -use crate::{ - Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; +use crate::{expr, Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature}; /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. 
@@ -93,26 +91,15 @@ impl WindowUDF { } } - /// creates a [`Expr`] that calls the window function given - /// the `partition_by`, `order_by`, and `window_frame` definition + /// creates a [`expr::WindowFunction`] builder for calling the + /// window function given. + /// + /// Use the methods on the builder to set the `partition_by`, + /// `order_by`, and `window_frame` definitions /// /// This utility allows using the UDWF without requiring access to /// the registry, such as with the DataFrame API. - pub fn call( - &self, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: WindowFrame, - ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); - - Expr::WindowFunction(crate::expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - }) + pub fn call(&self, args: Vec) -> expr::WindowFunction { + expr::WindowFunction::new(Arc::new(self.clone()), args) } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 6f30bff69b6a..901fbd566019 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -61,6 +61,30 @@ pub fn find_df_window_func(name: &str) -> Option { } } +impl From for WindowFunction { + fn from(value: AggregateFunction) -> Self { + Self::AggregateFunction(value) + } +} + +impl From for WindowFunction { + fn from(value: BuiltInWindowFunction) -> Self { + Self::BuiltInWindowFunction(value) + } +} + +impl From> for WindowFunction { + fn from(value: Arc) -> Self { + Self::AggregateUDF(value) + } +} + +impl From> for WindowFunction { + fn from(value: Arc) -> Self { + Self::WindowUDF(value) + } +} + impl fmt::Display for BuiltInWindowFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.name()) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 412abbfae644..d577e9b0bd33 100644 --- 
a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -433,13 +433,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { }) => { let window_frame = coerce_window_frame(window_frame, &self.schema, &order_by)?; - let expr = Expr::WindowFunction(WindowFunction::new( + let expr = Expr::WindowFunction(WindowFunction { fun, args, partition_by, order_by, window_frame, - )); + }); Ok(expr) } expr => Ok(expr), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index f08f357ec42c..5bb01862d877 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -101,17 +101,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + Expr::WindowFunction(expr::WindowFunction { + fun: WindowFunction::AggregateFunction(aggregate_fun), args, partition_by, order_by, window_frame, - )) + }) } - _ => Expr::WindowFunction(expr::WindowFunction::new( + _ => Expr::WindowFunction(expr::WindowFunction { fun, - self.function_args_to_expr( + args: self.function_args_to_expr( function.args, schema, planner_context, @@ -119,7 +119,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { partition_by, order_by, window_frame, - )), + }), }; return Ok(expr); } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index e37830d0ba53..e8ed567fac97 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -182,21 +182,22 @@ where partition_by, order_by, window_frame, - }) => Ok(Expr::WindowFunction(WindowFunction::new( - fun.clone(), - args.iter() + }) => Ok(Expr::WindowFunction(WindowFunction { + fun: fun.clone(), + args: args + .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, - partition_by + partition_by: partition_by .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, - 
order_by + order_by: order_by .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, - window_frame.clone(), - ))), + window_frame: window_frame.clone(), + })), Expr::AggregateUDF(AggregateUDF { fun, args,