diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index 3283dff6217..79394e427a8 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -2056,13 +2056,13 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer, usize)> for StructArray { /// assert_eq!(array.keys().collect::>>(), vec![Some(0), Some(0), Some(1), Some(2)]); /// ``` pub struct DictionaryArray { - /// Array of keys, much like a PrimitiveArray + /// Array of keys, stored as a PrimitiveArray. data: ArrayDataRef, /// Pointer to the key values. raw_values: RawPtrBox, - /// Array of any values. + /// Array of dictionary values (can be any DataType). values: ArrayRef, /// Values are ordered. diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index d8cb480a80c..0b8c9d3d0c0 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Defines cast kernels for `ArrayRef`, allowing casting arrays between supported -//! datatypes. +//! Defines cast kernels for `ArrayRef`, to convert `Array`s between +//! supported datatypes. //! //! Example: //! @@ -44,7 +44,8 @@ use crate::compute::kernels::arithmetic::{divide, multiply}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -/// Cast array to provided data type +/// Cast `array` to the provided data type and return a new Array with +/// type `to_type`, if possible. 
/// /// Behavior: /// * Boolean to Utf8: `true` => '1', `false` => `0` @@ -125,6 +126,34 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { Ok(list_array) } + (Dictionary(index_type, _), _) => match **index_type { + DataType::Int8 => dictionary_cast::(array, to_type), + DataType::Int16 => dictionary_cast::(array, to_type), + DataType::Int32 => dictionary_cast::(array, to_type), + DataType::Int64 => dictionary_cast::(array, to_type), + DataType::UInt8 => dictionary_cast::(array, to_type), + DataType::UInt16 => dictionary_cast::(array, to_type), + DataType::UInt32 => dictionary_cast::(array, to_type), + DataType::UInt64 => dictionary_cast::(array, to_type), + _ => Err(ArrowError::ComputeError(format!( + "Casting from dictionary type {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (_, Dictionary(index_type, value_type)) => match **index_type { + DataType::Int8 => cast_to_dictionary::(array, value_type), + DataType::Int16 => cast_to_dictionary::(array, value_type), + DataType::Int32 => cast_to_dictionary::(array, value_type), + DataType::Int64 => cast_to_dictionary::(array, value_type), + DataType::UInt8 => cast_to_dictionary::(array, value_type), + DataType::UInt16 => cast_to_dictionary::(array, value_type), + DataType::UInt32 => cast_to_dictionary::(array, value_type), + DataType::UInt64 => cast_to_dictionary::(array, value_type), + _ => Err(ArrowError::ComputeError(format!( + "Casting from type {:?} to dictionary type {:?} not supported", + from_type, to_type, + ))), + }, (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), UInt16 => cast_numeric_to_bool::(array), @@ -755,10 +784,253 @@ where Ok(b.finish()) } +/// Attempts to cast a `DictionaryArray` with index type K into +/// `to_type` for supported types. 
+/// +/// K is the key type +fn dictionary_cast( + array: &ArrayRef, + to_type: &DataType, +) -> Result { + use DataType::*; + + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + match to_type { + Dictionary(to_index_type, to_value_type) => { + let keys_array: ArrayRef = Arc::new(dict_array.keys_array()); + let values_array: ArrayRef = dict_array.values(); + let cast_keys = cast(&keys_array, to_index_type)?; + let cast_values = cast(&values_array, to_value_type)?; + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } + + // keys are data, child_data is values (dictionary) + let data = Arc::new(ArrayData::new( + to_type.clone(), + cast_keys.len(), + Some(cast_keys.null_count()), + cast_keys + .data() + .null_bitmap() + .clone() + .map(|bitmap| bitmap.bits), + cast_keys.data().offset(), + cast_keys.data().buffers().to_vec(), + vec![cast_values.data()], + )); + + // create the appropriate array type + let new_array: ArrayRef = match **to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::ComputeError(format!( + "Unsupported type {:?} for dictionary index", + to_index_type + ))) + } + }; + + 
Ok(new_array) + } + // numeric types + Int8 => unpack_dictionary_to_numeric::(dict_array, to_type), + Int16 => unpack_dictionary_to_numeric::(dict_array, to_type), + Int32 => unpack_dictionary_to_numeric::(dict_array, to_type), + Int64 => unpack_dictionary_to_numeric::(dict_array, to_type), + UInt8 => unpack_dictionary_to_numeric::(dict_array, to_type), + UInt16 => unpack_dictionary_to_numeric::(dict_array, to_type), + UInt32 => unpack_dictionary_to_numeric::(dict_array, to_type), + UInt64 => unpack_dictionary_to_numeric::(dict_array, to_type), + Utf8 => unpack_dictionary_to_string::(dict_array), + _ => Err(ArrowError::ComputeError(format!( + "Unsupported output type for dictionary conversion: {:?}", + to_type + ))), + } +} + +// Unpack the dictionary where the keys are of type K and the values +// are of type V into a primitive array of type to_type +fn unpack_dictionary_to_numeric( + dict_array: &DictionaryArray, + to_type: &DataType, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowNumericType, +{ + // attempt to cast the dict values to the target type + let cast_dict_values = cast(&dict_array.values(), to_type)?; + let dict_values = cast_dict_values + .as_any() + .downcast_ref::>() + .unwrap(); + + let mut b = PrimitiveBuilder::::new(dict_array.len()); + + // copy each element one at a time + for key in dict_array.keys() { + match key { + Some(key) => { + let key = key.to_usize().ok_or_else(|| { + ArrowError::ComputeError(format!( + "Could not convert {:?} to usize for dictionary index in StringArray", + key + )) + })?; + b.append_value(dict_values.value(key))? 
+ } + None => b.append_null()?, + } + } + Ok(Arc::new(b.finish())) +} + +/// Unpack the dictionary into a `StringArray` +fn unpack_dictionary_to_string(dict_array: &DictionaryArray) -> Result +where + K: ArrowDictionaryKeyType, +{ + use DataType::*; + + // attempt to cast the dict values to the target type Utf8 (Strings) and then copy them over + let cast_dict_values = cast(&dict_array.values(), &Utf8)?; + let dict_values = cast_dict_values + .as_any() + .downcast_ref::() + .unwrap(); + + let mut b = StringBuilder::new(dict_array.len()); + + // copy each element one at a time + for key in dict_array.keys() { + match key { + Some(key) => { + let key = key.to_usize().ok_or_else(|| { + ArrowError::ComputeError(format!( + "Could not convert {:?} to usize for dictionary index in StringArray", + key + )) + })?; + b.append_value(dict_values.value(key))? + } + None => b.append_null()?, + } + } + Ok(Arc::new(b.finish())) +} + +/// Attempts to encode an array into a `DictionaryArray` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &ArrayRef, + dict_value_type: &DataType, +) -> Result { + use DataType::*; + + match *dict_value_type { + Int8 => pack_numeric_to_dictionary::(array, dict_value_type), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type), + Int64 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type), + Utf8 => pack_string_to_dictionary::(array), + _ => Err(ArrowError::ComputeError(format!( + "Internal Error: Unsupported output type for dictionary packing: {:?}", + dict_value_type + ))), + } +} + +// Packs the data from the primitive array into a +// 
DictionaryArray with keys of type K and values of value_type V +fn pack_numeric_to_dictionary( + array: &ArrayRef, + dict_value_type: &DataType, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowNumericType, +{ + // attempt to cast the source array values to the target value type (the dictionary values type) + let cast_values = cast(array, &dict_value_type)?; + let values = cast_values + .as_any() + .downcast_ref::>() + .unwrap(); + + let keys_builder = PrimitiveBuilder::::new(values.len()); + let values_builder = PrimitiveBuilder::::new(values.len()); + let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null()?; + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +// Packs the data as a StringDictionaryArray, if possible, with the +// key types of K +fn pack_string_to_dictionary(array: &ArrayRef) -> Result +where + K: ArrowDictionaryKeyType, +{ + let cast_values = cast(array, &DataType::Utf8)?; + let values = cast_values.as_any().downcast_ref::().unwrap(); + + let keys_builder = PrimitiveBuilder::::new(values.len()); + let values_builder = StringBuilder::new(values.len()); + let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null()?; + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + #[cfg(test)] mod tests { use super::*; - use crate::buffer::Buffer; + use crate::{buffer::Buffer, util::pretty::array_value_to_string}; #[test] fn test_cast_i32_to_f64() { @@ -2048,6 +2320,7 @@ mod tests { ); } + /// Convert `array` into a vector of strings by casting to data type dt fn get_cast_values(array: &ArrayRef, dt: &DataType) -> Vec where T: ArrowNumericType, @@ -2064,4 +2337,209 @@ mod tests { } v } + + #[test] + fn test_cast_utf8_dict() { + // FROM a 
dictionary of Utf8 values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + builder.append("one").unwrap(); + builder.append_null().unwrap(); + builder.append("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Test casting TO StringArray + let cast_type = Utf8; + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Test casting TO Dictionary (with different index sizes) + + let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt32), 
Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_primitive() { + use DataType::*; + // test converting from an array that has indexes of a type + // that are out of bounds for a particular other kind of + // index. + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + builder.append(i).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_utf8() { + use DataType::*; + // Same test as test_cast_dict_to_dict_bad_index_value but use + // string values (and encode the expected behavior here); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary 
indexed + // with int8) + for i in 0..200 { + let val = format!("val{}", i); + builder.append(&val).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_primitive_dict() { + // FROM a dictionary of INT32 values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + builder.append(1).unwrap(); + builder.append_null().unwrap(); + builder.append(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Test casting TO PrimitiveArray, different dictionary type + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 succeeded"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Utf8); + + let cast_array = cast(&array, &Int64).expect("cast to int64 succeeded"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Int64); + } + + #[test] + fn test_cast_primitive_array_to_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(10); + builder.append_value(1).unwrap(); + builder.append_null().unwrap(); + builder.append_value(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Cast to a dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + 
assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_string_array_to_dict() { + use DataType::*; + + let mut builder = StringBuilder::new(10); + builder.append_value("one").unwrap(); + builder.append_null().unwrap(); + builder.append_value("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Cast to a dictionary (same value type, Utf8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + /// Print the `DictionaryArray` `array` as a vector of strings + fn array_to_strings(array: &ArrayRef) -> Vec { + (0..array.len()) + .map(|i| { + if array.is_null(i) { + "null".to_string() + } else { + array_value_to_string(array, i).expect("Convert array to String") + } + }) + .collect() + } } diff --git a/rust/arrow/src/util/pretty.rs b/rust/arrow/src/util/pretty.rs index 3c23e28f7ed..b881c3ae25d 100644 --- a/rust/arrow/src/util/pretty.rs +++ b/rust/arrow/src/util/pretty.rs @@ -92,6 +92,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result make_string!(array::StringArray, column, row), DataType::Boolean => make_string!(array::BooleanArray, column, row), + DataType::Int8 => make_string!(array::Int8Array, column, row), DataType::Int16 => make_string!(array::Int16Array, column, row), DataType::Int32 => make_string!(array::Int32Array, column, row), DataType::Int64 => make_string!(array::Int64Array, column, row), diff --git 
a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 8d21e381e25..dea025a9e51 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,8 +37,7 @@ use crate::{ }; use crate::{ physical_plan::{ - aggregates, expressions::binary_operator_data_type, functions, - type_coercion::can_coerce_from, udf::ScalarUDF, + aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, }, sql::parser::FileType, }; @@ -323,21 +322,19 @@ impl Expr { /// /// # Errors /// - /// This function errors when it is impossible to cast the expression to the target [arrow::datatypes::DataType]. + /// Currently no errors happen at plan time. If it is impossible + /// to cast the expression to the target + /// [arrow::datatypes::DataType] then an error will occur at + /// runtime. pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self.clone()) - } else if can_coerce_from(cast_to_type, &this_type) { + } else { Ok(Expr::Cast { expr: Box::new(self.clone()), data_type: cast_to_type.clone(), }) - } else { - Err(ExecutionError::General(format!( - "Cannot automatically convert {:?} to {:?}", - this_type, cast_to_type - ))) } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 6c5e1e8da8d..7cb3335803a 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -21,9 +21,12 @@ use std::sync::Arc; extern crate arrow; extern crate datafusion; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::TimeUnit}; +use arrow::{datatypes::Int32Type, record_batch::RecordBatch}; +use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + util::pretty::array_value_to_string, +}; use datafusion::datasource::{csv::CsvReadOptions, MemTable}; use datafusion::error::Result; @@ -93,10 +96,18 
@@ async fn parquet_query() { // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = - "4\t\"0\"\n5\t\"1\"\n6\t\"0\"\n7\t\"1\"\n2\t\"0\"\n3\t\"1\"\n0\t\"0\"\n1\t\"1\"" - .to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["4", "0"], + vec!["5", "1"], + vec!["6", "0"], + vec!["7", "1"], + vec!["2", "0"], + vec!["3", "1"], + vec!["0", "0"], + vec!["1", "1"], + ]; + assert_eq!(expected, actual); } @@ -122,8 +133,8 @@ async fn csv_count_star() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100\t100\t100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100", "100", "100"]]; assert_eq!(expected, actual); Ok(()) } @@ -133,8 +144,11 @@ async fn csv_query_with_predicate() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "\"e\"\t0.39144436569161134\n\"d\"\t0.38870280983958583".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["e", "0.39144436569161134"], + vec!["d", "0.38870280983958583"], + ]; assert_eq!(expected, actual); Ok(()) } @@ -144,8 +158,8 @@ async fn csv_query_with_negated_predicate() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "21".to_string(); + let actual = execute(&mut 
ctx, sql).await; + let expected = vec![vec!["21"]]; assert_eq!(expected, actual); Ok(()) } @@ -155,8 +169,8 @@ async fn csv_query_with_is_not_null_predicate() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NOT NULL"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100"]]; assert_eq!(expected, actual); Ok(()) } @@ -166,8 +180,8 @@ async fn csv_query_with_is_null_predicate() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "0".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"]]; assert_eq!(expected, actual); Ok(()) } @@ -179,8 +193,14 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "1\t0.05636955101974106\t0.9965400387585364\n2\t0.16301110515739792\t0.991517828651004\n3\t0.047343434291126085\t0.9293883502480845\n4\t0.02182578039211991\t0.9237877978193884\n5\t0.01479305307777301\t0.9723580396501548".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["1", "0.05636955101974106", "0.9965400387585364"], + vec!["2", "0.16301110515739792", "0.991517828651004"], + vec!["3", "0.047343434291126085", "0.9293883502480845"], + vec!["4", "0.02182578039211991", "0.9237877978193884"], + vec!["5", "0.01479305307777301", "0.9723580396501548"], + ]; + assert_eq!(expected, actual); Ok(()) } @@ -191,34 +211,34 @@ async fn csv_query_group_by_two_columns() -> Result<()> { let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; let 
mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = [ - "\"a\"\t1\t-85", - "\"a\"\t2\t-48", - "\"a\"\t3\t-72", - "\"a\"\t4\t-101", - "\"a\"\t5\t-101", - "\"b\"\t1\t12", - "\"b\"\t2\t-60", - "\"b\"\t3\t-101", - "\"b\"\t4\t-117", - "\"b\"\t5\t-82", - "\"c\"\t1\t-24", - "\"c\"\t2\t-117", - "\"c\"\t3\t-2", - "\"c\"\t4\t-90", - "\"c\"\t5\t-94", - "\"d\"\t1\t-99", - "\"d\"\t2\t93", - "\"d\"\t3\t-76", - "\"d\"\t4\t5", - "\"d\"\t5\t-59", - "\"e\"\t1\t36", - "\"e\"\t2\t-61", - "\"e\"\t3\t-95", - "\"e\"\t4\t-56", - "\"e\"\t5\t-86", + let expected = vec![ + vec!["a", "1", "-85"], + vec!["a", "2", "-48"], + vec!["a", "3", "-72"], + vec!["a", "4", "-101"], + vec!["a", "5", "-101"], + vec!["b", "1", "12"], + vec!["b", "2", "-60"], + vec!["b", "3", "-101"], + vec!["b", "4", "-117"], + vec!["b", "5", "-82"], + vec!["c", "1", "-24"], + vec!["c", "2", "-117"], + vec!["c", "3", "-2"], + vec!["c", "4", "-90"], + vec!["c", "5", "-94"], + vec!["d", "1", "-99"], + vec!["d", "2", "93"], + vec!["d", "3", "-76"], + vec!["d", "4", "5"], + vec!["d", "5", "-59"], + vec!["e", "1", "36"], + vec!["e", "2", "-61"], + vec!["e", "3", "-95"], + vec!["e", "4", "-56"], + vec!["e", "5", "-86"], ]; - assert_eq!(expected.join("\n"), actual.join("\n")); + assert_eq!(expected, actual); Ok(()) } @@ -229,8 +249,8 @@ async fn csv_query_avg_sqrt() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "0.6706002946036462".to_string(); - assert_eq!(actual.join("\n"), expected); + let expected = vec![vec!["0.6706002946036462"]]; + assert_eq!(actual, expected); Ok(()) } @@ -243,8 +263,8 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { register_aggregate_csv(&mut ctx)?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; - let expected = "0.6584408483418833".to_string(); - assert_eq!(actual.join("\n"), expected); + let 
expected = vec![vec!["0.6584408483418833"]]; + assert_eq!(actual, expected); Ok(()) } @@ -255,14 +275,14 @@ async fn sqrt_f32_vs_f64() -> Result<()> { register_aggregate_csv(&mut ctx)?; // sqrt(f32)'s plan passes let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; - let actual = &execute(&mut ctx, sql).await[0]; - let expected = "0.6584408485889435".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.6584408485889435"]]; - assert_eq!(*actual, expected); + assert_eq!(actual, expected); let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; - let actual = &execute(&mut ctx, sql).await[0]; - let expected = "0.6584408483418833".to_string(); - assert_eq!(*actual, expected); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.6584408483418833"]]; + assert_eq!(actual, expected); Ok(()) } @@ -285,8 +305,8 @@ async fn csv_query_sqrt_sqrt() -> Result<()> { let sql = "SELECT sqrt(sqrt(c12)) FROM aggregate_test_100 LIMIT 1"; let actual = execute(&mut ctx, sql).await; // sqrt(sqrt(c12=0.9294097332465232)) = 0.9818650561397431 - let expected = "0.9818650561397431".to_string(); - assert_eq!(actual.join("\n"), expected); + let expected = vec![vec!["0.9818650561397431"]]; + assert_eq!(actual, expected); Ok(()) } @@ -328,8 +348,8 @@ async fn csv_query_avg() -> Result<()> { let sql = "SELECT avg(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "0.5089725099127211".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![vec!["0.5089725099127211"]]; + assert_eq!(expected, actual); Ok(()) } @@ -340,8 +360,14 @@ async fn csv_query_group_by_avg() -> Result<()> { let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = 
"\"a\"\t0.48754517466109415\n\"b\"\t0.41040709263815384\n\"c\"\t0.6600456536439784\n\"d\"\t0.48855379387549824\n\"e\"\t0.48600669271341534".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["a", "0.48754517466109415"], + vec!["b", "0.41040709263815384"], + vec!["c", "0.6600456536439784"], + vec!["d", "0.48855379387549824"], + vec!["e", "0.48600669271341534"], + ]; + assert_eq!(expected, actual); Ok(()) } @@ -352,8 +378,14 @@ async fn csv_query_group_by_avg_with_projection() -> Result<()> { let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "0.41040709263815384\t\"b\"\n0.48600669271341534\t\"e\"\n0.48754517466109415\t\"a\"\n0.48855379387549824\t\"d\"\n0.6600456536439784\t\"c\"".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["0.41040709263815384", "b"], + vec!["0.48600669271341534", "e"], + vec!["0.48754517466109415", "a"], + vec!["0.48855379387549824", "d"], + vec!["0.6600456536439784", "c"], + ]; + assert_eq!(expected, actual); Ok(()) } @@ -382,8 +414,8 @@ async fn csv_query_count() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT count(c12) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100"]]; assert_eq!(expected, actual); Ok(()) } @@ -395,8 +427,14 @@ async fn csv_query_group_by_int_count() -> Result<()> { let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "\"a\"\t21\n\"b\"\t19\n\"c\"\t21\n\"d\"\t18\n\"e\"\t21".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["a", "21"], + vec!["b", "19"], + vec!["c", "21"], + vec!["d", "18"], + vec!["e", "21"], + ]; + 
assert_eq!(expected, actual); Ok(()) } @@ -407,8 +445,14 @@ async fn csv_query_group_with_aliased_aggregate() -> Result<()> { let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = "\"a\"\t21\n\"b\"\t19\n\"c\"\t21\n\"d\"\t18\n\"e\"\t21".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["a", "21"], + vec!["b", "19"], + vec!["c", "21"], + vec!["d", "18"], + vec!["e", "21"], + ]; + assert_eq!(expected, actual); Ok(()) } @@ -419,9 +463,14 @@ async fn csv_query_group_by_string_min_max() -> Result<()> { let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); - let expected = - "\"a\"\t0.02182578039211991\t0.9800193410444061\n\"b\"\t0.04893135681998029\t0.9185813970744787\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"d\"\t0.061029375346466685\t0.9748360509016578\n\"e\"\t0.01479305307777301\t0.9965400387585364".to_string(); - assert_eq!(expected, actual.join("\n")); + let expected = vec![ + vec!["a", "0.02182578039211991", "0.9800193410444061"], + vec!["b", "0.04893135681998029", "0.9185813970744787"], + vec!["c", "0.0494924465469434", "0.991517828651004"], + vec!["d", "0.061029375346466685", "0.9748360509016578"], + vec!["e", "0.01479305307777301", "0.9965400387585364"], + ]; + assert_eq!(expected, actual); Ok(()) } @@ -430,8 +479,8 @@ async fn csv_query_cast() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "0.39144436569161134\n0.38870280983958583".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.39144436569161134"], vec!["0.38870280983958583"]]; assert_eq!(expected, actual); Ok(()) } @@ -441,8 +490,11 
@@ async fn csv_query_cast_literal() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "0.9294097332465232\t1.0\n0.3114712539863804\t1.0".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["0.9294097332465232", "1"], + vec!["0.3114712539863804", "1"], + ]; assert_eq!(expected, actual); Ok(()) } @@ -452,8 +504,8 @@ async fn csv_query_limit() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "\"c\"\n\"d\"".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["c"], vec!["d"]]; assert_eq!(expected, actual); Ok(()) } @@ -463,8 +515,109 @@ async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["2"], + vec!["5"], + vec!["1"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["3"], + vec!["3"], + vec!["1"], + vec!["4"], + vec!["1"], + vec!["4"], + vec!["3"], + vec!["2"], + vec!["1"], + vec!["1"], + vec!["2"], + vec!["1"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["4"], + 
vec!["5"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["5"], + vec!["3"], + vec!["1"], + vec!["2"], + vec!["3"], + vec!["3"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["2"], + vec!["1"], + vec!["4"], + vec!["1"], + vec!["4"], + vec!["2"], + vec!["5"], + vec!["4"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["4"], + vec!["4"], + vec!["5"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["2"], + vec!["4"], + vec!["2"], + vec!["3"], + vec!["5"], + vec!["1"], + vec!["1"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["2"], + vec!["1"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["5"], + vec!["2"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["4"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["3"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["5"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["3"], + vec!["4"], + vec!["4"], + ]; assert_eq!(expected, actual); Ok(()) } @@ -474,8 +627,109 @@ async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["2"], + vec!["5"], + vec!["1"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["3"], + vec!["3"], + vec!["1"], + vec!["4"], + vec!["1"], + vec!["4"], + vec!["3"], + vec!["2"], + vec!["1"], + vec!["1"], + vec!["2"], + vec!["1"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["2"], + 
vec!["1"], + vec!["4"], + vec!["5"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["5"], + vec!["3"], + vec!["1"], + vec!["2"], + vec!["3"], + vec!["3"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["2"], + vec!["1"], + vec!["4"], + vec!["1"], + vec!["4"], + vec!["2"], + vec!["5"], + vec!["4"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["4"], + vec!["4"], + vec!["5"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["2"], + vec!["4"], + vec!["2"], + vec!["3"], + vec!["5"], + vec!["1"], + vec!["1"], + vec!["4"], + vec!["2"], + vec!["1"], + vec!["2"], + vec!["1"], + vec!["1"], + vec!["5"], + vec!["4"], + vec!["5"], + vec!["2"], + vec!["3"], + vec!["2"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["4"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["3"], + vec!["3"], + vec!["2"], + vec!["5"], + vec!["5"], + vec!["4"], + vec!["1"], + vec!["3"], + vec!["3"], + vec!["4"], + vec!["4"], + ]; assert_eq!(expected, actual); Ok(()) } @@ -485,8 +739,8 @@ async fn csv_query_limit_zero() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected: Vec> = vec![]; assert_eq!(expected, actual); Ok(()) } @@ -496,8 +750,22 @@ async fn csv_query_create_external_table() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "\"c\"\t2\t1\t18109\t2033001162\t-6513304855495910254\t25\t43062\t1491205016\t10\t0.110830784\t0.9294097332465232\t\"6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW\"".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec![ + "c", + "2", 
+ "1", + "18109", + "2033001162", + "-6513304855495910254", + "25", + "43062", + "1491205016", + "10", + "0.110830784", + "0.9294097332465232", + "6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW", + ]]; assert_eq!(expected, actual); } @@ -506,8 +774,8 @@ async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(c12) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100"]]; assert_eq!(expected, actual); } @@ -516,8 +784,8 @@ async fn csv_query_count_star() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(*) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100"]]; assert_eq!(expected, actual); } @@ -526,8 +794,8 @@ async fn csv_query_count_one() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(1) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "100".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["100"]]; assert_eq!(expected, actual); } @@ -536,13 +804,18 @@ async fn csv_explain() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "\"logical_plan\"\t\"Projection: #c1\\n Filter: #c2 Gt Int64(10)\\n TableScan: aggregate_test_100 projection=None\"".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec![ + "logical_plan", + "Projection: #c1\n Filter: #c2 Gt Int64(10)\n TableScan: aggregate_test_100 
projection=None" + ] + ]; assert_eq!(expected, actual); // Also, expect same result with lowercase explain let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await.join("\n"); + let actual = execute(&mut ctx, sql).await; assert_eq!(expected, actual); } @@ -551,7 +824,11 @@ async fn csv_explain_verbose() { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await.join("\n"); + let actual = execute(&mut ctx, sql).await; + + // flatten to a single string + let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); + // Don't actually test the contents of the debuging output (as // that may change and keeping this test updated will be a // pain). Instead just check for a few key pieces. @@ -638,107 +915,68 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) { .unwrap(); } -/// Execute query and return result set as tab delimited string -async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec { - let plan = ctx.create_logical_plan(&sql).unwrap(); +/// Execute query and return result set as 2-d table of Vecs +/// `result[row][column]` +async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec> { + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(&sql).expect(&msg); let logical_schema = plan.schema(); - let plan = ctx.optimize(&plan).unwrap(); + + let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); let optimized_logical_schema = plan.schema(); - let plan = ctx.create_physical_plan(&plan).unwrap(); + + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).expect(&msg); let physical_schema = plan.schema(); - let results = ctx.collect(plan).await.unwrap(); + + let msg = 
format!("Executing physical plan for '{}': {:?}", sql, plan); + let results = ctx.collect(plan).await.expect(&msg); assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); assert_eq!(logical_schema.as_ref(), physical_schema.as_ref()); - result_str(&results) + result_vec(&results) } -/// Converts an array's value at `row_index` to a string. -fn array_str(array: &Arc, row_index: usize) -> String { - if array.is_null(row_index) { +/// Specialised String representation +fn col_str(column: &ArrayRef, row_index: usize) -> String { + if column.is_null(row_index) { return "NULL".to_string(); } - // beyond this point, we can assume that `array...downcast().value(row_index)` is valid, - // due to the `if` above. - - match array.data_type() { - DataType::Int8 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Int16 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Int32 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Int64 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::UInt8 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::UInt16 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::UInt32 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::UInt64 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Float32 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Float64 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", 
array.value(row_index)) - } - DataType::Utf8 => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::Boolean => { - let array = array.as_any().downcast_ref::().unwrap(); - format!("{:?}", array.value(row_index)) - } - DataType::FixedSizeList(_, n) => { - let array = array.as_any().downcast_ref::().unwrap(); - let array = array.value(row_index); - let mut r = Vec::with_capacity(*n as usize); - for i in 0..*n { - r.push(array_str(&array, i as usize)); - } - format!("[{}]", r.join(",")) + // Special case ListArray as there is no pretty print support for it yet + if let DataType::FixedSizeList(_, n) = column.data_type() { + let array = column + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index); + + let mut r = Vec::with_capacity(*n as usize); + for i in 0..*n { + r.push(col_str(&array, i as usize)); } - _ => "???".to_string(), + return format!("[{}]", r.join(",")); } + + array_value_to_string(column, row_index) + .ok() + .unwrap_or_else(|| "???".to_string()) } -fn result_str(results: &[RecordBatch]) -> Vec { +/// Converts the results into a 2d array of strings, `result[row][column]` +/// Special cases nulls to NULL for testing +fn result_vec(results: &[RecordBatch]) -> Vec> { let mut result = vec![]; for batch in results { for row_index in 0..batch.num_rows() { - let mut str = String::new(); - for column_index in 0..batch.num_columns() { - if column_index > 0 { - str.push_str("\t"); - } - let column = batch.column(column_index); - - str.push_str(&array_str(column, row_index)); - } - result.push(str); + let row_vec = batch + .columns() + .iter() + .map(|column| col_str(column, row_index)) + .collect(); + result.push(row_vec); } } result @@ -759,8 +997,8 @@ async fn generic_query_length>>( let mut ctx = ExecutionContext::new(); ctx.register_table("test", Box::new(table)); let sql = "SELECT length(c1) FROM test"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = 
"0\n1\n2\n3".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; assert_eq!(expected, actual); Ok(()) } @@ -793,8 +1031,8 @@ async fn query_not() -> Result<()> { let mut ctx = ExecutionContext::new(); ctx.register_table("test", Box::new(table)); let sql = "SELECT NOT c1 FROM test"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "true\nNULL\nfalse".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["true"], vec!["NULL"], vec!["false"]]; assert_eq!(expected, actual); Ok(()) } @@ -820,7 +1058,12 @@ async fn query_concat() -> Result<()> { ctx.register_table("test", Box::new(table)); let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; - let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "NULL", "\"aaa-hi-3\""]; + let expected = vec![ + vec!["-hi-0"], + vec!["a-hi-1"], + vec!["NULL"], + vec!["aaa-hi-3"], + ]; assert_eq!(expected, actual); Ok(()) } @@ -847,10 +1090,10 @@ async fn query_array() -> Result<()> { let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![ - "[\"\",\"0\"]", - "[\"a\",\"1\"]", - "[\"aa\",NULL]", - "[\"aaa\",\"3\"]", + vec!["[,0]"], + vec!["[a,1]"], + vec!["[aa,NULL]"], + vec!["[aaa,3]"], ]; assert_eq!(expected, actual); Ok(()) @@ -872,9 +1115,9 @@ async fn like() -> Result<()> { register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; // check that the physical and logical schemas are equal - let actual = execute(&mut ctx, sql).await.join("\n"); + let actual = execute(&mut ctx, sql).await; - let expected = "1".to_string(); + let expected = vec![vec!["1"]]; assert_eq!(expected, actual); Ok(()) } @@ -908,9 +1151,9 @@ async fn to_timstamp() -> Result<()> { ctx.register_table("ts_data", make_timestamp_nano_table()?); let sql = 
"SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; - let actual = execute(&mut ctx, sql).await.join("\n"); + let actual = execute(&mut ctx, sql).await; - let expected = "2".to_string(); + let expected = vec![vec!["2"]]; assert_eq!(expected, actual); Ok(()) } @@ -933,8 +1176,8 @@ async fn query_is_null() -> Result<()> { let mut ctx = ExecutionContext::new(); ctx.register_table("test", Box::new(table)); let sql = "SELECT c1 IS NULL FROM test"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "false\ntrue\nfalse".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["false"], vec!["true"], vec!["false"]]; assert_eq!(expected, actual); Ok(()) } @@ -957,8 +1200,75 @@ async fn query_is_not_null() -> Result<()> { let mut ctx = ExecutionContext::new(); ctx.register_table("test", Box::new(table)); let sql = "SELECT c1 IS NOT NULL FROM test"; - let actual = execute(&mut ctx, sql).await.join("\n"); - let expected = "true\nfalse\ntrue".to_string(); + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["true"], vec!["false"], vec!["true"]]; + assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_on_string_dictionary() -> Result<()> { + // Test to ensure DataFusion can operate on dictionary types + // Use StringDictionary (32 bit indexes = keys) + let field_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + builder.append("one")?; + builder.append_null()?; + builder.append("three")?; + let array = Arc::new(builder.finish()); + + let data = RecordBatch::try_new(schema.clone(), vec![array])?; + + let table = MemTable::new(schema, vec![vec![data]])?; + let mut ctx = 
ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + + // Basic SELECT + let sql = "SELECT * FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]]; + assert_eq!(expected, actual); + + // basic filtering + let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one"], vec!["three"]]; + assert_eq!(expected, actual); + + // The following queries are not yet supported + + // // filtering with constant + // let sql = "SELECT * FROM test WHERE d1 = 'three'"; + // let actual = execute(&mut ctx, sql).await; + // let expected = vec![ + // vec!["three"], + // ]; + // assert_eq!(expected, actual); + + // // Expression evaluation + // let sql = "SELECT concat(d1, '-foo') FROM test"; + // let actual = execute(&mut ctx, sql).await; + // let expected = vec![ + // vec!["one-foo"], + // vec!["NULL"], + // vec!["three-foo"], + // ]; + // assert_eq!(expected, actual); + + // // aggregation + // let sql = "SELECT COUNT(d1) FROM test"; + // let actual = execute(&mut ctx, sql).await; + // let expected = vec![ + // vec!["2"] + // ]; + // assert_eq!(expected, actual); + + Ok(()) +}