Skip to content

Commit

Permalink
distinct
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jul 17, 2024
1 parent 9d17c1c commit ac640fb
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 497 deletions.
3 changes: 2 additions & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1842,8 +1842,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
// TODO: Remove this after array_agg are all udafs
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::UDF(udf)
if udf.name() == "ARRAY_AGG" && (*distinct || order_by.is_some()) =>
if udf.name() == "ARRAY_AGG" && order_by.is_some() =>
{
// not yet support UDAF, fallback to builtin
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
Expand Down
79 changes: 78 additions & 1 deletion datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! Defines physical expressions that can evaluated at runtime during query execution

use arrow::array::{Array, ArrayRef};
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::datatypes::DataType;
use arrow_schema::Field;

Expand All @@ -29,6 +29,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::AggregateUDFImpl;
use datafusion_expr::{Accumulator, Signature, Volatility};
use std::collections::HashSet;
use std::sync::Arc;

make_udaf_expr_and_func!(
Expand Down Expand Up @@ -82,6 +83,14 @@ impl AggregateUDFImpl for ArrayAgg {
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
if args.is_distinct {
return Ok(vec![Field::new_list(
format_state_name(args.name, "distinct_array_agg"),
Field::new("item", args.input_type.clone(), true),
true,
)]);
}

Ok(vec![Field::new_list(
format_state_name(args.name, "array_agg"),
Field::new("item", args.input_type.clone(), true),
Expand All @@ -90,6 +99,12 @@ impl AggregateUDFImpl for ArrayAgg {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
acc_args.input_type,
)?));
}

Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?))
}
}
Expand Down Expand Up @@ -170,3 +185,65 @@ impl Accumulator for ArrayAggAccumulator {
- std::mem::size_of_val(&self.datatype)
}
}

#[derive(Debug)]
struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
datatype: DataType,
}

impl DistinctArrayAggAccumulator {
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
datatype: datatype.clone(),
})
}
}

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
assert_eq!(values.len(), 1, "batch input should only include 1 column!");

let array = &values[0];

for i in 0..array.len() {
let scalar = ScalarValue::try_from_array(&array, i)?;
self.values.insert(scalar);
}

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}

states[0]
.as_list::<i32>()
.iter()
.flatten()
.try_for_each(|val| self.update_batch(&[val]))
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
if values.is_empty() {
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}
let arr = ScalarValue::new_list(&values, &self.datatype, true);
Ok(ScalarValue::List(arr))
}

fn size(&self) -> usize {
std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
- std::mem::size_of_val(&self.values)
+ self.datatype.size()
- std::mem::size_of_val(&self.datatype)
}
}
Loading

0 comments on commit ac640fb

Please sign in to comment.