Skip to content

Commit

Permalink
Add Library Guide for User Defined Functions: Window/Aggregate (#8171)
Browse files Browse the repository at this point in the history
* udwf doc

Signed-off-by: veeupup <[email protected]>

* Add Library Guide for User Defined Functions: Window/Aggregate

Signed-off-by: veeupup <[email protected]>

* make docs prettier

Signed-off-by: veeupup <[email protected]>

---------

Signed-off-by: veeupup <[email protected]>
  • Loading branch information
Veeupup committed Nov 15, 2023
1 parent 7c2c2f0 commit cd1c648
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 5 deletions.
4 changes: 4 additions & 0 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ async fn main() -> Result<()> {
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
ctx.register_udaf(geometric_mean.clone());

let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
sql_df.show().await?;

// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async fn main() -> Result<()> {
"SELECT \
car, \
speed, \
smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\
smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
time \
from cars \
ORDER BY \
Expand All @@ -109,7 +109,7 @@ async fn main() -> Result<()> {
"SELECT \
car, \
speed, \
smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\
time \
from cars \
ORDER BY \
Expand Down
316 changes: 313 additions & 3 deletions docs/source/library-user-guide/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ A Scalar UDF is a function that takes a row of data and returns a single value.
```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::common::Result;

use datafusion::common::cast::as_int64_array;
Expand Down Expand Up @@ -78,6 +78,11 @@ The challenge however is that DataFusion doesn't know about this function. We ne
To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier.

```rust
use datafusion::logical_expr::{Volatility, create_udf};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

let udf = create_udf(
"add_one",
vec![DataType::Int64],
Expand All @@ -98,6 +103,8 @@ A few things to note:
That gives us a `ScalarUDF` that we can register with the `SessionContext`:

```rust
use datafusion::execution::context::SessionContext;

let mut ctx = SessionContext::new();

ctx.register_udf(udf);
Expand All @@ -115,10 +122,313 @@ let df = ctx.sql(&sql).await.unwrap();

Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation.

Body coming soon.
For example, we will declare a user defined window function that computes a moving average.

```rust
use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type};
use datafusion::logical_expr::{PartitionEvaluator};
use datafusion::common::ScalarValue;
use datafusion::error::Result;
/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct values of `PARTITION BY`
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
fn new() -> Self {
Self {}
}
}

/// Different evaluation methods are called depending on the various
/// settings of WindowUDF. This example uses the simplest and most
/// general, `evaluate`. See `PartitionEvaluator` for the other more
/// advanced uses.
impl PartitionEvaluator for MyPartitionEvaluator {
/// Tell DataFusion the window function varies based on the value
/// of the window frame.
fn uses_window_frame(&self) -> bool {
true
}

/// This function is called once per input row.
///
/// `range`specifies which indexes of `values` should be
/// considered for the calculation.
///
/// Note this is the SLOWEST, but simplest, way to evaluate a
/// window function. It is much faster to implement
/// evaluate_all or evaluate_all_with_rank, if possible
fn evaluate(
&mut self,
values: &[ArrayRef],
range: &std::ops::Range<usize>,
) -> Result<ScalarValue> {
// Again, the input argument is an array of floating
// point numbers to calculate a moving average
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

let range_len = range.end - range.start;

// our smoothing function will average all the values in the
let output = if range_len > 0 {
let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
Some(sum / range_len as f64)
} else {
None
};

Ok(ScalarValue::Float64(output))
}
}

/// Create a `PartitionEvalutor` to evaluate this function on a new
/// partition.
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}
```

### Registering a Window UDF

To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier.

```rust
use datafusion::logical_expr::{Volatility, create_udwf};
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

// here is where we define the UDWF. We also declare its signature:
let smooth_it = create_udwf(
"smooth_it",
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(make_partition_evaluator),
);
```

The `create_udwf` has five arguments to check:

- The first argument is the name of the function. This is the name that will be used in SQL queries.
- **The second argument** is the `DataType` of input array (attention: this is not a list of arrays). I.e. in this case, the function accepts `Float64` as argument.
- The third argument is the return type of the function. I.e. in this case, the function returns an `Float64`.
- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input.
- **The fifth argument** is the function implementation. This is the function that we defined above.

That gives us a `WindowUDF` that we can register with the `SessionContext`:

```rust
use datafusion::execution::context::SessionContext;

let ctx = SessionContext::new();

ctx.register_udwf(smooth_it);
```

At this point, you can use the `smooth_it` function in your query:

For example, if we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like

```csv
car,speed,time
red,20.0,1996-04-12T12:05:03.000000000
red,20.3,1996-04-12T12:05:04.000000000
green,10.0,1996-04-12T12:05:03.000000000
green,10.3,1996-04-12T12:05:04.000000000
...
```

Then, we can query like below:

```rust
use datafusion::datasource::file_format::options::CsvReadOptions;
// register csv table first
let csv_path = "cars.csv".to_string();
ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?;
// do query with smooth_it
let df = ctx
.sql(
"SELECT \
car, \
speed, \
smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\
time \
from cars \
ORDER BY \
car",
)
.await?;
// print the results
df.show().await?;
```

the output will be like:

```csv
+-------+-------+--------------------+---------------------+
| car | speed | smooth_speed | time |
+-------+-------+--------------------+---------------------+
| green | 10.0 | 10.0 | 1996-04-12T12:05:03 |
| green | 10.3 | 10.15 | 1996-04-12T12:05:04 |
| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 |
| green | 10.5 | 10.3 | 1996-04-12T12:05:06 |
| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 |
| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 |
| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 |
| green | 15.0 | 11.65 | 1996-04-12T12:05:10 |
| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 |
| green | 15.2 | 12.35 | 1996-04-12T12:05:12 |
| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 |
| green | 2.0 | 11.125 | 1996-04-12T12:05:14 |
| red | 20.0 | 20.0 | 1996-04-12T12:05:03 |
| red | 20.3 | 20.15 | 1996-04-12T12:05:04 |
...
```

## Adding an Aggregate UDF

Aggregate UDFs are functions that take a group of rows and return a single value. These are akin to SQL's `SUM` or `COUNT` functions.

Body coming soon.
For example, we will declare a single-type, single return type UDAF that computes the geometric mean.

```rust
use datafusion::arrow::array::ArrayRef;
use datafusion::scalar::ScalarValue;
use datafusion::{error::Result, physical_plan::Accumulator};

/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.
#[derive(Debug)]
struct GeometricMean {
n: u32,
prod: f64,
}

impl GeometricMean {
// how the struct is initialized
pub fn new() -> Self {
GeometricMean { n: 0, prod: 1.0 }
}
}

// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
// to use them.
impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
])
}

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}

