From dfec16ab6313720661f3e11b578f87bd092f5624 Mon Sep 17 00:00:00 2001
From: Dadepo Aderemi
Date: Sat, 20 Apr 2024 14:01:18 +0400
Subject: [PATCH] switch implementation of erf to use ScalarUDFImpl

---
Review notes (ignored by `git am`):
 * `Erf::invoke` accepts scalar arguments again. The previous revision of this
   patch returned an "Empty argument" error for `ColumnarValue::Scalar`, which
   broke calls such as `SELECT erf(0.5)` that worked through
   `make_scalar_function`.
 * The input data type is read with `Array::data_type()` instead of cloning
   the `ArrayData`, and the per-type builder loops are plain iterator chains.

 src/postgres/math_udfs.rs | 109 ++++++++++++++++++++++----------------
 src/postgres/mod.rs       |  19 +-----
 2 files changed, 67 insertions(+), 61 deletions(-)

diff --git a/src/postgres/math_udfs.rs b/src/postgres/math_udfs.rs
index 3a23066..00844a3 100644
--- a/src/postgres/math_udfs.rs
+++ b/src/postgres/math_udfs.rs
@@ -189,54 +189,73 @@ pub fn div(args: &[ArrayRef]) -> Result<ArrayRef> {
 }
 
 /// Error function
-pub fn erf(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let column_data = &args[0];
-    let data = column_data.into_data();
-    let data_type = data.data_type();
+#[derive(Debug)]
+pub struct Erf {
+    signature: Signature,
+}
 
-    let mut float64array_builder = Float64Array::builder(args[0].len());
-    match data_type {
-        DataType::Float64 => {
-            let values = datafusion::common::cast::as_float64_array(&args[0])?;
-            values.iter().try_for_each(|value| {
-                if let Some(value) = value {
-                    float64array_builder.append_value(libm::erf(value))
-                } else {
-                    float64array_builder.append_null();
-                }
-                Ok::<(), DataFusionError>(())
-            })?;
-        }
-        DataType::Int64 => {
-            let values = datafusion::common::cast::as_int64_array(&args[0])?;
-            values.iter().try_for_each(|value| {
-                if let Some(value) = value {
-                    float64array_builder.append_value(libm::erf(value as f64))
-                } else {
-                    float64array_builder.append_null();
-                }
-                Ok::<(), DataFusionError>(())
-            })?;
-        }
-        DataType::UInt64 => {
-            let values = datafusion::common::cast::as_uint64_array(&args[0])?;
-            values.iter().try_for_each(|value| {
-                if let Some(value) = value {
-                    float64array_builder.append_value(libm::erf(value as f64))
-                } else {
-                    float64array_builder.append_null();
-                }
-                Ok::<(), DataFusionError>(())
-            })?;
-        }
-        t => {
-            return Err(DataFusionError::Internal(format!(
-                "Unsupported type {t} for erf function"
-            )))
+impl Erf {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
         }
-    };
+    }
+}
 
-    Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
+impl ScalarUDFImpl for Erf {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "erf"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(Float64)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let column_data = args
+            .first()
+            .ok_or_else(|| DataFusionError::Internal("Empty argument".to_string()))?;
+
+        // A literal argument (e.g. `erf(0.5)`) arrives as a scalar: expand it to a
+        // one-row array, run the same kernel, and hand a scalar back at the end.
+        let is_scalar = matches!(column_data, ColumnarValue::Scalar(_));
+        let col_array = column_data.clone().into_array(1)?;
+
+        let result: Float64Array = match col_array.data_type() {
+            Float64 => datafusion::common::cast::as_float64_array(&col_array)?
+                .iter()
+                .map(|value| value.map(libm::erf))
+                .collect(),
+            Int64 => datafusion::common::cast::as_int64_array(&col_array)?
+                .iter()
+                .map(|value| value.map(|v| libm::erf(v as f64)))
+                .collect(),
+            UInt64 => datafusion::common::cast::as_uint64_array(&col_array)?
+                .iter()
+                .map(|value| value.map(|v| libm::erf(v as f64)))
+                .collect(),
+            t => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported type {t} for erf function"
+                )))
+            }
+        };
+        let result = Arc::new(result) as ArrayRef;
+
+        if is_scalar {
+            datafusion::scalar::ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
+        } else {
+            Ok(ColumnarValue::Array(result))
+        }
+    }
 }
 
 /// Complementary error function
diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs
index fff4690..5372390 100644
--- a/src/postgres/mod.rs
+++ b/src/postgres/mod.rs
@@ -3,14 +3,14 @@
 
 use std::sync::Arc;
 
-use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt64, UInt8, Utf8};
+use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt8, Utf8};
 use datafusion::error::Result;
 use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
 use datafusion::physical_expr::functions::make_scalar_function;
 use datafusion::prelude::SessionContext;
 
 use crate::postgres::math_udfs::{
-    acosd, asind, atand, ceiling, cosd, cotd, div, erf, sind, tand, Erfc, RandomNormal,
+    acosd, asind, atand, ceiling, cosd, cotd, div, sind, tand, Erf, Erfc, RandomNormal,
 };
 use crate::postgres::network_udfs::{
     broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
@@ -36,7 +36,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
     register_tand(ctx);
     register_ceiling(ctx);
     register_div(ctx);
-    register_erf(ctx);
+    ctx.register_udf(ScalarUDF::from(Erf::new()));
     ctx.register_udf(ScalarUDF::from(Erfc::new()));
     ctx.register_udf(ScalarUDF::from(RandomNormal::new()));
     Ok(())
@@ -146,19 +146,6 @@ fn register_ceiling(ctx: &SessionContext) {
     ctx.register_udf(ceiling_udf);
 }
 
-fn register_erf(ctx: &SessionContext) {
-    let erf_udf = make_scalar_function(erf);
-    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
-    let erf_udf = ScalarUDF::new(
-        "erf",
-        &Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
-        &return_type,
-        &erf_udf,
-    );
-
-    ctx.register_udf(erf_udf);
-}
-
 fn register_div(ctx: &SessionContext) {
     let udf = make_scalar_function(div);
     let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64)));