-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-10330: [Rust][DataFusion] Implement NULLIF() SQL function #8688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4e364ac
e0b4975
1fa8567
ea2dc55
bcb9822
ee86ef6
431f32e
b1048ad
f188050
09b6bc4
ebc56e6
8753cb1
9edc13c
308c341
174db8e
adc36bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ use arrow::array::{self, Array, BooleanBuilder, LargeStringArray}; | |
| use arrow::compute; | ||
| use arrow::compute::kernels; | ||
| use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; | ||
| use arrow::compute::kernels::boolean::{and, or}; | ||
| use arrow::compute::kernels::boolean::{and, nullif, or}; | ||
| use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; | ||
| use arrow::compute::kernels::comparison::{ | ||
| eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, | ||
|
|
@@ -1535,6 +1535,80 @@ pub fn binary( | |
| Ok(Arc::new(BinaryExpr::new(l, op, r))) | ||
| } | ||
|
|
||
| /// Invoke a compute kernel on a primitive array and a Boolean Array | ||
| macro_rules! compute_bool_array_op { | ||
| ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ | ||
| let ll = $LEFT | ||
| .as_any() | ||
| .downcast_ref::<$DT>() | ||
| .expect("compute_op failed to downcast array"); | ||
| let rr = $RIGHT | ||
| .as_any() | ||
| .downcast_ref::<BooleanArray>() | ||
| .expect("compute_op failed to downcast array"); | ||
| Ok(Arc::new($OP(&ll, &rr)?)) | ||
| }}; | ||
| } | ||
|
|
||
| /// Binary op between primitive and boolean arrays | ||
| macro_rules! primitive_bool_array_op { | ||
| ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ | ||
| match $LEFT.data_type() { | ||
| DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array), | ||
| DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array), | ||
| DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array), | ||
| DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array), | ||
| DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array), | ||
| DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array), | ||
| DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array), | ||
| DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array), | ||
| DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array), | ||
| DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array), | ||
| other => Err(DataFusionError::Internal(format!( | ||
| "Unsupported data type {:?} for NULLIF/primitive/boolean operator", | ||
| other | ||
| ))), | ||
| } | ||
| }}; | ||
| } | ||
|
|
||
| /// | ||
| /// Implements NULLIF(expr1, expr2) | ||
| /// Args: 0 - left expr is any array | ||
| /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. | ||
| /// | ||
| pub fn nullif_func(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| if args.len() != 2 { | ||
| return Err(DataFusionError::Internal(format!( | ||
| "{:?} args were supplied but NULLIF takes exactly two args", | ||
| args.len(), | ||
| ))); | ||
| } | ||
|
|
||
| // Get args0 == args1 evaluated and produce a boolean array | ||
| let cond_array = binary_array_op!(args[0], args[1], eq)?; | ||
|
|
||
| // Now, invoke nullif on the result | ||
| primitive_bool_array_op!(args[0], *cond_array, nullif) | ||
| } | ||
|
|
||
| /// Currently supported types by the nullif function. | ||
| /// The order of these types correspond to the order on which coercion applies | ||
| /// This should thus be from least informative to most informative | ||
| pub static SUPPORTED_NULLIF_TYPES: &'static [DataType] = &[ | ||
| DataType::Boolean, | ||
| DataType::UInt8, | ||
| DataType::UInt16, | ||
| DataType::UInt32, | ||
| DataType::UInt64, | ||
| DataType::Int8, | ||
| DataType::Int16, | ||
| DataType::Int32, | ||
| DataType::Int64, | ||
| DataType::Float32, | ||
| DataType::Float64, | ||
| ]; | ||
|
||
|
|
||
| /// Not expression | ||
| #[derive(Debug)] | ||
| pub struct NotExpr { | ||
|
|
@@ -3151,6 +3225,70 @@ mod tests { | |
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn nullif_int32() -> Result<()> { | ||
| let a = Int32Array::from(vec![ | ||
| Some(1), | ||
| Some(2), | ||
| None, | ||
| None, | ||
| Some(3), | ||
| None, | ||
| None, | ||
| Some(4), | ||
| Some(5), | ||
| ]); | ||
| let a = Arc::new(a); | ||
| let a_len = a.len(); | ||
|
|
||
| let lit_array = Arc::new(Int32Array::from(vec![2; a.len()])); | ||
|
|
||
| let result = nullif_func(&[a.clone(), lit_array])?; | ||
|
|
||
| assert_eq!(result.len(), a_len); | ||
|
|
||
| let expected = Int32Array::from(vec![ | ||
| Some(1), | ||
| None, | ||
| None, | ||
| None, | ||
| Some(3), | ||
| None, | ||
| None, | ||
| Some(4), | ||
| Some(5), | ||
| ]); | ||
| assert_array_eq::<Int32Type>(expected, result); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| // Ensure that arrays with no nulls can also invoke NULLIF() correctly | ||
| fn nullif_int32_nonulls() -> Result<()> { | ||
| let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); | ||
| let a = Arc::new(a); | ||
| let a_len = a.len(); | ||
|
|
||
| let lit_array = Arc::new(Int32Array::from(vec![1; a.len()])); | ||
|
|
||
| let result = nullif_func(&[a.clone(), lit_array])?; | ||
| assert_eq!(result.len(), a_len); | ||
|
|
||
| let expected = Int32Array::from(vec![ | ||
| None, | ||
| Some(3), | ||
| Some(10), | ||
| Some(7), | ||
| Some(8), | ||
| None, | ||
| Some(2), | ||
| Some(4), | ||
| Some(5), | ||
| ]); | ||
| assert_array_eq::<Int32Type>(expected, result); | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn aggregate( | ||
| batch: &RecordBatch, | ||
| agg: Arc<dyn AggregateExpr>, | ||
|
|
@@ -3275,7 +3413,11 @@ mod tests { | |
| .expect("Actual array should unwrap to type of expected array"); | ||
|
|
||
| for i in 0..expected.len() { | ||
| assert_eq!(expected.value(i), actual.value(i)); | ||
| if expected.is_null(i) { | ||
| assert!(actual.is_null(i)); | ||
| } else { | ||
| assert_eq!(expected.value(i), actual.value(i)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.