Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
// specific language governing permissions and limitations
// under the License.

//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]-like functions.
//!
//! This mod provides various kinds of helper functions to work with [`GroupsAccumulator`],
//! here is a quick summary of the functions provided and their purpose/differences:
//! - [`accumulate`]: Accumulate a single, primitive value per group.
//! - [`accumulate_multiple`]: Accumulate multiple, primitive values per group.
//! - [`accumulate_indices`]: Accumulate indices only (without actual value) per group.
//!
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator

Expand Down
306 changes: 297 additions & 9 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use ahash::RandomState;
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datafusion_common::stats::Precision;
use datafusion_expr::expr::WindowFunction;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions;
use std::collections::HashSet;
Expand All @@ -41,7 +43,7 @@ use arrow::{
};

use arrow::{
array::{Array, BooleanArray, Int64Array, PrimitiveArray},
array::{Array, BooleanArray, Int64Array, ListArray, PrimitiveArray},
buffer::BooleanBuffer,
};
use datafusion_common::{
Expand All @@ -62,7 +64,9 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::{
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::binary_map::OutputType;

use datafusion_common::cast::as_list_array;
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;

make_udaf_expr_and_func!(
Count,
count,
Expand Down Expand Up @@ -207,7 +211,9 @@ impl AggregateUDFImpl for Count {
format_state_name(args.name, "count distinct"),
// See COMMENTS.md to understand why nullable is set to true
Field::new_list_field(args.input_types[0].clone(), true),
false,
// For the group count distinct accumulator, a null list item stands for an
// empty value set (i.e., only NULL values seen so far for that group).
true,
)])
} else {
Ok(vec![Field::new(
Expand Down Expand Up @@ -344,20 +350,23 @@ impl AggregateUDFImpl for Count {
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
// groups accumulator only supports `COUNT(c1)`, not
// groups accumulator only supports `COUNT(c1)` or `COUNT(distinct c1)`, not
// `COUNT(c1, c2)`, etc
if args.is_distinct {
return false;
}
args.exprs.len() == 1
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// instantiate specialized accumulator
Ok(Box::new(CountGroupsAccumulator::new()))
if args.is_distinct {
Ok(Box::new(DistinctCountGroupsAccumulator::new(
args.exprs[0].data_type(args.schema)?,
)))
} else {
Ok(Box::new(CountGroupsAccumulator::new()))
}
}

fn reverse_expr(&self) -> ReversedUDAF {
Expand Down Expand Up @@ -752,10 +761,245 @@ impl Accumulator for DistinctCountAccumulator {
}
}

/// GroupsAccumulator for COUNT DISTINCT operations
#[derive(Debug)]
pub struct DistinctCountGroupsAccumulator {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow-up, this could be specialized for types as well (e.g. PrimitiveDistinctCountGroupsAccumulator)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also using the HashTable API would probably give some further gains
https://docs.rs/hashbrown/latest/hashbrown/struct.HashTable.html

/// One HashSet per group to track distinct values
distinct_sets: Vec<HashSet<ScalarValue, RandomState>>,
Copy link
Contributor

@Dandandan Dandandan Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if a single HashSet<(u64, ScalarValue), RandomState>> (i.e. also index by group id rather than create a new HashSet per group) might be faster? It will use less memory and intuitively should be more cache friendly.

data_type: DataType,
}

impl DistinctCountGroupsAccumulator {
    /// Create an accumulator that tracks distinct values of `data_type`,
    /// one value set per group.
    pub fn new(data_type: DataType) -> Self {
        Self {
            distinct_sets: Vec::new(),
            data_type,
        }
    }

    /// Grow `distinct_sets` so that every group index below
    /// `total_num_groups` has a set; existing per-group state is
    /// never shrunk or cleared.
    fn ensure_sets(&mut self, total_num_groups: usize) {
        let current = self.distinct_sets.len();
        if current < total_num_groups {
            self.distinct_sets
                .extend((current..total_num_groups).map(|_| HashSet::default()));
        }
    }
}

impl GroupsAccumulator for DistinctCountGroupsAccumulator {
    /// Accumulate the distinct input values of this batch into the
    /// per-group `HashSet`s. Rows whose input value is null, or that
    /// are excluded by `opt_filter`, are skipped.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument");
        self.ensure_sets(total_num_groups);

        let array = &values[0];

        // Use a pattern similar to accumulate_indices to process rows
        // that are not null and pass the filter
        let nulls = array.logical_nulls();

        // Specialize on (nulls, filter) presence so the common
        // no-null / no-filter path avoids per-row Option checks.
        // NOTE(review): `if let Ok(...)` silently drops rows where
        // `ScalarValue::try_from_array` fails; consider propagating
        // with `?` instead — confirm whether conversion can fail here.
        match (nulls.as_ref(), opt_filter) {
            (None, None) => {
                // No nulls, no filter - process all rows
                for (row_idx, &group_idx) in group_indices.iter().enumerate() {
                    if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) {
                        self.distinct_sets[group_idx].insert(scalar);
                    }
                }
            }
            (Some(nulls), None) => {
                // Has nulls, no filter
                for (row_idx, (&group_idx, is_valid)) in
                    group_indices.iter().zip(nulls.iter()).enumerate()
                {
                    if is_valid {
                        if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) {
                            self.distinct_sets[group_idx].insert(scalar);
                        }
                    }
                }
            }
            (None, Some(filter)) => {
                // No nulls, has filter
                for (row_idx, (&group_idx, filter_value)) in
                    group_indices.iter().zip(filter.iter()).enumerate()
                {
                    if let Some(true) = filter_value {
                        if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) {
                            self.distinct_sets[group_idx].insert(scalar);
                        }
                    }
                }
            }
            (Some(nulls), Some(filter)) => {
                // Has nulls and filter: a row contributes only when it is
                // both valid (non-null) and selected by the filter.
                let iter = filter
                    .iter()
                    .zip(group_indices.iter())
                    .zip(nulls.iter())
                    .enumerate();

                for (row_idx, ((filter_value, &group_idx), is_valid)) in iter {
                    if is_valid && filter_value == Some(true) {
                        if let Ok(scalar) = ScalarValue::try_from_array(array, row_idx) {
                            self.distinct_sets[group_idx].insert(scalar);
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Emit the final COUNT(DISTINCT ...) result: one non-null `i64`
    /// per emitted group — the size of that group's distinct-value set.
    /// `emit_to` controls how many leading groups are drained.
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let distinct_sets: Vec<HashSet<ScalarValue, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

        let counts = distinct_sets
            .iter()
            .map(|set| set.len() as i64)
            .collect::<Vec<_>>();
        Ok(Arc::new(Int64Array::from(counts)))
    }

    /// Merge intermediate state (as produced by `state`): a single
    /// `ListArray` whose row `i` lists the distinct values another
    /// partial accumulator collected for the group `group_indices[i]`
    /// maps to. The filter argument is intentionally ignored during merge.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(
            values.len(),
            1,
            "COUNT DISTINCT merge expects a single state array"
        );
        self.ensure_sets(total_num_groups);

        let list_array = as_list_array(&values[0])?;

        // For each group in the incoming batch
        for (i, &group_idx) in group_indices.iter().enumerate() {
            if i < list_array.len() {
                let inner_array = list_array.value(i);
                // Add each value to our set for this group; inner nulls
                // are skipped (NULL never counts toward COUNT DISTINCT).
                for j in 0..inner_array.len() {
                    if !inner_array.is_null(j) {
                        let scalar = ScalarValue::try_from_array(&inner_array, j)?;
                        self.distinct_sets[group_idx].insert(scalar);
                    }
                }
            }
        }

        Ok(())
    }

    /// Serialize the drained groups into intermediate state: a single
    /// `ListArray` with one list of distinct values per group. Offsets
    /// are built while flattening the sets so values are materialized
    /// in one pass.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let distinct_sets: Vec<HashSet<ScalarValue, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

        let mut offsets = Vec::with_capacity(distinct_sets.len() + 1);
        offsets.push(0);
        let mut curr_len = 0i32;

        let mut value_iter = distinct_sets
            .into_iter()
            .flat_map(|set| {
                // build offset
                curr_len += set.len() as i32;
                offsets.push(curr_len);
                // convert into iter
                set.into_iter()
            })
            .peekable();
        // `iter_to_array` cannot infer a type from an empty iterator,
        // so fall back to an explicitly typed empty array.
        let data_array: ArrayRef = if value_iter.peek().is_none() {
            arrow::array::new_empty_array(&self.data_type) as _
        } else {
            Arc::new(ScalarValue::iter_to_array(value_iter)?) as _
        };
        let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets));

        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offset_buffer,
            data_array,
            None,
        );

        Ok(vec![Arc::new(list_array) as _])
    }

    /// Fast path used before any grouping has happened: each input row
    /// becomes a single-element list (its own "distinct set"), with
    /// filtered-out rows mapped to null list entries.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        // For a single distinct value per row, create a list array with that value
        assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument");
        let values = ArrayRef::clone(&values[0]);

        // Offsets 0..=len produce exactly one element per list entry.
        let offsets =
            OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1));
        let nulls = filtered_null_mask(opt_filter, &values);
        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(values.data_type().clone(), true)),
            offsets,
            values,
            nulls,
        );

        Ok(vec![Arc::new(list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    /// Estimate the in-memory size of this accumulator.
    fn size(&self) -> usize {
        // Base size of the struct
        let mut size = size_of::<Self>();

        // Size of the vector holding the HashSets
        size += size_of::<Vec<HashSet<ScalarValue, RandomState>>>()
            + self.distinct_sets.capacity()
                * size_of::<HashSet<ScalarValue, RandomState>>();

        // Estimate HashSet contents size more efficiently
        // Instead of iterating through all values which is expensive, use an approximation
        // NOTE(review): the extrapolation below only covers the non-sampled
        // elements, so the sampled elements' heap allocations are not added —
        // this slightly under-counts; confirm that is acceptable for an estimate.
        for set in &self.distinct_sets {
            // Base size of the HashSet
            size += set.capacity() * size_of::<(ScalarValue, ())>();

            // Estimate ScalarValue size using sample-based approach
            // Only look at up to 10 items as a sample
            let sample_size = 10.min(set.len());
            if sample_size > 0 {
                let avg_size = set
                    .iter()
                    .take(sample_size)
                    .map(|v| v.size())
                    .sum::<usize>()
                    / sample_size;

                // Extrapolate to the full set
                size += avg_size * (set.len() - sample_size);
            }
        }

        size
    }
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::NullArray;
use arrow::array::{Int32Array, NullArray, StringArray};

#[test]
fn count_accumulator_nulls() -> Result<()> {
Expand All @@ -764,4 +1008,48 @@ mod tests {
assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
Ok(())
}

#[test]
fn test_distinct_count_groups_basic() -> Result<()> {
    let mut acc = DistinctCountGroupsAccumulator::new(DataType::Int32);
    let input: Vec<ArrayRef> =
        vec![Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 1]))];

    // Spread the six rows over 3 groups.
    let groups = vec![0, 1, 0, 2, 1, 0];
    acc.update_batch(&input, &groups, None, 3)?;

    let result = acc.evaluate(EmitTo::All)?;
    let counts = result.as_primitive::<Int64Type>();

    // Group 0 saw 1, 1, 1 -> one distinct value
    // Group 1 saw 2, 2    -> one distinct value
    // Group 2 saw 3       -> one distinct value
    assert_eq!(counts.value(0), 1);
    assert_eq!(counts.value(1), 1);
    assert_eq!(counts.value(2), 1);

    Ok(())
}

#[test]
fn test_distinct_count_groups_with_filter() -> Result<()> {
    let mut acc = DistinctCountGroupsAccumulator::new(DataType::Utf8);
    let input: Vec<ArrayRef> =
        vec![Arc::new(StringArray::from(vec!["a", "b", "a", "c", "b", "d"]))];

    // Two groups; the filter drops rows 2 and 4.
    let groups = vec![0, 0, 0, 1, 1, 1];
    let filter = BooleanArray::from(vec![true, true, false, true, false, true]);
    acc.update_batch(&input, &groups, Some(&filter), 2)?;

    let result = acc.evaluate(EmitTo::All)?;
    let counts = result.as_primitive::<Int64Type>();

    // Group 0 keeps "a", "b" (the second "a" is filtered out) -> 2 distinct
    // Group 1 keeps "c", "d" ("b" is filtered out)            -> 2 distinct
    assert_eq!(counts.value(0), 2);
    assert_eq!(counts.value(1), 2);

    Ok(())
}
}
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2302,6 +2302,12 @@ SELECT count(c1, c2) FROM test
query error DataFusion error: This feature is not implemented: COUNT DISTINCT with multiple arguments
SELECT count(distinct c1, c2) FROM test

# count(distinct) and count() together
query II
SELECT count(c1), count(distinct c1) FROM test
----
4 3

# count_null
query III
SELECT count(null), count(null, null), count(distinct null) FROM test
Expand Down