diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index e629e99e1657..bb1b1bc3387d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] +//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]-like functions. +//! +//! This module provides several kinds of helper functions for working with [`GroupsAccumulator`]. +//! Here is a quick summary of the provided functions and how they differ: +//! - [`accumulate`]: Accumulate a single primitive value per group. +//! - [`accumulate_multiple`]: Accumulate multiple primitive values per group. +//! - [`accumulate_indices`]: Accumulate only the indices (without the actual values) per group. //! //! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a4179..fb87e713b1c7 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,9 +16,13 @@ // under the License. use ahash::RandomState; +use arrow::array::UInt64Array; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; +use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_expr::expr::WindowFunction; -use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; use std::collections::HashSet; @@ -30,18 +34,11 @@ use std::sync::Arc; use arrow::{ array::{ArrayRef, AsArray}, compute, - datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, - }, + datatypes::{DataType, Field, Int64Type}, }; use arrow::{ - array::{Array, BooleanArray, Int64Array, PrimitiveArray}, + array::{Array, BooleanArray, Int64Array, ListArray, PrimitiveArray}, buffer::BooleanBuffer, }; use datafusion_common::{ @@ -55,14 +52,13 @@ use datafusion_expr::{ use datafusion_expr::{ Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, }; -use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, -}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; -use datafusion_physical_expr_common::binary_map::OutputType; +use datafusion_common::cast::as_list_array; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; + +type HashValueType = u64; + make_udaf_expr_and_func!( Count, count, @@ -206,8 +202,11 @@ impl 
AggregateUDFImpl for Count { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), - false, + // Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(DataType::UInt64, true), + // For group count distinct accumulator, null list item stands for an + // empty value set (i.e., all NULL value so far for that group). + true, )]) } else { Ok(vec![Field::new( @@ -227,116 +226,11 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - DataType::Date32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Date64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time32(TimeUnit::Millisecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time32(TimeUnit::Second) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time64(TimeUnit::Microsecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time64(TimeUnit::Nanosecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Second, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - - DataType::Float16 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float32 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float64 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - - DataType::Utf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) - } - DataType::LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - DataType::Binary => 
Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( - OutputType::BinaryView, - )), - DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: data_type.clone(), - }), - }) + Ok(Box::new(DistinctCountAccumulator { + values: HashSet::default(), + random_state: RandomState::with_seeds(1, 2, 3, 4), + batch_hashes: vec![], + })) } fn aliases(&self) -> &[String] { @@ -344,20 +238,25 @@ impl AggregateUDFImpl for Count { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - // groups accumulator only supports `COUNT(c1)`, not + // groups accumulator only supports `COUNT(c1)` or `COUNT(distinct c1)`, not // `COUNT(c1, c2)`, etc - if args.is_distinct { - return false; - } args.exprs.len() == 1 } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) + if args.is_distinct { + if args.exprs.len() > 1 { + return not_impl_err!("COUNT DISTINCT with multiple arguments"); + } + + Ok(Box::new(DistinctCountGroupsAccumulator::new())) + } else { + Ok(Box::new(CountGroupsAccumulator::new())) + } } fn reverse_expr(&self) -> ReversedUDAF { @@ -654,8 +553,9 @@ fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { /// [`BytesDistinctCountAccumulator`] #[derive(Debug)] struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, + values: HashSet, + random_state: RandomState, + batch_hashes: Vec, } impl DistinctCountAccumulator { @@ -664,37 +564,24 @@ impl DistinctCountAccumulator { // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { size_of_val(self) - + (size_of::() * self.values.capacity()) + + (size_of::() * self.values.capacity()) + self .values .iter() .next() - .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) + .map(|vals| size_of::() - size_of_val(vals)) .unwrap_or(0) + size_of::() } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - size_of_val(self) - + (size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) - .sum::() - + size_of::() - } } impl Accumulator for DistinctCountAccumulator { /// Returns the distinct values seen so far as (one element) ListArray. 
fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = - ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) + let values = self.values.iter().cloned().collect::>(); + let arr = Arc::new(UInt64Array::from(values)) as _; + let list_scalar = SingleRowListArrayBuilder::new(arr).build_list_scalar(); + Ok(vec![list_scalar]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -707,13 +594,24 @@ impl Accumulator for DistinctCountAccumulator { return Ok(()); } - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) + // (0..arr.len()).try_for_each(|index| { + // if !arr.is_null(index) { + // let scalar = ScalarValue::try_from_array(arr, index)?; + // self.values.insert(scalar); + // } + // Ok(()) + // }) + self.batch_hashes.clear(); + self.batch_hashes.resize(arr.len(), 0); + let hashes = create_hashes( + &[ArrayRef::clone(arr)], + &self.random_state, + &mut self.batch_hashes, + )?; + for hash in hashes.as_slice() { + self.values.insert(*hash); + } + Ok(()) } /// Merges multiple sets of distinct values into the current set. @@ -734,7 +632,11 @@ impl Accumulator for DistinctCountAccumulator { "Intermediate results of COUNT DISTINCT should always be non null" ); }; - self.update_batch(&[inner_array])?; + // self.update_batch(&[inner_array])?; + let hash_array = inner_array.as_any().downcast_ref::().unwrap(); + for i in 0..hash_array.len() { + self.values.insert(hash_array.value(i)); + } } Ok(()) } @@ -744,18 +646,245 @@ impl Accumulator for DistinctCountAccumulator { } fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), + self.fixed_size() + } +} + +/// GroupsAccumulator for COUNT DISTINCT operations +#[derive(Debug)] +pub struct DistinctCountGroupsAccumulator { + /// One HashSet per group to track distinct values + distinct_sets: Vec>, + random_state: RandomState, + batch_hashes: Vec, +} + +impl Default for DistinctCountGroupsAccumulator { + fn default() -> Self { + Self::new() + } +} + +impl DistinctCountGroupsAccumulator { + pub fn new() -> Self { + Self { + distinct_sets: vec![], + random_state: RandomState::with_seeds(1, 2, 3, 4), + batch_hashes: vec![], + } + } + + fn ensure_sets(&mut self, total_num_groups: usize) { + if self.distinct_sets.len() < total_num_groups { + self.distinct_sets + .resize_with(total_num_groups, HashSet::default); + } + } +} + +impl GroupsAccumulator for DistinctCountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); + self.ensure_sets(total_num_groups); + + let array = &values[0]; + self.batch_hashes.clear(); + self.batch_hashes.resize(array.len(), 0); + let hashes = create_hashes( + &[ArrayRef::clone(array)], + &self.random_state, + &mut self.batch_hashes, + )?; + + // Use a pattern similar to accumulate_indices to process rows + // that are not null and pass the filter + let nulls = array.logical_nulls(); + + match (nulls.as_ref(), opt_filter) { + (None, None) => { + // No nulls, no filter - process all rows + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + 
self.distinct_sets[group_idx].insert(hashes[row_idx]); + } + } + (Some(nulls), None) => { + // Has nulls, no filter + for (row_idx, (&group_idx, is_valid)) in + group_indices.iter().zip(nulls.iter()).enumerate() + { + if is_valid { + self.distinct_sets[group_idx].insert(hashes[row_idx]); + } + } + } + (None, Some(filter)) => { + // No nulls, has filter + for (row_idx, (&group_idx, filter_value)) in + group_indices.iter().zip(filter.iter()).enumerate() + { + if let Some(true) = filter_value { + self.distinct_sets[group_idx].insert(hashes[row_idx]); + } + } + } + (Some(nulls), Some(filter)) => { + // Has nulls and filter + let iter = filter + .iter() + .zip(group_indices.iter()) + .zip(nulls.iter()) + .enumerate(); + + for (row_idx, ((filter_value, &group_idx), is_valid)) in iter { + if is_valid && filter_value == Some(true) { + self.distinct_sets[group_idx].insert(hashes[row_idx]); + } + } + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let distinct_sets: Vec> = + emit_to.take_needed(&mut self.distinct_sets); + + let counts = distinct_sets + .iter() + .map(|set| set.len() as i64) + .collect::>(); + Ok(Arc::new(Int64Array::from(counts))) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!( + values.len(), + 1, + "COUNT DISTINCT merge expects a single state array" + ); + self.ensure_sets(total_num_groups); + + let list_array = as_list_array(&values[0])?; + + // For each group in the incoming batch + for (i, &group_idx) in group_indices.iter().enumerate() { + if i < list_array.len() { + let inner_array = list_array.value(i); + let inner_array = + inner_array.as_any().downcast_ref::().unwrap(); + // Add each value to our set for this group + for j in 0..inner_array.len() { + if !inner_array.is_null(j) { + // let scalar = ScalarValue::try_from_array(&inner_array, j)?; + self.distinct_sets[group_idx].insert(inner_array.value(j)); + } + } + } + } + + Ok(()) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let distinct_sets: Vec> = + emit_to.take_needed(&mut self.distinct_sets); + + let mut offsets = Vec::with_capacity(distinct_sets.len() + 1); + offsets.push(0); + let mut curr_len = 0i32; + + let mut value_iter = distinct_sets + .into_iter() + .flat_map(|set| { + // build offset + curr_len += set.len() as i32; + offsets.push(curr_len); + // convert into iter + set.into_iter() + }) + .peekable(); + let data_array: ArrayRef = if value_iter.peek().is_none() { + arrow::array::new_empty_array(&DataType::UInt64) as _ + } else { + Arc::new(UInt64Array::from_iter_values(value_iter)) + }; + let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets)); + + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::UInt64, true)), + offset_buffer, + data_array, + None, + ); + + Ok(vec![Arc::new(list_array) as _]) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // For a single distinct value per row, create a list array with that value + assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); + let values = ArrayRef::clone(&values[0]); + + let offsets = + OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1)); + let nulls = filtered_null_mask(opt_filter, &values); + let list_array = ListArray::new( + // Arc::new(Field::new_list_field(values.data_type().clone(), true)), + 
Arc::new(Field::new_list_field(DataType::UInt64, true)), + offsets, + values, + nulls, + ); + + Ok(vec![Arc::new(list_array)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + // Base size of the struct + let mut size = size_of::(); + + // Size of the vector holding the HashSets + size += size_of::>>() + + self.distinct_sets.capacity() + * size_of::>(); + + // Estimate HashSet contents size more efficiently + // Instead of iterating through all values which is expensive, use an approximation + for set in &self.distinct_sets { + // Base size of the HashSet + size += set.capacity() * size_of::<(HashValueType, ())>(); + size += size_of::() * set.len(); } + + size } } #[cfg(test)] mod tests { use super::*; - use arrow::array::NullArray; + use arrow::array::{Int32Array, NullArray, StringArray}; #[test] fn count_accumulator_nulls() -> Result<()> { @@ -764,4 +893,48 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn test_distinct_count_groups_basic() -> Result<()> { + let mut accumulator = DistinctCountGroupsAccumulator::new(); + let values = vec![Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 1])) as ArrayRef]; + + // 3 groups + let group_indices = vec![0, 1, 0, 2, 1, 0]; + accumulator.update_batch(&values, &group_indices, None, 3)?; + + let result = accumulator.evaluate(EmitTo::All)?; + let counts = result.as_primitive::(); + + // Group 0 should have distinct values [1] (1 appears 3 times) -> count 1 + // Group 1 should have distinct values [2] (2 appears 2 times) -> count 1 + // Group 2 should have distinct values [3] (3 appears 1 time) -> count 1 + assert_eq!(counts.value(0), 1); // Group 0: distinct values 1, 1, 1 -> count 1 + assert_eq!(counts.value(1), 1); // Group 1: distinct values 2, 2 -> count 1 + assert_eq!(counts.value(2), 1); // Group 2: distinct values 3 -> count 1 + + Ok(()) + } + + #[test] + fn test_distinct_count_groups_with_filter() -> Result<()> { + let mut accumulator = DistinctCountGroupsAccumulator::new(); + let values = vec![ + Arc::new(StringArray::from(vec!["a", "b", "a", "c", "b", "d"])) as ArrayRef, + ]; + // 2 groups + let group_indices = vec![0, 0, 0, 1, 1, 1]; + let filter = BooleanArray::from(vec![true, true, false, true, false, true]); + accumulator.update_batch(&values, &group_indices, Some(&filter), 2)?; + + let result = accumulator.evaluate(EmitTo::All)?; + let counts = result.as_primitive::(); + + // Group 0 should have ["a", "b"] (filter excludes the second "a") + // Group 1 should have ["c", "d"] (filter excludes "b") + assert_eq!(counts.value(0), 2); + assert_eq!(counts.value(1), 2); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bc43f6bc8e61..28088b56b102 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2302,6 +2302,12 @@ SELECT count(c1, c2) FROM test query error DataFusion error: This feature is not implemented: COUNT DISTINCT with multiple arguments SELECT count(distinct c1, c2) FROM test +# count(distinct) and count() together +query II +SELECT count(c1), count(distinct c1) FROM test +---- +4 3 + # count_null query III SELECT count(null), count(null, null), count(distinct null) FROM test
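
The sketch below is not part of the diff; it only illustrates the core idea the new DistinctCountGroupsAccumulator relies on: hash each input row once with datafusion_common::hash_utils::create_hashes and keep one HashSet<u64> of those hashes per group, so the distinct count for a group is the size of its set (the accumulator therefore counts distinct 64-bit hashes rather than distinct values). The helper name distinct_counts_per_group is illustrative only, and null handling, filters, and the ListArray state exchange used by the real accumulator are omitted.

use std::collections::HashSet;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{ArrayRef, Int32Array};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::Result;

/// Illustrative helper (not in the PR): distinct count per group via row hashing.
fn distinct_counts_per_group(
    values: &ArrayRef,
    group_indices: &[usize],
    num_groups: usize,
) -> Result<Vec<usize>> {
    // Fixed seeds, matching the accumulators in the diff, so hashes produced by
    // different accumulator instances are comparable.
    let random_state = RandomState::with_seeds(1, 2, 3, 4);

    // One 64-bit hash per input row.
    let mut hashes = vec![0u64; values.len()];
    create_hashes(&[Arc::clone(values)], &random_state, &mut hashes)?;

    // One set of distinct hashes per group (nulls and filters ignored in this sketch).
    let mut sets: Vec<HashSet<u64>> = vec![HashSet::new(); num_groups];
    for (row, &group) in group_indices.iter().enumerate() {
        sets[group].insert(hashes[row]);
    }
    Ok(sets.iter().map(|s| s.len()).collect())
}

fn main() -> Result<()> {
    // Same data and grouping as test_distinct_count_groups_basic above.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 1]));
    let counts = distinct_counts_per_group(&values, &[0, 1, 0, 2, 1, 0], 3)?;
    assert_eq!(counts, vec![1, 1, 1]);
    Ok(())
}

The fixed seeds passed to RandomState::with_seeds(1, 2, 3, 4) mirror the accumulators in the diff: because the intermediate state exchanged between partial and final aggregation is the raw u64 hashes, every accumulator instance has to hash identically for merge_batch to deduplicate values across partitions.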