Skip to content
75 changes: 64 additions & 11 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ use std::any::Any;

use super::power::PowerFunc;

use crate::utils::{calculate_binary_math, decimal128_to_i128};
use crate::utils::{
calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
};
use arrow::array::{Array, ArrayRef};
use arrow::compute::kernels::cast;
use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
Float32Type, Float64Type,
};
use arrow::error::ArrowError;
use arrow_buffer::i256;
Expand Down Expand Up @@ -102,6 +104,54 @@ impl LogFunc {
}
}

/// Binary function to calculate the integer logarithm of a Decimal32 `value`
/// (stored with `scale` fractional digits) using `base` base.
///
/// The decimal is first reduced to its integer part via `decimal32_to_i32`,
/// and the result is the integer logarithm of that part (rounded down),
/// returned as an `f64`. Non-positive values yield `NaN`, mirroring `f64::log`.
///
/// # Errors
/// Returns `ArrowError::ComputeError` if `base` is not a finite integer, if
/// `base` is less than 2, or if `decimal32_to_i32` rejects `scale`.
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
    if !base.is_finite() || base.trunc() != base {
        return Err(ArrowError::ComputeError(format!(
            "Log cannot use non-integer base: {base}"
        )));
    }
    // `base` is a finite integer here, so this also rejects 0, 1 and negatives.
    if base < 2.0 {
        return Err(ArrowError::ComputeError(format!(
            "Log base must be greater than 1: {base}"
        )));
    }

    let unscaled_value = decimal32_to_i32(value, scale)?;
    if unscaled_value <= 0 {
        // Reflect f64::log behaviour
        return Ok(f64::NAN);
    }

    if base > f64::from(i32::MAX) {
        // The base exceeds every positive i32, so the logarithm is 0.
        // An `as i32` cast would saturate to i32::MAX and wrongly report 1
        // for `unscaled_value == i32::MAX`.
        return Ok(0.0);
    }

    // `base` is an integer within 2..=i32::MAX, so the cast below is exact.
    let log_value: u32 = unscaled_value.ilog(base as i32);
    Ok(f64::from(log_value))
}

/// Binary function to calculate the integer logarithm of a Decimal64 `value`
/// (stored with `scale` fractional digits) using `base` base.
///
/// The decimal is first reduced to its integer part via `decimal64_to_i64`,
/// and the result is the integer logarithm of that part (rounded down),
/// returned as an `f64`. Non-positive values yield `NaN`, mirroring `f64::log`.
///
/// # Errors
/// Returns `ArrowError::ComputeError` if `base` is not a finite integer, if
/// `base` is less than 2, or if `decimal64_to_i64` rejects `scale`.
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
    // 2^63: the smallest f64 that no longer fits into an i64.
    const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;

    if !base.is_finite() || base.trunc() != base {
        return Err(ArrowError::ComputeError(format!(
            "Log cannot use non-integer base: {base}"
        )));
    }
    // `base` is a finite integer here, so this also rejects 0, 1 and negatives.
    if base < 2.0 {
        return Err(ArrowError::ComputeError(format!(
            "Log base must be greater than 1: {base}"
        )));
    }

    let unscaled_value = decimal64_to_i64(value, scale)?;
    if unscaled_value <= 0 {
        // Reflect f64::log behaviour
        return Ok(f64::NAN);
    }

    if base >= TWO_POW_63 {
        // The base exceeds every positive i64, so the logarithm is 0.
        // An `as i64` cast would saturate to i64::MAX and wrongly report 1
        // for `unscaled_value == i64::MAX`.
        return Ok(0.0);
    }

    // `base` is an integer below 2^63, so the cast below is exact.
    let log_value: u32 = unscaled_value.ilog(base as i64);
    Ok(f64::from(log_value))
}

/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
/// Returns error if base is invalid
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
Expand Down Expand Up @@ -223,15 +273,18 @@ impl ScalarUDFImpl for LogFunc {
|value, base| Ok(value.log(base)),
)?
}
// TODO: native log support for decimal 32 & 64; right now upcast
// to decimal128 to calculate
// https://github.com/apache/datafusion/issues/17555
DataType::Decimal32(precision, scale)
| DataType::Decimal64(precision, scale) => {
calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
&cast(&value, &DataType::Decimal128(*precision, *scale))?,
DataType::Decimal32(_, scale) => {
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
&value,
&base,
|value, base| log_decimal128(value, *scale, base),
|value, base| log_decimal32(value, *scale, base),
)?
}
DataType::Decimal64(_, scale) => {
calculate_binary_math::<Decimal64Type, Float64Type, Float64Type, _>(
&value,
&base,
|value, base| log_decimal64(value, *scale, base),
)?
}
DataType::Decimal128(_, scale) => {
Expand Down
137 changes: 137 additions & 0 deletions datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,40 @@ pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
}
}

/// Returns the integer part of a Decimal32 given its raw `value` and `scale`.
///
/// The raw value is divided by `10^scale`; the division truncates toward
/// zero, so the fractional digits are dropped for negative values as well.
///
/// # Errors
/// Returns `ArrowError::ComputeError` when `scale` is negative or when
/// `10^scale` does not fit into an `i32`.
pub fn decimal32_to_i32(value: i32, scale: i8) -> Result<i32, ArrowError> {
    match scale {
        i8::MIN..=-1 => Err(ArrowError::ComputeError(
            "Negative scale is not supported".into(),
        )),
        // Nothing to strip: the raw value already is the integer part.
        0 => Ok(value),
        1..=i8::MAX => 10_i32
            .checked_pow(u32::from(scale.unsigned_abs()))
            .map(|divisor| value / divisor)
            .ok_or_else(|| {
                ArrowError::ComputeError(format!("Cannot get a power of {scale}"))
            }),
    }
}

