-
Notifications
You must be signed in to change notification settings - Fork 1.9k
feat: implement GroupsAccumulator for count(DISTINCT) aggr
#15324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 4 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
0840207
feat: implement GroupsAccumulator for count(DISTINCT) aggr
waynexia 7c166e6
finalize
waynexia f5c0935
fix clippy
waynexia 2bd4986
record data type
waynexia 7bc506d
dedicate accumulator
waynexia 86390ef
a new method
waynexia ca805fb
also for normal accumulator
waynexia 9ecd21f
clean up
waynexia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,11 @@ | |
| // under the License. | ||
|
|
||
| use ahash::RandomState; | ||
| use arrow::buffer::{OffsetBuffer, ScalarBuffer}; | ||
| use datafusion_common::stats::Precision; | ||
| use datafusion_expr::expr::WindowFunction; | ||
| use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; | ||
| use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; | ||
| use datafusion_macros::user_doc; | ||
| use datafusion_physical_expr::expressions; | ||
| use std::collections::HashSet; | ||
|
|
@@ -41,7 +43,7 @@ use arrow::{ | |
| }; | ||
|
|
||
| use arrow::{ | ||
| array::{Array, BooleanArray, Int64Array, PrimitiveArray}, | ||
| array::{Array, BooleanArray, Int64Array, ListArray, PrimitiveArray}, | ||
| buffer::BooleanBuffer, | ||
| }; | ||
| use datafusion_common::{ | ||
|
|
@@ -62,7 +64,9 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::{ | |
| use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; | ||
| use datafusion_physical_expr_common::binary_map::OutputType; | ||
|
|
||
| use datafusion_common::cast::as_list_array; | ||
| use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; | ||
|
|
||
| make_udaf_expr_and_func!( | ||
| Count, | ||
| count, | ||
|
|
@@ -207,7 +211,9 @@ impl AggregateUDFImpl for Count { | |
| format_state_name(args.name, "count distinct"), | ||
| // See COMMENTS.md to understand why nullable is set to true | ||
| Field::new_list_field(args.input_types[0].clone(), true), | ||
| false, | ||
| // For group count distinct accumulator, null list item stands for an | ||
| // empty value set (i.e., all NULL value so far for that group). | ||
| true, | ||
| )]) | ||
| } else { | ||
| Ok(vec![Field::new( | ||
|
|
@@ -344,20 +350,23 @@ impl AggregateUDFImpl for Count { | |
| } | ||
|
|
||
| fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { | ||
| // groups accumulator only supports `COUNT(c1)`, not | ||
| // groups accumulator only supports `COUNT(c1)` or `COUNT(distinct c1)`, not | ||
| // `COUNT(c1, c2)`, etc | ||
| if args.is_distinct { | ||
| return false; | ||
| } | ||
| args.exprs.len() == 1 | ||
| } | ||
|
|
||
| fn create_groups_accumulator( | ||
| &self, | ||
| _args: AccumulatorArgs, | ||
| args: AccumulatorArgs, | ||
| ) -> Result<Box<dyn GroupsAccumulator>> { | ||
| // instantiate specialized accumulator | ||
| Ok(Box::new(CountGroupsAccumulator::new())) | ||
| if args.is_distinct { | ||
| Ok(Box::new(DistinctCountGroupsAccumulator::new( | ||
| args.exprs[0].data_type(args.schema)?, | ||
| ))) | ||
| } else { | ||
| Ok(Box::new(CountGroupsAccumulator::new())) | ||
| } | ||
| } | ||
|
|
||
| fn reverse_expr(&self) -> ReversedUDAF { | ||
|
|
@@ -752,10 +761,245 @@ impl Accumulator for DistinctCountAccumulator { | |
| } | ||
| } | ||
|
|
||
| /// GroupsAccumulator for COUNT DISTINCT operations | ||
| #[derive(Debug)] | ||
| pub struct DistinctCountGroupsAccumulator { | ||
| /// One HashSet per group to track distinct values | ||
| distinct_sets: Vec<HashSet<ScalarValue, RandomState>>, | ||
|
||
| data_type: DataType, | ||
| } | ||
|
|
||
| impl DistinctCountGroupsAccumulator { | ||
| pub fn new(data_type: DataType) -> Self { | ||
| Self { | ||
| distinct_sets: vec![], | ||
| data_type, | ||
| } | ||
| } | ||
|
|
||
| fn ensure_sets(&mut self, total_num_groups: usize) { | ||
| if self.distinct_sets.len() < total_num_groups { | ||
| self.distinct_sets | ||
| .resize_with(total_num_groups, HashSet::default); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl GroupsAccumulator for DistinctCountGroupsAccumulator { | ||
| fn update_batch( | ||
| &mut self, | ||
| values: &[ArrayRef], | ||
| group_indices: &[usize], | ||
| opt_filter: Option<&BooleanArray>, | ||
| total_num_groups: usize, | ||
| ) -> Result<()> { | ||
| assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); | ||
| self.ensure_sets(total_num_groups); | ||
|
|
||
| let array = &values[0]; | ||
|
|
||
| // Use a pattern similar to accumulate_indices to process rows | ||
| // that are not null and pass the filter | ||
| let nulls = array.logical_nulls(); | ||
|
|
||
| match (nulls.as_ref(), opt_filter) { | ||
| (None, None) => { | ||
| // No nulls, no filter - process all rows | ||
| for (row_idx, &group_idx) in group_indices.iter().enumerate() { | ||
| if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) { | ||
| self.distinct_sets[group_idx].insert(scalar); | ||
| } | ||
| } | ||
| } | ||
| (Some(nulls), None) => { | ||
| // Has nulls, no filter | ||
| for (row_idx, (&group_idx, is_valid)) in | ||
| group_indices.iter().zip(nulls.iter()).enumerate() | ||
| { | ||
| if is_valid { | ||
| if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) { | ||
| self.distinct_sets[group_idx].insert(scalar); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| (None, Some(filter)) => { | ||
| // No nulls, has filter | ||
| for (row_idx, (&group_idx, filter_value)) in | ||
| group_indices.iter().zip(filter.iter()).enumerate() | ||
| { | ||
| if let Some(true) = filter_value { | ||
| if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) { | ||
| self.distinct_sets[group_idx].insert(scalar); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| (Some(nulls), Some(filter)) => { | ||
| // Has nulls and filter | ||
| let iter = filter | ||
| .iter() | ||
| .zip(group_indices.iter()) | ||
| .zip(nulls.iter()) | ||
| .enumerate(); | ||
|
|
||
| for (row_idx, ((filter_value, &group_idx), is_valid)) in iter { | ||
| if is_valid && filter_value == Some(true) { | ||
| if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) { | ||
| self.distinct_sets[group_idx].insert(scalar); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { | ||
| let distinct_sets: Vec<HashSet<ScalarValue, RandomState>> = | ||
| emit_to.take_needed(&mut self.distinct_sets); | ||
|
|
||
| let counts = distinct_sets | ||
| .iter() | ||
| .map(|set| set.len() as i64) | ||
| .collect::<Vec<_>>(); | ||
| Ok(Arc::new(Int64Array::from(counts))) | ||
| } | ||
|
|
||
| fn merge_batch( | ||
| &mut self, | ||
| values: &[ArrayRef], | ||
| group_indices: &[usize], | ||
| _opt_filter: Option<&BooleanArray>, | ||
| total_num_groups: usize, | ||
| ) -> Result<()> { | ||
| assert_eq!( | ||
| values.len(), | ||
| 1, | ||
| "COUNT DISTINCT merge expects a single state array" | ||
| ); | ||
| self.ensure_sets(total_num_groups); | ||
|
|
||
| let list_array = as_list_array(&values[0])?; | ||
|
|
||
| // For each group in the incoming batch | ||
| for (i, &group_idx) in group_indices.iter().enumerate() { | ||
| if i < list_array.len() { | ||
| let inner_array = list_array.value(i); | ||
| // Add each value to our set for this group | ||
| for j in 0..inner_array.len() { | ||
| if !inner_array.is_null(j) { | ||
| let scalar = ScalarValue::try_from_array(&inner_array, j)?; | ||
| self.distinct_sets[group_idx].insert(scalar); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { | ||
| let distinct_sets: Vec<HashSet<ScalarValue, RandomState>> = | ||
| emit_to.take_needed(&mut self.distinct_sets); | ||
|
|
||
| let mut offsets = Vec::with_capacity(distinct_sets.len() + 1); | ||
| offsets.push(0); | ||
| let mut curr_len = 0i32; | ||
|
|
||
| let mut value_iter = distinct_sets | ||
| .into_iter() | ||
| .flat_map(|set| { | ||
| // build offset | ||
| curr_len += set.len() as i32; | ||
| offsets.push(curr_len); | ||
| // convert into iter | ||
| set.into_iter() | ||
| }) | ||
| .peekable(); | ||
| let data_array: ArrayRef = if value_iter.peek().is_none() { | ||
| arrow::array::new_empty_array(&self.data_type) as _ | ||
| } else { | ||
| Arc::new(ScalarValue::iter_to_array(value_iter)?) as _ | ||
| }; | ||
| let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets)); | ||
|
|
||
| let list_array = ListArray::new( | ||
| Arc::new(Field::new_list_field(self.data_type.clone(), true)), | ||
| offset_buffer, | ||
| data_array, | ||
| None, | ||
| ); | ||
|
|
||
| Ok(vec![Arc::new(list_array) as _]) | ||
| } | ||
|
|
||
| fn convert_to_state( | ||
| &self, | ||
| values: &[ArrayRef], | ||
| opt_filter: Option<&BooleanArray>, | ||
| ) -> Result<Vec<ArrayRef>> { | ||
| // For a single distinct value per row, create a list array with that value | ||
| assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); | ||
| let values = ArrayRef::clone(&values[0]); | ||
|
|
||
| let offsets = | ||
| OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1)); | ||
| let nulls = filtered_null_mask(opt_filter, &values); | ||
| let list_array = ListArray::new( | ||
| Arc::new(Field::new_list_field(values.data_type().clone(), true)), | ||
| offsets, | ||
| values, | ||
| nulls, | ||
| ); | ||
|
|
||
| Ok(vec![Arc::new(list_array)]) | ||
| } | ||
|
|
||
| fn supports_convert_to_state(&self) -> bool { | ||
| true | ||
| } | ||
|
|
||
| fn size(&self) -> usize { | ||
| // Base size of the struct | ||
| let mut size = size_of::<Self>(); | ||
|
|
||
| // Size of the vector holding the HashSets | ||
| size += size_of::<Vec<HashSet<ScalarValue, RandomState>>>() | ||
| + self.distinct_sets.capacity() | ||
| * size_of::<HashSet<ScalarValue, RandomState>>(); | ||
|
|
||
| // Estimate HashSet contents size more efficiently | ||
| // Instead of iterating through all values which is expensive, use an approximation | ||
| for set in &self.distinct_sets { | ||
| // Base size of the HashSet | ||
| size += set.capacity() * size_of::<(ScalarValue, ())>(); | ||
|
|
||
| // Estimate ScalarValue size using sample-based approach | ||
| // Only look at up to 10 items as a sample | ||
| let sample_size = 10.min(set.len()); | ||
| if sample_size > 0 { | ||
| let avg_size = set | ||
| .iter() | ||
| .take(sample_size) | ||
| .map(|v| v.size()) | ||
| .sum::<usize>() | ||
| / sample_size; | ||
|
|
||
| // Extrapolate to the full set | ||
| size += avg_size * (set.len() - sample_size); | ||
| } | ||
| } | ||
|
|
||
| size | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use arrow::array::NullArray; | ||
| use arrow::array::{Int32Array, NullArray, StringArray}; | ||
|
|
||
| #[test] | ||
| fn count_accumulator_nulls() -> Result<()> { | ||
|
|
@@ -764,4 +1008,48 @@ mod tests { | |
| assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_distinct_count_groups_basic() -> Result<()> { | ||
| let mut accumulator = DistinctCountGroupsAccumulator::new(DataType::Int32); | ||
| let values = vec![Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 1])) as ArrayRef]; | ||
|
|
||
| // 3 groups | ||
| let group_indices = vec![0, 1, 0, 2, 1, 0]; | ||
| accumulator.update_batch(&values, &group_indices, None, 3)?; | ||
|
|
||
| let result = accumulator.evaluate(EmitTo::All)?; | ||
| let counts = result.as_primitive::<Int64Type>(); | ||
|
|
||
| // Group 0 should have distinct values [1] (1 appears 3 times) -> count 1 | ||
| // Group 1 should have distinct values [2] (2 appears 2 times) -> count 1 | ||
| // Group 2 should have distinct values [3] (3 appears 1 time) -> count 1 | ||
| assert_eq!(counts.value(0), 1); // Group 0: distinct values 1, 1, 1 -> count 1 | ||
| assert_eq!(counts.value(1), 1); // Group 1: distinct values 2, 2 -> count 1 | ||
| assert_eq!(counts.value(2), 1); // Group 2: distinct values 3 -> count 1 | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_distinct_count_groups_with_filter() -> Result<()> { | ||
| let mut accumulator = DistinctCountGroupsAccumulator::new(DataType::Utf8); | ||
| let values = vec![ | ||
| Arc::new(StringArray::from(vec!["a", "b", "a", "c", "b", "d"])) as ArrayRef, | ||
| ]; | ||
| // 2 groups | ||
| let group_indices = vec![0, 0, 0, 1, 1, 1]; | ||
| let filter = BooleanArray::from(vec![true, true, false, true, false, true]); | ||
| accumulator.update_batch(&values, &group_indices, Some(&filter), 2)?; | ||
|
|
||
| let result = accumulator.evaluate(EmitTo::All)?; | ||
| let counts = result.as_primitive::<Int64Type>(); | ||
|
|
||
| // Group 0 should have ["a", "b"] (filter excludes the second "a") | ||
| // Group 1 should have ["c", "d"] (filter excludes "b") | ||
| assert_eq!(counts.value(0), 2); | ||
| assert_eq!(counts.value(1), 2); | ||
|
|
||
| Ok(()) | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a follow-up, this could be specialized for types as well (e.g.,
PrimitiveDistinctCountGroupsAccumulator).

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also using the
`HashTable` API would probably give some further gains: https://docs.rs/hashbrown/latest/hashbrown/struct.HashTable.html