Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ enum AggregateFunction {
AVG = 3;
COUNT = 4;
APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
}

message AggregateExprNode {
Expand Down
2 changes: 2 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
Expand Down Expand Up @@ -1358,6 +1359,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Avg => Self::Avg,
AggregateFunction::Count => Self::Count,
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxDistinct => {
AggregateFunction::ApproxDistinct
}
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
}
}
}
Expand Down
21 changes: 16 additions & 5 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use super::{
use crate::error::{DataFusionError, Result};
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema, TimeUnit};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
/// the implementation of an aggregate function
Expand All @@ -46,7 +46,7 @@ pub type AccumulatorFunctionImplementation =
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// Enum of all built-in scalar functions
/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub enum AggregateFunction {
/// count
Expand All @@ -61,6 +61,8 @@ pub enum AggregateFunction {
Avg,
/// Approximate aggregate function
ApproxDistinct,
/// array_agg
ArrayAgg,
}

impl fmt::Display for AggregateFunction {
Expand All @@ -80,6 +82,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand All @@ -105,6 +108,11 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<Da
AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
AggregateFunction::Sum => sum_return_type(&arg_types[0]),
AggregateFunction::Avg => avg_return_type(&arg_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
"element",
arg_types[0].clone(),
true,
)))),
}
}

Expand Down Expand Up @@ -157,6 +165,9 @@ pub fn create_aggregate_expr(
(AggregateFunction::ApproxDistinct, _) => Arc::new(
expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
),
(AggregateFunction::ArrayAgg, _) => {
Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
}
(AggregateFunction::Min, _) => {
Arc::new(expressions::Min::new(arg, name, return_type))
}
Expand Down Expand Up @@ -202,9 +213,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Count
| AggregateFunction::ApproxDistinct
| AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
Expand Down
246 changes: 246 additions & 0 deletions datafusion/src/physical_plan/expressions/array_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can evaluated at runtime during query execution

use super::format_state_name;
use crate::error::Result;
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;

/// ARRAY_AGG aggregate expression
#[derive(Debug)]
pub struct ArrayAgg {
name: String,
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl ArrayAgg {
/// Create a new ArrayAgg aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type: data_type,
}
}
}

impl AggregateExpr for ArrayAgg {
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new(
&self.name,
DataType::List(Box::new(Field::new(
"element",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for catching this. should be item to be consistent.

self.input_data_type.clone(),
true,
))),
false,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(
&self.input_data_type,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "array_agg"),
DataType::List(Box::new(Field::new(
"item",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe have a different name?

self.input_data_type.clone(),
true,
))),
false,
)])
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}
}

#[derive(Debug)]
pub(crate) struct ArrayAggAccumulator {
array: Vec<ScalarValue>,
datatype: DataType,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given element data type
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
array: vec![],
datatype: datatype.clone(),
})
}
}

impl Accumulator for ArrayAggAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::List(
Some(Box::new(self.array.clone())),
Box::new(self.datatype.clone()),
)])
}

fn update(&mut self, values: &[ScalarValue]) -> Result<()> {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you leave update_batch and merge_batch on purpose?
I think at this point it is hard to think of a much more efficient implementation (avoiding converting every item to scalars), given that we don't have columnar storage for aggregates yet.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, I just rely on the default implementation.

let value = &values[0];
self.array.push(value.clone());

Ok(())
}

fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, after we changed the state from a ScalarValue to a Vec<ScalarValue> in early commit, we also need to update how merge works. It cannot call update directly now. Found this when adding e2e tests.

self.update(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::List(
Some(Box::new(self.array.clone())),
Box::new(self.datatype.clone()),
))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::tests::aggregate;
use crate::{error::Result, generic_test_op};
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

#[test]
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

let list = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::Int32(Some(1)),
ScalarValue::Int32(Some(2)),
ScalarValue::Int32(Some(3)),
ScalarValue::Int32(Some(4)),
ScalarValue::Int32(Some(5)),
])),
Box::new(DataType::Int32),
);

generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}

#[test]
fn array_agg_nested() -> Result<()> {
let l1 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
])),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(4i32),
ScalarValue::from(5i32),
])),
Box::new(DataType::Int32),
),
])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let l2 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(6i32)])),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(7i32),
ScalarValue::from(8i32),
])),
Box::new(DataType::Int32),
),
])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let l3 = ScalarValue::List(
Some(Box::new(vec![ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(9i32)])),
Box::new(DataType::Int32),
)])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let list = ScalarValue::List(
Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();

generic_test_op!(
array,
DataType::List(Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true,))),
true,
))),
ArrayAgg,
list,
DataType::List(Box::new(Field::new("item", DataType::Int32, true,)))
)
}
}
2 changes: 2 additions & 0 deletions datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;

mod approx_distinct;
mod array_agg;
mod average;
#[macro_use]
mod binary;
Expand Down Expand Up @@ -58,6 +59,7 @@ pub mod helpers {
}

pub use approx_distinct::ApproxDistinct;
pub use array_agg::ArrayAgg;
pub use average::{avg_return_type, Avg, AvgAccumulator};
pub use binary::{binary, binary_operator_data_type, BinaryExpr};
pub use case::{case, CaseExpr};
Expand Down