diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index 08c6a2b3042..0b6e172d30a 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -44,6 +44,168 @@ use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 use crate::{array::*, compute::take};
 
+/// Return true if a value of type `from_type` can be cast into a
+/// value of `to_type`. Note that such a cast may be lossy.
+///
+/// A `true` result must stay consistent with the `cast` kernel below.
+pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
+    use self::DataType::*;
+    if from_type == to_type {
+        return true;
+    }
+
+    match (from_type, to_type) {
+        (Struct(_), _) => false,
+        (_, Struct(_)) => false,
+        (List(list_from), List(list_to)) => can_cast_types(list_from, list_to),
+        (List(_), _) => false,
+        (_, List(list_to)) => can_cast_types(from_type, list_to),
+        (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
+            can_cast_types(from_value_type, to_value_type)
+        }
+        (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
+        (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),
+
+        (_, Boolean) => DataType::is_numeric(from_type),
+        (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
+        (Utf8, _) => DataType::is_numeric(to_type),
+        (_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary,
+
+        // start numeric casts
+        (UInt8, UInt16) => true,
+        (UInt8, UInt32) => true,
+        (UInt8, UInt64) => true,
+        (UInt8, Int8) => true,
+        (UInt8, Int16) => true,
+        (UInt8, Int32) => true,
+        (UInt8, Int64) => true,
+        (UInt8, Float32) => true,
+        (UInt8, Float64) => true,
+
+        (UInt16, UInt8) => true,
+        (UInt16, UInt32) => true,
+        (UInt16, UInt64) => true,
+        (UInt16, Int8) => true,
+        (UInt16, Int16) => true,
+        (UInt16, Int32) => true,
+        (UInt16, Int64) => true,
+        (UInt16, Float32) => true,
+        (UInt16, Float64) => true,
+
+        (UInt32, UInt8) => true,
+        (UInt32, UInt16) => true,
+        (UInt32, UInt64) => true,
+        (UInt32, Int8) => true,
+        (UInt32, Int16) => true,
+        (UInt32, Int32) => true,
+        (UInt32, Int64) => true,
+        (UInt32, Float32) => true,
+        (UInt32, Float64) => true,
+
+        (UInt64, UInt8) => true,
+        (UInt64, UInt16) => true,
+        (UInt64, UInt32) => true,
+        (UInt64, Int8) => true,
+        (UInt64, Int16) => true,
+        (UInt64, Int32) => true,
+        (UInt64, Int64) => true,
+        (UInt64, Float32) => true,
+        (UInt64, Float64) => true,
+
+        (Int8, UInt8) => true,
+        (Int8, UInt16) => true,
+        (Int8, UInt32) => true,
+        (Int8, UInt64) => true,
+        (Int8, Int16) => true,
+        (Int8, Int32) => true,
+        (Int8, Int64) => true,
+        (Int8, Float32) => true,
+        (Int8, Float64) => true,
+
+        (Int16, UInt8) => true,
+        (Int16, UInt16) => true,
+        (Int16, UInt32) => true,
+        (Int16, UInt64) => true,
+        (Int16, Int8) => true,
+        (Int16, Int32) => true,
+        (Int16, Int64) => true,
+        (Int16, Float32) => true,
+        (Int16, Float64) => true,
+
+        (Int32, UInt8) => true,
+        (Int32, UInt16) => true,
+        (Int32, UInt32) => true,
+        (Int32, UInt64) => true,
+        (Int32, Int8) => true,
+        (Int32, Int16) => true,
+        (Int32, Int64) => true,
+        (Int32, Float32) => true,
+        (Int32, Float64) => true,
+
+        (Int64, UInt8) => true,
+        (Int64, UInt16) => true,
+        (Int64, UInt32) => true,
+        (Int64, UInt64) => true,
+        (Int64, Int8) => true,
+        (Int64, Int16) => true,
+        (Int64, Int32) => true,
+        (Int64, Float32) => true,
+        (Int64, Float64) => true,
+
+        (Float32, UInt8) => true,
+        (Float32, UInt16) => true,
+        (Float32, UInt32) => true,
+        (Float32, UInt64) => true,
+        (Float32, Int8) => true,
+        (Float32, Int16) => true,
+        (Float32, Int32) => true,
+        (Float32, Int64) => true,
+        (Float32, Float64) => true,
+
+        (Float64, UInt8) => true,
+        (Float64, UInt16) => true,
+        (Float64, UInt32) => true,
+        (Float64, UInt64) => true,
+        (Float64, Int8) => true,
+        (Float64, Int16) => true,
+        (Float64, Int32) => true,
+        (Float64, Int64) => true,
+        (Float64, Float32) => true,
+        // end numeric casts
+
+        // temporal casts
+        (Int32, Date32(_)) => true,
+        (Int32, Time32(u)) => *u == TimeUnit::Second || *u == TimeUnit::Millisecond,
+        (Date32(_), Int32) => true,
+        (Time32(_), Int32) => true,
+        (Int64, Date64(_)) => true,
+        (Int64, Time64(u)) => *u == TimeUnit::Microsecond || *u == TimeUnit::Nanosecond,
+        (Date64(_), Int64) => true,
+        (Time64(_), Int64) => true,
+        (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => true,
+        (Date64(DateUnit::Millisecond), Date32(DateUnit::Day)) => true,
+        (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true,
+        (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true,
+        (Time32(_), Time64(_)) => true,
+        (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true,
+        (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true,
+        (Time64(_), Time32(to_unit)) => match to_unit {
+            TimeUnit::Second => true,
+            TimeUnit::Millisecond => true,
+            _ => false,
+        },
+        (Timestamp(_, _), Int64) => true,
+        (Int64, Timestamp(_, _)) => true,
+        (Timestamp(_, _), Timestamp(_, _)) => true,
+        (Timestamp(_, _), Date32(_)) => true,
+        (Timestamp(_, _), Date64(_)) => true,
+        // date64 to timestamp might not make sense,
+
+        // end temporal casts
+        (_, _) => false,
+    }
+}
+
 /// Cast `array` to the provided data type and return a new Array with
 /// type `to_type`, if possible.
 ///
@@ -356,11 +518,24 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result {
 
         // temporal casts
         (Int32, Date32(_)) => cast_array_data::(array, to_type.clone()),
-        (Int32, Time32(_)) => cast_array_data::(array, to_type.clone()),
+        (Int32, Time32(TimeUnit::Second)) => {
+            cast_array_data::(array, to_type.clone())
+        }
+        (Int32, Time32(TimeUnit::Millisecond)) => {
+            cast_array_data::(array, to_type.clone())
+        }
+        // No support for microsecond/nanosecond with i32
         (Date32(_), Int32) => cast_array_data::(array, to_type.clone()),
         (Time32(_), Int32) => cast_array_data::(array, to_type.clone()),
         (Int64, Date64(_)) => cast_array_data::(array, to_type.clone()),
-        (Int64, Time64(_)) => cast_array_data::(array, to_type.clone()),
+        // No support for second/milliseconds with i64
+        (Int64, Time64(TimeUnit::Microsecond)) => {
+            cast_array_data::(array, to_type.clone())
+        }
+        (Int64, Time64(TimeUnit::Nanosecond)) => {
+            cast_array_data::(array, to_type.clone())
+        }
+
         (Date64(_), Int64) => cast_array_data::(array, to_type.clone()),
         (Time64(_), Int64) => cast_array_data::(array, to_type.clone()),
         (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => {
@@ -549,7 +724,18 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result {
         (Timestamp(from_unit, _), Date64(_)) => {
             let from_size = time_unit_multiple(&from_unit);
             let to_size = MILLISECONDS;
-            if from_size != to_size {
+
+            // Scale time_array by (to_size / from_size) using a
+            // single integer operation, but need to avoid integer
+            // math rounding down to zero
+
+            if to_size > from_size {
+                let time_array = Date64Array::from(array.data());
+                Ok(Arc::new(multiply(
+                    &time_array,
+                    &Date64Array::from(vec![to_size / from_size; array.len()]),
+                )?) as ArrayRef)
+            } else if to_size < from_size {
                 let time_array = Date64Array::from(array.data());
                 Ok(Arc::new(divide(
                     &time_array,
@@ -2477,4 +2663,290 @@ mod tests {
             })
             .collect()
     }
+
+    #[test]
+    fn test_can_cast_types() {
+        // this function attempts to ensure that can_cast_types stays
+        // in sync with cast.  It simply tries all combinations of
+        // types and makes sure that `can_cast_types` returns true
+        // exactly when `cast` succeeds
+
+        let all_types = get_all_types();
+
+        for array in get_arrays_of_all_types() {
+            for to_type in &all_types {
+                println!("Test casting {:?} --> {:?}", array.data_type(), to_type);
+                let cast_result = cast(&array, to_type);
+                let reported_cast_ability = can_cast_types(array.data_type(), to_type);
+
+                // check for mismatch
+                match (cast_result, reported_cast_ability) {
+                    (Ok(_), false) => {
+                        panic!("Was able to cast array from {:?} to {:?} but can_cast_types reported false",
+                               array.data_type(), to_type)
+                    },
+                    (Err(e), true) => {
+                        panic!("Was not able to cast array from {:?} to {:?} but can_cast_types reported true. \
+                                Error was {:?}",
+                               array.data_type(), to_type, e)
+                    },
+                    // otherwise it was a match
+                    _ => {},
+                };
+            }
+        }
+    }
+
+    /// Create instances of arrays with varying types for cast tests
+    fn get_arrays_of_all_types() -> Vec {
+        let tz_name = Arc::new(String::from("America/New_York"));
+        let binary_data: Vec<&[u8]> = vec![b"foo", b"bar"];
+        vec![
+            Arc::new(BinaryArray::from(binary_data.clone())),
+            Arc::new(LargeBinaryArray::from(binary_data.clone())),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_primitive::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            make_dictionary_utf8::(),
+            Arc::new(make_list_array()),
+            Arc::new(make_large_list_array()),
+            Arc::new(make_fixed_size_list_array()),
+            Arc::new(make_fixed_size_binary_array()),
+            Arc::new(StructArray::from(vec![
+                (
+                    Field::new("a", DataType::Boolean, false),
+                    Arc::new(BooleanArray::from(vec![false, false, true, true]))
+                        as Arc,
+                ),
+                (
+                    Field::new("b", DataType::Int32, false),
+                    Arc::new(Int32Array::from(vec![42, 28, 19, 31])),
+                ),
+            ])),
+            //Arc::new(make_union_array()),
+            Arc::new(NullArray::new(10)),
+            Arc::new(StringArray::from(vec!["foo", "bar"])),
+            Arc::new(LargeStringArray::from(vec!["foo", "bar"])),
+            Arc::new(BooleanArray::from(vec![true, false])),
+            Arc::new(Int8Array::from(vec![1, 2])),
+            Arc::new(Int16Array::from(vec![1, 2])),
+            Arc::new(Int32Array::from(vec![1, 2])),
+            Arc::new(Int64Array::from(vec![1, 2])),
+            Arc::new(UInt8Array::from(vec![1, 2])),
+            Arc::new(UInt16Array::from(vec![1, 2])),
+            Arc::new(UInt32Array::from(vec![1, 2])),
+            Arc::new(UInt64Array::from(vec![1, 2])),
+            Arc::new(Float32Array::from(vec![1.0, 2.0])),
+            Arc::new(Float64Array::from(vec![1.0, 2.0])),
+            Arc::new(TimestampSecondArray::from_vec(vec![1000, 2000], None)),
+            Arc::new(TimestampMillisecondArray::from_vec(vec![1000, 2000], None)),
+            Arc::new(TimestampMicrosecondArray::from_vec(vec![1000, 2000], None)),
+            Arc::new(TimestampNanosecondArray::from_vec(vec![1000, 2000], None)),
+            Arc::new(TimestampSecondArray::from_vec(
+                vec![1000, 2000],
+                Some(tz_name.clone()),
+            )),
+            Arc::new(TimestampMillisecondArray::from_vec(
+                vec![1000, 2000],
+                Some(tz_name.clone()),
+            )),
+            Arc::new(TimestampMicrosecondArray::from_vec(
+                vec![1000, 2000],
+                Some(tz_name.clone()),
+            )),
+            Arc::new(TimestampNanosecondArray::from_vec(
+                vec![1000, 2000],
+                Some(tz_name.clone()),
+            )),
+            Arc::new(Date32Array::from(vec![1000, 2000])),
+            Arc::new(Date64Array::from(vec![1000, 2000])),
+            Arc::new(Time32SecondArray::from(vec![1000, 2000])),
+            Arc::new(Time32MillisecondArray::from(vec![1000, 2000])),
+            Arc::new(Time64MicrosecondArray::from(vec![1000, 2000])),
+            Arc::new(Time64NanosecondArray::from(vec![1000, 2000])),
+            Arc::new(IntervalYearMonthArray::from(vec![1000, 2000])),
+            Arc::new(IntervalDayTimeArray::from(vec![1000, 2000])),
+            Arc::new(DurationSecondArray::from(vec![1000, 2000])),
+            Arc::new(DurationMillisecondArray::from(vec![1000, 2000])),
+            Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])),
+            Arc::new(DurationNanosecondArray::from(vec![1000, 2000])),
+        ]
+    }
+
+    fn make_list_array() -> ListArray {
+        // Construct a value array
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(8)
+            .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
+            .build();
+
+        // Construct a buffer for value offsets, for the nested array:
+        //  [[0, 1, 2], [3, 4, 5], [6, 7]]
+        let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice());
+
+        // Construct a list array from the above two
+        let list_data_type = DataType::List(Box::new(DataType::Int32));
+        let list_data = ArrayData::builder(list_data_type.clone())
+            .len(3)
+            .add_buffer(value_offsets.clone())
+            .add_child_data(value_data.clone())
+            .build();
+        ListArray::from(list_data)
+    }
+
+    fn make_large_list_array() -> LargeListArray {
+        // Construct a value array
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(8)
+            .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
+            .build();
+
+        // Construct a buffer for value offsets, for the nested array:
+        //  [[0, 1, 2], [3, 4, 5], [6, 7]]
+        let value_offsets = Buffer::from(&[0i64, 3, 6, 8].to_byte_slice());
+
+        // Construct a list array from the above two
+        let list_data_type = DataType::LargeList(Box::new(DataType::Int32));
+        let list_data = ArrayData::builder(list_data_type.clone())
+            .len(3)
+            .add_buffer(value_offsets.clone())
+            .add_child_data(value_data.clone())
+            .build();
+        LargeListArray::from(list_data)
+    }
+
+    fn make_fixed_size_list_array() -> FixedSizeListArray {
+        // Construct a value array
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(10)
+            .add_buffer(Buffer::from(
+                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9].to_byte_slice(),
+            ))
+            .build();
+
+        // Construct a fixed size list array from the value array above
+        let list_data_type = DataType::FixedSizeList(Box::new(DataType::Int32), 2);
+        let list_data = ArrayData::builder(list_data_type)
+            .len(5)
+            .add_child_data(value_data.clone())
+            .build();
+        FixedSizeListArray::from(list_data)
+    }
+
+    fn make_fixed_size_binary_array() -> FixedSizeBinaryArray {
+        let values: [u8; 15] = *b"hellotherearrow";
+
+        let array_data = ArrayData::builder(DataType::FixedSizeBinary(5))
+            .len(3)
+            .add_buffer(Buffer::from(&values[..]))
+            .build();
+        FixedSizeBinaryArray::from(array_data)
+    }
+
+    fn make_union_array() -> UnionArray {
+        let mut builder = UnionBuilder::new_dense(7);
+        builder.append::("a", 1).unwrap();
+        builder.append::("b", false).unwrap();
+        builder.build().unwrap()
+    }
+
+    /// Creates a dictionary with primitive dictionary values, and keys of type K
+    fn make_dictionary_primitive() -> ArrayRef {
+        let keys_builder = PrimitiveBuilder::::new(2);
+        // Pick Int32 arbitrarily for dictionary values
+        let values_builder = PrimitiveBuilder::::new(2);
+        let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder);
+        b.append(1).unwrap();
+        b.append(2).unwrap();
+        Arc::new(b.finish())
+    }
+
+    /// Creates a dictionary with utf8 values, and keys of type K
+    fn make_dictionary_utf8() -> ArrayRef {
+        let keys_builder = PrimitiveBuilder::::new(2);
+        // Dictionary values are utf8 strings
+        let values_builder = StringBuilder::new(2);
+        let mut b = StringDictionaryBuilder::new(keys_builder, values_builder);
+        b.append("foo").unwrap();
+        b.append("bar").unwrap();
+        Arc::new(b.finish())
+    }
+
+    // Get a selection of datatypes to try and cast to
+    fn get_all_types() -> Vec {
+        use DataType::*;
+        let tz_name = Arc::new(String::from("America/New_York"));
+
+        vec![
+            Null,
+            Boolean,
+            Int8,
+            Int16,
+            Int32,
+            Int64,
+            UInt8,
+            UInt16,
+            UInt32,
+            UInt64,
+            Float16,
+            Float32,
+            Float64,
+            Timestamp(TimeUnit::Second, None),
+            Timestamp(TimeUnit::Millisecond, None),
+            Timestamp(TimeUnit::Microsecond, None),
+            Timestamp(TimeUnit::Nanosecond, None),
+            Timestamp(TimeUnit::Second, Some(tz_name.clone())),
+            Timestamp(TimeUnit::Millisecond, Some(tz_name.clone())),
+            Timestamp(TimeUnit::Microsecond, Some(tz_name.clone())),
+            Timestamp(TimeUnit::Nanosecond, Some(tz_name.clone())),
+            Date32(DateUnit::Day),
+            Date64(DateUnit::Day),
+            Date32(DateUnit::Millisecond),
+            Date64(DateUnit::Millisecond),
+            Time32(TimeUnit::Second),
+            Time32(TimeUnit::Millisecond),
+            Time64(TimeUnit::Microsecond),
+            Time64(TimeUnit::Nanosecond),
+            Duration(TimeUnit::Second),
+            Duration(TimeUnit::Millisecond),
+            Duration(TimeUnit::Microsecond),
+            Duration(TimeUnit::Nanosecond),
+            Interval(IntervalUnit::YearMonth),
+            Interval(IntervalUnit::DayTime),
+            Binary,
+            FixedSizeBinary(10),
+            LargeBinary,
+            Utf8,
+            LargeUtf8,
+            List(Box::new(DataType::Int8)),
+            List(Box::new(DataType::Utf8)),
+            FixedSizeList(Box::new(DataType::Int8), 10),
+            FixedSizeList(Box::new(DataType::Utf8), 10),
+            LargeList(Box::new(DataType::Int8)),
+            LargeList(Box::new(DataType::Utf8)),
+            Struct(vec![
+                Field::new("f1", DataType::Int32, false),
+                Field::new("f2", DataType::Utf8, true),
+            ]),
+            Union(vec![
+                Field::new("f1", DataType::Int32, false),
+                Field::new("f2", DataType::Utf8, true),
+            ]),
+            Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
+            Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
+            Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+        ]
+    }
 }
diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index 0d05f826d37..2db43062f2a 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -1129,6 +1129,16 @@ impl DataType {
             DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
         }
     }
+
+    /// Returns true if this type is numeric: (UInt*, Int*, Float32, or Float64)
+    pub fn is_numeric(t: &DataType) -> bool {
+        use DataType::*;
+        match t {
+            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32
+            | Float64 => true,
+            _ => false,
+        }
+    }
 }
 
 impl Field {
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index b8d0cc7fb82..6df92fe190e 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -25,7 +25,10 @@ use fmt::Debug;
 use std::{any::Any, collections::HashMap, collections::HashSet, fmt, sync::Arc};
 
 use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::{
+    compute::can_cast_types,
+    datatypes::{DataType, Field, Schema, SchemaRef},
+};
 
 use crate::datasource::parquet::ParquetTable;
 use crate::datasource::TableProvider;
@@ -37,8 +40,7 @@ use crate::{
 };
 use crate::{
     physical_plan::{
-        aggregates, expressions::binary_operator_data_type, functions,
-        type_coercion::can_coerce_from, udf::ScalarUDF,
+        aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
     },
     sql::parser::FileType,
 };
@@ -333,12 +335,13 @@ impl Expr {
     ///
     /// # Errors
     ///
-    /// This function errors when it is impossible to cast the expression to the target [arrow::datatypes::DataType].
+    /// This function errors when it is impossible to cast the
+    /// expression to the target [arrow::datatypes::DataType].
     pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result {
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             Ok(self.clone())
-        } else if can_coerce_from(cast_to_type, &this_type) {
+        } else if can_cast_types(&this_type, cast_to_type) {
             Ok(Expr::Cast {
                 expr: Box::new(self.clone()),
                 data_type: cast_to_type.clone(),
diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs
index 1f5dafdc19d..084f8186c5e 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -49,6 +49,7 @@ use arrow::{
     },
     datatypes::Field,
 };
+use compute::can_cast_types;
 
 /// returns the name of the state
 pub fn format_state_name(name: &str, state_name: &str) -> String {
@@ -1525,7 +1526,10 @@ impl PhysicalExpr for CastExpr {
     }
 }
 
-/// Returns a cast operation, if casting needed.
+/// Returns a physical cast operation that casts `expr` to `cast_type`
+/// if casting is needed.
+///
+/// Note that such casts may lose type information.
 pub fn cast(
     expr: Arc,
     input_schema: &Schema,
@@ -1533,19 +1537,12 @@ pub fn cast(
 ) -> Result> {
     let expr_type = expr.data_type(input_schema)?;
     if expr_type == cast_type {
-        return Ok(expr.clone());
-    }
-    if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) {
-        Ok(Arc::new(CastExpr { expr, cast_type }))
-    } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 {
-        Ok(Arc::new(CastExpr { expr, cast_type }))
-    } else if is_numeric(&expr_type)
-        && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None)
-    {
+        Ok(expr.clone())
+    } else if can_cast_types(&expr_type, &cast_type) {
         Ok(Arc::new(CastExpr { expr, cast_type }))
     } else {
         Err(ExecutionError::General(format!(
-            "Invalid CAST from {:?} to {:?}",
+            "Unsupported CAST from {:?} to {:?}",
             expr_type, cast_type
         )))
     }
@@ -1985,9 +1982,10 @@ mod tests {
 
     #[test]
     fn invalid_cast() -> Result<()> {
-        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
-        let result = cast(col("a"), &schema, DataType::Int32);
-        result.expect_err("Invalid CAST from Utf8 to Int32");
+        // Ensure a useful error happens at plan time if invalid casts are used
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let result = cast(col("a"), &schema, DataType::LargeBinary);
+        result.expect_err("expected Invalid CAST");
         Ok(())
     }
 