diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index 22f26674261..4e1d84524d5 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -2106,13 +2106,13 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer, usize)> for StructArray { /// assert_eq!(array.keys().collect::>>(), vec![Some(0), Some(0), Some(1), Some(2)]); /// ``` pub struct DictionaryArray { - /// Array of keys, much like a PrimitiveArray + /// Array of keys, stored as a PrimitiveArray. data: ArrayDataRef, /// Pointer to the key values. raw_values: RawPtrBox, - /// Array of any values. + /// Array of dictionary values (can be any DataType). values: ArrayRef, /// Values are ordered. diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index 7a0dcd36433..88dc3dee20f 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Defines cast kernels for `ArrayRef`, allowing casting arrays between supported -//! datatypes. +//! Defines cast kernels for `ArrayRef`, to convert `Array`s between +//! supported datatypes. //! //! Example: //! @@ -38,13 +38,14 @@ use std::str; use std::sync::Arc; -use crate::array::*; use crate::buffer::Buffer; use crate::compute::kernels::arithmetic::{divide, multiply}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; +use crate::{array::*, compute::take}; -/// Cast array to provided data type +/// Cast `array` to the provided data type and return a new Array with +/// type `to_type`, if possible.
/// /// Behavior: /// * Boolean to Utf8: `true` => '1', `false` => `0` @@ -125,6 +126,34 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { Ok(list_array) } + (Dictionary(index_type, _), _) => match **index_type { + DataType::Int8 => dictionary_cast::(array, to_type), + DataType::Int16 => dictionary_cast::(array, to_type), + DataType::Int32 => dictionary_cast::(array, to_type), + DataType::Int64 => dictionary_cast::(array, to_type), + DataType::UInt8 => dictionary_cast::(array, to_type), + DataType::UInt16 => dictionary_cast::(array, to_type), + DataType::UInt32 => dictionary_cast::(array, to_type), + DataType::UInt64 => dictionary_cast::(array, to_type), + _ => Err(ArrowError::ComputeError(format!( + "Casting from dictionary type {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (_, Dictionary(index_type, value_type)) => match **index_type { + DataType::Int8 => cast_to_dictionary::(array, value_type), + DataType::Int16 => cast_to_dictionary::(array, value_type), + DataType::Int32 => cast_to_dictionary::(array, value_type), + DataType::Int64 => cast_to_dictionary::(array, value_type), + DataType::UInt8 => cast_to_dictionary::(array, value_type), + DataType::UInt16 => cast_to_dictionary::(array, value_type), + DataType::UInt32 => cast_to_dictionary::(array, value_type), + DataType::UInt64 => cast_to_dictionary::(array, value_type), + _ => Err(ArrowError::ComputeError(format!( + "Casting from type {:?} to dictionary type {:?} not supported", + from_type, to_type, + ))), + }, (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), UInt16 => cast_numeric_to_bool::(array), @@ -740,10 +769,203 @@ where .collect() } +/// Attempts to cast an `ArrayDictionary` with index type K into +/// `to_type` for supported types. 
+/// +/// K is the key type +fn dictionary_cast( + array: &ArrayRef, + to_type: &DataType, +) -> Result { + use DataType::*; + + match to_type { + Dictionary(to_index_type, to_value_type) => { + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + let keys_array: ArrayRef = Arc::new(dict_array.keys_array()); + let values_array: ArrayRef = dict_array.values(); + let cast_keys = cast(&keys_array, to_index_type)?; + let cast_values = cast(&values_array, to_value_type)?; + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } + + // keys are data, child_data is values (dictionary) + let data = Arc::new(ArrayData::new( + to_type.clone(), + cast_keys.len(), + Some(cast_keys.null_count()), + cast_keys + .data() + .null_bitmap() + .clone() + .map(|bitmap| bitmap.bits), + cast_keys.data().offset(), + cast_keys.data().buffers().to_vec(), + vec![cast_values.data()], + )); + + // create the appropriate array type + let new_array: ArrayRef = match **to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::ComputeError(format!( + "Unsupported type {:?} for dictionary index", + to_index_type + ))) + } + }; + + 
Ok(new_array) + } + _ => unpack_dictionary::(array, to_type), + } +} + +// Unpack a dictionary where the keys are of type `K` into a flattened array of type to_type +fn unpack_dictionary(array: &ArrayRef, to_type: &DataType) -> Result +where + K: ArrowDictionaryKeyType, +{ + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + // attempt to cast the dict values to the target type + // use the take kernel to expand out the dictionary + let cast_dict_values = cast(&dict_array.values(), to_type)?; + + // Note take requires first casting the indices to u32 + let keys_array: ArrayRef = Arc::new(dict_array.keys_array()); + let indicies = cast(&keys_array, &DataType::UInt32)?; + let u32_indicies = + indicies + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dict indicies to UInt32".to_string(), + ) + })?; + + take(&cast_dict_values, u32_indicies, None) +} + +/// Attempts to encode an array into a `DictionaryArray` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &ArrayRef, + dict_value_type: &DataType, +) -> Result { + use DataType::*; + + match *dict_value_type { + Int8 => pack_numeric_to_dictionary::(array, dict_value_type), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type), + Int64 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type), + Utf8 => pack_string_to_dictionary::(array), + _ => Err(ArrowError::ComputeError(format!( + "Internal Error: 
Unsupported output type for dictionary packing: {:?}", + dict_value_type + ))), + } +} + +// Packs the data from the primitive array of type to a +// DictionaryArray with keys of type K and values of value_type V +fn pack_numeric_to_dictionary( + array: &ArrayRef, + dict_value_type: &DataType, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowNumericType, +{ + // attempt to cast the source array values to the target value type (the dictionary values type) + let cast_values = cast(array, &dict_value_type)?; + let values = cast_values + .as_any() + .downcast_ref::>() + .unwrap(); + + let keys_builder = PrimitiveBuilder::::new(values.len()); + let values_builder = PrimitiveBuilder::::new(values.len()); + let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null()?; + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +// Packs the data as a StringDictionaryArray, if possible, with the +// key types of K +fn pack_string_to_dictionary(array: &ArrayRef) -> Result +where + K: ArrowDictionaryKeyType, +{ + let cast_values = cast(array, &DataType::Utf8)?; + let values = cast_values.as_any().downcast_ref::().unwrap(); + + let keys_builder = PrimitiveBuilder::::new(values.len()); + let values_builder = StringBuilder::new(values.len()); + let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null()?; + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + #[cfg(test)] mod tests { use super::*; - use crate::buffer::Buffer; + use crate::{buffer::Buffer, util::pretty::array_value_to_string}; #[test] fn test_cast_i32_to_f64() { @@ -2033,6 +2255,7 @@ mod tests { ); } + /// Convert `array` into a vector of strings by casting to data type dt fn get_cast_values(array: &ArrayRef, 
dt: &DataType) -> Vec where T: ArrowNumericType, @@ -2049,4 +2272,209 @@ mod tests { } v } + + #[test] + fn test_cast_utf8_dict() { + // FROM a dictionary of Utf8 values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + builder.append("one").unwrap(); + builder.append_null().unwrap(); + builder.append("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Test casting TO StringArray + let cast_type = Utf8; + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Test casting TO Dictionary (with different index sizes) + + let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + 
assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_primitive() { + use DataType::*; + // test converting from an array that has indexes of a type + // that are out of bounds for a particular other kind of + // index. + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + builder.append(i).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_utf8() { + use DataType::*; + // Same test as test_cast_dict_to_dict_bad_index_value but use + // string values (and encode the expected behavior here); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = 
StringDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + let val = format!("val{}", i); + builder.append(&val).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_primitive_dict() { + // FROM a dictionary with of INT32 values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + builder.append(1).unwrap(); + builder.append_null().unwrap(); + builder.append(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Test casting TO PrimitiveArray, different dictionary type + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 succeeded"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Utf8); + + let cast_array = cast(&array, &Int64).expect("cast to int64 succeeded"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Int64); + } + + #[test] + fn test_cast_primitive_array_to_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(10); + builder.append_value(1).unwrap(); + builder.append_null().unwrap(); + builder.append_value(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Cast to a 
dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_string_array_to_dict() { + use DataType::*; + + let mut builder = StringBuilder::new(10); + builder.append_value("one").unwrap(); + builder.append_null().unwrap(); + builder.append_value("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Cast to a dictionary (same value type, Utf8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + /// Print the `DictionaryArray` `array` as a vector of strings + fn array_to_strings(array: &ArrayRef) -> Vec { + (0..array.len()) + .map(|i| { + if array.is_null(i) { + "null".to_string() + } else { + array_value_to_string(array, i).expect("Convert array to String") + } + }) + .collect() + } } diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 4c9029e7195..198d8c27d0e 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -51,6 +51,8 @@ use arrow::{ datatypes::Field, }; +use super::type_coercion::can_coerce_from; + /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) 
@@ -1086,7 +1088,40 @@ impl fmt::Display for BinaryExpr { } } -// the type that both lhs and rhs can be casted to for the purpose of a string computation +/// Coercion rules for dictionary values (aka the type of the dictionary itself) +fn dictionary_value_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type)) +} + +/// Coercion rules for Dictionaries: the type that both lhs and rhs +/// can be casted to for the purpose of a computation. +/// +/// It would likely be preferable to cast primitive values to +/// dictionaries, and thus avoid unpacking dictionary as well as doing +/// faster comparisons. However, the arrow compute kernels (e.g. eq) +/// don't have DictionaryArray support yet, so fall back to unpacking +/// the dictionaries +fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + match (lhs_type, rhs_type) { + ( + DataType::Dictionary(_lhs_index_type, lhs_value_type), + DataType::Dictionary(_rhs_index_type, rhs_value_type), + ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), + (DataType::Dictionary(_index_type, value_type), _) => { + dictionary_value_coercion(value_type, rhs_type) + } + (_, DataType::Dictionary(_index_type, value_type)) => { + dictionary_value_coercion(lhs_type, value_type) + } + _ => None, + } +} + +/// Coercion rules for Strings: the type that both lhs and rhs can be +/// casted to for the purpose of a string computation fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -1098,7 +1133,9 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } -/// coercion rule for numerical types +/// Coercion rule for numerical types: The type that both lhs and rhs +/// can be casted to for numerical calculation, while maintaining +/// maximum precision pub fn numerical_coercion(lhs_type: &DataType, 
rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; @@ -1156,6 +1193,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { return Some(lhs_type.clone()); } numerical_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) } // coercion rules that assume an ordered set, such as "less than". @@ -1166,16 +1204,13 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option return Some(lhs_type.clone()); } - match numerical_coercion(lhs_type, rhs_type) { - None => { - // strings are naturally ordered, and thus ordering can be applied to them. - string_coercion(lhs_type, rhs_type) - } - t => t, - } + numerical_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) } -/// coercion rules for all binary operators +/// Coercion rules for all binary operators. Returns the output type +/// of applying `op` to an argument of `lhs_type` and `rhs_type`. fn common_binary_type( lhs_type: &DataType, op: &Operator, @@ -1532,7 +1567,8 @@ impl PhysicalExpr for CastExpr { } } -/// Returns a cast operation, if casting needed. 
+/// Returns a physical cast operation that casts `expr` to `cast_type` +/// if casting is needed pub fn cast( expr: Arc, input_schema: &Schema, @@ -1541,14 +1577,7 @@ pub fn cast( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { return Ok(expr.clone()); - } - if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if is_numeric(&expr_type) - && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) - { + } else if can_coerce_from(&cast_type, &expr_type) { Ok(Arc::new(CastExpr { expr, cast_type })) } else { Err(ExecutionError::General(format!( @@ -1675,11 +1704,14 @@ impl PhysicalSortExpr { mod tests { use super::*; use crate::error::Result; - use arrow::array::{ - LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray, - Time64NanosecondArray, - }; use arrow::datatypes::*; + use arrow::{ + array::{ + LargeStringArray, PrimitiveArray, PrimitiveArrayOps, PrimitiveBuilder, + StringArray, StringDictionaryBuilder, Time64NanosecondArray, + }, + util::pretty::array_value_to_string, + }; // Create a binary expression without coercion. Used here when we do not want to coerce the expressions // to valid types. Usage can result in an execution (after plan) error. @@ -1782,11 +1814,13 @@ mod tests { // runs an end-to-end test of physical type coercion: // 1. construct a record batch with two columns of type A and B + // (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements) // 2. construct a physical expression of A OP B // 3. evaluate the expression // 4. verify that the resulting expression is of type C + // 5. verify that the results of evaluation are $VEC macro_rules! 
test_coercion { - ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{ let schema = Schema::new(vec![ Field::new("a", $A_TYPE, false), Field::new("b", $B_TYPE, false), @@ -1802,18 +1836,18 @@ mod tests { let expression = binary(col("a"), $OP, col("b"), &schema)?; // verify that the expression's type is correct - assert_eq!(expression.data_type(&schema)?, $TYPE); + assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute let result = expression.evaluate(&batch)?; // verify that the array's data_type is correct - assert_eq!(*result.data_type(), $TYPE); + assert_eq!(*result.data_type(), $C_TYPE); // verify that the data itself is downcastable let result = result .as_any() - .downcast_ref::<$TYPEARRAY>() + .downcast_ref::<$C_ARRAY>() .expect("failed to downcast"); // verify that the result itself is correct for (i, x) in $VEC.iter().enumerate() { @@ -1887,6 +1921,107 @@ mod tests { Ok(()) } + #[test] + fn test_dictionary_type_coersion() -> Result<()> { + use DataType::*; + + // TODO: In the future, this would ideally return Dictionary types and avoid unpacking + let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32)); + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Utf8; + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + + let lhs_type = Utf8; + let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), 
Some(Utf8)); + + Ok(()) + } + + // Note it would be nice to use the same test_coercion macro as + // above, but sadly the type of the values of the dictionary is + // not encoded in the rust type of the DictionaryArray. Thus there + // is no way at the time of this writing to create a dictionary + // array using the `From` trait + #[test] + fn test_dictionary_type_to_array_coersion() -> Result<()> { + // Test comparing a string dictionary with a string array + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let string_type = DataType::Utf8; + + // build dictionary + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + dict_builder.append("one")?; + dict_builder.append_null()?; + dict_builder.append("three")?; + dict_builder.append("four")?; + let dict_array = dict_builder.finish(); + + let str_array = + StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict", dict_type.clone(), true), + Field::new("str", string_type.clone(), true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(dict_array), Arc::new(str_array)], + )?; + + let expected = "false\n\n\ntrue"; + + // Test 1: dict = str + + // verify that we can construct the expression + let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?; + assert_eq!(expression.data_type(&schema)?, DataType::Boolean); + + // evaluate and verify the result type matched + let result = expression.evaluate(&batch)?; + assert_eq!(result.data_type(), &DataType::Boolean); + + // verify that the result itself is correct + assert_eq!(expected, array_to_string(&result)?); + + // Test 2: now test the other direction + // str = dict + + // verify that we can construct the expression + let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?; + 
assert_eq!(expression.data_type(&schema)?, DataType::Boolean); + + // evaluate and verify the result type matched + let result = expression.evaluate(&batch)?; + assert_eq!(result.data_type(), &DataType::Boolean); + + // verify that the result itself is correct + assert_eq!(expected, array_to_string(&result)?); + + Ok(()) + } + + // Convert the array to a newline delimited string of pretty printed values + fn array_to_string(array: &ArrayRef) -> Result { + let s = (0..array.len()) + .map(|i| array_value_to_string(array, i)) + .collect::, arrow::error::ArrowError>>()? + .join("\n"); + Ok(s) + } + #[test] fn test_coersion_error() -> Result<()> { let expr = @@ -1992,6 +2127,7 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { + // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); let result = cast(col("a"), &schema, DataType::Int32); result.expect_err("Invalid CAST from Utf8 to Int32"); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 5640daa5303..3a8dc742ff8 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -21,8 +21,8 @@ use std::sync::Arc; extern crate arrow; extern crate datafusion; -use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::TimeUnit}; +use arrow::{datatypes::Int32Type, record_batch::RecordBatch}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, util::pretty::array_value_to_string, @@ -918,14 +918,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) { /// Execute query and return result set as 2-d table of Vecs /// `result[row][column]` async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec> { - let plan = ctx.create_logical_plan(&sql).unwrap(); + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(&sql).expect(&msg); let logical_schema = plan.schema(); - let plan = ctx.optimize(&plan).unwrap(); + + let msg = 
format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); let optimized_logical_schema = plan.schema(); - let plan = ctx.create_physical_plan(&plan).unwrap(); + + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).expect(&msg); let physical_schema = plan.schema(); - let results = ctx.collect(plan).await.unwrap(); + let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); + let results = ctx.collect(plan).await.expect(&msg); assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); assert_eq!(logical_schema.as_ref(), physical_schema.as_ref()); @@ -1200,3 +1206,59 @@ async fn query_is_not_null() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_on_string_dictionary() -> Result<()> { + // Test to ensure DataFusion can operate on dictionary types + // Use StringDictionary (32 bit indexes = keys) + let field_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + builder.append("one")?; + builder.append_null()?; + builder.append("three")?; + let array = Arc::new(builder.finish()); + + let data = RecordBatch::try_new(schema.clone(), vec![array])?; + + let table = MemTable::new(schema, vec![vec![data]])?; + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + + // Basic SELECT + let sql = "SELECT * FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]]; + assert_eq!(expected, actual); + + // basic filtering + let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let actual = execute(&mut ctx, sql).await; + 
let expected = vec![vec!["one"], vec!["three"]]; + assert_eq!(expected, actual); + + // filtering with constant + let sql = "SELECT * FROM test WHERE d1 = 'three'"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["three"]]; + assert_eq!(expected, actual); + + // Expression evaluation + let sql = "SELECT concat(d1, '-foo') FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; + assert_eq!(expected, actual); + + // aggregation + let sql = "SELECT COUNT(d1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["2"]]; + assert_eq!(expected, actual); + + Ok(()) +}