diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index bf8c957b66b..fe1a7de6d8f 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -32,6 +32,7 @@ use arrow::array::{ Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }; +use arrow::compute::kernels::cast::cast; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; @@ -196,6 +197,61 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// CAST expression casts an expression to a specific data type +pub struct CastExpr { + /// The expression to cast + expr: Arc, + /// The data type to cast to + cast_type: DataType, +} + +/// Determine if a DataType is numeric or not +fn is_numeric(dt: &DataType) -> bool { + match dt { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, + DataType::Float16 | DataType::Float32 | DataType::Float64 => true, + _ => false, + } +} + +impl CastExpr { + /// Create a CAST expression + pub fn try_new( + expr: Arc, + input_schema: &Schema, + cast_type: DataType, + ) -> Result { + let expr_type = expr.data_type(input_schema)?; + // numbers can be cast to numbers and strings + if is_numeric(&expr_type) + && (is_numeric(&cast_type) || cast_type == DataType::Utf8) + { + Ok(Self { expr, cast_type }) + } else { + Err(ExecutionError::General(format!( + "Invalid CAST from {:?} to {:?}", + expr_type, cast_type + ))) + } + } +} + +impl PhysicalExpr for CastExpr { + fn name(&self) -> String { + "CAST".to_string() + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.cast_type.clone()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let value = self.expr.evaluate(batch)?; + Ok(cast(&value, &self.cast_type)?) + } +} + /// Represents a non-null literal value pub struct Literal { value: ScalarValue, @@ -276,6 +332,7 @@ pub fn lit(value: ScalarValue) -> Arc { mod tests { use super::*; use crate::error::Result; + use arrow::array::BinaryArray; use arrow::datatypes::*; #[test] @@ -299,6 +356,56 @@ mod tests { Ok(()) } + #[test] + fn cast_i32_to_u32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let cast = CastExpr::try_new(col(0), &schema, DataType::UInt32)?; + let result = cast.evaluate(&batch)?; + assert_eq!(result.len(), 5); + + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to UInt32Array"); + assert_eq!(result.value(0), 1_u32); + + Ok(()) + } + + #[test] + fn cast_i32_to_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let cast = CastExpr::try_new(col(0), &schema, DataType::Utf8)?; + let result = cast.evaluate(&batch)?; + assert_eq!(result.len(), 5); + + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BinaryArray"); + assert_eq!(result.value(0), "1".as_bytes()); + + Ok(()) + } + + #[test] + fn invalid_cast() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + match CastExpr::try_new(col(0), &schema, DataType::Int32) { + Err(ExecutionError::General(ref str)) => { + assert_eq!(str, "Invalid CAST from Utf8 to Int32"); + Ok(()) + } + _ => panic!(), + } + } + #[test] fn sum_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);