diff --git a/rust/arrow/src/compute/kernels/concat.rs b/rust/arrow/src/compute/kernels/concat.rs index 5f78cea75bd..4dc945611e8 100644 --- a/rust/arrow/src/compute/kernels/concat.rs +++ b/rust/arrow/src/compute/kernels/concat.rs @@ -114,6 +114,7 @@ pub fn concat(array_list: &[ArrayRef]) -> Result { DataType::Duration(TimeUnit::Nanosecond) => { concat_primitive::(array_data_list) } + DataType::List(nested_type) => concat_list(array_data_list, *nested_type.clone()), t => Err(ArrowError::ComputeError(format!( "Concat not supported for data type {:?}", t @@ -131,6 +132,37 @@ where Ok(ArrayBuilder::finish(&mut builder)) } +#[inline] +fn concat_primitive_list(array_data_list: &[ArrayDataRef]) -> Result +where + T: ArrowNumericType, +{ + let mut builder = ListBuilder::new(PrimitiveArray::::builder(0)); + builder.append_data(array_data_list)?; + Ok(ArrayBuilder::finish(&mut builder)) +} + +#[inline] +fn concat_list( + array_data_list: &[ArrayDataRef], + data_type: DataType, +) -> Result { + match data_type { + DataType::Int8 => concat_primitive_list::(array_data_list), + DataType::Int16 => concat_primitive_list::(array_data_list), + DataType::Int32 => concat_primitive_list::(array_data_list), + DataType::Int64 => concat_primitive_list::(array_data_list), + DataType::UInt8 => concat_primitive_list::(array_data_list), + DataType::UInt16 => concat_primitive_list::(array_data_list), + DataType::UInt32 => concat_primitive_list::(array_data_list), + DataType::UInt64 => concat_primitive_list::(array_data_list), + t => Err(ArrowError::ComputeError(format!( + "Concat not supported for list with data type {:?}", + t + ))), + } +} + #[cfg(test)] mod tests { use super::*; @@ -285,4 +317,75 @@ mod tests { Ok(()) } + + #[test] + fn test_concat_primitive_list_arrays() -> Result<()> { + fn populate_list1( + b: &mut ListBuilder>, + ) -> Result<()> { + b.values().append_value(-1)?; + b.values().append_value(-1)?; + b.values().append_value(2)?; + b.values().append_null()?; + b.values().append_null()?; + b.append(true)?; + b.append(true)?; + b.append(false)?; + b.values().append_value(10)?; + b.append(true)?; + Ok(()) + } + + fn populate_list2( + b: &mut ListBuilder>, + ) -> Result<()> { + b.append(false)?; + b.values().append_value(100)?; + b.values().append_null()?; + b.values().append_value(101)?; + b.append(true)?; + b.values().append_value(102)?; + b.append(true)?; + Ok(()) + } + + fn populate_list3( + b: &mut ListBuilder>, + ) -> Result<()> { + b.values().append_value(1000)?; + b.values().append_value(1001)?; + b.append(true)?; + Ok(()) + } + + let mut builder_in1 = ListBuilder::new(PrimitiveArray::::builder(0)); + let mut builder_in2 = ListBuilder::new(PrimitiveArray::::builder(0)); + let mut builder_in3 = ListBuilder::new(PrimitiveArray::::builder(0)); + populate_list1(&mut builder_in1)?; + populate_list2(&mut builder_in2)?; + populate_list3(&mut builder_in3)?; + + let mut builder_expected = + ListBuilder::new(PrimitiveArray::::builder(0)); + populate_list1(&mut builder_expected)?; + populate_list2(&mut builder_expected)?; + populate_list3(&mut builder_expected)?; + + let array_result = concat(&[ + Arc::new(builder_in1.finish()), + Arc::new(builder_in2.finish()), + Arc::new(builder_in3.finish()), + ])?; + + let array_expected = builder_expected.finish(); + + assert!( + array_result.equals(&array_expected), + "expect {:#?} to be: {:#?}", + array_result, + &array_expected + ); + + Ok(()) + } } diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 784e1f7b803..eabc779e49d 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -967,6 +967,133 @@ mod tests { Ok(()) } + async fn run_count_distinct_integers_aggregated_scenario( + partitions: Vec>, + ) -> Result> { + let tmp_dir = TempDir::new()?; + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("c_group", DataType::Utf8, false), + Field::new("c_int8", DataType::Int8, false), + Field::new("c_int16", DataType::Int16, false), + Field::new("c_int32", DataType::Int32, false), + Field::new("c_int64", DataType::Int64, false), + Field::new("c_uint8", DataType::UInt8, false), + Field::new("c_uint16", DataType::UInt16, false), + Field::new("c_uint32", DataType::UInt32, false), + Field::new("c_uint64", DataType::UInt64, false), + ])); + + for (i, partition) in partitions.iter().enumerate() { + let filename = format!("partition-{}.csv", i); + let file_path = tmp_dir.path().join(&filename); + let mut file = File::create(file_path)?; + for row in partition { + let row_str = format!( + "{},{}\n", + row.0, + // Populate values for each of the integer fields in the + // schema. + (0..8) + .map(|_| { row.1.to_string() }) + .collect::>() + .join(","), + ); + file.write_all(row_str.as_bytes())?; + } + } + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + )?; + + let results = collect( + &mut ctx, + " + SELECT + c_group, + COUNT(c_uint64), + COUNT(DISTINCT c_int8), + COUNT(DISTINCT c_int16), + COUNT(DISTINCT c_int32), + COUNT(DISTINCT c_int64), + COUNT(DISTINCT c_uint8), + COUNT(DISTINCT c_uint16), + COUNT(DISTINCT c_uint32), + COUNT(DISTINCT c_uint64) + FROM test + GROUP BY c_group + ", + ) + .await?; + + Ok(results) + } + + #[tokio::test] + async fn count_distinct_integers_aggregated_single_partition() -> Result<()> { + let partitions = vec![ + // The first member of each tuple will be the value for the + // `c_group` column, and the second member will be the value for + // each of the int/uint fields. + vec![ + ("a", 1), + ("a", 1), + ("a", 2), + ("b", 9), + ("c", 9), + ("c", 10), + ("c", 9), + ], + ]; + + let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 10); + assert_eq!( + test::format_batch(&batch), + vec![ + "a,3,2,2,2,2,2,2,2,2", + "c,3,2,2,2,2,2,2,2,2", + "b,1,1,1,1,1,1,1,1,1", + ], + ); + + Ok(()) + } + + #[tokio::test] + async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> { + let partitions = vec![ + // The first member of each tuple will be the value for the + // `c_group` column, and the second member will be the value for + // each of the int/uint fields. + vec![("a", 1), ("a", 1), ("a", 2), ("b", 9), ("c", 9)], + vec![("a", 1), ("a", 3), ("b", 8), ("b", 9), ("b", 10), ("b", 11)], + ]; + + let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 10); + assert_eq!( + test::format_batch(&batch), + vec![ + "a,5,3,3,3,3,3,3,3,3", + "c,1,1,1,1,1,1,1,1,1", + "b,5,4,4,4,4,4,4,4,4", + ], + ); + + Ok(()) + } + #[test] fn aggregate_with_alias() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0550741a9f9..b8d0cc7fb82 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -50,6 +50,7 @@ pub use operators::Operator; fn create_function_name( fun: &String, + distinct: bool, args: &[Expr], input_schema: &Schema, ) -> Result { @@ -57,7 +58,11 @@ fn create_function_name( .iter() .map(|e| create_name(e, input_schema)) .collect::>()?; - Ok(format!("{}({})", fun, names.join(","))) + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) } /// Returns a readable name of an expression based on the input schema. @@ -90,14 +95,17 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result { Ok(format!("{} IS NOT NULL", expr)) } Expr::ScalarFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), args, input_schema) + create_function_name(&fun.to_string(), false, args, input_schema) } Expr::ScalarUDF { fun, args, .. } => { - create_function_name(&fun.name, args, input_schema) - } - Expr::AggregateFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), args, input_schema) + create_function_name(&fun.name, false, args, input_schema) } + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => create_function_name(&fun.to_string(), *distinct, args, input_schema), Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { @@ -195,6 +203,8 @@ pub enum Expr { fun: aggregates::AggregateFunction, /// List of expressions to feed to the functions as arguments args: Vec, + /// Whether this is a DISTINCT aggregation or not + distinct: bool, }, /// aggregate function AggregateUDF { @@ -447,6 +457,7 @@ pub fn col(name: &str) -> Expr { pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Min, + distinct: false, args: vec![expr], } } @@ -455,6 +466,7 @@ pub fn min(expr: Expr) -> Expr { pub fn max(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Max, + distinct: false, args: vec![expr], } } @@ -463,6 +475,7 @@ pub fn max(expr: Expr) -> Expr { pub fn sum(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Sum, + distinct: false, args: vec![expr], } } @@ -471,6 +484,7 @@ pub fn sum(expr: Expr) -> Expr { pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Avg, + distinct: false, args: vec![expr], } } @@ -479,6 +493,7 @@ pub fn avg(expr: Expr) -> Expr { pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Count, + distinct: false, args: vec![expr], } } @@ -620,9 +635,18 @@ pub fn create_udaf( ) } -fn fmt_function(f: &mut fmt::Formatter, fun: &String, args: &Vec) -> fmt::Result { +fn fmt_function( + f: &mut fmt::Formatter, + fun: &String, + distinct: bool, + args: &Vec, +) -> fmt::Result { let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); - write!(f, "{}({})", fun, args.join(", ")) + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) } impl fmt::Debug for Expr { @@ -658,13 +682,20 @@ impl fmt::Debug for Expr { } } Expr::ScalarFunction { fun, args, .. } => { - fmt_function(f, &fun.to_string(), args) + fmt_function(f, &fun.to_string(), false, args) + } + Expr::ScalarUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args) } - Expr::ScalarUDF { fun, ref args, .. } => fmt_function(f, &fun.name, args), - Expr::AggregateFunction { fun, ref args, .. } => { - fmt_function(f, &fun.to_string(), args) + Expr::AggregateFunction { + fun, + distinct, + ref args, + .. + } => fmt_function(f, &fun.to_string(), *distinct, args), + Expr::AggregateUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args) } - Expr::AggregateUDF { fun, ref args, .. } => fmt_function(f, &fun.name, args), Expr::Wildcard => write!(f, "*"), Expr::Nested(expr) => write!(f, "({:?})", expr), } diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 4c44055a9da..d45c18ac27d 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -241,9 +241,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result fun: fun.clone(), args: expressions.clone(), }), - Expr::AggregateFunction { fun, .. } => Ok(Expr::AggregateFunction { + Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: expressions.clone(), + distinct: *distinct, }), Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF { fun: fun.clone(), diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs index 03833f61a52..40bb562b0e4 100644 --- a/rust/datafusion/src/physical_plan/aggregates.rs +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -32,6 +32,7 @@ use super::{ Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{ExecutionError, Result}; +use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Schema}; use expressions::{avg_return_type, sum_return_type}; @@ -110,6 +111,7 @@ pub fn return_type( /// This function errors when `args`' can't be coerced to a valid argument type of the function. pub fn create_aggregate_expr( fun: &AggregateFunction, + distinct: bool, args: &Vec>, input_schema: &Schema, name: String, @@ -124,14 +126,40 @@ pub fn create_aggregate_expr( let return_type = return_type(&fun, &arg_types)?; - Ok(match fun { - AggregateFunction::Count => { + Ok(match (fun, distinct) { + (AggregateFunction::Count, false) => { Arc::new(expressions::Count::new(arg, name, return_type)) } - AggregateFunction::Sum => Arc::new(expressions::Sum::new(arg, name, return_type)), - AggregateFunction::Min => Arc::new(expressions::Min::new(arg, name, return_type)), - AggregateFunction::Max => Arc::new(expressions::Max::new(arg, name, return_type)), - AggregateFunction::Avg => Arc::new(expressions::Avg::new(arg, name, return_type)), + (AggregateFunction::Count, true) => { + Arc::new(distinct_expressions::DistinctCount::new( + arg_types, + args.clone(), + name, + return_type, + )) + } + (AggregateFunction::Sum, false) => { + Arc::new(expressions::Sum::new(arg, name, return_type)) + } + (AggregateFunction::Sum, true) => { + return Err(ExecutionError::NotImplemented( + "SUM(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::Min, _) => { + Arc::new(expressions::Min::new(arg, name, return_type)) + } + (AggregateFunction::Max, _) => { + Arc::new(expressions::Max::new(arg, name, return_type)) + } + (AggregateFunction::Avg, false) => { + Arc::new(expressions::Avg::new(arg, name, return_type)) + } + (AggregateFunction::Avg, true) => { + return Err(ExecutionError::NotImplemented( + "AVG(DISTINCT) aggregations are not available".to_string(), + )); + } }) } diff --git a/rust/datafusion/src/physical_plan/distinct_expressions.rs b/rust/datafusion/src/physical_plan/distinct_expressions.rs new file mode 100644 index 00000000000..2d2ab627d44 --- /dev/null +++ b/rust/datafusion/src/physical_plan/distinct_expressions.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` + +use std::cell::RefCell; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::hash::Hash; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field}; + +use fnv::FnvHashSet; + +use crate::error::{ExecutionError, Result}; +use crate::physical_plan::group_scalar::GroupByScalar; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct DistinctScalarValues(Vec); + +fn format_state_name(name: &str, state_name: &str) -> String { + format!("{}[{}]", name, state_name) +} + +/// Expression for a COUNT(DISTINCT) aggregation. +#[derive(Debug)] +pub struct DistinctCount { + /// Column name + name: String, + /// The DataType for the final count + data_type: DataType, + /// The DataType for each input argument + input_data_types: Vec, + /// The input arguments + exprs: Vec>, +} + +impl DistinctCount { + /// Create a new COUNT(DISTINCT) aggregate function. + pub fn new( + input_data_types: Vec, + exprs: Vec>, + name: String, + data_type: DataType, + ) -> Self { + Self { + input_data_types, + exprs, + name, + data_type, + } + } +} + +impl AggregateExpr for DistinctCount { + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn state_fields(&self) -> Result> { + Ok(self + .input_data_types + .iter() + .map(|data_type| { + Field::new( + &format_state_name(&self.name, "count distinct"), + DataType::List(Box::new(data_type.clone())), + false, + ) + }) + .collect::>()) + } + + fn expressions(&self) -> Vec> { + self.exprs.clone() + } + + fn create_accumulator(&self) -> Result>> { + Ok(Rc::new(RefCell::new(DistinctCountAccumulator { + values: FnvHashSet::default(), + data_types: self.input_data_types.clone(), + count_data_type: self.data_type.clone(), + }))) + } +} + +#[derive(Debug)] +struct DistinctCountAccumulator { + values: FnvHashSet, + data_types: Vec, + count_data_type: DataType, +} + +impl Accumulator for DistinctCountAccumulator { + fn update(&mut self, values: &Vec) -> Result<()> { + // If a row has a NULL, it is not included in the final count. + if !values.iter().any(|v| v.is_null()) { + self.values.insert(DistinctScalarValues( + values + .iter() + .map(GroupByScalar::try_from) + .collect::>>()?, + )); + } + + Ok(()) + } + + fn merge(&mut self, states: &Vec) -> Result<()> { + if states.len() == 0 { + return Ok(()); + } + + let col_values = states + .iter() + .map(|state| match state { + ScalarValue::List(Some(values), _) => Ok(values), + _ => Err(ExecutionError::InternalError( + "Unexpected accumulator state".to_string(), + )), + }) + .collect::>>()?; + + (0..col_values[0].len()) + .map(|row_index| { + let row_values = col_values + .iter() + .map(|col| col[row_index].clone()) + .collect::>(); + self.update(&row_values) + }) + .collect::>() + } + + fn state(&self) -> Result> { + let mut cols_out = self + .data_types + .iter() + .map(|data_type| ScalarValue::List(Some(Vec::new()), data_type.clone())) + .collect::>(); + + let mut cols_vec = cols_out + .iter_mut() + .map(|c| match c { + ScalarValue::List(Some(ref mut v), _) => v, + _ => unreachable!(), + }) + .collect::>(); + + self.values.iter().for_each(|distinct_values| { + distinct_values.0.iter().enumerate().for_each( + |(col_index, distinct_value)| { + cols_vec[col_index].push(ScalarValue::from(distinct_value)); + }, + ) + }); + + Ok(cols_out) + } + + fn evaluate(&self) -> Result { + match &self.count_data_type { + DataType::UInt64 => Ok(ScalarValue::UInt64(Some(self.values.len() as u64))), + t => { + return Err(ExecutionError::InternalError(format!( + "Invalid data type {:?} for count distinct aggregation", + t + ))) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::array::ArrayRef; + use arrow::array::{ + Int16Array, Int32Array, Int64Array, Int8Array, ListArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, + }; + use arrow::array::{Int32Builder, ListBuilder, UInt64Builder}; + use arrow::datatypes::DataType; + + macro_rules! build_list { + ($LISTS:expr, $BUILDER_TYPE:ident) => {{ + let mut builder = ListBuilder::new($BUILDER_TYPE::new(0)); + for list in $LISTS.iter() { + match list { + Some(values) => { + for value in values.iter() { + match value { + Some(v) => builder.values().append_value((*v).into())?, + None => builder.values().append_null()?, + } + } + + builder.append(true)?; + } + None => { + builder.append(false)?; + } + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + + Ok(array) as Result + }}; + } + + macro_rules! state_to_vec { + ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ + match $LIST { + ScalarValue::List(_, data_type) => match data_type { + DataType::$DATA_TYPE => (), + _ => panic!("Unexpected DataType for list"), + }, + _ => panic!("Expected a ScalarValue::List"), + } + + match $LIST { + ScalarValue::List(None, _) => None, + ScalarValue::List(Some(scalar_values), _) => { + let vec = scalar_values + .iter() + .map(|scalar_value| match scalar_value { + ScalarValue::$DATA_TYPE(value) => *value, + _ => panic!("Unexpected ScalarValue variant"), + }) + .collect::>>(); + + Some(vec) + } + _ => unreachable!(), + } + }}; + } + + fn collect_states( + state1: &Vec>, + state2: &Vec>, + ) -> Vec<(Option, Option)> { + let mut states = state1 + .iter() + .zip(state2.iter()) + .map(|(a, b)| (a.clone(), b.clone())) + .collect::, Option)>>(); + states.sort(); + states + } + + fn run_update_batch( + arrays: &Vec, + ) -> Result<(Vec, ScalarValue)> { + let agg = DistinctCount::new( + arrays + .iter() + .map(|a| a.data_type().clone()) + .collect::>(), + vec![], + String::from("__col_name__"), + DataType::UInt64, + ); + + let accum = agg.create_accumulator()?; + let mut accum = accum.borrow_mut(); + accum.update_batch(arrays)?; + + Ok((accum.state()?, accum.evaluate()?)) + } + + fn run_update( + data_types: &Vec, + rows: &Vec>, + ) -> Result<(Vec, ScalarValue)> { + let agg = DistinctCount::new( + data_types.clone(), + vec![], + String::from("__col_name__"), + DataType::UInt64, + ); + + let accum = agg.create_accumulator()?; + let mut accum = accum.borrow_mut(); + + for row in rows.iter() { + accum.update(row)? + } + + Ok((accum.state()?, accum.evaluate()?)) + } + + fn run_merge_batch( + arrays: &Vec, + ) -> Result<(Vec, ScalarValue)> { + let agg = DistinctCount::new( + arrays + .iter() + .map(|a| a.as_any().downcast_ref::().unwrap()) + .map(|a| a.values().data_type().clone()) + .collect::>(), + vec![], + String::from("__col_name__"), + DataType::UInt64, + ); + + let accum = agg.create_accumulator()?; + let mut accum = accum.borrow_mut(); + accum.merge_batch(arrays)?; + + Ok((accum.state()?, accum.evaluate()?)) + } + + macro_rules! test_count_distinct_update_batch_numeric { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + let values: Vec> = vec![ + Some(1), + Some(1), + None, + Some(3), + Some(2), + None, + Some(2), + Some(3), + Some(1), + ]; + + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = + state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + state_vec.sort(); + + assert_eq!(states.len(), 1); + assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); + assert_eq!(result, ScalarValue::UInt64(Some(3))); + + Ok(()) + }}; + } + + #[test] + fn count_distinct_update_batch_i8() -> Result<()> { + test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8) + } + + #[test] + fn count_distinct_update_batch_i16() -> Result<()> { + test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16) + } + + #[test] + fn count_distinct_update_batch_i32() -> Result<()> { + test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32) + } + + #[test] + fn count_distinct_update_batch_i64() -> Result<()> { + test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64) + } + + #[test] + fn count_distinct_update_batch_u8() -> Result<()> { + test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8) + } + + #[test] + fn count_distinct_update_batch_u16() -> Result<()> { + test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16) + } + + #[test] + fn count_distinct_update_batch_u32() -> Result<()> { + test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32) + } + + #[test] + fn count_distinct_update_batch_u64() -> Result<()> { + test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64) + } + + #[test] + fn count_distinct_update_batch_all_nulls() -> Result<()> { + let arrays = vec![Arc::new(Int32Array::from( + vec![None, None, None, None] as Vec> + )) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + assert_eq!(states.len(), 1); + assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert_eq!(result, ScalarValue::UInt64(Some(0))); + + Ok(()) + } + + #[test] + fn count_distinct_update_batch_empty() -> Result<()> { + let arrays = + vec![Arc::new(Int32Array::from(vec![] as Vec>)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + assert_eq!(states.len(), 1); + assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert_eq!(result, ScalarValue::UInt64(Some(0))); + + Ok(()) + } + + #[test] + fn count_distinct_update_batch_multiple_columns() -> Result<()> { + let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2])); + let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4])); + let arrays = vec![array_int8, array_int16]; + + let (states, result) = run_update_batch(&arrays)?; + + let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap(); + let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap(); + let state_pairs = collect_states::(&state_vec1, &state_vec2); + + assert_eq!(states.len(), 2); + assert_eq!( + state_pairs, + vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))] + ); + + assert_eq!(result, ScalarValue::UInt64(Some(2))); + + Ok(()) + } + + #[test] + fn count_distinct_update() -> Result<()> { + let (states, result) = run_update( + &vec![DataType::Int32, DataType::UInt64], + &vec![ + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))], + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))], + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(6))], + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(7))], + vec![ScalarValue::Int32(Some(2)), ScalarValue::UInt64(Some(7))], + ], + )?; + + let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); + let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); + let state_pairs = collect_states::(&state_vec1, &state_vec2); + + assert_eq!(states.len(), 2); + assert_eq!( + state_pairs, + vec![ + (Some(-1_i32), Some(5_u64)), + (Some(-1_i32), Some(6_u64)), + (Some(-1_i32), Some(7_u64)), + (Some(2_i32), Some(7_u64)), + (Some(5_i32), Some(1_u64)), + ] + ); + assert_eq!(result, ScalarValue::UInt64(Some(5))); + + Ok(()) + } + + #[test] + fn count_distinct_update_with_nulls() -> Result<()> { + let (states, result) = run_update( + &vec![DataType::Int32, DataType::UInt64], + &vec![ + // None of these updates contains a None, so these are accumulated. + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(Some(-2)), ScalarValue::UInt64(Some(5))], + // Each of these updates contains at least one None, so these + // won't be accumulated. + vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)], + vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)], + ], + )?; + + let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); + let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); + let state_pairs = collect_states::(&state_vec1, &state_vec2); + + assert_eq!(states.len(), 2); + assert_eq!( + state_pairs, + vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))] + ); + + assert_eq!(result, ScalarValue::UInt64(Some(2))); + + Ok(()) + } + + #[test] + fn count_distinct_merge_batch() -> Result<()> { + let state_in1 = build_list!( + vec![ + Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]), + Some(vec![Some(-2_i32), Some(-3_i32)]), + ], + Int32Builder + )?; + + let state_in2 = build_list!( + vec![ + Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]), + Some(vec![Some(5_u64), Some(7_u64)]), + ], + UInt64Builder + )?; + + let (states, result) = run_merge_batch(&vec![state_in1, state_in2])?; + + let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); + let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); + let state_pairs = collect_states::(&state_out_vec1, &state_out_vec2); + + assert_eq!( + state_pairs, + vec![ + (Some(-3_i32), Some(7_u64)), + (Some(-2_i32), Some(5_u64)), + (Some(-2_i32), Some(7_u64)), + (Some(-1_i32), Some(5_u64)), + (Some(-1_i32), Some(6_u64)), + ] + ); + + assert_eq!(result, ScalarValue::UInt64(Some(5))); + + Ok(()) + } +} diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs new file mode 100644 index 00000000000..6647df946c9 --- /dev/null +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines scalars used to construct groups, ex. in GROUP BY clauses. + +use std::convert::{From, TryFrom}; + +use crate::error::{ExecutionError, Result}; +use crate::scalar::ScalarValue; + +/// Enumeration of types that can be used in a GROUP BY expression (all primitives except +/// for floating point numerics) +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub(crate) enum GroupByScalar { + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + Utf8(String), +} + +impl TryFrom<&ScalarValue> for GroupByScalar { + type Error = ExecutionError; + + fn try_from(scalar_value: &ScalarValue) -> Result { + Ok(match scalar_value { + ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v), + ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v), + ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v), + ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v), + ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v), + ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), + ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), + ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), + ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone()), + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::UInt8(None) + | ScalarValue::UInt16(None) + | ScalarValue::UInt32(None) + | ScalarValue::UInt64(None) + | ScalarValue::Utf8(None) => { + return Err(ExecutionError::InternalError(format!( + "Cannot convert a ScalarValue holding NULL ({:?})", + scalar_value + ))); + } + v => { + return Err(ExecutionError::InternalError(format!( + "Cannot convert a ScalarValue with associated DataType {:?}", + v.get_datatype() + ))) + } + }) + } +} + +impl From<&GroupByScalar> for ScalarValue { + fn from(group_by_scalar: &GroupByScalar) -> Self { + match group_by_scalar { + GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)), + GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)), + GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)), + GroupByScalar::Int64(v) => ScalarValue::Int64(Some(*v)), + GroupByScalar::UInt8(v) => ScalarValue::UInt8(Some(*v)), + GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)), + GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)), + GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)), + GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone())), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::error::{ExecutionError, Result}; + + #[test] + fn from_scalar_holding_none() -> Result<()> { + let scalar_value = ScalarValue::Int8(None); + let result = GroupByScalar::try_from(&scalar_value); + + match result { + Err(ExecutionError::InternalError(error_message)) => assert_eq!( + error_message, + String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))") + ), + _ => panic!("Unexpected result"), + } + + Ok(()) + } + + #[test] + fn from_scalar_unsupported() -> Result<()> { + // Use any ScalarValue type not supported by GroupByScalar. + let scalar_value = ScalarValue::Float32(Some(1.1)); + let result = GroupByScalar::try_from(&scalar_value); + + match result { + Err(ExecutionError::InternalError(error_message)) => assert_eq!( + error_message, + String::from( + "Cannot convert a ScalarValue with associated DataType Float32" + ) + ), + _ => panic!("Unexpected result"), + } + + Ok(()) + } +} diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 1698577021b..adda633ff7d 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -40,7 +40,7 @@ use arrow::{ use fnv::FnvHashMap; -use super::{common, expressions::Column, Source}; +use super::{common, expressions::Column, group_scalar::GroupByScalar, Source}; use async_trait::async_trait; @@ -677,21 +677,6 @@ fn finalize_aggregation( } } -/// Enumeration of types that can be used in a GROUP BY expression (all primitives except -/// for floating point numerics) -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -enum GroupByScalar { - UInt8(u8), - UInt16(u16), - UInt32(u32), - UInt64(u64), - Int8(i8), - Int16(i16), - Int32(i32), - Int64(i64), - Utf8(String), -} - /// Create a Vec that can be used as a map key fn create_key( group_by_keys: &[ArrayRef], diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index 3565610ca99..79696f83b73 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -192,11 +192,13 @@ pub mod array_expressions; pub mod common; pub mod csv; pub mod datetime_expressions; +pub mod distinct_expressions; pub mod empty; pub mod explain; pub mod expressions; pub mod filter; pub mod functions; +pub mod group_scalar; pub mod hash_aggregate; pub mod limit; pub mod math_expressions; diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 95ba68bc941..663d756695f 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -476,12 +476,23 @@ impl DefaultPhysicalPlanner { }; match e { - Expr::AggregateFunction { fun, args, .. } => { + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => { let args = args .iter() .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) .collect::>>()?; - aggregates::create_aggregate_expr(fun, &args, input_schema, name) + aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + input_schema, + name, + ) } Expr::AggregateUDF { fun, args, .. } => { let args = args diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index d745d706c71..cbec200d168 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -21,8 +21,12 @@ use std::{convert::TryFrom, fmt, sync::Arc}; use arrow::array::{ Array, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeStringArray, StringArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + Int8Array, LargeStringArray, ListArray, StringArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, +}; +use arrow::array::{ + Int16Builder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, UInt16Builder, + UInt32Builder, UInt64Builder, UInt8Builder, }; use arrow::{ array::{ArrayRef, PrimitiveArrayOps}, @@ -61,6 +65,8 @@ pub enum ScalarValue { Utf8(Option), /// utf-8 encoded string representing a LargeString's arrow type. LargeUtf8(Option), + /// list of nested ScalarValue + List(Option>, DataType), } macro_rules! typed_cast { @@ -73,10 +79,40 @@ macro_rules! typed_cast { }}; } +macro_rules! build_list { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr) => {{ + match $VALUES { + None => { + let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new(0)); + builder.append(false).unwrap(); + builder.finish() + } + Some(values) => { + let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new(values.len())); + + for scalar_value in values { + match scalar_value { + ScalarValue::$SCALAR_TY(Some(v)) => { + builder.values().append_value(*v).unwrap() + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null().unwrap(); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + + builder.append(true).unwrap(); + builder.finish() + } + } + }}; +} + impl ScalarValue { /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { - match *self { + match self { ScalarValue::Boolean(_) => DataType::Boolean, ScalarValue::UInt8(_) => DataType::UInt8, ScalarValue::UInt16(_) => DataType::UInt16, @@ -90,6 +126,9 @@ impl ScalarValue { ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::List(_, data_type) => { + DataType::List(Box::new(data_type.clone())) + } } } @@ -108,7 +147,8 @@ impl ScalarValue { | ScalarValue::Float32(None) | ScalarValue::Float64(None) | ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) => true, + | ScalarValue::LargeUtf8(None) + | ScalarValue::List(None, _) => true, _ => false, } } @@ -131,6 +171,17 @@ impl ScalarValue { ScalarValue::LargeUtf8(e) => { Arc::new(LargeStringArray::from(vec![e.as_deref()])) } + ScalarValue::List(values, data_type) => Arc::new(match data_type { + DataType::Int8 => build_list!(Int8Builder, Int8, values), + DataType::Int16 => build_list!(Int16Builder, Int16, values), + DataType::Int32 => build_list!(Int32Builder, Int32, values), + DataType::Int64 => build_list!(Int64Builder, Int64, values), + DataType::UInt8 => build_list!(UInt8Builder, UInt8, values), + DataType::UInt16 => build_list!(UInt16Builder, UInt16, values), + DataType::UInt32 => build_list!(UInt32Builder, UInt32, values), + DataType::UInt64 => build_list!(UInt64Builder, UInt64, values), + _ => panic!("Unexpected DataType for list"), + }), } } @@ -150,6 +201,24 @@ impl ScalarValue { DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), + DataType::List(nested_type) => { + let list_array = array.as_any().downcast_ref::().ok_or( + ExecutionError::InternalError( + "Failed to downcast ListArray".to_string(), + ), + )?; + let value = match list_array.is_null(index) { + true => None, + false => { + let nested_array = list_array.value(index); + let scalar_vec = (0..nested_array.len()) + .map(|i| ScalarValue::try_from_array(&nested_array, i)) + .collect::>>()?; + Some(scalar_vec) + } + }; + ScalarValue::List(value, *nested_type.clone()) + } other => { return Err(ExecutionError::NotImplemented(format!( "Can't create a scalar of array of type \"{:?}\"", @@ -244,6 +313,9 @@ impl TryFrom<&DataType> for ScalarValue { &DataType::UInt64 => ScalarValue::UInt64(None), &DataType::Utf8 => ScalarValue::Utf8(None), &DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + &DataType::List(ref nested_type) => { + ScalarValue::List(None, *nested_type.clone()) + } _ => { return Err(ExecutionError::NotImplemented(format!( "Can't create a scalar of type \"{:?}\"", @@ -279,6 +351,17 @@ impl fmt::Display for ScalarValue { ScalarValue::UInt64(e) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::List(e, _) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, }; Ok(()) } @@ -300,6 +383,53 @@ impl fmt::Debug for ScalarValue { ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), ScalarValue::Utf8(_) => write!(f, "Utf8(\"{}\")", self), ScalarValue::LargeUtf8(_) => write!(f, "LargeUtf8(\"{}\")", self), + ScalarValue::List(_, _) => write!(f, "List([{}])", self), } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn scalar_list_null_to_array() -> Result<()> { + let list_array_ref = ScalarValue::List(None, DataType::UInt64).to_array(); + let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); + + assert!(list_array.is_null(0)); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + + Ok(()) + } + + #[test] + fn scalar_list_to_array() -> Result<()> { + let list_array_ref = ScalarValue::List( + Some(vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]), + DataType::UInt64, + ) + .to_array(); + + let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = prim_array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + + Ok(()) + } +} diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index b2ed920a8be..a9ec8ca15e0 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -539,7 +539,11 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { .collect::>>()? }; - return Ok(Expr::AggregateFunction { fun, args }); + return Ok(Expr::AggregateFunction { + fun, + distinct: function.distinct, + args, + }); }; // finally, user-defined functions (UDF) and UDAF diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 42d930cb8da..e2d603ebe4b 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -135,6 +135,13 @@ pub fn format_batch(batch: &RecordBatch) -> Vec { } let array = batch.column(column_index); match array.data_type() { + DataType::Utf8 => s.push_str( + array + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index), + ), DataType::Int8 => s.push_str(&format!( "{:?}", array @@ -244,3 +251,50 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { } pub mod variable; + +mod tests { + use super::*; + + use arrow::array::{BooleanArray, Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + #[test] + fn test_format_batch() -> Result<()> { + let array_int32 = Int32Array::from(vec![1000, 2000]); + let array_string = StringArray::from(vec!["bow \u{1F3F9}", "arrow \u{2191}"]); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + + let record_batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(array_int32), Arc::new(array_string)], + )?; + + let result = format_batch(&record_batch); + + assert_eq!(result, vec!["1000,bow \u{1F3F9}", "2000,arrow \u{2191}"]); + + Ok(()) + } + + #[test] + fn test_format_batch_unknown() -> Result<()> { + // Use any Array type not yet handled by format_batch(). + let array_bool = BooleanArray::from(vec![false, true]); + + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array_bool)])?; + + let result = format_batch(&record_batch); + + assert_eq!(result, vec!["?", "?"]); + + Ok(()) + } +} diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 5640daa5303..afc302dba67 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1200,3 +1200,29 @@ async fn query_is_not_null() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_count_distinct() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + Some(3), + ]))], + )?; + + let table = MemTable::new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = "SELECT COUNT(DISTINCT c1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["3".to_string()]]; + assert_eq!(expected, actual); + Ok(()) +}