// DataFusion calls this function to update the accumulator's state for a batch
// of inputs rows. In this case the product is updated with values from the first column
// and the count is updated based on the row count
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = &values[0];
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::Float64(Some(value)) = v {
self.prod *= value;
self.n += 1;
} else {
unreachable!("")
}
Ok(())
})
}

// Optimization hint: this trait also supports `update_batch` and `merge_batch`,
// that can be used to perform these operations on arrays instead of single values.
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
let v = states
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1])
{
self.prod *= prod;
self.n += n;
} else {
unreachable!("")
}
Ok(())
})
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}
```

### registering an Aggregate UDF

To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier.

```rust
use datafusion::logical_expr::{Volatility, create_udaf};
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

// here is where we define the UDAF. We also declare its signature:
let geometric_mean = create_udaf(
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
"geo_mean",
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
vec![DataType::Float64],
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(GeometricMean::new()))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
```

The `create_udaf` has six arguments to check:

- The first argument is the name of the function. This is the name that will be used in SQL queries.
- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Float64` argument.
- The third argument is the return type of the function. I.e. in this case, the function returns an `Int64`.
- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input.
- The fifth argument is the function implementation. This is the function that we defined above.
- The sixth argument is the description of the state, which will by passed between execution stages.

That gives us a `AggregateUDF` that we can register with the `SessionContext`:

```rust
use datafusion::execution::context::SessionContext;

let ctx = SessionContext::new();

ctx.register_udaf(geometric_mean);
```

Then, we can query like below:

```rust
let df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
```

0 comments on commit cd1c648

Please sign in to comment.