Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ enum AggregateFunction {
AVG = 3;
COUNT = 4;
APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
}

message AggregateExprNode {
Expand Down
2 changes: 2 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
Expand Down Expand Up @@ -1358,6 +1359,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Avg => Self::Avg,
AggregateFunction::Count => Self::Count,
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxDistinct => {
AggregateFunction::ApproxDistinct
}
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
}
}
}
Expand Down
21 changes: 16 additions & 5 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use super::{
use crate::error::{DataFusionError, Result};
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema, TimeUnit};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
/// the implementation of an aggregate function
Expand All @@ -46,7 +46,7 @@ pub type AccumulatorFunctionImplementation =
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// Enum of all built-in scalar functions
/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub enum AggregateFunction {
/// count
Expand All @@ -61,6 +61,8 @@ pub enum AggregateFunction {
Avg,
/// Approximate aggregate function
ApproxDistinct,
/// array_agg
ArrayAgg,
}

impl fmt::Display for AggregateFunction {
Expand All @@ -80,6 +82,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand All @@ -105,6 +108,11 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<Da
AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
AggregateFunction::Sum => sum_return_type(&arg_types[0]),
AggregateFunction::Avg => avg_return_type(&arg_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
"item",
arg_types[0].clone(),
true,
)))),
}
}

Expand Down Expand Up @@ -157,6 +165,9 @@ pub fn create_aggregate_expr(
(AggregateFunction::ApproxDistinct, _) => Arc::new(
expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
),
(AggregateFunction::ArrayAgg, _) => {
Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
}
(AggregateFunction::Min, _) => {
Arc::new(expressions::Min::new(arg, name, return_type))
}
Expand Down Expand Up @@ -202,9 +213,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Count
| AggregateFunction::ApproxDistinct
| AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
Expand Down
257 changes: 257 additions & 0 deletions datafusion/src/physical_plan/expressions/array_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can evaluated at runtime during query execution

use super::format_state_name;
use crate::error::Result;
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;

/// ARRAY_AGG aggregate expression
#[derive(Debug)]
pub struct ArrayAgg {
name: String,
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl ArrayAgg {
/// Create a new ArrayAgg aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type: data_type,
}
}
}

impl AggregateExpr for ArrayAgg {
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new(
&self.name,
DataType::List(Box::new(Field::new(
"item",
self.input_data_type.clone(),
true,
))),
false,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(
&self.input_data_type,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "array_agg"),
DataType::List(Box::new(Field::new(
"item",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe have a different name?

self.input_data_type.clone(),
true,
))),
false,
)])
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}
}

#[derive(Debug)]
pub(crate) struct ArrayAggAccumulator {
array: Vec<ScalarValue>,
datatype: DataType,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given item data type
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
array: vec![],
datatype: datatype.clone(),
})
}
}

impl Accumulator for ArrayAggAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::List(
Some(Box::new(self.array.clone())),
Box::new(self.datatype.clone()),
)])
}

fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you leave update_batch and merge_batch on purpose?
I think at this point it is hard to think of a much more efficient implementation (avoiding converting every item to scalars), given that we don't have columnar storage for aggregates yet.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, I just rely on the default implementation.

let value = &values[0];
self.array.push(value.clone());

Ok(())
}

fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, after we changed the state from a ScalarValue to a Vec<ScalarValue> in early commit, we also need to update how merge works. It cannot call update directly now. Found this when adding e2e tests.

if states.is_empty() {
return Ok(());
};

assert!(states.len() == 1, "states length should be 1!");
match &states[0] {
Copy link
Contributor

@alamb alamb Nov 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would makes sense to assert here that states.len() == 1? so we don't (silently) end up ignoring any other items that might be added (accidentally / erroniously)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, let me add it.

ScalarValue::List(Some(array), _) => {
self.array.extend((&**array).clone());
}
_ => unreachable!(),
}
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::List(
Some(Box::new(self.array.clone())),
Box::new(self.datatype.clone()),
))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::tests::aggregate;
use crate::{error::Result, generic_test_op};
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

#[test]
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

let list = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::Int32(Some(1)),
ScalarValue::Int32(Some(2)),
ScalarValue::Int32(Some(3)),
ScalarValue::Int32(Some(4)),
ScalarValue::Int32(Some(5)),
])),
Box::new(DataType::Int32),
);

generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}

#[test]
fn array_agg_nested() -> Result<()> {
let l1 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
])),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(4i32),
ScalarValue::from(5i32),
])),
Box::new(DataType::Int32),
),
])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let l2 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(6i32)])),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(7i32),
ScalarValue::from(8i32),
])),
Box::new(DataType::Int32),
),
])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let l3 = ScalarValue::List(
Some(Box::new(vec![ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(9i32)])),
Box::new(DataType::Int32),
)])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let list = ScalarValue::List(
Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
)))),
);

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();

generic_test_op!(
array,
DataType::List(Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true,))),
true,
))),
ArrayAgg,
list,
DataType::List(Box::new(Field::new("item", DataType::Int32, true,)))
)
}
}
2 changes: 2 additions & 0 deletions datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;

mod approx_distinct;
mod array_agg;
mod average;
#[macro_use]
mod binary;
Expand Down Expand Up @@ -58,6 +59,7 @@ pub mod helpers {
}

pub use approx_distinct::ApproxDistinct;
pub use array_agg::ArrayAgg;
pub use average::{avg_return_type, Avg, AvgAccumulator};
pub use binary::{binary, binary_operator_data_type, BinaryExpr};
pub use case::{case, CaseExpr};
Expand Down
Loading