diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 493fb97b82b16..36a2dd3c8a28f 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -169,6 +169,7 @@ enum AggregateFunction {
   COUNT = 4;
   APPROX_DISTINCT = 5;
   ARRAY_AGG = 6;
+  SET_AGG = 7;
 }
 
 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 47b5df47cd730..2c162d26eece4 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1129,6 +1129,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                         protobuf::AggregateFunction::ApproxDistinct
                     }
                     AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
+                    AggregateFunction::SetAgg => protobuf::AggregateFunction::SetAgg,
                     AggregateFunction::Min => protobuf::AggregateFunction::Min,
                     AggregateFunction::Max => protobuf::AggregateFunction::Max,
                     AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1364,6 +1365,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Count => Self::Count,
             AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
             AggregateFunction::ArrayAgg => Self::ArrayAgg,
+            AggregateFunction::SetAgg => Self::SetAgg,
         }
     }
 }
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index f5442c40e660f..1f7096849485f 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -119,6 +119,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
                 AggregateFunction::ApproxDistinct
             }
             protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
+            protobuf::AggregateFunction::SetAgg => AggregateFunction::SetAgg,
         }
     }
 }
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index e9f9696a56e8c..ed05b68c6312e 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -64,6 +64,8 @@ pub enum AggregateFunction {
     ApproxDistinct,
     /// array_agg
     ArrayAgg,
+    /// set_agg
+    SetAgg,
 }
 
 impl fmt::Display for AggregateFunction {
@@ -84,6 +86,7 @@ impl FromStr for AggregateFunction {
             "sum" => AggregateFunction::Sum,
             "approx_distinct" => AggregateFunction::ApproxDistinct,
             "array_agg" => AggregateFunction::ArrayAgg,
+            "set_agg" => AggregateFunction::SetAgg,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -122,6 +125,11 @@ pub fn return_type(
             coerced_data_types[0].clone(),
             true,
         )))),
+        AggregateFunction::SetAgg => Ok(DataType::List(Box::new(Field::new(
+            "item",
+            coerced_data_types[0].clone(),
+            true,
+        )))),
     }
 }
 
@@ -192,6 +200,11 @@ pub fn create_aggregate_expr(
             name,
             coerced_exprs_types[0].clone(),
         )),
+        (AggregateFunction::SetAgg, _) => Arc::new(expressions::SetAgg::new(
+            coerced_phy_exprs[0].clone(),
+            name,
+            coerced_exprs_types[0].clone(),
+        )),
         (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
             coerced_phy_exprs[0].clone(),
             name,
@@ -245,7 +258,8 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
     match fun {
         AggregateFunction::Count
         | AggregateFunction::ApproxDistinct
-        | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
+        | AggregateFunction::ArrayAgg
+        | AggregateFunction::SetAgg => Signature::any(1, Volatility::Immutable),
         AggregateFunction::Min | AggregateFunction::Max => {
             let valid = STRINGS
                 .iter()
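Reviewer note (not part of the patch): a minimal sketch of how the new variant flows through the helpers changed above, namely the FromStr parsing and return_type. It assumes the patch is applied and that return_type takes the function plus the input types, as in this version of the crate; the Int32 column type and the main harness are illustrative only.

use std::str::FromStr;

use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::error::Result;
use datafusion::physical_plan::aggregates::{return_type, AggregateFunction};

fn main() -> Result<()> {
    // "set_agg" now parses to the new enum variant via FromStr.
    let fun = AggregateFunction::from_str("set_agg")?;

    // Like array_agg, set_agg returns a List of the (coerced) input type.
    let dt = return_type(&fun, &[DataType::Int32])?;
    assert_eq!(
        dt,
        DataType::List(Box::new(Field::new("item", DataType::Int32, true)))
    );
    Ok(())
}
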
diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
index e76e4a6b023e0..276cd8bf4e030 100644
--- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
+++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
@@ -59,6 +59,7 @@ pub(crate) fn coerce_types(
             Ok(input_types.to_vec())
         }
         AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
+        AggregateFunction::SetAgg => Ok(input_types.to_vec()),
         AggregateFunction::Min | AggregateFunction::Max => {
             // min and max support the dictionary data type
             // unpack the dictionary to get the value
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 134c6d89ac4f1..2594d10722f47 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -50,6 +50,7 @@
 mod nth_value;
 mod nullif;
 mod rank;
 mod row_number;
+mod set_agg;
 mod sum;
 mod try_cast;
@@ -84,6 +85,7 @@ pub use nth_value::NthValue;
 pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
 pub use rank::{dense_rank, percent_rank, rank};
 pub use row_number::RowNumber;
+pub use set_agg::SetAgg;
 pub(crate) use sum::is_sum_support_arg_type;
 pub use sum::{sum_return_type, Sum};
 pub use try_cast::{try_cast, TryCastExpr};
diff --git a/datafusion/src/physical_plan/expressions/set_agg.rs b/datafusion/src/physical_plan/expressions/set_agg.rs
new file mode 100644
index 0000000000000..b12f7b0edb452
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/set_agg.rs
@@ -0,0 +1,307 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expression for `set_agg` which aggregates unique values into an array
+
+use super::format_state_name;
+use crate::error::Result;
+use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use hashbrown::HashSet;
+use std::any::Any;
+use std::iter::FromIterator;
+use std::sync::Arc;
+
+/// SET_AGG aggregate expression
+#[derive(Debug)]
+pub struct SetAgg {
+    name: String,
+    input_data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl SetAgg {
+    /// Create a new SetAgg aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            input_data_type: data_type,
+        }
+    }
+}
+
+impl AggregateExpr for SetAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        ))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(SetAggAccumulator::try_new(&self.input_data_type)?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "set_agg"),
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct SetAggAccumulator {
+    set: HashSet<ScalarValue>,
+    datatype: DataType,
+}
+
+impl SetAggAccumulator {
+    /// new set_agg accumulator based on given item data type
+    pub fn try_new(datatype: &DataType) -> Result<Self> {
+        Ok(Self {
+            set: HashSet::new(),
+            datatype: datatype.clone(),
+        })
+    }
+}
+
+impl Accumulator for SetAggAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::List(
+            Some(Box::new(Vec::from_iter(self.set.clone()))),
+            Box::new(self.datatype.clone()),
+        )])
+    }
+
+    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+        let value = &values[0];
+        self.set.insert(value.clone());
+
+        Ok(())
+    }
+
+    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        };
+
+        assert!(states.len() == 1, "states length should be 1!");
+        match &states[0] {
+            ScalarValue::List(Some(array), _) => {
+                for v in (&**array).iter() {
+                    self.set.insert(v.clone());
+                }
+            }
+            _ => unreachable!(),
+        }
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::List(
+            Some(Box::new(Vec::from_iter(self.set.clone().into_iter()))),
+            Box::new(self.datatype.clone()),
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Result;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::expressions::tests::aggregate;
+    use arrow::array::ArrayRef;
+    use arrow::array::Int32Array;
+    use arrow::datatypes::*;
+    use arrow::record_batch::RecordBatch;
+
+    // When converting HashSet to Vec, ordering is unpredictable, so we are unable to use the
+    // generic_test_op macro. This function is similar to generic_test_op except it checks for
+    // the correct set_agg semantics by confirming the following:
+    // 1. `expected` and `actual` have the same number of elements.
+    // 2. `expected` contains no duplicates.
+    // 3. `expected` and `actual` contain the same unique elements.
+    fn check_set_agg(
+        input: ArrayRef,
+        expected: ScalarValue,
+        datatype: DataType,
+    ) -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?;
+
+        let agg = Arc::new(SetAgg::new(
+            col("a", &schema)?,
+            "bla".to_string(),
+            datatype,
+        ));
+        let actual = aggregate(&batch, agg)?;
+
+        match (expected, actual) {
+            (ScalarValue::List(Some(e), _), ScalarValue::List(Some(a), _)) => {
+                // Check that the inputs are the same length.
+                assert_eq!(e.len(), a.len());
+
+                let h1: HashSet<ScalarValue> = HashSet::from_iter(e.clone().into_iter());
+                let h2: HashSet<ScalarValue> = HashSet::from_iter(a.into_iter());
+
+                // Check that e's elements are unique.
+                assert_eq!(h1.len(), e.len());
+
+                // Check that a contains the same unique elements as e.
+                assert_eq!(h1, h2);
+            }
+            _ => {
+                unreachable!()
+            }
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn set_agg_i32() -> Result<()> {
+        let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+
+        let out = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(7)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+            ])),
+            Box::new(DataType::Int32),
+        );
+
+        check_set_agg(col, out, DataType::Int32)
+    }
+
+    #[test]
+    fn set_agg_nested() -> Result<()> {
+        // [[1, 2, 3], [4, 5]]
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // [[6], [7, 8]]
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // [[9]]
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // Duplicate l1 in the input array and check that it is deduped in the output.
+        let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap();
+
+        check_set_agg(
+            array,
+            list,
+            DataType::List(Box::new(Field::new(
+                "item",
+                DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
+                true,
+            ))),
+        )
+    }
+}
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
index 243d0084d890e..2423e5faf00e3 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -219,3 +219,36 @@ async fn csv_query_array_agg_one() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn csv_query_set_agg() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT set_agg(c2) FROM aggregate_test_100";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    // Since set_agg ordering is nondeterministic, check the schema and contents.
+    assert_eq!(
+        *actual[0].schema(),
+        Schema::new(vec![Field::new(
+            "SETAGG(aggregate_test_100.c2)",
+            DataType::List(Box::new(Field::new("item", DataType::UInt32, true))),
+            false
+        ),])
+    );
+
+    // Extract the underlying array buffer data and sort it to check correctness. `Buffer::typed_data`
+    // is unsafe so this must be wrapped in an unsafe block.
+    unsafe {
+        let data =
+            actual[0].column(0).data().child_data()[0].buffers()[0].typed_data::<u32>();
+
+        let mut sorted_data: Vec<u32> = vec![0; 5];
+        sorted_data[..5].clone_from_slice(data);
+        sorted_data.sort_unstable();
+
+        assert_eq!(sorted_data, &[1, 2, 3, 4, 5]);
+    }
+
+    Ok(())
+}
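
Usage sketch (not part of the patch): exercising the new expression through the public AggregateExpr/Accumulator API, assuming this diff is applied. The column name a, the alias distinct_a, and the sample values are made up; update_batch relies on the trait's default row-by-row implementation, and the order of the returned list is unspecified because the values come out of a HashSet.

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::physical_plan::expressions::{col, SetAgg};
use datafusion::physical_plan::{Accumulator, AggregateExpr};

fn main() -> Result<()> {
    // One Int32 column named "a".
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    // Build the physical expression the same way create_aggregate_expr does.
    let set_agg = SetAgg::new(col("a", &schema)?, "distinct_a", DataType::Int32);

    // Feed a batch with duplicates; the accumulator keeps only distinct values.
    let mut acc = set_agg.create_accumulator()?;
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3, 1]));
    acc.update_batch(&[input])?;

    // evaluate() yields a ScalarValue::List of the distinct values.
    println!("{:?}", acc.evaluate()?);
    Ok(())
}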