Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions datafusion/functions/benches/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
util::bench_util::create_primitive_array,
};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::math::signum;
Expand Down Expand Up @@ -88,6 +89,51 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

// Scalar benchmarks (the optimization we added)
let scalar_f32_args =
vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))];
let scalar_f32_arg_fields =
vec![Field::new("a", DataType::Float32, false).into()];
let return_field_f32 = Field::new("f", DataType::Float32, false).into();

c.bench_function(&format!("signum f32 scalar: {size}"), |b| {
b.iter(|| {
black_box(
signum
.invoke_with_args(ScalarFunctionArgs {
args: scalar_f32_args.clone(),
arg_fields: scalar_f32_arg_fields.clone(),
number_rows: 1,
return_field: Arc::clone(&return_field_f32),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let scalar_f64_args =
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))];
let scalar_f64_arg_fields =
vec![Field::new("a", DataType::Float64, false).into()];
let return_field_f64 = Field::new("f", DataType::Float64, false).into();

c.bench_function(&format!("signum f64 scalar: {size}"), |b| {
b.iter(|| {
black_box(
signum
.invoke_with_args(ScalarFunctionArgs {
args: scalar_f64_args.clone(),
arg_fields: scalar_f64_arg_fields.clone(),
number_rows: 1,
return_field: Arc::clone(&return_field_f64),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});
}
}

Expand Down
80 changes: 48 additions & 32 deletions datafusion/functions/src/math/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::array::AsArray;
use arrow::datatypes::DataType::{Float32, Float64};
use arrow::datatypes::{DataType, Float32Type, Float64Type};

use datafusion_common::{Result, exec_err};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::make_scalar_function;

#[user_doc(
doc_section(label = "Math Functions"),
description = r#"Returns the sign of a number.
Expand Down Expand Up @@ -98,41 +97,58 @@ impl ScalarUDFImpl for SignumFunc {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(signum, vec![])(&args.args)
let return_type = args.return_type().clone();
let [arg] = take_function_args(self.name(), args.args)?;

match arg {
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
return ColumnarValue::Scalar(ScalarValue::Null)
.cast_to(&return_type, None);
}

match scalar {
ScalarValue::Float64(Some(v)) => {
let result = if v == 0.0 { 0.0 } else { v.signum() };
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
}
ScalarValue::Float32(Some(v)) => {
let result = if v == 0.0 { 0.0 } else { v.signum() };
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
}
_ => {
internal_err!(
"Unexpected scalar type for signum: {:?}",
scalar.data_type()
)
}
}
}
ColumnarValue::Array(array) => match array.data_type() {
Float64 => Ok(ColumnarValue::Array(Arc::new(
array.as_primitive::<Float64Type>().unary::<_, Float64Type>(
|x: f64| {
if x == 0.0 { 0.0 } else { x.signum() }
},
),
))),
Float32 => Ok(ColumnarValue::Array(Arc::new(
array.as_primitive::<Float32Type>().unary::<_, Float32Type>(
|x: f32| {
if x == 0.0 { 0.0 } else { x.signum() }
},
),
))),
other => exec_err!("Unsupported data type {other:?} for function signum"),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this should be an internal error, to be consistent with the scalar path above

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide a basic mental model for when I should use exec_err and when internal_err? Is there any documentation for this?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exec err -> things that can happen in normal execution, such as an invalid value passed to a function (e.g. trying to get an ASCII character from an integer input when the value, like 99999, doesn't have a corresponding character)

internal err -> things that shouldn't normally happen, i.e. they occur only if some other bug in DataFusion allowed this code path to be reached

in this case, the signature should already guarantee that we only get f32/f64 inputs; therefore, if at this point we find an array that is not of one of those types, then something went wrong in the type coercion/signature code and it's an internal bug

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Jefffrey

},
}
}

/// Returns the documentation for this function by delegating to `self.doc()`.
// NOTE(review): `doc()` is presumably generated by the `#[user_doc]` attribute
// applied to this type — confirm against the macro's output.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// Array kernel for the `signum` SQL function.
///
/// Reads only the first entry of `args` and dispatches on its data type:
/// `Float64` and `Float32` arrays are mapped element-wise, every other type
/// yields an execution error.
///
/// Any value that compares equal to zero (which includes `-0.0`) maps to
/// positive `0.0`; every other value is passed through the float `signum`
/// method of its type.
///
/// # Panics
///
/// Panics if `args` is empty, because the first element is indexed directly.
fn signum(args: &[ArrayRef]) -> Result<ArrayRef> {
    let input = &args[0];

    match input.data_type() {
        Float64 => {
            let values = input.as_primitive::<Float64Type>();
            // The explicit zero check keeps zero inputs at 0.0 instead of
            // letting `f64::signum` turn them into +/-1.0.
            let signs = values.unary::<_, Float64Type>(|x: f64| {
                if x == 0_f64 { 0_f64 } else { x.signum() }
            });
            Ok(Arc::new(signs) as ArrayRef)
        }

        Float32 => {
            let values = input.as_primitive::<Float32Type>();
            // Same zero handling as the Float64 arm, at single precision.
            let signs = values.unary::<_, Float32Type>(|x: f32| {
                if x == 0_f32 { 0_f32 } else { x.signum() }
            });
            Ok(Arc::new(signs) as ArrayRef)
        }

        other => exec_err!("Unsupported data type {other:?} for function signum"),
    }
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand Down