Skip to content

Commit

Permalink
implement max_by aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Sep 2, 2024
1 parent 94d178e commit d64d16f
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 1 deletion.
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl FirstValueAccumulator {
}

// Updates state with the values in the given row.
fn update_with_new_row(&mut self, row: &[ScalarValue]) {
pub fn update_with_new_row(&mut self, row: &[ScalarValue]) {
self.first = row[0].clone();
self.orderings = row[1..].to_vec();
self.is_set = true;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod average;
pub mod bit_and_or_xor;
pub mod bool_and_or;
pub mod grouping;
pub mod max_by;
pub mod nth_value;
pub mod string_agg;

Expand Down Expand Up @@ -169,6 +170,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
average::avg_udaf(),
grouping::grouping_udaf(),
nth_value::nth_value_udaf(),
max_by::max_by_udaf(),
]
}

Expand Down
168 changes: 168 additions & 0 deletions datafusion/functions-aggregate/src/max_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef};
use arrow_schema::DataType;
use datafusion_common::utils::get_row_at_idx;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Deref;

use crate::first_last::FirstValueAccumulator;

make_udaf_expr_and_func!(
MaxByFunction,
max_by,
x y,
"Returns the value of the first column corresponding to the maximum value in the second column.",
max_by_udaf
);

pub struct MaxByFunction {
signature: Signature,
requirement_satisfied: bool,
}

impl Debug for MaxByFunction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("MaxBy")
.field("name", &self.name())
.field("signature", &self.signature)
.field("accumulator", &"<FUNC>")
.finish()
}
}
impl Default for MaxByFunction {
fn default() -> Self {
Self::new()
}
}

impl MaxByFunction {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
requirement_satisfied: false,
}
}
}

fn get_min_max_by_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
// min and max support the dictionary data type
// unpack the dictionary to get the value
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
// TODO add checker for datatype which min and max supported
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
_ => Ok(input_types.to_vec()),
}
}

impl AggregateUDFImpl for MaxByFunction {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"max_by"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].to_owned())
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let ordering_dtypes = acc_args
.ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;
let requirement_satisfied =
acc_args.ordering_req.is_empty() || self.requirement_satisfied;
let first_value_accumulator = FirstValueAccumulator::try_new(
acc_args.return_type,
&ordering_dtypes,
acc_args.ordering_req.to_vec(),
acc_args.ignore_nulls,
)?
.with_requirement_satisfied(requirement_satisfied);
MaxByAccumulator::try_new(first_value_accumulator)
.map(|acc| Box::new(acc) as Box<dyn Accumulator>)
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
get_min_max_by_result_type(arg_types)
}
}

#[derive(Debug)]
pub struct MaxByAccumulator {
first_value_accumulator: FirstValueAccumulator,
max_order: ScalarValue,
}

impl MaxByAccumulator {
pub fn try_new(first_value_accumulator: FirstValueAccumulator) -> Result<Self> {
Ok(Self {
first_value_accumulator,
max_order: ScalarValue::Null,
})
}

fn update_with_new_row(&mut self, row: &[ScalarValue]) {
let order = row[1].clone();
if self.max_order.is_null() || order > self.max_order {
self.max_order = order;
self.first_value_accumulator.update_with_new_row(row);
}
}
}

impl Accumulator for MaxByAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
self.first_value_accumulator.state()
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
for i in 0..values[0].len() {
let row = get_row_at_idx(values, i)?;
self.update_with_new_row(&row);
}
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.first_value_accumulator.merge_batch(states)
}

fn evaluate(&mut self) -> Result<ScalarValue> {
self.first_value_accumulator.evaluate()
}

fn size(&self) -> usize {
self.first_value_accumulator.size()
}
}
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5863,3 +5863,32 @@ ORDER BY k;
----
1 1.8125 6.8007813 Float16 Float16
2 8.5 8.5 Float16 Float16


# test max_by function
query I
SELECT max_by(num_column, value) AS max_num
FROM VALUES
(10, 1),
(20, 2),
(30, 3) AS tab(num_column, value);
----
30

query R
SELECT max_by(float_column, value) AS max_float
FROM VALUES
(10.5, 5),
(20.75, 10),
(15.25, 15) AS tab(float_column, value);
----
15.25

query T
SELECT max_by(date_column, value) AS max_date
FROM VALUES
('2024-01-01', 1),
('2024-02-01', 2),
('2024-03-01', 3) AS tab(date_column, value);
----
2024-03-01

0 comments on commit d64d16f

Please sign in to comment.