diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index 72a9cf4555787..18229fb076ad3 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -21,11 +21,13 @@ use std::any::Any;
 
 use super::power::PowerFunc;
 
-use crate::utils::{calculate_binary_math, decimal128_to_i128};
+use crate::utils::{
+    calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
+};
 use arrow::array::{Array, ArrayRef};
-use arrow::compute::kernels::cast;
 use arrow::datatypes::{
-    DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
+    DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
+    Float32Type, Float64Type,
 };
 use arrow::error::ArrowError;
 use arrow_buffer::i256;
@@ -102,6 +104,54 @@ impl LogFunc {
     }
 }
 
+/// Binary function to calculate logarithm of Decimal32 `value` using `base` base
+/// Returns error if base is invalid
+fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
+    if !base.is_finite() || base.trunc() != base {
+        return Err(ArrowError::ComputeError(format!(
+            "Log cannot use non-integer base: {base}"
+        )));
+    }
+    if (base as u32) < 2 {
+        return Err(ArrowError::ComputeError(format!(
+            "Log base must be greater than 1: {base}"
+        )));
+    }
+
+    let unscaled_value = decimal32_to_i32(value, scale)?;
+    if unscaled_value > 0 {
+        let log_value: u32 = unscaled_value.ilog(base as i32);
+        Ok(log_value as f64)
+    } else {
+        // Reflect f64::log behaviour
+        Ok(f64::NAN)
+    }
+}
+
+/// Binary function to calculate logarithm of Decimal64 `value` using `base` base
+/// Returns error if base is invalid
+fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
+    if !base.is_finite() || base.trunc() != base {
+        return Err(ArrowError::ComputeError(format!(
+            "Log cannot use non-integer base: {base}"
+        )));
+    }
+    if (base as u32) < 2 {
+        return Err(ArrowError::ComputeError(format!(
+            "Log base must be greater than 1: {base}"
+        )));
+    }
+
+    let unscaled_value = decimal64_to_i64(value, scale)?;
+    if unscaled_value > 0 {
+        let log_value: u32 = unscaled_value.ilog(base as i64);
+        Ok(log_value as f64)
+    } else {
+        // Reflect f64::log behaviour
+        Ok(f64::NAN)
+    }
+}
+
 /// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
 /// Returns error if base is invalid
 fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
@@ -223,15 +273,18 @@ impl ScalarUDFImpl for LogFunc {
                     |value, base| Ok(value.log(base)),
                 )?
             }
-            // TODO: native log support for decimal 32 & 64; right now upcast
-            //       to decimal128 to calculate
-            //       https://github.com/apache/datafusion/issues/17555
-            DataType::Decimal32(precision, scale)
-            | DataType::Decimal64(precision, scale) => {
-                calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
-                    &cast(&value, &DataType::Decimal128(*precision, *scale))?,
+            DataType::Decimal32(_, scale) => {
+                calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
+                    &value,
                     &base,
-                    |value, base| log_decimal128(value, *scale, base),
+                    |value, base| log_decimal32(value, *scale, base),
+                )?
+            }
+            DataType::Decimal64(_, scale) => {
+                calculate_binary_math::<Decimal64Type, Float64Type, Float64Type, _>(
+                    &value,
+                    &base,
+                    |value, base| log_decimal64(value, *scale, base),
                 )?
             }
             DataType::Decimal128(_, scale) => {
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index c4f15d0cca7f6..e160eb68d55e2 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -219,6 +219,48 @@ pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
     }
 }
 
