Skip to content

Commit

Permalink
Add nanvl builtin function (#7171)
Browse files Browse the repository at this point in the history
* Add nanvl.

* Add tests.

* Apply cargo fmt.

* Add doc.

* Minor doc change.
  • Loading branch information
sarutak committed Aug 2, 2023
1 parent c2cb616 commit 354dc19
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 4 deletions.
3 changes: 3 additions & 0 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> {
test_expression!("atan2(NULL, NULL)", "NULL");
test_expression!("atan2(1, NULL)", "NULL");
test_expression!("atan2(NULL, 1)", "NULL");
test_expression!("nanvl(NULL, NULL)", "NULL");
test_expression!("nanvl(1, NULL)", "NULL");
test_expression!("nanvl(NULL, 1)", "NULL");
Ok(())
}

Expand Down
6 changes: 6 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,9 @@ query RRRRRRR
SELECT atan2(2.0, 1.0), atan2(-2.0, 1.0), atan2(2.0, -1.0), atan2(-2.0, -1.0), atan2(NULL, 1.0), atan2(2.0, NULL), atan2(NULL, NULL);
----
1.107148717794 -1.107148717794 2.034443935796 -2.034443935796 NULL NULL NULL

# nanvl
query RRR
SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10))
----
1 1 NaN
35 changes: 35 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,41 @@ select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integ
NaN 13.28771 NaN
NaN 6.64386 NaN

## nanvl

# nanvl scalar function
query RRR rowsort
select nanvl(0, 1), nanvl(asin(10), 2), nanvl(3, asin(10));
----
0 2 3

# nanvl scalar nulls
query R rowsort
select nanvl(null, 64);
----
NULL

# nanvl scalar nulls #1
query R rowsort
select nanvl(2, null);
----
NULL

# nanvl scalar nulls #2
query R rowsort
select nanvl(null, null);
----
NULL

# nanvl with columns (round is needed to normalize the outputs of different operating systems)
query RRR rowsort
select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(nanvl(asin(d + e), 4), 5) from small_floats;
----
0.7754 1.11977 -0.9273
2 -0.20136 0.7754
2 -1.11977 4
NULL NULL NULL

## pi

# pi scalar function
Expand Down
13 changes: 13 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ pub enum BuiltinScalarFunction {
Log10,
/// log2
Log2,
/// nanvl
Nanvl,
/// pi
Pi,
/// power
Expand Down Expand Up @@ -328,6 +330,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Log10 => Volatility::Immutable,
BuiltinScalarFunction::Log2 => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Pi => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Round => Volatility::Immutable,
Expand Down Expand Up @@ -760,6 +763,11 @@ impl BuiltinScalarFunction {
_ => Ok(Float64),
},

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::ArrowTypeof => Ok(Utf8),

BuiltinScalarFunction::Abs
Expand Down Expand Up @@ -1120,6 +1128,10 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Factorial => {
Signature::uniform(1, vec![Int64], self.volatility())
}
Expand Down Expand Up @@ -1193,6 +1205,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Log10 => &["log10"],
BuiltinScalarFunction::Log2 => &["log2"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Pi => &["pi"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Radians => &["radians"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ scalar_expr!(
scalar_expr!(CurrentDate, current_date, ,"returns current UTC date as a [`DataType::Date32`] value");
scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the same value for all instances of now() in same statement");
scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");

scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");

Expand Down Expand Up @@ -989,6 +990,7 @@ mod test {
test_unary_scalar_expr!(Log10, log10);
test_unary_scalar_expr!(Ln, ln);
test_scalar_expr!(Atan2, atan2, y, x);
test_scalar_expr!(Nanvl, nanvl, x, y);

test_scalar_expr!(Ascii, ascii, input);
test_scalar_expr!(BitLength, bit_length, string);
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln),
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
BuiltinScalarFunction::Nanvl => {
Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args))
}
BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians),
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
BuiltinScalarFunction::Round => {
Expand Down
83 changes: 83 additions & 0 deletions datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,53 @@ pub fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Nanvl SQL function
pub fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Float64 => {
let compute_nanvl = |x: f64, y: f64| {
if x.is_nan() {
y
} else {
x
}
};

Ok(Arc::new(make_function_inputs2!(
&args[0],
&args[1],
"x",
"y",
Float64Array,
{ compute_nanvl }
)) as ArrayRef)
}

DataType::Float32 => {
let compute_nanvl = |x: f32, y: f32| {
if x.is_nan() {
y
} else {
x
}
};

Ok(Arc::new(make_function_inputs2!(
&args[0],
&args[1],
"x",
"y",
Float32Array,
{ compute_nanvl }
)) as ArrayRef)
}

other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function nanvl"
))),
}
}

