diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index e753105f225..94df59d95f9 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -16,7 +16,7 @@ // under the License. use arrow::{ - array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder}, + array::{ArrayRef, Float32Array, Float64Array}, datatypes::DataType, record_batch::RecordBatch, util::pretty, @@ -63,8 +63,7 @@ async fn main() -> Result<()> { let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| { // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: // 1. cast the values to the type we want - // 2. perform the computation for every element in the array (using a loop or SIMD) - // 3. construct the resulting array + // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); @@ -82,21 +81,23 @@ async fn main() -> Result<()> { // this is guaranteed by DataFusion. We place it just to make it obvious. assert_eq!(exponent.len(), base.len()); - // 2. Arrow's builder is used to construct an Arrow array. - let mut builder = Float64Builder::new(base.len()); - for index in 0..base.len() { - // in arrow, any value can be null. - // Here we decide to make our UDF to return null when either base or exponent is null. - if base.is_null(index) || exponent.is_null(index) { - builder.append_null()?; - } else { - // 3. computation. Since we do not have any SIMD `pow` operation at our hands, - // we loop over each entry. Array's values are obtained via `.value(index)`. - let value = base.value(index).powf(exponent.value(index)); - builder.append_value(value)?; - } - } - Ok(Arc::new(builder.finish())) + // 2. 
perform the computation + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + // in arrow, any value can be null. + // Here we decide to make our UDF return null when either base or exponent is null. + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::<Float64Array>(); + + // `Ok` because no error occurred during the calculation (we could return an error when the base is negative and the exponent is not an integer, because `powf` then yields NaN instead of panicking) + // `Arc` because arrays are immutable, thread-safe, trait objects. + Ok(Arc::new(array)) }); // Next: