diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 18158114a93b7..493fb97b82b16 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -168,6 +168,7 @@ enum AggregateFunction {
   AVG = 3;
   COUNT = 4;
   APPROX_DISTINCT = 5;
+  ARRAY_AGG = 6;
 }
 
 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index e4c7656cc8a1c..805fe3173b7f4 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1124,6 +1124,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                 AggregateFunction::ApproxDistinct => {
                     protobuf::AggregateFunction::ApproxDistinct
                 }
+                AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
                 AggregateFunction::Min => protobuf::AggregateFunction::Min,
                 AggregateFunction::Max => protobuf::AggregateFunction::Max,
                 AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1358,6 +1359,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Avg => Self::Avg,
             AggregateFunction::Count => Self::Count,
             AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
+            AggregateFunction::ArrayAgg => Self::ArrayAgg,
         }
     }
 }
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 4a32b24b95317..b5c3c3c364680 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -117,6 +117,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::ApproxDistinct => {
                 AggregateFunction::ApproxDistinct
             }
+            protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
         }
     }
 }
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index eb3f6ca409a4f..0c99c4f99caf7 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -34,7 +34,7 @@ use super::{
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::distinct_expressions;
 use crate::physical_plan::expressions;
-use arrow::datatypes::{DataType, Schema, TimeUnit};
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use expressions::{avg_return_type, sum_return_type};
 use std::{fmt, str::FromStr, sync::Arc};
 /// the implementation of an aggregate function
@@ -46,7 +46,7 @@ pub type AccumulatorFunctionImplementation =
 pub type StateTypeFunction =
     Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
 
-/// Enum of all built-in scalar functions
+/// Enum of all built-in aggregate functions
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
 pub enum AggregateFunction {
     /// count
@@ -61,6 +61,8 @@ pub enum AggregateFunction {
     Avg,
     /// Approximate aggregate function
     ApproxDistinct,
+    /// array_agg
+    ArrayAgg,
 }
 
 impl fmt::Display for AggregateFunction {
@@ -80,6 +82,7 @@ impl FromStr for AggregateFunction {
             "avg" => AggregateFunction::Avg,
             "sum" => AggregateFunction::Sum,
             "approx_distinct" => AggregateFunction::ApproxDistinct,
+            "array_agg" => AggregateFunction::ArrayAgg,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -105,6 +108,11 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
         AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
         AggregateFunction::Sum => sum_return_type(&arg_types[0]),
         AggregateFunction::Avg => avg_return_type(&arg_types[0]),
+        AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
+            "item",
+            arg_types[0].clone(),
+            true,
+        )))),
     }
 }
 
@@ -157,6 +165,9 @@ pub fn create_aggregate_expr(
         (AggregateFunction::ApproxDistinct, _) => Arc::new(
             expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
         ),
+        (AggregateFunction::ArrayAgg, _) => {
+            Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
+        }
         (AggregateFunction::Min, _) => {
             Arc::new(expressions::Min::new(arg, name, return_type))
         }
@@ -202,9 +213,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
 pub fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
     match fun {
-        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
-            Signature::any(1, Volatility::Immutable)
-        }
+        AggregateFunction::Count
+        | AggregateFunction::ApproxDistinct
+        | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
         AggregateFunction::Min | AggregateFunction::Max => {
             let valid = STRINGS
                 .iter()
diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs
new file mode 100644
index 0000000000000..213b392f627b9
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/array_agg.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use super::format_state_name;
+use crate::error::Result;
+use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::sync::Arc;
+
+/// ARRAY_AGG aggregate expression
+#[derive(Debug)]
+pub struct ArrayAgg {
+    name: String,
+    input_data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ArrayAgg {
+    /// Create a new ArrayAgg aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            input_data_type: data_type,
+        }
+    }
+}
+
+impl AggregateExpr for ArrayAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        ))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(ArrayAggAccumulator::try_new(
+            &self.input_data_type,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "array_agg"),
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrayAggAccumulator {
+    array: Vec<ScalarValue>,
+    datatype: DataType,
+}
+
+impl ArrayAggAccumulator {
+    /// Create a new array_agg accumulator based on the given item data type
+    pub fn try_new(datatype: &DataType) -> Result<Self> {
+        Ok(Self {
+            array: vec![],
+            datatype: datatype.clone(),
+        })
+    }
+}
+
+impl Accumulator for ArrayAggAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::List(
+            Some(Box::new(self.array.clone())),
+            Box::new(self.datatype.clone()),
+        )])
+    }
+
+    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+        let value = &values[0];
+        self.array.push(value.clone());
+
+        Ok(())
+    }
+
+    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        };
+
+        assert!(states.len() == 1, "states length should be 1!");
+        match &states[0] {
+            ScalarValue::List(Some(array), _) => {
+                self.array.extend((&**array).clone());
+            }
+            _ => unreachable!(),
+        }
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::List(
+            Some(Box::new(self.array.clone())),
+            Box::new(self.datatype.clone()),
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::expressions::tests::aggregate;
+    use crate::{error::Result, generic_test_op};
+    use arrow::array::ArrayRef;
+    use arrow::array::Int32Array;
+    use arrow::datatypes::*;
+    use arrow::record_batch::RecordBatch;
+
+    #[test]
+    fn array_agg_i32() -> Result<()> {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(3)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+            ])),
+            Box::new(DataType::Int32),
+        );
+
+        generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
+    }
+
+    #[test]
+    fn array_agg_nested() -> Result<()> {
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+
+        generic_test_op!(
+            array,
+            DataType::List(Box::new(Field::new(
+                "item",
+                DataType::List(Box::new(Field::new("item", DataType::Int32, true,))),
+                true,
+            ))),
+            ArrayAgg,
+            list,
+            DataType::List(Box::new(Field::new("item", DataType::Int32, true,)))
+        )
+    }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index dba3bde9a40ec..5647ee0a4d270 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
 use arrow::record_batch::RecordBatch;
 
 mod approx_distinct;
+mod array_agg;
 mod average;
 #[macro_use]
 mod binary;
@@ -58,6 +59,7 @@ pub mod helpers {
 }
 
 pub use approx_distinct::ApproxDistinct;
+pub use array_agg::ArrayAgg;
 pub use average::{avg_return_type, Avg, AvgAccumulator};
 pub use binary::{binary, binary_operator_data_type, BinaryExpr};
 pub use case::{case, CaseExpr};
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index eeb6c10926b14..15241ee269fba 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1281,6 +1281,60 @@ async fn csv_query_approx_count() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_array_agg() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+------------------------------------------------------------------+",
+        "| ARRAYAGG(test.c13)                                                |",
+        "+------------------------------------------------------------------+",
+        "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] |",
+        "+------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_array_agg_empty() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| ARRAYAGG(test.c13) |",
+        "+--------------------+",
+        "| []                 |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_array_agg_one() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------------------------------+",
+        "| ARRAYAGG(test.c13)               |",
+        "+----------------------------------+",
+        "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] |",
+        "+----------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 /// for window functions without order by the first, last, and nth function call does not make sense
 #[tokio::test]
 async fn csv_query_window_with_empty_over() -> Result<()> {
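Reviewer note: a minimal sketch (not part of the diff above) of how the two-stage aggregation path exercises the new accumulator. Each partition's ArrayAggAccumulator collects values through update(), exposes its partial result via state() as a single ScalarValue::List, and a final accumulator concatenates those partial lists in merge(). The module name, test name, and placement inside array_agg.rs are hypothetical, and the sketch assumes ScalarValue implements PartialEq for the final assertion; it only reuses the methods defined in array_agg.rs in this PR.

// Hypothetical in-crate unit test (ArrayAggAccumulator is pub(crate)); it is
// not part of this PR and only uses the APIs introduced in array_agg.rs.
#[cfg(test)]
mod two_stage_sketch {
    use super::*;
    use crate::error::Result;
    use crate::physical_plan::Accumulator;
    use crate::scalar::ScalarValue;
    use arrow::datatypes::DataType;

    #[test]
    fn array_agg_two_stage() -> Result<()> {
        // Partial aggregation: each partition accumulates its own input rows.
        let mut partial_a = ArrayAggAccumulator::try_new(&DataType::Int32)?;
        partial_a.update(&[ScalarValue::Int32(Some(1))])?;
        partial_a.update(&[ScalarValue::Int32(Some(2))])?;

        let mut partial_b = ArrayAggAccumulator::try_new(&DataType::Int32)?;
        partial_b.update(&[ScalarValue::Int32(Some(3))])?;

        // Final aggregation: each partial state is a single ScalarValue::List
        // (see state()), and merge() appends its items to this accumulator.
        let mut final_acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
        final_acc.merge(&partial_a.state()?)?;
        final_acc.merge(&partial_b.state()?)?;

        let expected = ScalarValue::List(
            Some(Box::new(vec![
                ScalarValue::Int32(Some(1)),
                ScalarValue::Int32(Some(2)),
                ScalarValue::Int32(Some(3)),
            ])),
            Box::new(DataType::Int32),
        );
        assert_eq!(final_acc.evaluate()?, expected);
        Ok(())
    }
}

Returning the partial result as one List scalar keeps the serialized state a single column (see state_fields()), which is why merge() asserts that exactly one state value arrives per input row of the intermediate batch.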