diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0550741a9f9..1fcdcc59c9a 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -38,7 +38,7 @@ use crate::{ use crate::{ physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, - type_coercion::can_coerce_from, udf::ScalarUDF, + type_casting::can_cast_types, udf::ScalarUDF, }, sql::parser::FileType, }; @@ -323,12 +323,13 @@ impl Expr { /// /// # Errors /// - /// This function errors when it is impossible to cast the expression to the target [arrow::datatypes::DataType]. + /// This function errors when it is impossible to cast the + /// expression to the target [arrow::datatypes::DataType]. pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self.clone()) - } else if can_coerce_from(cast_to_type, &this_type) { + } else if can_cast_types(&this_type, cast_to_type) { Ok(Expr::Cast { expr: Box::new(self.clone()), data_type: cast_to_type.clone(), diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 4c9029e7195..06eb3d51d8d 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -24,7 +24,9 @@ use std::{cell::RefCell, convert::TryFrom}; use crate::error::{ExecutionError, Result}; use crate::logical_plan::Operator; -use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::physical_plan::{ + type_casting::can_cast_types, Accumulator, AggregateExpr, PhysicalExpr, +}; use crate::scalar::ScalarValue; use arrow::array::{ Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, @@ -1532,7 +1534,10 @@ impl PhysicalExpr for CastExpr { } } -/// Returns a cast operation, if casting needed. +/// Returns a physical cast operation that casts `expr` to `cast_type` +/// if casting is needed. +/// +/// Note that such casts may lose type information pub fn cast( expr: Arc, input_schema: &Schema, @@ -1540,19 +1545,12 @@ pub fn cast( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - return Ok(expr.clone()); - } - if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if is_numeric(&expr_type) - && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) - { + Ok(expr.clone()) + } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr { expr, cast_type })) } else { Err(ExecutionError::General(format!( - "Invalid CAST from {:?} to {:?}", + "Unsupported CAST from {:?} to {:?}", expr_type, cast_type ))) } @@ -1992,9 +1990,10 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let result = cast(col("a"), &schema, DataType::Int32); - result.expect_err("Invalid CAST from Utf8 to Int32"); + // Ensure a useful error happens at plan time if invalid casts are used + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let result = cast(col("a"), &schema, DataType::LargeBinary); + result.expect_err("expected Invalid CAST"); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index 3565610ca99..7d46249b996 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -207,6 +207,7 @@ pub mod planner; pub mod projection; pub mod sort; pub mod string_expressions; +pub mod type_casting; pub mod type_coercion; pub mod udaf; pub mod udf; diff --git a/rust/datafusion/src/physical_plan/type_casting.rs b/rust/datafusion/src/physical_plan/type_casting.rs new file mode 100644 index 00000000000..3f2fda16d31 --- /dev/null +++ b/rust/datafusion/src/physical_plan/type_casting.rs @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides a way of checking what type casts are +//! supported at planning time for DataFusion. Since DataFusion uses +//! the Arrow `cast` compute kernel, the supported casts are the same +//! as the Arrow casts. +//! +//! The rules in this module are designed to be redundant with the +//! rules in the Arrow `cast` kernel. The redundancy is needed so that +//! DataFusion can generate an error at plan time rather than during +//! execution (which could happen many hours after execution starts, +//! when the query finally reaches that point) +//! + +use arrow::datatypes::*; + +/// Return true if a value of type `from_type` can be cast into a +/// value of `to_type`. Note that such as cast may be lossy. For +/// lossless type conversions, see the `type_coercion` module +/// +/// See the module level documentation for more detail on casting +pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + if from_type == to_type { + return true; + } + + // Note this is meant to mirror the structure in arrow/src/compute/kernels/cast.rs + match (from_type, to_type) { + (Struct(_), _) => false, + (_, Struct(_)) => false, + (List(list_from), List(list_to)) => can_cast_types(list_from, list_to), + (List(_), _) => false, + (_, List(list_to)) => can_cast_types(from_type, list_to), + (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + can_cast_types(from_value_type, to_value_type) + } + (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + + (_, Boolean) => is_numeric_type(from_type), + (Boolean, _) => is_numeric_type(from_type) || from_type == &Utf8, + (Utf8, _) => is_numeric_type(to_type), + (_, Utf8) => is_numeric_type(from_type) || from_type == &Binary, + + // start numeric casts + (UInt8, UInt16) => true, + (UInt8, UInt32) => true, + (UInt8, UInt64) => true, + (UInt8, Int8) => true, + (UInt8, Int16) => true, + (UInt8, Int32) => true, + (UInt8, Int64) => true, + (UInt8, Float32) => true, + (UInt8, Float64) => true, + + (UInt16, UInt8) => true, + (UInt16, UInt32) => true, + (UInt16, UInt64) => true, + (UInt16, Int8) => true, + (UInt16, Int16) => true, + (UInt16, Int32) => true, + (UInt16, Int64) => true, + (UInt16, Float32) => true, + (UInt16, Float64) => true, + + (UInt32, UInt8) => true, + (UInt32, UInt16) => true, + (UInt32, UInt64) => true, + (UInt32, Int8) => true, + (UInt32, Int16) => true, + (UInt32, Int32) => true, + (UInt32, Int64) => true, + (UInt32, Float32) => true, + (UInt32, Float64) => true, + + (UInt64, UInt8) => true, + (UInt64, UInt16) => true, + (UInt64, UInt32) => true, + (UInt64, Int8) => true, + (UInt64, Int16) => true, + (UInt64, Int32) => true, + (UInt64, Int64) => true, + (UInt64, Float32) => true, + (UInt64, Float64) => true, + + (Int8, UInt8) => true, + (Int8, UInt16) => true, + (Int8, UInt32) => true, + (Int8, UInt64) => true, + (Int8, Int16) => true, + (Int8, Int32) => true, + (Int8, Int64) => true, + (Int8, Float32) => true, + (Int8, Float64) => true, + + (Int16, UInt8) => true, + (Int16, UInt16) => true, + (Int16, UInt32) => true, + (Int16, UInt64) => true, + (Int16, Int8) => true, + (Int16, Int32) => true, + (Int16, Int64) => true, + (Int16, Float32) => true, + (Int16, Float64) => true, + + (Int32, UInt8) => true, + (Int32, UInt16) => true, + (Int32, UInt32) => true, + (Int32, UInt64) => true, + (Int32, Int8) => true, + (Int32, Int16) => true, + (Int32, Int64) => true, + (Int32, Float32) => true, + (Int32, Float64) => true, + + (Int64, UInt8) => true, + (Int64, UInt16) => true, + (Int64, UInt32) => true, + (Int64, UInt64) => true, + (Int64, Int8) => true, + (Int64, Int16) => true, + (Int64, Int32) => true, + (Int64, Float32) => true, + (Int64, Float64) => true, + + (Float32, UInt8) => true, + (Float32, UInt16) => true, + (Float32, UInt32) => true, + (Float32, UInt64) => true, + (Float32, Int8) => true, + (Float32, Int16) => true, + (Float32, Int32) => true, + (Float32, Int64) => true, + (Float32, Float64) => true, + + (Float64, UInt8) => true, + (Float64, UInt16) => true, + (Float64, UInt32) => true, + (Float64, UInt64) => true, + (Float64, Int8) => true, + (Float64, Int16) => true, + (Float64, Int32) => true, + (Float64, Int64) => true, + (Float64, Float32) => true, + // end numeric casts + + // temporal casts + (Int32, Date32(_)) => true, + (Int32, Time32(_)) => true, + (Date32(_), Int32) => true, + (Time32(_), Int32) => true, + (Int64, Date64(_)) => true, + (Int64, Time64(_)) => true, + (Date64(_), Int64) => true, + (Time64(_), Int64) => true, + (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => true, + (Date64(DateUnit::Millisecond), Date32(DateUnit::Day)) => true, + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true, + (Time32(_), Time64(_)) => true, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true, + (Time64(_), Time32(to_unit)) => match to_unit { + TimeUnit::Second => true, + TimeUnit::Millisecond => true, + _ => false, + }, + (Timestamp(_, _), Int64) => true, + (Int64, Timestamp(_, _)) => true, + (Timestamp(_, _), Timestamp(_, _)) => true, + (Timestamp(_, _), Date32(_)) => true, + (Timestamp(_, _), Date64(_)) => true, + // date64 to timestamp might not make sense, + + // end temporal casts + (_, _) => false, + } +} + +fn is_numeric_type(t: &DataType) -> bool { + use self::DataType::*; + match t { + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 + | Float64 => true, + _ => false, + } +} + +#[cfg(test)] +mod tests { + // The purpose of this test is to verify that the rules of type + // casting between Arrow and DataFusion remain in sync. + + // At a high level, each test attempts to cast the input arrays + // into the target type using the cast kernel and verifies the + // compatibility between `can_cast_from` and the cast kernel + + #[test] + fn test_casting() { + //let arrays = vec![]; + } +}