+/// Returns the integer part of a Decimal32 `value` stored with `scale`
+/// fractional digits, truncating toward zero.
+///
+/// Errors if `scale` is negative or `10^scale` does not fit in an `i32`.
+pub fn decimal32_to_i32(value: i32, scale: i8) -> Result<i32, ArrowError> {
+    if scale < 0 {
+        Err(ArrowError::ComputeError(
+            "Negative scale is not supported".into(),
+        ))
+    } else if scale == 0 {
+        Ok(value)
+    } else {
+        match 10_i32.checked_pow(scale as u32) {
+            Some(divisor) => Ok(value / divisor),
+            None => Err(ArrowError::ComputeError(format!(
+                "Cannot get a power of {scale}"
+            ))),
+        }
+    }
+}
+
+/// Returns the integer part of a Decimal64 `value` stored with `scale`
+/// fractional digits, truncating toward zero.
+///
+/// Errors if `scale` is negative or `10^scale` does not fit in an `i64`.
+pub fn decimal64_to_i64(value: i64, scale: i8) -> Result<i64, ArrowError> {
+    if scale < 0 {
+        Err(ArrowError::ComputeError(
+            "Negative scale is not supported".into(),
+        ))
+    } else if scale == 0 {
+        Ok(value)
+    } else {
+        match i64::from(10).checked_pow(scale as u32) {
+            Some(divisor) => Ok(value / divisor),
+            None => Err(ArrowError::ComputeError(format!(
+                "Cannot get a power of {scale}"
+            ))),
+        }
+    }
+}
+
 #[cfg(test)]
 pub mod test {
     /// $FUNC ScalarUDFImpl to test
@@ -334,6 +376,7 @@ pub mod test {
     }
 
     use arrow::datatypes::DataType;
+    use itertools::Either;
     pub(crate) use test_function;
 
     use super::*;
@@ -376,4 +419,106 @@ pub mod test {
             }
         }
     }
+
+    #[test]
+    fn test_decimal32_to_i32() {
+        let cases: [(i32, i8, Either<i32, String>); _] = [
+            (123, 0, Either::Left(123)),
+            (1230, 1, Either::Left(123)),
+            (123000, 3, Either::Left(123)),
+            (1234567, 2, Either::Left(12345)),
+            (-1234567, 2, Either::Left(-12345)),
+            (1, 0, Either::Left(1)),
+            (
+                123,
+                -3,
+                Either::Right("Negative scale is not supported".into()),
+            ),
+            (
+                123,
+                i8::MAX,
+                Either::Right("Cannot get a power of 127".into()),
+            ),
+            (999999999, 0, Either::Left(999999999)),
+            (999999999, 3, Either::Left(999999)),
+        ];
+
+        for (value, scale, expected) in cases {
+            match decimal32_to_i32(value, scale) {
+                Ok(actual) => {
+                    let expected_value =
+                        expected.left().expect("Got value but expected none");
+                    assert_eq!(
+                        actual, expected_value,
+                        "{value} and {scale} vs {expected_value:?}"
+                    );
+                }
+                Err(ArrowError::ComputeError(msg)) => {
+                    assert_eq!(
+                        msg,
+                        expected.right().expect("Got error but expected value")
+                    );
+                }
+                Err(_) => {
+                    assert!(expected.is_right())
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_decimal64_to_i64() {
+        let cases: [(i64, i8, Either<i64, String>); _] = [
+            (123, 0, Either::Left(123)),
+            (1234567890, 2, Either::Left(12345678)),
+            (-1234567890, 2, Either::Left(-12345678)),
+            (
+                123,
+                -3,
+                Either::Right("Negative scale is not supported".into()),
+            ),
+            (
+                123,
+                i8::MAX,
+                Either::Right("Cannot get a power of 127".into()),
+            ),
+            (
+                999999999999999999i64,
+                0,
+                Either::Left(999999999999999999i64),
+            ),
+            (
+                999999999999999999i64,
+                3,
+                Either::Left(999999999999999999i64 / 1000),
+            ),
+            (
+                -999999999999999999i64,
+                3,
+                Either::Left(-999999999999999999i64 / 1000),
+            ),
+        ];
+
+        for (value, scale, expected) in cases {
+            match decimal64_to_i64(value, scale) {
+                Ok(actual) => {
+                    let expected_value =
+                        expected.left().expect("Got value but expected none");
+                    assert_eq!(
+                        actual, expected_value,
+                        "{value} and {scale} vs {expected_value:?}"
+                    );
+                }
+                Err(ArrowError::ComputeError(msg)) => {
+                    assert_eq!(
+                        msg,
+                        expected.right().expect("Got error but expected value")
+                    );
+                }
+                Err(_) => {
+                    assert!(expected.is_right())
+                }
+            }
+        }
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index a6b6dd04889ac..143cd786ab3c0 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -794,6 +794,11 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
 ----
 2
 
+query R
+select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
+----
+13
+
 # log for small decimal64
 query R
 select log(arrow_cast(100, 'Decimal64(18, 0)'));
@@ -805,6 +810,12 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
 ----
 2
 
+query R
+select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
+----
+13
+
+
 # log for small decimal128
 query R
 select log(arrow_cast(100, 'Decimal128(38, 0)'));