diff --git a/rust/arrow/src/compute/kernels/boolean.rs b/rust/arrow/src/compute/kernels/boolean.rs index f404ff4d1e8..07cf5288fcf 100644 --- a/rust/arrow/src/compute/kernels/boolean.rs +++ b/rust/arrow/src/compute/kernels/boolean.rs @@ -22,14 +22,15 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +use std::ops::Not; use std::sync::Arc; -use crate::array::{Array, ArrayData, BooleanArray}; +use crate::array::{Array, ArrayData, BooleanArray, PrimitiveArray}; use crate::buffer::{ buffer_bin_and, buffer_bin_or, buffer_unary_not, Buffer, MutableBuffer, }; use crate::compute::util::combine_option_bitmap; -use crate::datatypes::DataType; +use crate::datatypes::{ArrowNumericType, DataType}; use crate::error::{ArrowError, Result}; use crate::util::bit_util::ceil; @@ -223,6 +224,102 @@ pub fn is_not_null(input: &Array) -> Result { Ok(BooleanArray::from(Arc::new(data))) } +/// Copies the `left` array, marking a slot as null wherever the `right` boolean array is valid and `true`; slots where `right` is `false` or null keep their original validity. +/// Typically used to implement NULLIF. Returns `ArrowError::ComputeError` if `left` and `right` differ in length. +// NOTE: For now this only supports Primitive Arrays. Although the code could be made generic, the issue +// is that currently the bitmap operations result in a final bitmap which is aligned to bit 0, and thus +// the left array's data needs to be sliced to a new offset, and for non-primitive arrays shifting the +// data might be too complicated. In the future, to avoid shifting left array's data, we could instead +// shift the final bitbuffer to the right, prepending with 0's instead. 
+pub fn nullif( + left: &PrimitiveArray, + right: &BooleanArray, +) -> Result> +where + T: ArrowNumericType, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + let left_data = left.data(); + let right_data = right.data(); + + // Truth table for the output validity bit (a missing left bitmap is treated as "every left slot is valid"): + // left=0 (null) right=null output bitmap=null + // left=0 right=1 output bitmap=null + // left=1 (set) right=null output bitmap=set (passthrough) + // left=1 right=1 & comp=true output bitmap=null + // left=1 right=1 & comp=false output bitmap=set + // + // Thus: result = left null bitmap & (!right_values | !right_bitmap) + // OR left null bitmap & !(right_values & right_bitmap) + // + // Do the right expression !(right_values & right_bitmap) first since there are two steps + // TRICK: convert BooleanArray buffer as a bitmap for faster operation + let right_combo_buffer = match right.data().null_bitmap() { + Some(right_bitmap) => { + // NOTE: right values and bitmaps are combined and stay at bit offset right.offset(). NOTE(review): `.ok()` turns a failed `&` into `None`, after which the right-hand condition is silently skipped below - confirm this is intended. + (&right.values() & &right_bitmap.bits).ok().map(|b| b.not()) + } + None => Some(!&right.values()), + }; + + // AND of original left null bitmap with right expression + // Here we take care of the possible offsets of the left and right arrays all at once. 
+ let modified_null_buffer = match left_data.null_bitmap() { + Some(left_null_bitmap) => match right_combo_buffer { + Some(rcb) => Some(buffer_bin_and( + &left_null_bitmap.bits, + left_data.offset(), + &rcb, + right_data.offset(), + left_data.len(), + )), + None => Some( + left_null_bitmap + .bits + .bit_slice(left_data.offset(), left.len()), + ), + }, + None => right_combo_buffer + .map(|rcb| rcb.bit_slice(right_data.offset(), right_data.len())), + }; + + // Align/shift left data on offset as needed, since new bitmaps are shifted and aligned to 0 already + // NOTE: this probably only works for primitive arrays. + let data_buffers = if left.offset() == 0 { + left_data.buffers().to_vec() + } else { + // Shift each data buffer by type's bit_width * offset. + left_data + .buffers() + .iter() + .map(|buf| { + buf.bit_slice( + left.offset() * T::get_bit_width(), + left.len() * T::get_bit_width(), + ) + }) + .collect::>() + }; + + // Construct new array with same values but modified null bitmap + // The data buffers were already re-aligned above, so the new array is built with offset 0 + let data = ArrayData::new( + T::DATA_TYPE, + left.len(), + None, // force new to compute the number of null bits + modified_null_buffer, + 0, // No need for offset since left data has been shifted + data_buffers, + left_data.child_data().to_vec(), + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + #[cfg(test)] mod tests { use super::*; @@ -585,4 +682,50 @@ mod tests { assert_eq!(expected, res); assert_eq!(&None, res.data_ref().null_bitmap()); } + + #[test] + fn test_nullif_int_array() { + let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9)]); + let comp = + BooleanArray::from(vec![Some(false), None, Some(true), Some(false), None]); + let res = nullif(&a, &comp).unwrap(); + + let expected = Int32Array::from(vec![ + Some(15), + None, + None, // comp true, slot 2 turned into null + Some(1), + // Even though comp array / right is null, should still pass through original value + // 
(slot 4: comp is null, so Some(9) is kept) + 
Some(9), + ]); + + assert_eq!(expected, res); + } + + #[test] + fn test_nullif_int_array_offset() { + let a = Int32Array::from(vec![None, Some(15), Some(8), Some(1), Some(9)]); + let a = a.slice(1, 3); // left uses offset 1, len 3 => Some(15), Some(8), Some(1) + let a = a.as_any().downcast_ref::().unwrap(); + let comp = BooleanArray::from(vec![ + Some(false), + Some(false), + Some(false), + None, + Some(true), + Some(false), + None, + ]); + let comp = comp.slice(2, 3); // right uses a different offset (2), len 3 => Some(false), None, Some(true) + let comp = comp.as_any().downcast_ref::().unwrap(); + let res = nullif(&a, &comp).unwrap(); + + let expected = Int32Array::from(vec![ + Some(15), // False => keep it + Some(8), // None => keep it + None, // true => None + ]); + assert_eq!(&expected, &res) + } } diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 114cc41b518..bd828c678e8 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -58,6 +58,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - String functions - [x] Length - [x] Concatenate +- Miscellaneous/Boolean functions + - [x] nullif - Common date/time functions - [ ] Basic date functions - [ ] Basic time functions diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 4247cf0a752..0004fdfc1ae 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -30,7 +30,7 @@ use arrow::array::{self, Array, BooleanBuilder, LargeStringArray}; use arrow::compute; use arrow::compute::kernels; use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; -use arrow::compute::kernels::boolean::{and, or}; +use arrow::compute::kernels::boolean::{and, nullif, or}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, @@ -1535,6 +1535,80 @@ pub fn binary( 
Ok(Arc::new(BinaryExpr::new(l, op, r))) } +/// Downcast `$LEFT` to the primitive array type `$DT` and `$RIGHT` to a Boolean Array, then invoke the compute kernel `$OP` on them (panics if either downcast fails) +macro_rules! compute_bool_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::() + .expect("compute_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?)) + }}; +} + +/// Binary op between primitive and boolean arrays: dispatches `$OP` on the data type of `$LEFT` (Int8..Int64, UInt8..UInt64, Float32, Float64); any other type yields a `DataFusionError::Internal` +macro_rules! primitive_bool_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + match $LEFT.data_type() { + DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for NULLIF/primitive/boolean operator", + other + ))), + } + }}; +} + +/// +/// Implements NULLIF(expr1, expr2) +/// Args: 0 - left expr; must be one of the numeric primitive arrays handled by `primitive_bool_array_op` +/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. 
+/// +pub fn nullif_func(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but NULLIF takes exactly two args", + args.len(), + ))); + } + + // Evaluate args[0] == args[1] element-wise, producing a boolean condition array + let cond_array = binary_array_op!(args[0], args[1], eq)?; + + // Now, invoke nullif so that every slot of args[0] where the condition is true becomes null + primitive_bool_array_op!(args[0], *cond_array, nullif) +} + +/// Types currently accepted by the nullif function. +/// The order of these types corresponds to the order in which coercion applies. +/// This should thus be from least informative to most informative. NOTE(review): `DataType::Boolean` is listed here but `primitive_bool_array_op` has no Boolean arm, so a Boolean input would end in its "Unsupported data type" error - confirm intent. +pub static SUPPORTED_NULLIF_TYPES: &'static [DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, +]; + /// Not expression #[derive(Debug)] pub struct NotExpr { @@ -3151,6 +3225,70 @@ mod tests { ) } + #[test] + fn nullif_int32() -> Result<()> { + let a = Int32Array::from(vec![ + Some(1), + Some(2), + None, + None, + Some(3), + None, + None, + Some(4), + Some(5), + ]); + let a = Arc::new(a); + let a_len = a.len(); + + let lit_array = Arc::new(Int32Array::from(vec![2; a.len()])); + + let result = nullif_func(&[a.clone(), lit_array])?; + + assert_eq!(result.len(), a_len); + + let expected = Int32Array::from(vec![ + Some(1), + None, + None, + None, + Some(3), + None, + None, + Some(4), + Some(5), + ]); + assert_array_eq::(expected, result); + Ok(()) + } + + #[test] + // Ensure that arrays with no nulls can also invoke NULLIF() correctly + fn nullif_int32_nonulls() -> Result<()> { + let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); + let a = Arc::new(a); + let a_len = a.len(); + + let lit_array = Arc::new(Int32Array::from(vec![1; a.len()])); + + let result = nullif_func(&[a.clone(), lit_array])?; + assert_eq!(result.len(), a_len); + + let expected = 
Int32Array::from(vec![ + None, + Some(3), + Some(10), + Some(7), + Some(8), + None, + Some(2), + Some(4), + Some(5), + ]); + assert_array_eq::(expected, result); + Ok(()) + } + fn aggregate( batch: &RecordBatch, agg: Arc, @@ -3275,7 +3413,11 @@ mod tests { .expect("Actual array should unwrap to type of expected array"); for i in 0..expected.len() { - assert_eq!(expected.value(i), actual.value(i)); + if expected.is_null(i) { + assert!(actual.is_null(i)); + } else { + assert_eq!(expected.value(i), actual.value(i)); + } } } diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index b63c05d64d3..19d3b6e2633 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -36,6 +36,7 @@ use super::{ use crate::error::{DataFusionError, Result}; use crate::physical_plan::array_expressions; use crate::physical_plan::datetime_expressions; +use crate::physical_plan::expressions::{nullif_func, SUPPORTED_NULLIF_TYPES}; use crate::physical_plan::math_expressions; use crate::physical_plan::string_expressions; use arrow::{ @@ -121,6 +122,8 @@ pub enum BuiltinScalarFunction { ToTimestamp, /// construct an array from columns Array, + /// SQL NULLIF() + NullIf, } impl fmt::Display for BuiltinScalarFunction { @@ -155,6 +158,7 @@ impl FromStr for BuiltinScalarFunction { "concat" => BuiltinScalarFunction::Concat, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "array" => BuiltinScalarFunction::Array, + "nullif" => BuiltinScalarFunction::NullIf, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -184,9 +188,8 @@ pub fn return_type( )); } - // the return type of the built in function. Eventually there - // will be built-in functions whose return type depends on the - // incoming type. + // the return type of the built in function. + // Some built-in functions' return type depends on the incoming type. 
match fun { BuiltinScalarFunction::Length => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::Int64, @@ -206,6 +209,11 @@ pub fn return_type( Box::new(Field::new("item", arg_types[0].clone(), true)), arg_types.len() as i32, )), + BuiltinScalarFunction::NullIf => { + // NULLIF has two args and they might get coerced, get a preview of this + let coerced_types = data_types(arg_types, &signature(fun)); + coerced_types.map(|typs| typs[0].clone()) + } _ => Ok(DataType::Float64), } } @@ -235,6 +243,7 @@ pub fn create_physical_expr( BuiltinScalarFunction::Trunc => math_expressions::trunc, BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Signum => math_expressions::signum, + BuiltinScalarFunction::NullIf => nullif_func, BuiltinScalarFunction::Length => |args| Ok(length(args[0].as_ref())?), BuiltinScalarFunction::Concat => { |args| Ok(Arc::new(string_expressions::concatenate(args)?)) @@ -274,6 +283,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Array => { Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) } + BuiltinScalarFunction::NullIf => { + Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) + } // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). 
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 0d275391ca9..e27fd8e3de6 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -788,6 +788,16 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_where_nullif_division() { + let sql = "SELECT c3/(c4+c5) \ + FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1"; + let expected = "Projection: #c3 Divide #c4 Plus #c5\ + \n Filter: #c3 Divide nullif(#c4 Plus #c5, Int64(0)) Gt Float64(0.1)\ + \n TableScan: aggregate_test_100 projection=None"; + quick_test(sql, expected); + } + #[test] fn select_order_by() { let sql = "SELECT id FROM person ORDER BY id"; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index fc35f4fd975..49df98d5700 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -508,6 +508,29 @@ async fn csv_query_avg_multi_batch() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_nullif_divide_by_0() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let actual = &actual[80..90]; // Compare only rows 80-89; this window includes a row whose quotient is NULL (third expected entry) + let expected = vec![ + vec!["258"], + vec!["664"], + vec!["NULL"], + vec!["22"], + vec!["164"], + vec!["448"], + vec!["365"], + vec!["1640"], + vec!["671"], + vec!["203"], + ]; + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn csv_query_count() -> Result<()> { let mut ctx = ExecutionContext::new();