diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index f84794e274f7..8b6d66b5de2a 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -15,10 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use arrow::{ + array::{AsArray, Float64Array}, + datatypes::Float64Type, +}; +use arrow_schema::DataType; use datafusion::datasource::file_format::options::CsvReadOptions; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_common::DataFusionError; +use datafusion_expr::{ + partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF, +}; // create local execution context with `cars.csv` registered as a table named `cars` async fn create_context() -> Result { @@ -39,6 +50,9 @@ async fn create_context() -> Result { async fn main() -> Result<()> { let ctx = create_context().await?; + // register the window function with DataFusion so we can call it + ctx.register_udwf(my_average()); + // Use SQL to run the new window function let df = ctx.sql("SELECT * from cars").await?; // print the results @@ -52,6 +66,7 @@ async fn main() -> Result<()> { "SELECT car, \ speed, \ lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\ + my_average(speed) OVER (PARTITION BY car ORDER BY time),\ time \ from cars", ) @@ -59,16 +74,137 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so that each invocation only sees 5 rows: the 2 before and 2 after) using - let df = ctx.sql("SELECT car, \ - speed, \ - lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\ - time \ - from cars").await?; - // print the results - df.show().await?; + // // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window function so that each invocation only sees 5 rows: the 2 before and 2 
after) using + // let df = ctx.sql("SELECT car, \ + // speed, \ + // lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\ + // time \ + // from cars").await?; + // // print the results + // df.show().await?; // todo show how to run dataframe API as well Ok(()) } + +// TODO make a helper function like `create_udf` that helps to make these signatures + +fn my_average() -> WindowUDF { + WindowUDF { + name: String::from("my_average"), + // it takes exactly 1 argument -- the column to average + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + return_type: Arc::new(return_type), + partition_evaluator: Arc::new(make_partition_evaluator), + } +} + +/// Compute the return type of the function given the argument types +fn return_type(arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(DataFusionError::Plan(format!( + "my_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ))); + } + Ok(Arc::new(arg_types[0].clone())) +} + +/// Create a partition evaluator for this argument +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct value of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// These different evaluation methods are called depending on the various settings of WindowUDF +impl PartitionEvaluator for MyPartitionEvaluator { + fn get_range(&self, _idx: usize, _n_rows: usize) -> Result> { + Err(DataFusionError::NotImplemented( + "get_range is not implemented for this window function".to_string(), + )) + } + + /// This function is given the values of each partition + fn evaluate( + &self, + values: &[arrow::array::ArrayRef], 
_num_rows: usize, + ) -> Result { + // datafusion has handled ensuring we get the correct input argument + assert_eq!(values.len(), 1); + + // For this example, we convert the input argument to an + // array of floating point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + // implement a simple moving average by averaging the current + // value with the previous value + // + // value | avg + // ------+------ + // 10 | 10 + // 20 | 15 + // 30 | 25 + // 30 | 30 + // + let mut previous_value = None; + let new_values: Float64Array = arr + .values() + .iter() + .map(|&value| { + let new_value = previous_value + .map(|previous_value| (value + previous_value) / 2.0) + .unwrap_or(value); + previous_value = Some(value); + new_value + }) + .collect(); + + Ok(Arc::new(new_values)) + } + + fn evaluate_stateful( + &mut self, + _values: &[arrow::array::ArrayRef], + ) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_stateful is not implemented by default".into(), + )) + } + + fn evaluate_with_rank( + &self, + _num_rows: usize, + _ranks_in_partition: &[std::ops::Range], + ) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_partition_with_rank is not implemented by default".into(), + )) + } + + fn evaluate_inside_range( + &self, + _values: &[arrow::array::ArrayRef], + _range: &std::ops::Range, + ) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_inside_range is not implemented by default".into(), + )) + } +} + +// TODO show how to use other evaluate methods diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 6b81a39691d6..c65e16f8868e 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -34,7 +34,7 @@ use crate::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, StringifiedPlan, 
UserDefinedLogicalNode, + DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -797,6 +797,20 @@ impl SessionContext { .insert(f.name.clone(), Arc::new(f)); } + /// Registers a window UDF within this context. + /// + /// Note in SQL queries, window function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"` + /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"` + pub fn register_udwf(&self, f: WindowUDF) { + self.state + .write() + .window_functions + .insert(f.name.clone(), Arc::new(f)); + } + /// Creates a [`DataFrame`] for reading a data source. /// /// For more control such as reading multiple files, you can use @@ -1290,6 +1304,10 @@ impl FunctionRegistry for SessionContext { fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } + + fn udwf(&self, name: &str) -> Result> { + self.state.read().udwf(name) + } } /// A planner used to add extensions to DataFusion logical and physical plans. @@ -1340,6 +1358,8 @@ pub struct SessionState { scalar_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, + /// Window functions registered in the context + window_functions: HashMap>, /// Deserializer registry for extensions. 
serializer_registry: Arc, /// Session configuration @@ -1483,6 +1503,7 @@ impl SessionState { catalog_list, scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), config, execution_props: ExecutionProps::new(), @@ -1959,6 +1980,11 @@ impl SessionState { &self.aggregate_functions } + /// Return reference to window functions + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + /// Return [SerializerRegistry] for extensions pub fn serializer_registry(&self) -> Arc { self.serializer_registry.clone() @@ -1992,6 +2018,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; @@ -2039,6 +2069,16 @@ impl FunctionRegistry for SessionState { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Plan(format!( + "There is no UDWF named \"{name}\" in the registry" + )) + }) + } } impl OptimizerConfig for SessionState { @@ -2068,6 +2108,7 @@ impl From<&SessionState> for TaskContext { state.config.clone(), state.scalar_functions.clone(), state.aggregate_functions.clone(), + state.window_functions.clone(), state.runtime_env.clone(), ) } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 8fafbef2e55c..47ea0479e1fb 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -198,7 +198,11 @@ fn create_udwf_window_expr( name: String, ) -> Result> { // need to get the types into an owned vec for some reason - let input_types: Vec<_> = 
input_schema.fields().iter().map(|f| f.data_type().clone()).collect(); + let input_types: Vec<_> = args + .iter() + .map(|arg| arg.data_type(input_schema).map(|dt| dt.clone())) + .collect::>()?; + // figure out the output type let data_type = (fun.return_type)(&input_types)?; Ok(Arc::new(WindowUDFExpr { @@ -227,7 +231,11 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn field(&self) -> Result { let nullable = false; - Ok(Field::new(&self.name, self.data_type.as_ref().clone(), nullable)) + Ok(Field::new( + &self.name, + self.data_type.as_ref().clone(), + nullable, + )) } fn expressions(&self) -> Vec> { @@ -235,7 +243,11 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn create_evaluator(&self) -> Result> { - todo!() + (self.fun.partition_evaluator)() + } + + fn name(&self) -> &str { + &self.name } } diff --git a/datafusion/core/tests/data/cars.csv b/datafusion/core/tests/data/cars.csv index 24f363ccf432..bc40f3b01e7a 100644 --- a/datafusion/core/tests/data/cars.csv +++ b/datafusion/core/tests/data/cars.csv @@ -24,4 +24,3 @@ green,15.1,1996-04-12T12:05:11.000000000 green,15.2,1996-04-12T12:05:12.000000000 green,8.0,1996-04-12T12:05:13.000000000 green,2.0,1996-04-12T12:05:14.000000000 -green,0.0,1996-04-12T12:05:15.000000000 diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index ef06c74cc292..9ba487e715b3 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -18,7 +18,7 @@ //! FunctionRegistry trait use datafusion_common::Result; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode}; +use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::{collections::HashSet, sync::Arc}; /// A registry knows how to build logical expressions out of user-defined function' names @@ -31,6 +31,9 @@ pub trait FunctionRegistry { /// Returns a reference to the udaf named `name`. 
fn udaf(&self, name: &str) -> Result>; + + /// Returns a reference to the udwf named `name`. + fn udwf(&self, name: &str) -> Result>; } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index ca1bc9369e35..15581c85bf5f 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -24,7 +24,7 @@ use datafusion_common::{ config::{ConfigOptions, Extensions}, DataFusionError, Result, }; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, @@ -48,6 +48,8 @@ pub struct TaskContext { scalar_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, + /// Window functions associated with this task context + window_functions: HashMap>, /// Runtime environment associated with this task context runtime: Arc, } @@ -60,6 +62,7 @@ impl TaskContext { session_config: SessionConfig, scalar_functions: HashMap>, aggregate_functions: HashMap>, + window_functions: HashMap>, runtime: Arc, ) -> Self { Self { @@ -68,6 +71,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, } } @@ -92,6 +96,7 @@ impl TaskContext { config.set(&k, &v)?; } let session_config = SessionConfig::from(config); + let window_functions = HashMap::new(); Ok(Self::new( Some(task_id), @@ -99,6 +104,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, )) } @@ -153,6 +159,16 @@ impl FunctionRegistry for TaskContext { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDWF named \"{name}\" in the TaskContext" + )) + }) + } } #[cfg(test)] diff 
--git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 0e12bf9da21a..c0088036905f 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -58,8 +58,7 @@ pub type StateTypeFunction = /// Factory that creates a PartitionEvaluator for the given aggregate, given /// its return datatype. pub type PartitionEvaluatorFunctionFactory = - Arc Result> + Send + Sync>; - + Arc Result> + Send + Sync>; macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 8a556669709a..fdb1b6e60a8a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -42,6 +42,7 @@ mod literal; pub mod logical_plan; mod nullif; mod operator; +pub mod partition_evaluator; mod signature; pub mod struct_expressions; mod table_source; @@ -52,9 +53,8 @@ mod udf; mod udwf; pub mod utils; pub mod window_frame; -pub mod window_function; -pub mod partition_evaluator; pub mod window_frame_state; +pub mod window_function; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 316b4a5d58be..87261db4addb 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -25,7 +25,6 @@ use std::any::Any; use std::fmt::Debug; use std::ops::Range; - /// Trait for the state managed by this partition evaluator /// /// This follows the existing pattern, but maybe we can improve it :thinking: @@ -138,7 +137,7 @@ pub trait PartitionEvaluator: Debug + Send { /// Sets the internal state for window function /// /// Only used for stateful evaluation - fn set_state(&mut self, state: Box) -> Result<()> { + fn set_state(&mut self, _state: Box) -> Result<()> { Err(DataFusionError::NotImplemented( "set_state is not implemented for this window function".to_string(), )) diff --git 
a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a1b767fa2804..7fa8e52d4648 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,9 +17,9 @@ //! Support for user-defined window (UDWF) window functions -use std::fmt::{self, Debug, Formatter}; +use std::fmt::{self, Debug, Display, Formatter}; -use crate::{ReturnTypeFunction, Signature}; +use crate::{function::PartitionEvaluatorFunctionFactory, ReturnTypeFunction, Signature}; /// Logical representation of a user-defined window function (UDWF) /// A UDAF is different from a UDF in that it is stateful across batches. @@ -31,8 +31,8 @@ pub struct WindowUDF { pub signature: Signature, /// Return type pub return_type: ReturnTypeFunction, - // /// actual implementation - // pub accumulator: AccumulatorFunctionImplementation, + /// Factory that creates the partition evaluator for this function + pub partition_evaluator: PartitionEvaluatorFunctionFactory, } impl Debug for WindowUDF { @@ -41,10 +41,16 @@ impl Debug for WindowUDF { } } +/// Defines how the WindowUDF is shown to users +impl Display for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name) + } +} + impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - todo!(); - //self.name == other.name && self.signature == other.signature + self.name == other.name && self.signature == other.signature } } @@ -52,8 +58,8 @@ impl Eq for WindowUDF {} impl std::hash::Hash for WindowUDF { fn hash(&self, state: &mut H) { - // self.name.hash(state); - // self.signature.hash(state); + self.name.hash(state); + self.signature.hash(state); } } diff --git a/datafusion/expr/src/window_frame_state.rs b/datafusion/expr/src/window_frame_state.rs index 3100f3a58f1e..1dc11d0fdefe 100644 --- a/datafusion/expr/src/window_frame_state.rs +++ b/datafusion/expr/src/window_frame_state.rs @@ -18,20 +18,20 @@ //! This module provides utilities for window frame index calculations //! depending on the window frame mode: RANGE, ROWS, GROUPS. 
+use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use arrow::array::ArrayRef; -use arrow::compute::{concat}; +use arrow::compute::concat; use arrow::compute::kernels::sort::SortOptions; +use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; - /// State for each unique partition determined according to PARTITION BY column(s) #[derive(Debug)] pub struct PartitionBatchState { @@ -43,7 +43,6 @@ pub struct PartitionBatchState { pub n_out_row: usize, } - #[derive(Debug)] pub struct WindowAggState { /// The range that we calculate the window function @@ -110,6 +109,21 @@ impl WindowAggState { } } +impl WindowAggState { + pub fn new(out_type: &DataType) -> Result { + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + Ok(Self { + window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: empty_out_col, + n_row_result_missing: 0, + is_end: false, + }) + } +} + /// This object stores the window frame state for use in incremental calculations. 
#[derive(Debug)] pub enum WindowFrameContext { @@ -629,9 +643,9 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunction::WindowUDF(fun) => fun.fmt(f), } } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 030c20c5743c..4e5665f0a62e 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -21,7 +21,6 @@ use std::any::Any; use std::ops::Range; use std::sync::Arc; -use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; use crate::window::window_expr::{ @@ -37,6 +36,7 @@ use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::window_frame_state::WindowFrameContext; use datafusion_expr::WindowFrame; /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. @@ -211,12 +211,21 @@ impl WindowExpr for BuiltInWindowExpr { state.update(&out_col, partition_batch_state)?; if self.window_frame.start_bound.is_unbounded() { - let mut evaluator_state = evaluator.state()?; - if let BuiltinWindowState::NthValue(nth_value_state) = - &mut evaluator_state - { - memoize_nth_value(state, nth_value_state)?; - evaluator.set_state(&evaluator_state)?; + let Some(evaluator_state) = evaluator.state()? 
else { + return Ok(()) + }; + + let evaluator_state = evaluator_state + .as_any() + .downcast_ref::() + .unwrap(); + + if let BuiltinWindowState::NthValue(nth_value_state) = &evaluator_state { + let mut nth_value_state = nth_value_state.clone(); + memoize_nth_value(state, &mut nth_value_state)?; + let evaluator_state = + Box::new(BuiltinWindowState::NthValue(nth_value_state)); + evaluator.set_state(evaluator_state)?; } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 59438a72f275..9f6694dc77d5 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use super::partition_evaluator::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::sync::Arc; @@ -46,9 +46,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. - fn name(&self) -> &str { - "BuiltInWindowFunctionExpr: default name" - } + fn name(&self) -> &str; /// Evaluate window function's arguments against the input window /// batch and return an [`ArrayRef`]. diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 46997578001d..945ab183e6b7 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -18,13 +18,13 @@ //! Defines physical expression for `cume_dist` that can evaluated //! 
at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::Float64Array; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::iter; use std::ops::Range; diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 8d97d5ebc0b3..5414d6dff98c 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -18,7 +18,6 @@ //! Defines physical expression for `lead` and `lag` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, LeadLagState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -27,6 +26,7 @@ use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::cmp::min; use std::ops::{Neg, Range}; @@ -182,9 +182,13 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn state(&self) -> Result { + fn state( + &self, + ) -> Result>, DataFusionError> { // If we do not use state we just return Default - Ok(BuiltinWindowState::LeadLag(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::LeadLag( + self.state.clone(), + )))) } fn update_state( diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 3d7ee7a67f47..88a5599ac036 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -30,12 +30,12 @@ mod 
window_expr; pub use aggregate::PlainAggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; -pub use partition_evaluator::PartitionEvaluator; +pub use datafusion_expr::partition_evaluator::PartitionEvaluator; +pub use datafusion_expr::window_frame_state::PartitionBatchState; +pub use datafusion_expr::window_frame_state::WindowAggState; pub use sliding_aggregate::SlidingAggregateWindowExpr; -pub use window_expr::PartitionBatchState; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; -pub use window_expr::WindowAggState; pub use window_expr::WindowExpr; pub use window_expr::WindowState; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 4bfe514c38da..5dbf34bee854 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -18,7 +18,6 @@ //! Defines physical expressions for `first_value`, `last_value`, and `nth_value` //! 
that can evaluated at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NthValueKind, NthValueState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -26,6 +25,7 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -152,9 +152,13 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - fn state(&self) -> Result { + fn state( + &self, + ) -> Result>, DataFusionError> { // If we do not use state we just return Default - Ok(BuiltinWindowState::NthValue(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::NthValue( + self.state.clone(), + )))) } fn update_state( @@ -169,7 +173,8 @@ impl PartitionEvaluator for NthValueEvaluator { Ok(()) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + fn set_state(&mut self, state: Box<(dyn PartitionState + 'static)>) -> Result<()> { + let state = state.as_any().downcast_ref::().unwrap(); if let BuiltinWindowState::NthValue(nth_value_state) = state { self.state = nth_value_state.clone() } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 479fa263337a..6618de423f2b 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -18,13 +18,13 @@ //! Defines physical expression for `ntile` that can evaluated //! 
at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::Field; use arrow_schema::DataType; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 89ca40dd564f..81d8c7107108 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -18,7 +18,6 @@ //! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, RankState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -27,6 +26,7 @@ use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::iter; use std::ops::Range; @@ -125,8 +125,10 @@ impl PartitionEvaluator for RankEvaluator { Ok(Range { start, end }) } - fn state(&self) -> Result { - Ok(BuiltinWindowState::Rank(self.state.clone())) + fn state( + &self, + ) -> Result>, DataFusionError> { + Ok(Some(Box::new(BuiltinWindowState::Rank(self.state.clone())))) } fn update_state( diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index 9883d67f7cd8..229504bbd7f2 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -17,13 +17,13 @@ //! 
Defines physical expression for `row_number` that can evaluated at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -76,9 +76,11 @@ pub(crate) struct NumRowsEvaluator { } impl PartitionEvaluator for NumRowsEvaluator { - fn state(&self) -> Result { + fn state(&self) -> Result>> { // If we do not use state we just return Default - Ok(BuiltinWindowState::NumRows(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::NumRows( + self.state.clone(), + )))) } fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index dc459abdcaee..4dbed8705505 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,16 +15,17 @@ // specific language governing permissions and limitations // under the License. 
-use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_frame_state::WindowFrameContext; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; +use datafusion_expr::window_frame_state::{ + PartitionBatchState, WindowAggState, WindowFrameContext, +}; use datafusion_expr::{Accumulator, WindowFrame}; use indexmap::IndexMap; use std::any::Any; @@ -337,6 +338,12 @@ pub enum BuiltinWindowState { Default, } +impl PartitionState for BuiltinWindowState { + fn as_any(&self) -> &dyn Any { + self + } +} + /// Key for IndexMap for each unique partition /// /// For instance, if window frame is `OVER(PARTITION BY a,b)`, @@ -352,18 +359,3 @@ pub type PartitionWindowAggStates = IndexMap; /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. 
pub type PartitionBatches = IndexMap; - -impl WindowAggState { - pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); - Ok(Self { - window_frame_range: Range { start: 0, end: 0 }, - window_frame_ctx: None, - last_calculated_index: 0, - offset_pruned_rows: 0, - out_col: empty_out_col, - n_row_result_missing: 0, - is_end: false, - }) - } -} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 104a65832dcd..ab667d97c676 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -183,11 +183,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn find_window_func(&self, name: &str) -> Result { window_function::find_df_window_func(name) + // next check user defined aggregates .or_else(|| { self.schema_provider .get_aggregate_meta(name) .map(WindowFunction::AggregateUDF) }) + // next check user defined window functions + .or_else(|| { + self.schema_provider + .get_window_meta(name) + .map(WindowFunction::WindowUDF) + }) .ok_or_else(|| { DataFusionError::Plan(format!("There is no window function named {name}")) }) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ceec01037425..26ff5466f408 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -22,6 +22,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::field_not_found; +use datafusion_expr::WindowUDF; use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -46,6 +47,8 @@ pub trait ContextProvider { fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type fn get_variable_type(&self, variable_names: &[String]) -> Option;