diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index fe1a7de6d8f..9e0741c263c 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -23,16 +23,18 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::logicalplan::ScalarValue; +use crate::logicalplan::{Operator, ScalarValue}; use arrow::array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::array::{ Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }; +use arrow::compute::kernels::boolean::{and, or}; use arrow::compute::kernels::cast::cast; +use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; @@ -197,6 +199,140 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// Invoke a compute kernel on a pair of arrays +macro_rules! compute_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?)) + }}; +} + +/// Invoke a compute kernel on a pair of arrays +macro_rules! comparison_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + match $LEFT.data_type() { + DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(ExecutionError::General(format!( + "Unsupported data type {:?}", + other + ))), + } + }}; +} + +/// Invoke a boolean kernel on a pair of arrays +macro_rules! boolean_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::() + .expect("boolean_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::() + .expect("boolean_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?)) + }}; +} +/// Binary expression +pub struct BinaryExpr { + left: Arc, + op: Operator, + right: Arc, +} + +impl BinaryExpr { + /// Create new binary expression + pub fn new( + left: Arc, + op: Operator, + right: Arc, + ) -> Self { + Self { left, op, right } + } +} + +impl PhysicalExpr for BinaryExpr { + fn name(&self) -> String { + format!("{:?}", self.op) + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.left.data_type(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let left = self.left.evaluate(batch)?; + let right = self.right.evaluate(batch)?; + if left.data_type() != right.data_type() { + return Err(ExecutionError::General(format!( + "Cannot evaluate binary expression {:?} with types {:?} and {:?}", + self.op, + left.data_type(), + right.data_type() + ))); + } + match &self.op { + Operator::Lt => comparison_op!(left, right, lt), + Operator::LtEq => comparison_op!(left, right, lt_eq), + Operator::Gt => comparison_op!(left, right, gt), + Operator::GtEq => comparison_op!(left, right, gt_eq), + Operator::Eq => comparison_op!(left, right, eq), + Operator::NotEq => comparison_op!(left, right, neq), + Operator::And => { + if left.data_type() == &DataType::Boolean { + boolean_op!(left, right, and) + } else { + return Err(ExecutionError::General(format!( + "Cannot evaluate binary expression {:?} with types {:?} and {:?}", + self.op, + left.data_type(), + right.data_type() + ))); + } + } + Operator::Or => { + if left.data_type() == &DataType::Boolean { + boolean_op!(left, right, or) + } else { + return Err(ExecutionError::General(format!( + "Cannot evaluate binary expression {:?} with types {:?} and {:?}", + self.op, + left.data_type(), + right.data_type() + ))); + } + } + _ => Err(ExecutionError::General("Unsupported operator".to_string())), + } + } +} + +/// Create a binary expression +pub fn binary( + l: Arc, + op: Operator, + r: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(l, op, r)) +} + /// CAST expression casts an expression to a specific data type pub struct CastExpr { /// The expression to cast @@ -335,6 +471,71 @@ mod tests { use arrow::array::BinaryArray; use arrow::datatypes::*; + #[test] + fn binary_comparison() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![1, 2, 4, 8, 16]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + )?; + + // expression: "a < b" + let lt = binary(col(0), Operator::Lt, col(1)); + let result = lt.evaluate(&batch)?; + assert_eq!(result.len(), 5); + + let expected = vec![false, false, true, true, true]; + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + for i in 0..5 { + assert_eq!(result.value(i), expected[i]); + } + + Ok(()) + } + + #[test] + fn binary_nested() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let a = Int32Array::from(vec![2, 4, 6, 8, 10]); + let b = Int32Array::from(vec![2, 5, 4, 8, 8]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + )?; + + // expression: "a < b OR a == b" + let expr = binary( + binary(col(0), Operator::Lt, col(1)), + Operator::Or, + binary(col(0), Operator::Eq, col(1)), + ); + let result = expr.evaluate(&batch)?; + assert_eq!(result.len(), 5); + + let expected = vec![true, true, false, true, false]; + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + for i in 0..5 { + print!("{}", i); + assert_eq!(result.value(i), expected[i]); + } + + Ok(()) + } + #[test] fn literal_i32() -> Result<()> { // create an arbitrary record bacth