Support User Defined Window Functions (#6703)
* Support User Defined Window Functions

* Apply suggestions from code review

Co-authored-by: Mustafa Akur <[email protected]>

* Remove left over println

* Apply suggestions from code review

Co-authored-by: Mustafa Akur <[email protected]>

* Apply suggestions from code review

Co-authored-by: Mustafa Akur <[email protected]>

* fix docs

---------

Co-authored-by: Mustafa Akur <[email protected]>
alamb and mustafasrepo authored Jun 22, 2023
1 parent eb290a0 commit b1b8c9c
Showing 32 changed files with 1,232 additions and 31 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
@@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false }
prost-derive = { version = "0.11", default-features = false }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.82"
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
tonic = "0.9"
url = "2.2"
1 change: 1 addition & 0 deletions datafusion-examples/README.md
@@ -57,6 +57,7 @@ cargo run --example csv_sql
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)

## Distributed

6 changes: 5 additions & 1 deletion datafusion-examples/examples/rewrite_expr.rs
@@ -20,7 +20,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{
AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource,
AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
};
use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};
use datafusion_optimizer::optimizer::Optimizer;
@@ -216,6 +216,10 @@ impl ContextProvider for MyContextProvider {
None
}

fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}

fn options(&self) -> &ConfigOptions {
&self.options
}
211 changes: 211 additions & 0 deletions datafusion-examples/examples/simple_udwf.rs
@@ -0,0 +1,211 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{
array::{ArrayRef, AsArray, Float64Array},
datatypes::Float64Type,
};
use arrow_schema::DataType;
use datafusion::datasource::file_format::options::CsvReadOptions;

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowFrame, WindowUDF,
};

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
// declare a new context. In the Spark API, this corresponds to a new Spark SQL session
let ctx = SessionContext::new();

// register a CSV file as a table. In the Spark API, this corresponds to createDataFrame(...).
println!("pwd: {}", std::env::current_dir().unwrap().display());
let csv_path = "datafusion/core/tests/data/cars.csv".to_string();
let read_options = CsvReadOptions::default().has_header(true);

ctx.register_csv("cars", &csv_path, read_options).await?;
Ok(ctx)
}

/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL
#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context().await?;

// register the window function with DataFusion so we can call it
ctx.register_udwf(smooth_it());

// Use SQL to run the new window function
let df = ctx.sql("SELECT * from cars").await?;
// print the results
df.show().await?;

// Use SQL to run the new window function:
//
// `PARTITION BY car`: each distinct value of car (red and green)
// should be treated as a separate partition (and will result in
// creating a new `PartitionEvaluator`)
//
// `ORDER BY time`: within each partition ('green' or 'red') the
// rows will be ordered by the value in the `time` column
//
// `evaluate` is invoked with a window defined by the
// SQL. In this case:
//
// The first invocation will be passed row 0, the first row in the
// partition.
//
// The second invocation will be passed rows 0 and 1, the first
// two rows in the partition.
//
// etc.
let df = ctx
.sql(
"SELECT \
car, \
speed, \
smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\
time \
from cars \
ORDER BY \
car",
)
.await?;
// print the results
df.show().await?;

// this time, call the new window function with an explicit
// window frame so `evaluate` will be invoked with each window.
//
// `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation
// sees at most 3 rows: the row before, the current row, and the
// row after.
let df = ctx.sql(
"SELECT \
car, \
speed, \
smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
time \
from cars \
ORDER BY \
car",
).await?;
// print the results
df.show().await?;

// Now, run the function using the DataFrame API:
let window_expr = smooth_it().call(
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

// print the results
df.show().await?;

Ok(())
}

fn smooth_it() -> WindowUDF {
WindowUDF {
name: String::from("smooth_it"),
// it takes one argument -- the column to smooth
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
return_type: Arc::new(return_type),
partition_evaluator_factory: Arc::new(make_partition_evaluator),
}
}

/// Compute the return type of the smooth_it window function given
/// arguments of `arg_types`.
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
if arg_types.len() != 1 {
return Err(DataFusionError::Plan(format!(
"smooth_it expects 1 argument, got {}: {:?}",
arg_types.len(),
arg_types
)));
}
Ok(Arc::new(arg_types[0].clone()))
}

/// Create a `PartitionEvaluator` to evaluate this function on a new
/// partition.
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}

/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct value of `PARTITION BY` (each car type in our example)
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
fn new() -> Self {
Self {}
}
}

/// Different evaluation methods are called depending on the various
/// settings of WindowUDF. This example uses the simplest and most
/// general, `evaluate`. See `PartitionEvaluator` for the other more
/// advanced uses.
impl PartitionEvaluator for MyPartitionEvaluator {
/// Tell DataFusion the window function varies based on the value
/// of the window frame.
fn uses_window_frame(&self) -> bool {
true
}

/// This function is called once per input row.
///
/// `range` specifies which indexes of `values` should be
/// considered for the calculation.
///
/// Note this is the SLOWEST, but simplest, way to evaluate a
/// window function. It is much faster to implement
/// `evaluate_all` or `evaluate_all_with_rank`, if possible (see
/// the `evaluate_all` sketch after this file).
fn evaluate(
&mut self,
values: &[ArrayRef],
range: &std::ops::Range<usize>,
) -> Result<ScalarValue> {
// Again, the input argument is an array of floating
// point numbers to calculate a moving average
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

let range_len = range.end - range.start;

// our smoothing function will average all the values in the window
let output = if range_len > 0 {
let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
Some(sum / range_len as f64)
} else {
None
};

Ok(ScalarValue::Float64(output))
}
}
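
As the `evaluate` doc comment above notes, `evaluate_all` is the faster path. Below is a minimal sketch (illustrative, not part of this commit) of a batch variant of the same smoothing logic. It assumes the `PartitionEvaluator::evaluate_all(values, num_rows)` signature, and that DataFusion would call it (instead of `evaluate`) only when `uses_window_frame` returns false:

// Inside `impl PartitionEvaluator for MyPartitionEvaluator`:
fn evaluate_all(&mut self, values: &[ArrayRef], _num_rows: usize) -> Result<ArrayRef> {
    let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
    let mut sum = 0.0;
    // Mirror the running average the default frame produces in the
    // first query above: row i averages rows 0..=i of its partition.
    let averages: Float64Array = arr
        .values()
        .iter()
        .enumerate()
        .map(|(i, v)| {
            sum += *v;
            Some(sum / (i + 1) as f64)
        })
        .collect();
    Ok(Arc::new(averages) as ArrayRef)
}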
8 changes: 8 additions & 0 deletions datafusion/core/src/dataframe.rs
@@ -218,6 +218,14 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, plan))
}

/// Apply one or more window functions ([`Expr::WindowFunction`]) to extend the schema
pub fn window(self, window_exprs: Vec<Expr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.window(window_exprs)?
.build()?;
Ok(DataFrame::new(self.session_state, plan))
}

/// Limit the number of rows returned from this DataFrame.
///
/// `skip` - Number of rows to skip before fetching any rows
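
A short usage sketch for the new `window` method, mirroring the DataFrame API call in `simple_udwf.rs` above (assumes the `smooth_it` UDWF and the registered `cars` table from that example):

let window_expr = smooth_it().call(
    vec![col("speed")],                 // smooth_it(speed)
    vec![col("car")],                   // PARTITION BY car
    vec![col("time").sort(true, true)], // ORDER BY time ASC
    WindowFrame::new(false),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;
df.show().await?;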
43 changes: 42 additions & 1 deletion datafusion/core/src/execution/context.rs
@@ -32,7 +32,7 @@ use datafusion_common::alias::AliasGenerator;
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
DescribeTable, StringifiedPlan, UserDefinedLogicalNode,
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -786,6 +786,20 @@ impl SessionContext {
.insert(f.name.clone(), Arc::new(f));
}

/// Registers a window UDF within this context.
///
/// Note in SQL queries, window function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
pub fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
}
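
// Usage sketch (illustrative, not part of this diff) of the lookup rule
// described above, assuming the `smooth_it` UDWF from `simple_udwf.rs`:
//
//     ctx.register_udwf(smooth_it());
//     // unquoted identifiers are lowercased, so this finds "smooth_it"
//     ctx.sql("SELECT SMOOTH_IT(speed) OVER (PARTITION BY car) FROM cars").await?;
//     // quoted identifiers are looked up verbatim, so this would fail
//     ctx.sql(r#"SELECT "SMOOTH_IT"(speed) OVER (PARTITION BY car) FROM cars"#).await?;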

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
@@ -1279,6 +1293,10 @@ impl FunctionRegistry for SessionContext {
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
}

/// A planner used to add extensions to DataFusion logical and physical plans.
@@ -1329,6 +1347,8 @@ pub struct SessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Window functions registered in the context
window_functions: HashMap<String, Arc<WindowUDF>>,
/// Deserializer registry for extensions.
serializer_registry: Arc<dyn SerializerRegistry>,
/// Session configuration
@@ -1423,6 +1443,7 @@ impl SessionState {
catalog_list,
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: Arc::new(EmptySerializerRegistry),
config,
execution_props: ExecutionProps::new(),
@@ -1899,6 +1920,11 @@ impl SessionState {
&self.aggregate_functions
}

/// Return reference to window functions
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
}

/// Return [SerializerRegistry] for extensions
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
@@ -1932,6 +1958,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
self.state.aggregate_functions().get(name).cloned()
}

fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
self.state.window_functions().get(name).cloned()
}

fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
@@ -1979,6 +2009,16 @@ impl FunctionRegistry for SessionState {
))
})
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
let result = self.window_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDWF named \"{name}\" in the registry"
))
})
}
}
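
// Retrieval sketch (illustrative, not part of this diff): the new `udwf`
// accessor returns the registered function, or the plan error above:
//
//     let udwf: Arc<WindowUDF> = ctx.udwf("smooth_it")?;
//     assert_eq!(udwf.name, "smooth_it");
//     assert!(ctx.udwf("no_such_udwf").is_err());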

impl OptimizerConfig for SessionState {
@@ -2012,6 +2052,7 @@ impl From<&SessionState> for TaskContext {
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.window_functions.clone(),
state.runtime_env.clone(),
)
}