Skip to content

Commit

Permalink
Updates and get example compiling
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 9, 2023
1 parent edf0afc commit 1bc2d6e
Show file tree
Hide file tree
Showing 24 changed files with 330 additions and 83 deletions.
152 changes: 144 additions & 8 deletions datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{
array::{AsArray, Float64Array},
datatypes::Float64Type,
};
use arrow_schema::DataType;
use datafusion::datasource::file_format::options::CsvReadOptions;

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::DataFusionError;
use datafusion_expr::{
partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
};

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
Expand All @@ -39,6 +50,9 @@ async fn create_context() -> Result<SessionContext> {
async fn main() -> Result<()> {
let ctx = create_context().await?;

// register the window function with DataFusion so wecan call it
ctx.register_udwf(my_average());

// Use SQL to run the new window function
let df = ctx.sql("SELECT * from cars").await?;
// print the results
Expand All @@ -52,23 +66,145 @@ async fn main() -> Result<()> {
"SELECT car, \
speed, \
lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
my_average(speed) OVER (PARTITION BY car ORDER BY time),\
time \
from cars",
)
.await?;
// print the results
df.show().await?;

// ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so that each invocation only sees 5 rows: the 2 before and 2 after) using
let df = ctx.sql("SELECT car, \
speed, \
lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
time \
from cars").await?;
// print the results
df.show().await?;
// // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so that each invocation only sees 5 rows: the 2 before and 2 after) using
// let df = ctx.sql("SELECT car, \
// speed, \
// lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
// time \
// from cars").await?;
// // print the results
// df.show().await?;

// todo show how to run dataframe API as well

Ok(())
}

// TODO make a helper funciton like `crate_udf` that helps to make these signatures

fn my_average() -> WindowUDF {
WindowUDF {
name: String::from("my_average"),
// it will take 2 arguments -- the column and the window size
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
return_type: Arc::new(return_type),
partition_evaluator: Arc::new(make_partition_evaluator),
}
}

/// Compute the return type of the function given the argument types
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
if arg_types.len() != 1 {
return Err(DataFusionError::Plan(format!(
"my_udwf expects 1 argument, got {}: {:?}",
arg_types.len(),
arg_types
)));
}
Ok(Arc::new(arg_types[0].clone()))
}

/// Create a partition evaluator for this argument
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}

/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct values of `PARTITION BY` (each car type in our example)
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
fn new() -> Self {
Self {}
}
}

/// These different evaluation methods are called depending on the various settings of WindowUDF
impl PartitionEvaluator for MyPartitionEvaluator {
fn get_range(&self, _idx: usize, _n_rows: usize) -> Result<std::ops::Range<usize>> {
Err(DataFusionError::NotImplemented(
"get_range is not implemented for this window function".to_string(),
))
}

/// This function is given the values of each partition
fn evaluate(
&self,
values: &[arrow::array::ArrayRef],
_num_rows: usize,
) -> Result<arrow::array::ArrayRef> {
// datafusion has handled ensuring we get the correct input argument
assert_eq!(values.len(), 1);

// For this example, we convert convert the input argument to an
// array of floating point numbers to calculate a moving average
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

// implement a simple moving average by averaging the current
// value with the previous value
//
// value | avg
// ------+------
// 10 | 10
// 20 | 15
// 30 | 25
// 30 | 30
//
let mut previous_value = None;
let new_values: Float64Array = arr
.values()
.iter()
.map(|&value| {
let new_value = previous_value
.map(|previous_value| (value + previous_value) / 2.0)
.unwrap_or(value);
previous_value = Some(value);
new_value
})
.collect();

Ok(Arc::new(new_values))
}

fn evaluate_stateful(
&mut self,
_values: &[arrow::array::ArrayRef],
) -> Result<datafusion_common::ScalarValue> {
Err(DataFusionError::NotImplemented(
"evaluate_stateful is not implemented by default".into(),
))
}

fn evaluate_with_rank(
&self,
_num_rows: usize,
_ranks_in_partition: &[std::ops::Range<usize>],
) -> Result<arrow::array::ArrayRef> {
Err(DataFusionError::NotImplemented(
"evaluate_partition_with_rank is not implemented by default".into(),
))
}

fn evaluate_inside_range(
&self,
_values: &[arrow::array::ArrayRef],
_range: &std::ops::Range<usize>,
) -> Result<datafusion_common::ScalarValue> {
Err(DataFusionError::NotImplemented(
"evaluate_inside_range is not implemented by default".into(),
))
}
}

