diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index e29c338bfd1..bf8c957b66b 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -28,6 +28,10 @@ use arrow::array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; +use arrow::array::{ + Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, + Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, +}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; @@ -192,12 +196,109 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// Represents a non-null literal value +pub struct Literal { + value: ScalarValue, +} + +impl Literal { + /// Create a literal value expression + pub fn new(value: ScalarValue) -> Self { + Self { value } + } +} + +/// Build array containing the same literal value repeated. This is necessary because the Arrow +/// memory model does not have the concept of a scalar value currently. +macro_rules! 
build_literal_array { + ($BATCH:ident, $BUILDER:ident, $VALUE:expr) => {{ + let mut builder = $BUILDER::new($BATCH.num_rows()); + for _ in 0..$BATCH.num_rows() { + builder.append_value($VALUE)?; + } + Ok(Arc::new(builder.finish())) + }}; +} + +impl PhysicalExpr for Literal { + fn name(&self) -> String { + "lit".to_string() + } + + fn data_type(&self, _input_schema: &Schema) -> Result<DataType> { + Ok(self.value.get_datatype()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { + match &self.value { + ScalarValue::Int8(value) => build_literal_array!(batch, Int8Builder, *value), + ScalarValue::Int16(value) => { + build_literal_array!(batch, Int16Builder, *value) + } + ScalarValue::Int32(value) => { + build_literal_array!(batch, Int32Builder, *value) + } + ScalarValue::Int64(value) => { + build_literal_array!(batch, Int64Builder, *value) + } + ScalarValue::UInt8(value) => { + build_literal_array!(batch, UInt8Builder, *value) + } + ScalarValue::UInt16(value) => { + build_literal_array!(batch, UInt16Builder, *value) + } + ScalarValue::UInt32(value) => { + build_literal_array!(batch, UInt32Builder, *value) + } + ScalarValue::UInt64(value) => { + build_literal_array!(batch, UInt64Builder, *value) + } + ScalarValue::Float32(value) => { + build_literal_array!(batch, Float32Builder, *value) + } + ScalarValue::Float64(value) => { + build_literal_array!(batch, Float64Builder, *value) + } + other => Err(ExecutionError::General(format!( + "Unsupported literal type {:?}", + other + ))), + } + } +} + +/// Create a literal expression +pub fn lit(value: ScalarValue) -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(value)) +} + #[cfg(test)] mod tests { use super::*; use crate::error::Result; use arrow::datatypes::*; + #[test] + fn literal_i32() -> Result<()> { + // create an arbitrary record batch + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); + let batch = 
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // create and evaluate a literal expression + let literal_expr = lit(ScalarValue::Int32(42)); + let literal_array = literal_expr.evaluate(&batch)?; + let literal_array = literal_array.as_any().downcast_ref::<Int32Array>().unwrap(); + + // note that the contents of the literal array are unrelated to the batch contents except for the length of the array + assert_eq!(literal_array.len(), 5); // 5 rows in the batch + for i in 0..literal_array.len() { + assert_eq!(literal_array.value(i), 42); + } + + Ok(()) + } + #[test] fn sum_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);