/// Pi SQL function
pub fn pi(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if !matches!(&args[0], ColumnarValue::Array(_)) {
Expand Down Expand Up @@ -958,4 +1005,40 @@ mod tests {
assert_eq!(floats.value(3), 123.0);
assert_eq!(floats.value(4), -321.0);
}

#[test]
fn test_nanvl_f64() {
let args: Vec<ArrayRef> = vec![
Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y
Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x
];

let result = nanvl(&args).expect("failed to initialize function atan2");
let floats =
as_float64_array(&result).expect("failed to initialize function atan2");

assert_eq!(floats.len(), 4);
assert_eq!(floats.value(0), 1.0);
assert_eq!(floats.value(1), 6.0);
assert_eq!(floats.value(2), 3.0);
assert!(floats.value(3).is_nan());
}

#[test]
fn test_nanvl_f32() {
let args: Vec<ArrayRef> = vec![
Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y
Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x
];

let result = nanvl(&args).expect("failed to initialize function atan2");
let floats =
as_float32_array(&result).expect("failed to initialize function atan2");

assert_eq!(floats.len(), 4);
assert_eq!(floats.value(0), 1.0);
assert_eq!(floats.value(1), 6.0);
assert_eq!(floats.value(2), 3.0);
assert!(floats.value(3).is_nan());
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ enum ScalarFunction {
ArrayReplaceN = 108;
ArrayRemoveAll = 109;
ArrayReplaceAll = 110;
Nanvl = 111;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ use datafusion_expr::{
expr::{self, InList, Sort, WindowFunction},
factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, now, nullif, octet_length, pi, power, radians, random,
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with,
strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros,
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians,
random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt,
starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros,
to_timestamp_millis, to_timestamp_seconds, translate, trim, trim_array, trunc, upper,
uuid,
window_frame::regularize,
Expand Down Expand Up @@ -522,6 +522,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::StructFun => Self::Struct,
ScalarFunction::FromUnixtime => Self::FromUnixtime,
ScalarFunction::Atan2 => Self::Atan2,
ScalarFunction::Nanvl => Self::Nanvl,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
}
}
Expand Down Expand Up @@ -1527,6 +1528,10 @@ pub fn parse_expr(
ScalarFunction::CurrentDate => Ok(current_date()),
ScalarFunction::CurrentTime => Ok(current_time()),
ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)),
ScalarFunction::Nanvl => Ok(nanvl(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
_ => Err(proto_error(
"Protobuf deserialization error: Unsupported scalar function",
)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Struct => Self::StructFun,
BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime,
BuiltinScalarFunction::Atan2 => Self::Atan2,
BuiltinScalarFunction::Nanvl => Self::Nanvl,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
};

Expand Down
17 changes: 17 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- [log](#log)
- [log10](#log10)
- [log2](#log2)
- [nanvl](#nanvl)
- [pi](#pi)
- [power](#power)
- [pow](#pow)
Expand Down Expand Up @@ -353,6 +354,22 @@ log2(numeric_expression)
- **numeric_expression**: Numeric expression to operate on.
Can be a constant, column, or function, and any combination of arithmetic operators.

### `nanvl`

Returns the first argument if it's not _NaN_.
Returns the second argument otherwise.

```
nanvl(expression_x, expression_y)
```

#### Arguments

- **expression_x**: Numeric expression to return if it's not _NaN_.
Can be a constant, column, or function, and any combination of arithmetic operators.
- **expression_y**: Numeric expression to return if the first expression is _NaN_.
Can be a constant, column, or function, and any combination of arithmetic operators.

### `pi`

Returns an approximate value of π.
Expand Down

0 comments on commit 354dc19

Please sign in to comment.