Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
forsaken628 committed Jan 7, 2025
1 parent 020d6c9 commit b7fe665
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 36 deletions.
8 changes: 4 additions & 4 deletions src/query/expression/src/aggregate/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {

fn init_state(&self, place: &AggrState);

fn is_state(&self) -> bool {
false
}

fn state_layout(&self) -> Layout;

fn register_state(&self, register: &mut AggrStateRegister) {
Expand Down Expand Up @@ -257,6 +253,10 @@ impl AggrStateRegister {
/// Appends the current length of `states` to `offsets`.
// NOTE(review): presumably called after a function has registered its states,
// so that `offsets` marks where each function's group of states ends — confirm.
pub fn commit(&mut self) {
    let registered = self.states.len();
    self.offsets.push(registered);
}

/// Borrows the register's `states` as a slice.
pub fn states(&self) -> &[AggrStateType] {
    &self.states[..]
}
}

impl Default for AggrStateRegister {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl From<StateAddr> for usize {
}
}

pub fn get_state_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
let mut register = AggrStateRegister::new();
for func in funcs {
func.register_state(&mut register);
Expand Down
4 changes: 2 additions & 2 deletions src/query/expression/src/aggregate/payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use strength_reduce::StrengthReducedU64;

use super::payload_row::rowformat_size;
use super::payload_row::serialize_column_to_rowformat;
use crate::get_state_layout;
use crate::get_states_layout;
use crate::read;
use crate::store;
use crate::types::DataType;
Expand Down Expand Up @@ -90,7 +90,7 @@ impl Payload {
aggrs: Vec<AggregateFunctionRef>,
) -> Self {
let states_layout = if !aggrs.is_empty() {
Some(get_state_layout(&aggrs).unwrap())
Some(get_states_layout(&aggrs).unwrap())
} else {
None
};
Expand Down
49 changes: 37 additions & 12 deletions src/query/functions/src/aggregates/aggregate_combinator_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use databend_common_exception::Result;
use databend_common_expression::types::Bitmap;
use databend_common_expression::types::DataType;
use databend_common_expression::AggrStateRegister;
use databend_common_expression::AggrStateType;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::InputColumns;
use databend_common_expression::Scalar;
Expand All @@ -36,6 +37,7 @@ use crate::aggregates::AggregateFunctionRef;
/// Aggregate function combinator that wraps another (`nested`) aggregate function.
#[derive(Clone)]
pub struct AggregateStateCombinator {
// Name stored when the combinator is constructed.
// NOTE(review): presumably the display name of the combined function — confirm.
name: String,
// Type returned by `return_type`. The constructor builds it as a
// `DataType::Tuple` with one entry per state registered by `nested`:
// `Boolean` for `AggrStateType::Bool`, `Binary` for `AggrStateType::Custom`.
data_type: DataType,
// The wrapped aggregate function, obtained from `AggregateFunctionFactory`;
// state initialization and merging are delegated to it.
nested: AggregateFunctionRef,
}

Expand All @@ -56,7 +58,25 @@ impl AggregateStateCombinator {

let nested = AggregateFunctionFactory::instance().get(nested_name, params, arguments)?;

Ok(Arc::new(AggregateStateCombinator { name, nested }))
let mut register = AggrStateRegister::default();
nested.register_state(&mut register);

let sub_types = register
.states()
.iter()
.map(|typ| match typ {
AggrStateType::Bool => DataType::Boolean,
AggrStateType::Custom(_) => DataType::Binary,
})
.collect();

let data_type = DataType::Tuple(sub_types);

Ok(Arc::new(AggregateStateCombinator {
name,
data_type,
nested,
}))
}

pub fn combinator_desc() -> CombinatorDescription {
Expand All @@ -70,17 +90,13 @@ impl AggregateFunction for AggregateStateCombinator {
}

fn return_type(&self) -> Result<DataType> {
Ok(DataType::Binary)
Ok(self.data_type.clone())
}

/// Initializes the state at `place` by delegating to the nested function.
fn init_state(&self, place: &AggrState) {
    let inner = &self.nested;
    inner.init_state(place);
}

fn is_state(&self) -> bool {
true
}

/// Always panics via `unreachable!()`; this combinator does not report a layout of its own.
// NOTE(review): presumably callers obtain the layout through `register_state`
// rather than this method — confirm before relying on it.
fn state_layout(&self) -> Layout {
unreachable!()
}
Expand Down Expand Up @@ -146,12 +162,21 @@ impl AggregateFunction for AggregateStateCombinator {
self.nested.merge_states(place, rhs)
}

fn merge_result(&self, _place: &AggrState, _builder: &mut ColumnBuilder) -> Result<()> {
todo!()
// let str_builder = builder.as_binary_mut().unwrap();
// self.serialize(place, &mut str_builder.data)?;
// str_builder.commit_row();
// Ok(())
fn merge_result(&self, place: &AggrState, builder: &mut ColumnBuilder) -> Result<()> {
let builders = builder.as_tuple_mut().unwrap();

let loc = place
.loc()
.iter()
.enumerate()
.map(|(i, loc)| match loc {
AggrStateLoc::Bool(_, offset) => AggrStateLoc::Bool(i, *offset),
AggrStateLoc::Custom(_, offset) => AggrStateLoc::Custom(i, *offset),
})
.collect::<Vec<_>>()
.into_boxed_slice();
let place = AggrState::with_loc(place.addr, loc);
self.nested.serialize_builder(&place, builders)
}

fn need_manual_drop_state(&self) -> bool {
Expand Down
4 changes: 2 additions & 2 deletions src/query/functions/src/aggregates/aggregator_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::Scalar;

use super::get_state_layout;
use super::get_states_layout;
use super::AggrState;
use super::AggregateFunctionFactory;
use super::AggregateFunctionRef;
Expand Down Expand Up @@ -117,7 +117,7 @@ struct EvalAggr {
impl EvalAggr {
fn new(func: AggregateFunctionRef) -> Self {
let funcs = [func];
let state_layout = get_state_layout(&funcs).unwrap();
let state_layout = get_states_layout(&funcs).unwrap();

let _arena = Bump::new();
let place = _arena.alloc_layout(state_layout.layout);
Expand Down
7 changes: 5 additions & 2 deletions src/query/functions/tests/it/aggregates/agg_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use std::sync::Arc;

use bumpalo::Bump;
use databend_common_expression::block_debug::assert_block_value_sort_eq;
use databend_common_expression::get_states_layout;
use databend_common_expression::types::ArgType;
use databend_common_expression::types::BooleanType;
use databend_common_expression::types::DataType;
Expand Down Expand Up @@ -193,9 +194,11 @@ fn test_layout() {
type S = DecimalSumState<false, DecimalType<i128>>;
type M = DecimalSumState<false, DecimalType<I256>>;

let states_layout = get_states_layout(&[aggrs.clone()]).unwrap();

assert_eq!(
aggrs.state_layout(),
Layout::from_size_align(24, 8).unwrap()
states_layout.layout,
Layout::from_size_align(17, 8).unwrap()
);
assert_eq!(Layout::new::<S>(), Layout::from_size_align(16, 8).unwrap());
assert_eq!(Layout::new::<M>(), Layout::from_size_align(32, 8).unwrap());
Expand Down
14 changes: 7 additions & 7 deletions src/query/functions/tests/it/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ use std::io::Write;
use bumpalo::Bump;
use comfy_table::Table;
use databend_common_exception::Result;
use databend_common_expression::get_states_layout;
use databend_common_expression::type_check;
use databend_common_expression::types::AnyType;
use databend_common_expression::types::DataType;
use databend_common_expression::AggrState;
use databend_common_expression::AggrStateLoc;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;
Expand Down Expand Up @@ -192,23 +192,23 @@ pub fn simulate_two_groups_group_by(

let func = factory.get(name, params, arguments)?;
let data_type = func.return_type()?;
let states_layout = get_states_layout(&[func.clone()])?;
let loc = states_layout.loc[0].clone();

let arena = Bump::new();

// init state for two groups
let addr1 = arena.alloc_layout(func.state_layout()).into();
let state1 = AggrState::new(addr1, 0);
let addr1 = arena.alloc_layout(states_layout.layout.clone()).into();
let state1 = AggrState::with_loc(addr1, loc.clone());
func.init_state(&state1);
let addr2 = arena.alloc_layout(func.state_layout()).into();
let state2 = AggrState::new(addr2, 0);
let addr2 = arena.alloc_layout(states_layout.layout.clone()).into();
let state2 = AggrState::with_loc(addr2, loc.clone());
func.init_state(&state2);

let places = (0..rows)
.map(|i| if i % 2 == 0 { addr1 } else { addr2 })
.collect::<Vec<_>>();

let loc = vec![AggrStateLoc::Custom(0, 0)].into_boxed_slice();

func.accumulate_keys(&places, loc, columns.into(), rows)?;

let mut builder = ColumnBuilder::with_capacity(&data_type, 1024);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use databend_common_expression::types::DataType;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::DataBlock;
use databend_common_expression::DataSchemaRef;
use databend_common_functions::aggregates::get_state_layout;
use databend_common_functions::aggregates::get_states_layout;
use databend_common_functions::aggregates::AggregateFunctionRef;
use databend_common_functions::aggregates::StatesLayout;
use databend_common_sql::IndexType;
Expand Down Expand Up @@ -56,7 +56,7 @@ impl AggregatorParams {
max_spill_io_requests: usize,
) -> Result<Arc<AggregatorParams>> {
let states_layout = if !agg_funcs.is_empty() {
Some(get_state_layout(agg_funcs)?)
Some(get_states_layout(agg_funcs)?)
} else {
None
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::sync::Arc;

use databend_common_base::runtime::drop_guard;
use databend_common_exception::Result;
use databend_common_expression::get_state_layout;
use databend_common_expression::get_states_layout;
use databend_common_expression::types::DataType;
use databend_common_expression::types::NumberDataType;
use databend_common_expression::AggrState;
Expand Down Expand Up @@ -234,10 +234,10 @@ impl WindowFunctionImpl {
WindowFunctionInfo::Aggregate(agg, args) => {
let arena = Arena::new();

let state_layout = get_state_layout(&[agg.clone()])?;
let states_layout = get_states_layout(&[agg.clone()])?;
let place = AggrState::with_loc(
arena.alloc_layout(state_layout.layout).into(),
state_layout.loc[0].clone(),
arena.alloc_layout(states_layout.layout).into(),
states_layout.loc[0].clone(),
);
let agg = WindowFuncAggImpl {
_arena: arena,
Expand Down

0 comments on commit b7fe665

Please sign in to comment.