diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 31595c980a30..f1249a2fabdd 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false } prost-derive = { version = "0.11", default-features = false } serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.82" +tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } tonic = "0.9" url = "2.2" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index df6ad5a467b6..02dd9c417325 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -57,6 +57,7 @@ cargo run --example csv_sql - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) +- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs new file mode 100644 index 000000000000..8de7f575ac37 --- /dev/null +++ b/datafusion-examples/examples/simple_udwf.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{AsArray, Float64Array, ArrayRef}, + datatypes::Float64Type, +}; +use arrow_schema::DataType; +use datafusion::datasource::file_format::options::CsvReadOptions; + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ + partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF, +}; + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
+ println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = format!("datafusion/core/tests/data/cars.csv"); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + + // register the window function with DataFusion so we can call it + ctx.register_udwf(smooth_it()); + + // Use SQL to show the raw contents of the `cars` table + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`: each distinct value of car (red and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new window function with an explicit window + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the + // row afterward. 
+ let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // TODO: show how to run this with the DataFrame API as well + + Ok(()) +} +fn smooth_it() -> WindowUDF { + WindowUDF { + name: String::from("smooth_it"), + // it takes 1 argument -- the Float64 column to smooth + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + return_type: Arc::new(return_type), + partition_evaluator: Arc::new(make_partition_evaluator), + // specify that the user defined window function gets a window + // frame (so that the user can use a window frame definition + // such as `ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING`) + uses_window_frame: true, + supports_bounded_execution: false, + } +} + +/// Compute the return type of the smooth_it window function given +/// arguments of `arg_types`. +fn return_type(arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(DataFusionError::Plan(format!( + "smooth_it expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ))); + } + Ok(Arc::new(arg_types[0].clone())) +} + +/// Create a `PartitionEvaluator` to evaluate this function on a new +/// partition. 
+fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// These different evaluation methods are called depending on the various settings of WindowUDF +impl PartitionEvaluator for MyPartitionEvaluator { + fn get_range(&self, _idx: usize, _n_rows: usize) -> Result> { + Err(DataFusionError::NotImplemented( + "get_range is not implemented for this window function".to_string(), + )) + } + + /// This function is called once per input row. + /// + /// `range` + /// specifies which indexes of `values` should be considered for + /// the calculation. + /// + /// Note this is not the fastest way to evaluate a window + /// function. It is much faster to implement evaluate_stateful or + /// range less / rank based calculations if possible. + fn evaluate_inside_range( + &self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + //println!("evaluate_inside_range(). 
range: {range:#?}, values: {values:#?}"); + + // The input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the range + let output = if range_len > 0 { + let sum: f64 = arr + .values() + .iter() + .skip(range.start) + .take(range_len) + .sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 6b81a39691d6..c65e16f8868e 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -34,7 +34,7 @@ use crate::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, StringifiedPlan, UserDefinedLogicalNode, + DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -797,6 +797,20 @@ impl SessionContext { .insert(f.name.clone(), Arc::new(f)); } + /// Registers a window UDF within this context. + /// + /// Note in SQL queries, window function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"` + /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"` + pub fn register_udwf(&self, f: WindowUDF) { + self.state + .write() + .window_functions + .insert(f.name.clone(), Arc::new(f)); + } + /// Creates a [`DataFrame`] for reading a data source. 
/// /// For more control such as reading multiple files, you can use @@ -1290,6 +1304,10 @@ impl FunctionRegistry for SessionContext { fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } + + fn udwf(&self, name: &str) -> Result> { + self.state.read().udwf(name) + } } /// A planner used to add extensions to DataFusion logical and physical plans. @@ -1340,6 +1358,8 @@ pub struct SessionState { scalar_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, + /// Window functions registered in the context + window_functions: HashMap>, /// Deserializer registry for extensions. serializer_registry: Arc, /// Session configuration @@ -1483,6 +1503,7 @@ impl SessionState { catalog_list, scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), config, execution_props: ExecutionProps::new(), @@ -1959,6 +1980,11 @@ impl SessionState { &self.aggregate_functions } + /// Return reference to window functions + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + /// Return [SerializerRegistry] for extensions pub fn serializer_registry(&self) -> Arc { self.serializer_registry.clone() @@ -1992,6 +2018,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; @@ -2039,6 +2069,16 @@ impl FunctionRegistry for SessionState { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Plan(format!( + "There is no UDWF named \"{name}\" in the registry" + )) + }) + } } impl OptimizerConfig for SessionState { @@ -2068,6 +2108,7 @@ 
impl From<&SessionState> for TaskContext { state.config.clone(), state.scalar_functions.clone(), state.aggregate_functions.clone(), + state.window_functions.clone(), state.runtime_env.clone(), ) } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 73a3eb10c28f..24cc254ea07d 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -26,15 +26,15 @@ use crate::physical_plan::{ udaf, ExecutionPlan, PhysicalExpr, }; use arrow::datatypes::Schema; -use arrow_schema::{SchemaRef, SortOptions}; +use arrow_schema::{DataType, Field, SchemaRef, SortOptions}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ window_function::{BuiltInWindowFunction, WindowFunction}, - WindowFrame, + WindowFrame, WindowUDF, }; use datafusion_physical_expr::window::{ - BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr, + BuiltInWindowFunctionExpr, PartitionEvaluator, SlidingAggregateWindowExpr, }; use std::borrow::Borrow; use std::convert::TryInto; @@ -97,6 +97,12 @@ pub fn create_window_expr( order_by, window_frame, )), + WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), }) } @@ -184,6 +190,79 @@ fn create_built_in_window_expr( }) } +/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function +fn create_udwf_window_expr( + fun: &Arc, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + // need to get the types into an owned vec for some reason + let input_types: Vec<_> = args + .iter() + .map(|arg| arg.data_type(input_schema).map(|dt| dt.clone())) + .collect::>()?; + + // figure out the output type + let data_type = (fun.return_type)(&input_types)?; + Ok(Arc::new(WindowUDFExpr { + fun: Arc::clone(fun), + args: args.to_vec(), + name, + 
+ data_type, + })) +} + +/// Adapter that implements `BuiltInWindowFunctionExpr` for a `WindowUDF` +#[derive(Clone, Debug)] +struct WindowUDFExpr { + fun: Arc, + args: Vec>, + /// Display name + name: String, + /// result type + data_type: Arc, +} + +impl BuiltInWindowFunctionExpr for WindowUDFExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn field(&self) -> Result { + let nullable = true; // a user defined evaluator may produce NULL values + Ok(Field::new( + &self.name, + self.data_type.as_ref().clone(), + nullable, + )) + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn create_evaluator(&self) -> Result> { + (self.fun.partition_evaluator)() + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + None + } + + fn supports_bounded_execution(&self) -> bool { + self.fun.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.fun.uses_window_frame + } +} + pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, diff --git a/datafusion/core/tests/data/cars.csv b/datafusion/core/tests/data/cars.csv new file mode 100644 index 000000000000..bc40f3b01e7a --- /dev/null +++ b/datafusion/core/tests/data/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 
+green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 
+green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs index 1047f73df4cd..3a16fef19caf 100644 --- a/datafusion/core/tests/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined_aggregates.rs @@ -18,6 +18,9 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions +// TODO: rename this file user_defined_functions.rs (as it has examples of user defined window functions too now) + + use arrow::datatypes::Fields; use std::sync::Arc; @@ -239,3 +242,14 @@ impl Accumulator for FirstSelector { std::mem::size_of_val(self) } } + + +// Test 1: Evaluate over the entire partition + +// Test 2: Evaluate over window ranges + +// Test 3: Evaluate using rank() + +// Test 4: Evaluate using stateful evaluation + +// Test 5: Show using a scalar as argument diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index ef06c74cc292..9ba487e715b3 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -18,7 +18,7 @@ //! FunctionRegistry trait use datafusion_common::Result; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode}; +use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::{collections::HashSet, sync::Arc}; /// A registry knows how to build logical expressions out of user-defined function' names @@ -31,6 +31,9 @@ pub trait FunctionRegistry { /// Returns a reference to the udaf named `name`. fn udaf(&self, name: &str) -> Result>; + + /// Returns a reference to the udwf named `name`. + fn udwf(&self, name: &str) -> Result>; } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. 
diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index ca1bc9369e35..15581c85bf5f 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -24,7 +24,7 @@ use datafusion_common::{ config::{ConfigOptions, Extensions}, DataFusionError, Result, }; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, @@ -48,6 +48,8 @@ pub struct TaskContext { scalar_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, + /// Window functions associated with this task context + window_functions: HashMap>, /// Runtime environment associated with this task context runtime: Arc, } @@ -60,6 +62,7 @@ impl TaskContext { session_config: SessionConfig, scalar_functions: HashMap>, aggregate_functions: HashMap>, + window_functions: HashMap>, runtime: Arc, ) -> Self { Self { @@ -68,6 +71,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, } } @@ -92,6 +96,7 @@ impl TaskContext { config.set(&k, &v)?; } let session_config = SessionConfig::from(config); + let window_functions = HashMap::new(); Ok(Self::new( Some(task_id), @@ -99,6 +104,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, )) } @@ -153,6 +159,16 @@ impl FunctionRegistry for TaskContext { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDWF named \"{name}\" in the TaskContext" + )) + }) + } } #[cfg(test)] diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bec672ab6f6c..c0088036905f 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,6 +19,7 @@ use 
crate::function_err::generate_signature_error_msg; use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::partition_evaluator::PartitionEvaluator; use crate::type_coercion::functions::data_types; use crate::ColumnarValue; use crate::{ @@ -54,6 +55,11 @@ pub type AccumulatorFunctionImplementation = pub type StateTypeFunction = Arc Result>> + Send + Sync>; +/// Factory that creates a new PartitionEvaluator for a user defined +/// window function. It is invoked with no arguments. +pub type PartitionEvaluatorFunctionFactory = + Arc Result> + Send + Sync>; + macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5945480aba1d..fdb1b6e60a8a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -42,6 +42,7 @@ mod literal; pub mod logical_plan; mod nullif; mod operator; +pub mod partition_evaluator; mod signature; pub mod struct_expressions; mod table_source; @@ -49,8 +50,10 @@ pub mod tree_node; pub mod type_coercion; mod udaf; mod udf; +mod udwf; pub mod utils; pub mod window_frame; +pub mod window_frame_state; pub mod window_function; pub use accumulator::Accumulator; @@ -74,6 +77,7 @@ pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; +pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs similarity index 84% rename from datafusion/physical-expr/src/window/partition_evaluator.rs rename to datafusion/expr/src/partition_evaluator.rs index db60fdd5f1fa..381274bbf3c4 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ 
b/datafusion/expr/src/partition_evaluator.rs @@ -17,14 +17,24 @@ //! Partition evaluation module -use crate::window::window_expr::BuiltinWindowState; -use crate::window::WindowAggState; +use crate::window_frame_state::WindowAggState; use arrow::array::ArrayRef; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; +use std::any::Any; use std::fmt::Debug; use std::ops::Range; +/// Trait for the state managed by this partition evaluator +/// +/// TODO: this follows the existing pattern, but could likely be improved + +pub trait PartitionState { + /// Returns this state as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + /// Partition evaluator for Window Functions /// /// # Background @@ -100,12 +110,9 @@ pub trait PartitionEvaluator: Debug + Send { false } - /// Returns the internal state of the window function - /// - /// Only used for stateful evaluation - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::Default) + /// Returns the internal state of the window function, if any + fn state(&self) -> Result>> { + Ok(None) } /// Updates the internal state for window function @@ -130,7 +137,7 @@ pub trait PartitionEvaluator: Debug + Send { /// Sets the internal state for window function /// /// Only used for stateful evaluation - fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> { + fn set_state(&mut self, _state: Box) -> Result<()> { Err(DataFusionError::NotImplemented( "set_state is not implemented for this window function".to_string(), )) @@ -146,9 +153,18 @@ pub trait PartitionEvaluator: Debug + Send { )) } + /// Evaluate a window function on the entire input partition, + /// `values`, producing one output row for each input. 
+ /// /// Called for window functions that *do not use* values from the /// the window frame, such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, /// `PERCENT_RANK`, `CUME_DIST`, `LEAD`, `LAG`). + /// + /// The function is passed the entire partition as `values` and must + /// produce an output column with exactly `num_rows` values. + /// + /// `num_rows` is required to correctly compute the output in case + /// `values.len() == 0` fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate is not implemented by default".into(), @@ -207,7 +223,13 @@ pub trait PartitionEvaluator: Debug + Send { /// such as `FIRST_VALUE`, `LAST_VALUE`, `NTH_VALUE` and produce a /// single value for every row in the partition. /// - /// Returns a [`ScalarValue`] that is the value of the window function for the entire partition + /// This is the simplest and most general function to implement + /// but also the least performant as it creates the output one row + /// at a time. It is typically much faster to implement stateful + /// evaluation or one of the specialized rank-based or whole-partition (`evaluate`) variants. + /// + /// Returns a [`ScalarValue`] that is the value of the window + /// function within the given range fn evaluate_inside_range( &self, _values: &[ArrayRef], diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs new file mode 100644 index 000000000000..f1faf76a74f6 --- /dev/null +++ b/datafusion/expr/src/udwf.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support for user-defined window functions (UDWFs) + +use std::fmt::{self, Debug, Display, Formatter}; + +use crate::{function::PartitionEvaluatorFunctionFactory, ReturnTypeFunction, Signature}; + +/// Logical representation of a user-defined window function (UDWF) +/// A UDWF is different from a UDF in that it is evaluated over a partition of rows rather than one row at a time. +/// +/// Window Frames: +/// +/// TODO add a diagram here showing the input and the output w/ frames +/// (or document that elsewhere and link here) +#[derive(Clone)] +pub struct WindowUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// Factory that creates the partition evaluator for this function + pub partition_evaluator: PartitionEvaluatorFunctionFactory, + /// If true, the window function requires the window frame (e.g. 
the declared sliding window frame) + /// (TODO: see documentation on + /// BuiltInWindowFunctionExpr::uses_window_frame) + pub uses_window_frame: bool, + /// True if this function supports bounded execution (TODO see documentation on XXXX) + pub supports_bounded_execution: bool, + // TODO: Reverse expressions, uses rank() +} + +impl Debug for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator", &"") + .field("uses_window_frame", &self.uses_window_frame) + .field("supports_bounded_execution", &self.supports_bounded_execution) + .finish_non_exhaustive() + } +} + +/// Defines how the WindowUDF is shown to users +impl Display for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PartialEq for WindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl Eq for WindowUDF {} + +impl std::hash::Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl WindowUDF { + // /// Create a new WindowUDF + // pub fn new( + // name: &str, + // signature: &Signature, + // return_type: &ReturnTypeFunction, + // accumulator: &AccumulatorFunctionImplementation, + // state_type: &StateTypeFunction, + // ) -> Self { + // Self { + // name: name.to_owned(), + // signature: signature.clone(), + // return_type: return_type.clone(), + // accumulator: accumulator.clone(), + // state_type: state_type.clone(), + // } + // } +} diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_frame_state.rs similarity index 87% rename from datafusion/physical-expr/src/window/window_frame_state.rs rename to datafusion/expr/src/window_frame_state.rs index e23a58a09b66..1dc11d0fdefe 100644 --- 
a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/expr/src/window_frame_state.rs @@ -18,17 +18,112 @@ //! This module provides utilities for window frame index calculations //! depending on the window frame mode: RANGE, ROWS, GROUPS. +use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use arrow::array::ArrayRef; +use arrow::compute::concat; use arrow::compute::kernels::sort::SortOptions; +use arrow::datatypes::DataType; +use arrow::record_batch::RecordBatch; use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; +/// State for each unique partition determined according to PARTITION BY column(s) +#[derive(Debug)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// Flag indicating whether we have received all data for this partition + pub is_end: bool, + /// Number of rows emitted for each partition + pub n_out_row: usize, +} + +#[derive(Debug)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub window_frame_range: Range, + pub window_frame_ctx: Option, + /// The index of the last row that its result is calculated inside the partition record batch buffer. + pub last_calculated_index: usize, + /// The offset of the deleted row number + pub offset_pruned_rows: usize, + /// Stores the results calculated by window frame + pub out_col: ArrayRef, + /// Keeps track of how many rows should be generated to be in sync with input record_batch. + // (For each row in the input record batch we need to generate a window result). 
+ pub n_row_result_missing: usize, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. }) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } +} + +impl WindowAggState { + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } +} + +impl WindowAggState { + pub fn new(out_type: &DataType) -> Result { + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + Ok(Self { + window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: empty_out_col, + n_row_result_missing: 0, + is_end: false, + }) + } +} + /// This object stores the window frame state for use in incremental calculations. 
#[derive(Debug)] pub enum WindowFrameContext { @@ -547,11 +642,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result), + /// A user defined window function + WindowUDF(Arc), } /// Find DataFusion's built-in window function by name. @@ -69,6 +74,7 @@ impl fmt::Display for WindowFunction { WindowFunction::AggregateFunction(fun) => fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunction::WindowUDF(fun) => fun.fmt(f), } } } @@ -166,6 +172,9 @@ pub fn return_type( WindowFunction::AggregateUDF(fun) => { Ok((*(fun.return_type)(input_expr_types)?).clone()) } + WindowFunction::WindowUDF(fun) => { + Ok((*(fun.return_type)(input_expr_types)?).clone()) + } } } @@ -202,6 +211,7 @@ pub fn signature(fun: &WindowFunction) -> Signature { WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun), WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), WindowFunction::AggregateUDF(fun) => fun.signature.clone(), + WindowFunction::WindowUDF(fun) => fun.signature.clone(), } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 030c20c5743c..4e5665f0a62e 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -21,7 +21,6 @@ use std::any::Any; use std::ops::Range; use std::sync::Arc; -use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; use crate::window::window_expr::{ @@ -37,6 +36,7 @@ use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::window_frame_state::WindowFrameContext; use datafusion_expr::WindowFrame; /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. 
@@ -211,12 +211,21 @@ impl WindowExpr for BuiltInWindowExpr { state.update(&out_col, partition_batch_state)?; if self.window_frame.start_bound.is_unbounded() { - let mut evaluator_state = evaluator.state()?; - if let BuiltinWindowState::NthValue(nth_value_state) = - &mut evaluator_state - { - memoize_nth_value(state, nth_value_state)?; - evaluator.set_state(&evaluator_state)?; + let Some(evaluator_state) = evaluator.state()? else { + return Ok(()) + }; + + let evaluator_state = evaluator_state + .as_any() + .downcast_ref::() + .unwrap(); + + if let BuiltinWindowState::NthValue(nth_value_state) = &evaluator_state { + let mut nth_value_state = nth_value_state.clone(); + memoize_nth_value(state, &mut nth_value_state)?; + let evaluator_state = + Box::new(BuiltinWindowState::NthValue(nth_value_state)); + evaluator.set_state(evaluator_state)?; } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 59438a72f275..912cc152203d 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use super::partition_evaluator::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::sync::Arc; @@ -46,9 +46,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. 
- fn name(&self) -> &str { - "BuiltInWindowFunctionExpr: default name" - } + fn name(&self) -> &str; /// Evaluate window function's arguments against the input window /// batch and return an [`ArrayRef`]. @@ -88,10 +86,29 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { false } - /// Does the window function use the values from its window frame? + /// Does the window function use the values from the window frame, + /// if one is specified? /// /// If this function returns true, [`Self::create_evaluator`] must /// implement [`PartitionEvaluator::evaluate_inside_range`] + /// + /// This is an optimization: certain window functions are not + /// affected by the window frame, and thus DataFusion skips the + /// (costly) calculation of the window frame, if possible. + /// + /// For example, the `LAG` built in window function does not use the + /// values of its window frame (it can be computed in one shot on + /// the entire partition with `Self::evaluate`) + /// + /// ```sql + /// lag(x, 1) OVER (ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + /// ``` + /// + /// However, `avg` does use the values of its window frame + /// + /// ```sql + /// avg(x) OVER (ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + /// ``` fn uses_window_frame(&self) -> bool { false } diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 46997578001d..945ab183e6b7 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -18,13 +18,13 @@ //! Defines physical expression for `cume_dist` that can evaluated //!
at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::Float64Array; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::iter; use std::ops::Range; diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 8d97d5ebc0b3..5414d6dff98c 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -18,7 +18,6 @@ //! Defines physical expression for `lead` and `lag` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, LeadLagState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -27,6 +26,7 @@ use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::cmp::min; use std::ops::{Neg, Range}; @@ -182,9 +182,13 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn state(&self) -> Result { + fn state( + &self, + ) -> Result>, DataFusionError> { // If we do not use state we just return Default - Ok(BuiltinWindowState::LeadLag(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::LeadLag( + self.state.clone(), + )))) } fn update_state( diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 4c8b8b5a4e4b..88a5599ac036 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -22,21 +22,20 @@ pub(crate) 
mod cume_dist; pub(crate) mod lead_lag; pub(crate) mod nth_value; pub(crate) mod ntile; -pub(crate) mod partition_evaluator; pub(crate) mod rank; pub(crate) mod row_number; mod sliding_aggregate; mod window_expr; -mod window_frame_state; pub use aggregate::PlainAggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; +pub use datafusion_expr::partition_evaluator::PartitionEvaluator; +pub use datafusion_expr::window_frame_state::PartitionBatchState; +pub use datafusion_expr::window_frame_state::WindowAggState; pub use sliding_aggregate::SlidingAggregateWindowExpr; -pub use window_expr::PartitionBatchState; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; -pub use window_expr::WindowAggState; pub use window_expr::WindowExpr; pub use window_expr::WindowState; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 4bfe514c38da..5dbf34bee854 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -18,7 +18,6 @@ //! Defines physical expressions for `first_value`, `last_value`, and `nth_value` //! 
that can evaluated at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NthValueKind, NthValueState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -26,6 +25,7 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -152,9 +152,13 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - fn state(&self) -> Result { + fn state( + &self, + ) -> Result>, DataFusionError> { // If we do not use state we just return Default - Ok(BuiltinWindowState::NthValue(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::NthValue( + self.state.clone(), + )))) } fn update_state( @@ -169,7 +173,8 @@ impl PartitionEvaluator for NthValueEvaluator { Ok(()) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + fn set_state(&mut self, state: Box<(dyn PartitionState + 'static)>) -> Result<()> { + let state = state.as_any().downcast_ref::().unwrap(); if let BuiltinWindowState::NthValue(nth_value_state) = state { self.state = nth_value_state.clone() } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 479fa263337a..6618de423f2b 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -18,13 +18,13 @@ //! Defines physical expression for `ntile` that can evaluated //! 
at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::Field; use arrow_schema::DataType; use datafusion_common::Result; +use datafusion_expr::partition_evaluator::PartitionEvaluator; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 89ca40dd564f..81d8c7107108 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -18,7 +18,6 @@ //! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, RankState}; use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; @@ -27,6 +26,7 @@ use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::iter; use std::ops::Range; @@ -125,8 +125,10 @@ impl PartitionEvaluator for RankEvaluator { Ok(Range { start, end }) } - fn state(&self) -> Result { - Ok(BuiltinWindowState::Rank(self.state.clone())) + fn state( + &self, + ) -> Result>, DataFusionError> { + Ok(Some(Box::new(BuiltinWindowState::Rank(self.state.clone())))) } fn update_state( diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index 9883d67f7cd8..229504bbd7f2 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -17,13 +17,13 @@ //! 
Defines physical expression for `row_number` that can evaluated at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -76,9 +76,11 @@ pub(crate) struct NumRowsEvaluator { } impl PartitionEvaluator for NumRowsEvaluator { - fn state(&self) -> Result { + fn state(&self) -> Result>> { // If we do not use state we just return Default - Ok(BuiltinWindowState::NumRows(self.state.clone())) + Ok(Some(Box::new(BuiltinWindowState::NumRows( + self.state.clone(), + )))) } fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 7fe616feda61..4dbed8705505 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,16 +15,17 @@ // specific language governing permissions and limitations // under the License. 
-use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_frame_state::WindowFrameContext; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::partition_evaluator::{PartitionEvaluator, PartitionState}; +use datafusion_expr::window_frame_state::{ + PartitionBatchState, WindowAggState, WindowFrameContext, +}; use datafusion_expr::{Accumulator, WindowFrame}; use indexmap::IndexMap; use std::any::Any; @@ -337,83 +338,12 @@ pub enum BuiltinWindowState { Default, } -#[derive(Debug)] -pub struct WindowAggState { - /// The range that we calculate the window function - pub window_frame_range: Range, - pub window_frame_ctx: Option, - /// The index of the last row that its result is calculated inside the partition record batch buffer. - pub last_calculated_index: usize, - /// The offset of the deleted row number - pub offset_pruned_rows: usize, - /// Stores the results calculated by window frame - pub out_col: ArrayRef, - /// Keeps track of how many rows should be generated to be in sync with input record_batch. - // (For each row in the input record batch we need to generate a window result). 
- pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition - pub is_end: bool, -} - -impl WindowAggState { - pub fn prune_state(&mut self, n_prune: usize) { - self.window_frame_range = Range { - start: self.window_frame_range.start - n_prune, - end: self.window_frame_range.end - n_prune, - }; - self.last_calculated_index -= n_prune; - self.offset_pruned_rows += n_prune; - - match self.window_frame_ctx.as_mut() { - // Rows have no state do nothing - Some(WindowFrameContext::Rows(_)) => {} - Some(WindowFrameContext::Range { .. }) => {} - Some(WindowFrameContext::Groups { state, .. }) => { - let mut n_group_to_del = 0; - for (_, end_idx) in &state.group_end_indices { - if n_prune < *end_idx { - break; - } - n_group_to_del += 1; - } - state.group_end_indices.drain(0..n_group_to_del); - state - .group_end_indices - .iter_mut() - .for_each(|(_, start_idx)| *start_idx -= n_prune); - state.current_group_idx -= n_group_to_del; - } - None => {} - }; +impl PartitionState for BuiltinWindowState { + fn as_any(&self) -> &dyn Any { + self } } -impl WindowAggState { - pub fn update( - &mut self, - out_col: &ArrayRef, - partition_batch_state: &PartitionBatchState, - ) -> Result<()> { - self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; - self.n_row_result_missing = - partition_batch_state.record_batch.num_rows() - self.last_calculated_index; - self.is_end = partition_batch_state.is_end; - Ok(()) - } -} - -/// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] -pub struct PartitionBatchState { - /// The record_batch belonging to current partition - pub record_batch: RecordBatch, - /// Flag indicating whether we have received all data for this partition - pub is_end: bool, - /// Number of rows emitted for each partition - pub n_out_row: usize, -} - /// Key for IndexMap for each unique partition /// /// For instance, if window frame is 
`OVER(PARTITION BY a,b)`, @@ -429,18 +359,3 @@ pub type PartitionWindowAggStates = IndexMap; /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. pub type PartitionBatches = IndexMap; - -impl WindowAggState { - pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); - Ok(Self { - window_frame_range: Range { start: 0, end: 0 }, - window_frame_ctx: None, - last_calculated_index: 0, - offset_pruned_rows: 0, - out_col: empty_out_col, - n_row_result_missing: 0, - is_end: false, - }) - } -} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 104a65832dcd..ab667d97c676 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -183,11 +183,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn find_window_func(&self, name: &str) -> Result { window_function::find_df_window_func(name) + // next check user defined aggregates .or_else(|| { self.schema_provider .get_aggregate_meta(name) .map(WindowFunction::AggregateUDF) }) + // next check user defined window functions + .or_else(|| { + self.schema_provider + .get_window_meta(name) + .map(WindowFunction::WindowUDF) + }) .ok_or_else(|| { DataFusionError::Plan(format!("There is no window function named {name}")) }) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ceec01037425..26ff5466f408 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -22,6 +22,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::field_not_found; +use datafusion_expr::WindowUDF; use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -46,6 +47,8 @@ pub trait ContextProvider { fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description fn get_aggregate_meta(&self, name: &str) -> Option>; 
+ /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type fn get_variable_type(&self, variable_names: &[String]) -> Option;