Skip to content

Commit

Permalink
switch implementation of erf to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent f17db18 commit dfec16a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 61 deletions.
134 changes: 89 additions & 45 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,54 +189,98 @@ pub fn div(args: &[ArrayRef]) -> Result<ArrayRef> {
}

/// Error function
pub fn erf(args: &[ArrayRef]) -> Result<ArrayRef> {
let column_data = &args[0];
let data = column_data.into_data();
let data_type = data.data_type();
#[derive(Debug)]
pub struct Erf {
signature: Signature,
}

let mut float64array_builder = Float64Array::builder(args[0].len());
match data_type {
DataType::Float64 => {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::Int64 => {
let values = datafusion::common::cast::as_int64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::UInt64 => {
let values = datafusion::common::cast::as_uint64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
t => {
return Err(DataFusionError::Internal(format!(
"Unsupported type {t} for erf function"
)))
impl Erf {
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
}
};
}
}

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
impl ScalarUDFImpl for Erf {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"erf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let column_data = args
.first()
.ok_or(DataFusionError::Internal("Empty argument".to_string()))?;

let col_array = match column_data {
ColumnarValue::Array(array) => array,
ColumnarValue::Scalar(_) => {
return Err(DataFusionError::Execution("Empty argument".to_string()))
}
};

let mut float64array_builder = Float64Array::builder(col_array.len());
let column_data = col_array;
let data = column_data.into_data();
let data_type = data.data_type();

match data_type {
Float64 => {
let values = datafusion::common::cast::as_float64_array(&col_array)?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
Int64 => {
let values = datafusion::common::cast::as_int64_array(&col_array)?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
UInt64 => {
let values = datafusion::common::cast::as_uint64_array(&col_array)?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
t => {
return Err(DataFusionError::Internal(format!(
"Unsupported type {t} for erf function"
)))
}
};

Ok(ColumnarValue::Array(
Arc::new(float64array_builder.finish()) as ArrayRef,
))
}
}

/// Complementary error function
Expand Down
19 changes: 3 additions & 16 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

use std::sync::Arc;

use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt64, UInt8, Utf8};
use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::prelude::SessionContext;

use crate::postgres::math_udfs::{
acosd, asind, atand, ceiling, cosd, cotd, div, erf, sind, tand, Erfc, RandomNormal,
acosd, asind, atand, ceiling, cosd, cotd, div, sind, tand, Erf, Erfc, RandomNormal,
};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
Expand All @@ -36,7 +36,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_tand(ctx);
register_ceiling(ctx);
register_div(ctx);
register_erf(ctx);
ctx.register_udf(ScalarUDF::from(Erf::new()));
ctx.register_udf(ScalarUDF::from(Erfc::new()));
ctx.register_udf(ScalarUDF::from(RandomNormal::new()));
Ok(())
Expand Down Expand Up @@ -146,19 +146,6 @@ fn register_ceiling(ctx: &SessionContext) {
ctx.register_udf(ceiling_udf);
}

fn register_erf(ctx: &SessionContext) {
let erf_udf = make_scalar_function(erf);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
let erf_udf = ScalarUDF::new(
"erf",
&Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
&return_type,
&erf_udf,
);

ctx.register_udf(erf_udf);
}

fn register_div(ctx: &SessionContext) {
let udf = make_scalar_function(div);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64)));
Expand Down

0 comments on commit dfec16a

Please sign in to comment.