// TODO show how to use other evaluate methods
43 changes: 42 additions & 1 deletion datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
DescribeTable, StringifiedPlan, UserDefinedLogicalNode,
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
Expand Down Expand Up @@ -797,6 +797,20 @@ impl SessionContext {
.insert(f.name.clone(), Arc::new(f));
}

/// Registers an window UDF within this context.
///
/// Note in SQL queries, window function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
pub fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
}

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
Expand Down Expand Up @@ -1290,6 +1304,10 @@ impl FunctionRegistry for SessionContext {
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
}

/// A planner used to add extensions to DataFusion logical and physical plans.
Expand Down Expand Up @@ -1340,6 +1358,8 @@ pub struct SessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Window functions registered in the context
window_functions: HashMap<String, Arc<WindowUDF>>,
/// Deserializer registry for extensions.
serializer_registry: Arc<dyn SerializerRegistry>,
/// Session configuration
Expand Down Expand Up @@ -1483,6 +1503,7 @@ impl SessionState {
catalog_list,
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: Arc::new(EmptySerializerRegistry),
config,
execution_props: ExecutionProps::new(),
Expand Down Expand Up @@ -1959,6 +1980,11 @@ impl SessionState {
&self.aggregate_functions
}

/// Return reference to window functions
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
}

/// Return [SerializerRegistry] for extensions
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
Expand Down Expand Up @@ -1992,6 +2018,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
self.state.aggregate_functions().get(name).cloned()
}

fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
self.state.window_functions().get(name).cloned()
}

fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
Expand Down Expand Up @@ -2039,6 +2069,16 @@ impl FunctionRegistry for SessionState {
))
})
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
let result = self.window_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDWF named \"{name}\" in the registry"
))
})
}
}

impl OptimizerConfig for SessionState {
Expand Down Expand Up @@ -2068,6 +2108,7 @@ impl From<&SessionState> for TaskContext {
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.window_functions.clone(),
state.runtime_env.clone(),
)
}
Expand Down
18 changes: 15 additions & 3 deletions datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ fn create_udwf_window_expr(
name: String,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
// need to get the types into an owned vec for some reason
let input_types: Vec<_> = input_schema.fields().iter().map(|f| f.data_type().clone()).collect();
let input_types: Vec<_> = args
.iter()
.map(|arg| arg.data_type(input_schema).map(|dt| dt.clone()))
.collect::<Result<_>>()?;

// figure out the output type
let data_type = (fun.return_type)(&input_types)?;
Ok(Arc::new(WindowUDFExpr {
Expand Down Expand Up @@ -227,15 +231,23 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr {

fn field(&self) -> Result<Field> {
let nullable = false;
Ok(Field::new(&self.name, self.data_type.as_ref().clone(), nullable))
Ok(Field::new(
&self.name,
self.data_type.as_ref().clone(),
nullable,
))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.args.clone()
}

fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
(self.fun.partition_evaluator)()
}

fn name(&self) -> &str {
&self.name
}
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/core/tests/data/cars.csv
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@ green,15.1,1996-04-12T12:05:11.000000000
green,15.2,1996-04-12T12:05:12.000000000
green,8.0,1996-04-12T12:05:13.000000000
green,2.0,1996-04-12T12:05:14.000000000
green,0.0,1996-04-12T12:05:15.000000000
5 changes: 4 additions & 1 deletion datafusion/execution/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! FunctionRegistry trait

use datafusion_common::Result;
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode};
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::{collections::HashSet, sync::Arc};

/// A registry knows how to build logical expressions out of user-defined function' names
Expand All @@ -31,6 +31,9 @@ pub trait FunctionRegistry {

/// Returns a reference to the udaf named `name`.
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;

/// Returns a reference to the udwf named `name`.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
Expand Down
Loading

0 comments on commit 1bc2d6e

Please sign in to comment.