RFC: User Defined Window Functions #6617

Closed
1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
@@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false }
prost-derive = { version = "0.11", default-features = false }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.82"
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
tonic = "0.9"
url = "2.2"
1 change: 1 addition & 0 deletions datafusion-examples/README.md
@@ -57,6 +57,7 @@ cargo run --example csv_sql
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)

## Distributed

210 changes: 210 additions & 0 deletions datafusion-examples/examples/simple_udwf.rs
@@ -0,0 +1,210 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{
    array::{AsArray, Float64Array},
    datatypes::Float64Type,
};
use arrow_schema::DataType;
use datafusion::datasource::file_format::options::CsvReadOptions;

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::DataFusionError;
use datafusion_expr::{
    partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
};

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
    // declare a new context. In the Spark API, this corresponds to a new Spark SQL session
    let ctx = SessionContext::new();

    // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
    println!("pwd: {}", std::env::current_dir().unwrap().display());
    let csv_path = "datafusion/core/tests/data/cars.csv".to_string();
    let read_options = CsvReadOptions::default().has_header(true);

    ctx.register_csv("cars", &csv_path, read_options).await?;
    Ok(ctx)
}
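
// `cars.csv` is assumed to provide at least the columns referenced by the
// queries below: `car`, `speed`, and `time`.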

/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context().await?;

    // register the window function with DataFusion so we can call it
    ctx.register_udwf(my_average());

    // Use SQL to run the new window function
    let df = ctx.sql("SELECT * from cars").await?;
    // print the results
    df.show().await?;

    // Use SQL to run the new window function:
    // `PARTITION BY car`: each distinct value of car (red and green) should be treated separately
    // `ORDER BY time`: within each group (red or green) the values will be ordered by time
    let df = ctx
        .sql(
            "SELECT car, \
            speed, \
            lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
            my_average(speed) OVER (PARTITION BY car ORDER BY time),\
[Inline review comment from alamb (Contributor Author): This shows calling the user defined window function via SQL]

            time \
            from cars",
        )
        .await?;
    // print the results
    df.show().await?;

    // `ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING`: run the window function so
    // that each invocation only sees 5 rows: the 2 before and the 2 after
    // let df = ctx.sql("SELECT car, \
    //     speed, \
    //     lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
    //     time \
    //     from cars").await?;
    // // print the results
    // df.show().await?;

    // todo show how to run dataframe API as well

    Ok(())
}

// TODO make a helper function like `create_udf` that helps to make these signatures

[Inline review comment from alamb (Contributor Author): Here is the structure that provides metadata about the window function]

fn my_average() -> WindowUDF {
    WindowUDF {
        name: String::from("my_average"),
        // it takes one argument -- the column to average (the evaluator
        // below operates on Float64 values)
        signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
        return_type: Arc::new(return_type),
        partition_evaluator: Arc::new(make_partition_evaluator),
    }
}

/// Compute the return type of the function given the argument types
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
    if arg_types.len() != 1 {
        return Err(DataFusionError::Plan(format!(
            "my_udwf expects 1 argument, got {}: {:?}",
            arg_types.len(),
            arg_types
        )));
    }
    Ok(Arc::new(arg_types[0].clone()))
}

/// Create a partition evaluator for this argument
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
    Ok(Box::new(MyPartitionEvaluator::new()))
}
Comment on lines +146 to +148

stuartcarnie (Contributor) commented on Jun 13, 2023:

Is it possible we could support passing scalar arguments when creating an instance of the function, similar to the built-in functions?

For example, the lag function takes an optional scalar value for the second argument, which is the shift offset:

https://github.com/apache/arrow-datafusion/blob/a42cc8d98b6e875c485e7e9b106d30803a32b00a/datafusion/core/src/physical_plan/windows/mod.rs#L148-L152

I would use this for functions such as moving_average, which requires a scalar for specifying the minimum number of rows to average.

Note: This would be a welcome feature for UDAFs too.

alamb (Contributor Author) replied:

I will try and figure out how to do this.

alamb (Contributor Author) commented on Jun 13, 2023:

@stuartcarnie I looked into adding the arguments. The primary issue I encountered is that a WindowUDF is specified in terms of structures in datafusion-expr (aka it doesn't have access to PhysicalExprs, as those are defined in a different crate).

Here are some possible signatures we could provide. Do you have any feedback on these possibilities?

Option 1: Pass in the Exprs from the logical plan

This is not ideal in my mind, as the PartitionEvaluator is created during execution (where the Exprs are normally not around anymore):

/// Factory that creates a PartitionEvaluator for the given window function.
///
/// This function is passed its input arguments so that cases such as
/// constants can be correctly handled.
pub type PartitionEvaluatorFunctionFactory =
    Arc<dyn Fn(&[Expr]) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;

Option 2: Pass in an ArgType enum

This is also not ideal in my mind, as it seemingly artificially limits what the user defined window function can special-case (why not Columns, for example?):

enum ArgType {
    /// The argument was a single value
    Scalar(ScalarValue),
    /// The argument is something other than a single value
    Array,
}

/// Factory that creates a PartitionEvaluator for the given window function.
///
/// This function is passed its input arguments so that cases such as
/// constants can be specially handled if desired.
pub type PartitionEvaluatorFunctionFactory =
    Arc<dyn Fn(Vec<ArgType>) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;

Others?
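
For illustration, a factory following the second proposal might special-case a constant argument like this (a sketch against the proposed ArgType API above; make_moving_average and MovingAverageEvaluator are hypothetical names):

fn make_moving_average(args: Vec<ArgType>) -> Result<Box<dyn PartitionEvaluator>> {
    // the second argument, when it is a constant, carries the minimum
    // number of rows to average
    let min_rows = match args.get(1) {
        Some(ArgType::Scalar(ScalarValue::UInt64(Some(n)))) => *n as usize,
        _ => 1, // fall back to a default when no constant was supplied
    };
    // MovingAverageEvaluator is a hypothetical PartitionEvaluator impl
    Ok(Box::new(MovingAverageEvaluator { min_rows }))
}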

stuartcarnie (Contributor) commented on Jun 14, 2023:

What do you think of passing the PhysicalExpr trait objects? Example:

/// Factory that creates a PartitionEvaluator for the given window function.
///
/// This function is passed its input arguments and schema so that cases such as
/// constants can be correctly handled.
pub type PartitionEvaluatorFunctionFactory =
    Arc<dyn Fn(&[Arc<dyn PhysicalExpr>], &Schema) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;

Note: I've also included the input_schema, as this would be necessary to evaluate types for the arguments.

This would be similar to create_built_in_window_expr:

https://github.com/apache/arrow-datafusion/blob/a42cc8d98b6e875c485e7e9b106d30803a32b00a/datafusion/core/src/physical_plan/windows/mod.rs#L120-L125
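
A sketch of how a UDWF could recover such a constant under this proposal (make_moving_average is a hypothetical name; Literal is DataFusion's physical expression for constants, and ScalarValue comes from datafusion_common):

use arrow_schema::Schema;
use datafusion_physical_expr::expressions::Literal;

fn make_moving_average(
    args: &[Arc<dyn PhysicalExpr>],
    _input_schema: &Schema,
) -> Result<Box<dyn PartitionEvaluator>> {
    // constants arrive as `Literal` physical expressions and can be
    // recovered with a downcast
    let min_rows: Option<ScalarValue> = args
        .get(1)
        .and_then(|e| e.as_any().downcast_ref::<Literal>())
        .map(|lit| lit.value().clone());
    // ... build the evaluator using `min_rows` ...
    todo!()
}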

alamb (Contributor Author) replied:

> What do you think of passing the PhysicalExpr trait objects? Example:

I think this would be ideal. The problem is that PhysicalExpr is defined in datafusion-physical-expr, which is not a dependency of datafusion-expr (the dependency goes the other way): https://github.com/apache/arrow-datafusion/blob/6194d588d5c3e9f202a31a0c524f63e6fb08d040/datafusion/physical-expr/Cargo.toml#L54

Thus, since WindowUDF is defined in datafusion-expr, it can't depend on PhysicalExpr.

datafusion_expr: https://github.com/apache/arrow-datafusion/blob/6194d588d5c3e9f202a31a0c524f63e6fb08d040/datafusion/expr/Cargo.toml#L37

stuartcarnie (Contributor) commented on Jun 14, 2023:

Ah yes, of course!

I'd suggest we don't hold up this work, and move this problem to another PR to solve it for both user-defined aggregate and window functions.

It works today; it is just that update_batch feels a bit awkward, as the scalar argument is passed as an ArrayRef. We might be able to engineer it so that it isn't a breaking change in the future.

alamb (Contributor Author) replied:

The actual runtime passes ColumnarValue: https://docs.rs/datafusion/latest/datafusion/physical_plan/enum.ColumnarValue.html

This is either a scalar or an array. We could potentially update the signatures to accept that instead (though we would have to move it to the datafusion_expr crate).
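
For reference, ColumnarValue has roughly this shape (paraphrased from the linked docs):

pub enum ColumnarValue {
    /// Array of values
    Array(ArrayRef),
    /// A single value
    Scalar(ScalarValue),
}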


/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct value of `PARTITION BY` (each car type in our example)
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
    fn new() -> Self {
        Self {}
    }
}

[Inline review comment from alamb (Contributor Author): Here is the proposal of how a user would specify the window calculation -- by implementing PartitionEvaluator]

/// These different evaluation methods are called depending on the various settings of WindowUDF
impl PartitionEvaluator for MyPartitionEvaluator {
    fn get_range(&self, _idx: usize, _n_rows: usize) -> Result<std::ops::Range<usize>> {
        Err(DataFusionError::NotImplemented(
            "get_range is not implemented for this window function".to_string(),
        ))
    }

    /// This function is given the values of each partition
    fn evaluate(
        &self,
        values: &[arrow::array::ArrayRef],
        _num_rows: usize,
    ) -> Result<arrow::array::ArrayRef> {
        // datafusion has handled ensuring we get the correct input argument
        assert_eq!(values.len(), 1);

        // For this example, we convert the input argument to an array of
        // floating point numbers to calculate a moving average
        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

        // implement a simple moving average by averaging the current
        // value with the previous value
        //
        // value | avg
        // ------+------
        //  10   |  10
        //  20   |  15
        //  30   |  25
        //  30   |  30
        //
        let mut previous_value = None;
        let new_values: Float64Array = arr
            .values()
            .iter()
            .map(|&value| {
                let new_value = previous_value
                    .map(|previous_value| (value + previous_value) / 2.0)
                    .unwrap_or(value);
                previous_value = Some(value);
                new_value
            })
            .collect();

        Ok(Arc::new(new_values))
    }

    fn evaluate_stateful(
        &mut self,
        _values: &[arrow::array::ArrayRef],
    ) -> Result<datafusion_common::ScalarValue> {
        Err(DataFusionError::NotImplemented(
            "evaluate_stateful is not implemented by default".into(),
        ))
    }

    fn evaluate_with_rank(
        &self,
        _num_rows: usize,
        _ranks_in_partition: &[std::ops::Range<usize>],
    ) -> Result<arrow::array::ArrayRef> {
        Err(DataFusionError::NotImplemented(
            "evaluate_with_rank is not implemented by default".into(),
        ))
    }

    fn evaluate_inside_range(
        &self,
        _values: &[arrow::array::ArrayRef],
        _range: &std::ops::Range<usize>,
    ) -> Result<datafusion_common::ScalarValue> {
        Err(DataFusionError::NotImplemented(
            "evaluate_inside_range is not implemented by default".into(),
        ))
    }
}

// TODO show how to use other evaluate methods
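
// A minimal sketch (not part of the original example) of a unit test for the
// moving-average logic, checking the values from the table in `evaluate`:
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ArrayRef;

    #[test]
    fn moving_average_matches_doc_table() {
        let evaluator = MyPartitionEvaluator::new();
        let input: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 30.0]));
        let result = evaluator.evaluate(&[input], 4).unwrap();
        let result: &Float64Array = result.as_ref().as_primitive::<Float64Type>();
        assert_eq!(result.values().to_vec(), vec![10.0, 15.0, 25.0, 30.0]);
    }
}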
43 changes: 42 additions & 1 deletion datafusion/core/src/execution/context.rs
@@ -34,7 +34,7 @@ use crate::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
    logical_plan::{DdlStatement, Statement},
    DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -797,6 +797,20 @@ impl SessionContext {
            .insert(f.name.clone(), Arc::new(f));
    }

    /// Registers a window UDF within this context.
    ///
    /// Note in SQL queries, window function names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
    /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
    pub fn register_udwf(&self, f: WindowUDF) {
        self.state
            .write()
            .window_functions
            .insert(f.name.clone(), Arc::new(f));
    }

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
@@ -1290,6 +1304,10 @@ impl FunctionRegistry for SessionContext {
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.state.read().udaf(name)
    }

    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.state.read().udwf(name)
    }
}

/// A planner used to add extensions to DataFusion logical and physical plans.
@@ -1340,6 +1358,8 @@ pub struct SessionState {
    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate functions registered in the context
    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Window functions registered in the context
    window_functions: HashMap<String, Arc<WindowUDF>>,
    /// Deserializer registry for extensions.
    serializer_registry: Arc<dyn SerializerRegistry>,
    /// Session configuration
@@ -1483,6 +1503,7 @@ impl SessionState {
            catalog_list,
            scalar_functions: HashMap::new(),
            aggregate_functions: HashMap::new(),
            window_functions: HashMap::new(),
            serializer_registry: Arc::new(EmptySerializerRegistry),
            config,
            execution_props: ExecutionProps::new(),
@@ -1959,6 +1980,11 @@ impl SessionState {
        &self.aggregate_functions
    }

    /// Return reference to window functions
    pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
        &self.window_functions
    }

    /// Return [SerializerRegistry] for extensions
    pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
        self.serializer_registry.clone()
@@ -1992,6 +2018,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
        self.state.aggregate_functions().get(name).cloned()
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
        if variable_names.is_empty() {
            return None;
@@ -2039,6 +2069,16 @@ impl FunctionRegistry for SessionState {
            ))
        })
    }

    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        let result = self.window_functions.get(name);

        result.cloned().ok_or_else(|| {
            DataFusionError::Plan(format!(
                "There is no UDWF named \"{name}\" in the registry"
            ))
        })
    }
}
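
// A hypothetical lookup through the new registry method (a sketch, assuming
// the `my_average` UDWF from the example above has been registered):
//
//   let udwf = ctx.udwf("my_average")?;
//   assert_eq!(udwf.name, "my_average");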

impl OptimizerConfig for SessionState {
@@ -2068,6 +2108,7 @@ impl From<&SessionState> for TaskContext {
            state.config.clone(),
            state.scalar_functions.clone(),
            state.aggregate_functions.clone(),
            state.window_functions.clone(),
            state.runtime_env.clone(),
        )
    }