Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions rust/datafusion/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use arrow::{
array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder},
array::{ArrayRef, Float32Array, Float64Array},
datatypes::DataType,
record_batch::RecordBatch,
util::pretty,
Expand Down Expand Up @@ -63,8 +63,7 @@ async fn main() -> Result<()> {
let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD)
// 3. construct the resulting array
// 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result

// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);
Expand All @@ -82,21 +81,23 @@ async fn main() -> Result<()> {
// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());

// 2. Arrow's builder is used to construct an Arrow array.
let mut builder = Float64Builder::new(base.len());
for index in 0..base.len() {
// in arrow, any value can be null.
// Here we decide to make our UDF to return null when either base or exponent is null.
if base.is_null(index) || exponent.is_null(index) {
builder.append_null()?;
} else {
// 3. computation. Since we do not have any SIMD `pow` operation at our hands,
// we loop over each entry. Array's values are obtained via `.value(index)`.
let value = base.value(index).powf(exponent.value(index));
builder.append_value(value)?;
}
}
Ok(Arc::new(builder.finish()))
// 2. perform the computation
let array = base
.iter()
.zip(exponent.iter())
.map(|(base, exponent)| {
match (base, exponent) {
// in arrow, any value can be null.
// Here we decide to make our UDF to return null when either base or exponent is null.
(Some(base), Some(exponent)) => Some(base.powf(exponent)),
_ => None,
}
})
.collect::<Float64Array>();

// `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!)
// `Arc` because arrays are immutable, thread-safe, trait objects.
Ok(Arc::new(array))
});

// Next:
Expand Down