/// Returns the integer part of a Decimal64 given its raw `value` and `scale`.
///
/// The raw value is divided by `10^scale`; the division truncates toward
/// zero, so the fractional digits are dropped for negative values as well.
///
/// # Errors
/// Returns `ArrowError::ComputeError` when `scale` is negative or when
/// `10^scale` does not fit into an `i64`.
pub fn decimal64_to_i64(value: i64, scale: i8) -> Result<i64, ArrowError> {
    match scale {
        i8::MIN..=-1 => Err(ArrowError::ComputeError(
            "Negative scale is not supported".into(),
        )),
        // Nothing to strip: the raw value already is the integer part.
        0 => Ok(value),
        1..=i8::MAX => 10_i64
            .checked_pow(u32::from(scale.unsigned_abs()))
            .map(|divisor| value / divisor)
            .ok_or_else(|| {
                ArrowError::ComputeError(format!("Cannot get a power of {scale}"))
            }),
    }
}

#[cfg(test)]
pub mod test {
/// $FUNC ScalarUDFImpl to test
Expand Down Expand Up @@ -334,6 +368,7 @@ pub mod test {
}

use arrow::datatypes::DataType;
use itertools::Either;
pub(crate) use test_function;

use super::*;
Expand Down Expand Up @@ -376,4 +411,106 @@ pub mod test {
}
}
}

#[test]
fn test_decimal32_to_i32() {
    // Each case is (raw value, scale, expected outcome): `Left` carries the
    // expected integer part, `Right` the expected compute-error message.
    let cases: [(i32, i8, Either<i32, String>); _] = [
        (123, 0, Either::Left(123)),
        (1230, 1, Either::Left(123)),
        (123000, 3, Either::Left(123)),
        (1234567, 2, Either::Left(12345)),
        (-1234567, 2, Either::Left(-12345)),
        (1, 0, Either::Left(1)),
        (
            123,
            -3,
            Either::Right("Negative scale is not supported".into()),
        ),
        (
            123,
            i8::MAX,
            Either::Right("Cannot get a power of 127".into()),
        ),
        (999999999, 0, Either::Left(999999999)),
        (999999999, 3, Either::Left(999999)),
    ];

    for (value, scale, expected) in cases {
        // Pair the actual outcome with the expectation and check them together.
        match (decimal32_to_i32(value, scale), expected) {
            (Ok(actual), Either::Left(expected_value)) => assert_eq!(
                actual, expected_value,
                "{value} and {scale} vs {expected_value:?}"
            ),
            (Ok(_), Either::Right(_)) => panic!("Got value but expected none"),
            (Err(ArrowError::ComputeError(msg)), Either::Right(expected_msg)) => {
                assert_eq!(msg, expected_msg)
            }
            (Err(ArrowError::ComputeError(_)), Either::Left(_)) => {
                panic!("Got error but expected value")
            }
            // Any other error kind is accepted only when an error was expected.
            (Err(_), expected) => assert!(expected.is_right()),
        }
    }
}

#[test]
fn test_decimal64_to_i64() {
    // Each case is (raw value, scale, expected outcome): `Left` carries the
    // expected integer part, `Right` the expected compute-error message.
    let cases: [(i64, i8, Either<i64, String>); _] = [
        (123, 0, Either::Left(123)),
        (1234567890, 2, Either::Left(12345678)),
        (-1234567890, 2, Either::Left(-12345678)),
        (
            123,
            -3,
            Either::Right("Negative scale is not supported".into()),
        ),
        (
            123,
            i8::MAX,
            Either::Right("Cannot get a power of 127".into()),
        ),
        (
            999999999999999999i64,
            0,
            Either::Left(999999999999999999i64),
        ),
        (
            999999999999999999i64,
            3,
            Either::Left(999999999999999999i64 / 1000),
        ),
        (
            -999999999999999999i64,
            3,
            Either::Left(-999999999999999999i64 / 1000),
        ),
    ];

    for (value, scale, expected) in cases {
        // Pair the actual outcome with the expectation and check them together.
        match (decimal64_to_i64(value, scale), expected) {
            (Ok(actual), Either::Left(expected_value)) => assert_eq!(
                actual, expected_value,
                "{value} and {scale} vs {expected_value:?}"
            ),
            (Ok(_), Either::Right(_)) => panic!("Got value but expected none"),
            (Err(ArrowError::ComputeError(msg)), Either::Right(expected_msg)) => {
                assert_eq!(msg, expected_msg)
            }
            (Err(ArrowError::ComputeError(_)), Either::Left(_)) => {
                panic!("Got error but expected value")
            }
            // Any other error kind is accepted only when an error was expected.
            (Err(_), expected) => assert!(expected.is_right()),
        }
    }
}
}
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,11 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
----
2

query R
select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
----
13

# log for small decimal64
query R
select log(arrow_cast(100, 'Decimal64(18, 0)'));
Expand All @@ -805,6 +810,12 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
----
2

query R
select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
----
13


# log for small decimal128
query R
select log(arrow_cast(100, 'Decimal128(38, 0)'));
Expand Down