diff --git a/Cargo.lock b/Cargo.lock index 64d9d9267653b..afa04e74e308f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2874,6 +2874,7 @@ dependencies = [ "serde", "serde_json", "simdutf8", + "strength_reduce", "terminal_size", "tonic 0.10.2", "typetag", diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml index a374f2872aca2..ee5d12239c8b3 100644 --- a/src/query/expression/Cargo.toml +++ b/src/query/expression/Cargo.toml @@ -54,6 +54,7 @@ rust_decimal = "1.26" serde = { workspace = true } serde_json = { workspace = true } simdutf8 = "0.1.4" +strength_reduce = "0.2.4" terminal_size = "0.2.6" tonic = { workspace = true } typetag = { workspace = true } diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index c3d87cad9cea7..279df546121bd 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -42,6 +42,8 @@ pub type Entry = u64; pub struct AggregateHashTable { pub payload: PartitionedPayload, + // use for append rows directly during deserialize + pub direct_append: bool, config: HashTableConfig, current_radix_bits: u64, entries: Vec, @@ -71,6 +73,7 @@ impl AggregateHashTable { Self { entries: vec![0u64; capacity], count: 0, + direct_append: false, current_radix_bits: config.initial_radix_bits, payload: PartitionedPayload::new(group_types, aggrs, 1 << config.initial_radix_bits), capacity, @@ -134,7 +137,15 @@ impl AggregateHashTable { state.row_count = row_count; group_hash_columns(group_columns, &mut state.group_hashes); - let new_group_count = self.probe_and_create(state, group_columns, row_count); + let new_group_count = if self.direct_append { + for idx in 0..row_count { + state.empty_vector[idx] = idx; + } + self.payload.append_rows(state, row_count, group_columns); + row_count + } else { + self.probe_and_create(state, group_columns, row_count) + }; if !self.payload.aggrs.is_empty() { for i in 0..row_count { diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs index 1f10ece16aa8c..d045ce41ee1a3 100644 --- a/src/query/expression/src/aggregate/payload.rs +++ b/src/query/expression/src/aggregate/payload.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use bumpalo::Bump; use databend_common_base::runtime::drop_guard; +use strength_reduce::StrengthReducedU64; use super::payload_row::rowformat_size; use super::payload_row::serialize_column_to_rowformat; @@ -26,8 +27,10 @@ use crate::store; use crate::types::DataType; use crate::AggregateFunctionRef; use crate::Column; +use crate::PayloadFlushState; use crate::SelectVector; use crate::StateAddr; +use crate::BATCH_SIZE; use crate::MAX_PAGE_SIZE; // payload layout @@ -38,6 +41,7 @@ use crate::MAX_PAGE_SIZE; // [STATE_ADDRS] is the state_addrs of the aggregate functions, 8 bytes each pub struct Payload { pub arena: Arc, + pub arenas: Vec>, // if true, the states are moved out of the payload into other payload, and will not be dropped pub state_move_out: bool, pub group_types: Vec, @@ -120,7 +124,8 @@ impl Payload { let row_per_page = (u16::MAX as usize).min(MAX_PAGE_SIZE / tuple_size).max(1); Self { - arena, + arena: arena.clone(), + arenas: vec![arena], state_move_out: false, pages: vec![], current_write_page: 0, @@ -333,6 +338,44 @@ impl Payload { self.pages.iter().map(|x| x.rows).sum::() ); } + + pub fn scatter(&self, state: &mut PayloadFlushState, partition_count: usize) -> bool { + if state.flush_page >= self.pages.len() { + return false; + } + + let page = &self.pages[state.flush_page]; + + // ToNext + if state.flush_page_row >= page.rows { + state.flush_page += 1; + state.flush_page_row = 0; + state.row_count = 0; + return self.scatter(state, partition_count); + } + + let end = (state.flush_page_row + BATCH_SIZE).min(page.rows); + let rows = end - state.flush_page_row; + state.row_count = rows; + + state.probe_state.reset_partitions(partition_count); + + let mods: StrengthReducedU64 = StrengthReducedU64::new(partition_count as u64); + for idx in 0..rows { + state.addresses[idx] = self.data_ptr(page, idx + state.flush_page_row); + + let hash = + unsafe { core::ptr::read::(state.addresses[idx].add(self.hash_offset) as _) }; + + let partition_idx = (hash % mods) as usize; + + let sel = &mut state.probe_state.partition_entries[partition_idx]; + sel[state.probe_state.partition_count[partition_idx]] = idx; + state.probe_state.partition_count[partition_idx] += 1; + } + state.flush_page_row = end; + true + } } impl Drop for Payload { diff --git a/src/query/service/src/pipelines/builders/builder_aggregate.rs b/src/query/service/src/pipelines/builders/builder_aggregate.rs index a0df6d3100ab2..ee0f333513d3c 100644 --- a/src/query/service/src/pipelines/builders/builder_aggregate.rs +++ b/src/query/service/src/pipelines/builders/builder_aggregate.rs @@ -104,8 +104,7 @@ impl PipelineBuilder { let enable_experimental_aggregate_hashtable = self .settings - .get_enable_experimental_aggregate_hashtable()? - && self.ctx.get_cluster().is_empty(); + .get_enable_experimental_aggregate_hashtable()?; let params = Self::build_aggregator_params( aggregate.input.output_schema()?, @@ -213,8 +212,7 @@ impl PipelineBuilder { let max_block_size = self.settings.get_max_block_size()?; let enable_experimental_aggregate_hashtable = self .settings - .get_enable_experimental_aggregate_hashtable()? - && self.ctx.get_cluster().is_empty(); + .get_enable_experimental_aggregate_hashtable()?; let params = Self::build_aggregator_params( aggregate.before_group_by_schema.clone(), diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs index 4deee8193f708..1232679ae8ac7 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs @@ -21,6 +21,8 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::BlockMetaInfoDowncast; use databend_common_expression::DataBlock; +use databend_common_expression::Payload; +use databend_common_expression::PayloadFlushState; use databend_common_hashtable::FastHash; use databend_common_hashtable::HashtableEntryMutRefLike; use databend_common_hashtable::HashtableEntryRefLike; @@ -78,6 +80,7 @@ impl ExchangeSorting AggregateMeta::Serialized(v) => Ok(v.bucket), AggregateMeta::HashTable(v) => Ok(v.bucket), AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(v) => Ok(v.bucket), AggregateMeta::Spilled(_) | AggregateMeta::Spilling(_) | AggregateMeta::BucketSpilled(_) => Ok(-1), @@ -139,6 +142,42 @@ fn scatter( Ok(res) } +fn scatter_paylaod(mut payload: Payload, buckets: usize) -> Result> { + let mut buckets = Vec::with_capacity(buckets); + + let group_types = payload.group_types.clone(); + let aggrs = payload.aggrs.clone(); + let mut state = PayloadFlushState::default(); + + for _ in 0..buckets.capacity() { + buckets.push(Payload::new( + Arc::new(Bump::new()), + group_types.clone(), + aggrs.clone(), + )); + } + + for bucket in buckets.iter_mut() { + bucket.arenas.extend_from_slice(&payload.arenas); + } + + // scatter each page of the payload. + while payload.scatter(&mut state, buckets.len()) { + // copy to the corresponding bucket. + for (idx, bucket) in buckets.iter_mut().enumerate() { + let count = state.probe_state.partition_count[idx]; + + if count > 0 { + let sel = &state.probe_state.partition_entries[idx]; + bucket.copy_rows(sel, count, &state.addresses); + } + } + } + payload.state_move_out = true; + + Ok(buckets) +} + impl FlightScatter for HashTableHashScatter { @@ -176,8 +215,18 @@ impl FlightScatter }); } } - - AggregateMeta::AggregateHashTable(_) => todo!("AGG_HASHTABLE"), + AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(p) => { + for payload in scatter_paylaod(p.payload, self.buckets)? { + blocks.push(DataBlock::empty_with_meta( + AggregateMeta::::create_agg_payload( + p.bucket, + payload, + p.max_partition_count, + ), + )) + } + } }; return Ok(blocks); diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs index 6bd84eb63bbf0..b6f944a4908a0 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs @@ -21,6 +21,7 @@ use databend_common_expression::BlockMetaInfoPtr; use databend_common_expression::Column; use databend_common_expression::DataBlock; use databend_common_expression::PartitionedPayload; +use databend_common_expression::Payload; use crate::pipelines::processors::transforms::aggregator::HashTableCell; use crate::pipelines::processors::transforms::group_by::HashMethodBounds; @@ -34,6 +35,8 @@ pub struct HashTablePayload { pub struct SerializedPayload { pub bucket: isize, pub data_block: DataBlock, + // use for new agg_hashtable + pub max_partition_count: usize, } impl SerializedPayload { @@ -50,10 +53,17 @@ pub struct BucketSpilledPayload { pub columns_layout: Vec, } +pub struct AggregatePayload { + pub bucket: isize, + pub payload: Payload, + pub max_partition_count: usize, +} + pub enum AggregateMeta { Serialized(SerializedPayload), HashTable(HashTablePayload), AggregateHashTable(PartitionedPayload), + AggregatePayload(AggregatePayload), BucketSpilled(BucketSpilledPayload), Spilled(Vec), Spilling(HashTablePayload, V>), @@ -73,10 +83,29 @@ impl AggregateMeta::AggregateHashTable(payload)) } - pub fn create_serialized(bucket: isize, block: DataBlock) -> BlockMetaInfoPtr { + pub fn create_agg_payload( + bucket: isize, + payload: Payload, + max_partition_count: usize, + ) -> BlockMetaInfoPtr { + Box::new(AggregateMeta::::AggregatePayload( + AggregatePayload { + bucket, + payload, + max_partition_count, + }, + )) + } + + pub fn create_serialized( + bucket: isize, + block: DataBlock, + max_partition_count: usize, + ) -> BlockMetaInfoPtr { Box::new(AggregateMeta::::Serialized(SerializedPayload { bucket, data_block: block, + max_partition_count, })) } @@ -136,6 +165,9 @@ impl Debug for AggregateMeta AggregateMeta::AggregateHashTable(_) => { f.debug_struct("AggregateMeta:AggHashTable").finish() } + AggregateMeta::AggregatePayload(_) => { + f.debug_struct("AggregateMeta:AggregatePayload").finish() + } } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/serde_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/serde_meta.rs index 9e10af17b7c08..abec404babfe0 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/serde_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/serde_meta.rs @@ -29,6 +29,9 @@ pub struct AggregateSerdeMeta { pub location: Option, pub data_range: Option>, pub columns_layout: Vec, + // use for new agg_hashtable + pub is_agg_payload: bool, + pub max_partition_count: usize, } impl AggregateSerdeMeta { @@ -39,6 +42,20 @@ impl AggregateSerdeMeta { location: None, data_range: None, columns_layout: vec![], + is_agg_payload: false, + max_partition_count: 0, + }) + } + + pub fn create_agg_payload(bucket: isize, max_partition_count: usize) -> BlockMetaInfoPtr { + Box::new(AggregateSerdeMeta { + typ: BUCKET_TYPE, + bucket, + location: None, + data_range: None, + columns_layout: vec![], + is_agg_payload: true, + max_partition_count, }) } @@ -54,6 +71,8 @@ impl AggregateSerdeMeta { columns_layout, location: Some(location), data_range: Some(data_range), + is_agg_payload: false, + max_partition_count: 0, }) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs index 9d212ff15dabb..95f770ba1adc7 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs @@ -22,6 +22,7 @@ use databend_common_expression::types::binary::BinaryColumnBuilder; use databend_common_expression::BlockMetaInfoDowncast; use databend_common_expression::Column; use databend_common_expression::DataBlock; +use databend_common_expression::PayloadFlushState; use databend_common_functions::aggregates::StateAddr; use databend_common_hashtable::HashtableEntryRefLike; use databend_common_hashtable::HashtableLike; @@ -31,15 +32,15 @@ use databend_common_pipeline_core::processors::OutputPort; use databend_common_pipeline_core::processors::Processor; use databend_common_pipeline_core::processors::ProcessorPtr; +use super::SerializePayload; use crate::pipelines::processors::transforms::aggregator::create_state_serializer; +use crate::pipelines::processors::transforms::aggregator::empty_block; use crate::pipelines::processors::transforms::aggregator::estimated_key_size; use crate::pipelines::processors::transforms::aggregator::AggregateMeta; use crate::pipelines::processors::transforms::aggregator::AggregateSerdeMeta; use crate::pipelines::processors::transforms::aggregator::AggregatorParams; -use crate::pipelines::processors::transforms::aggregator::HashTablePayload; use crate::pipelines::processors::transforms::group_by::HashMethodBounds; use crate::pipelines::processors::transforms::group_by::KeysColumnBuilder; - pub struct TransformAggregateSerializer { method: Method, params: Arc, @@ -137,15 +138,23 @@ impl TransformAggregateSerializer { AggregateMeta::Serialized(_) => unreachable!(), AggregateMeta::BucketSpilled(_) => unreachable!(), AggregateMeta::Partitioned { .. } => unreachable!(), + AggregateMeta::AggregateHashTable(_) => unreachable!(), AggregateMeta::HashTable(payload) => { self.input_data = Some(SerializeAggregateStream::create( &self.method, &self.params, - payload, + SerializePayload::::HashTablePayload(payload), + )); + return Ok(Event::Sync); + } + AggregateMeta::AggregatePayload(p) => { + self.input_data = Some(SerializeAggregateStream::create( + &self.method, + &self.params, + SerializePayload::::AggregatePayload(p), )); return Ok(Event::Sync); } - AggregateMeta::AggregateHashTable(_) => todo!("AGG_HASHTABLE"), } } } @@ -198,8 +207,9 @@ pub fn serialize_aggregate( pub struct SerializeAggregateStream { method: Method, params: Arc, - pub payload: Pin>>, - iter: as HashtableLike>::Iterator<'static>, + pub payload: Pin>>, + // old hashtable' iter + iter: Option< as HashtableLike>::Iterator<'static>>, end_iter: bool, } @@ -211,13 +221,16 @@ impl SerializeAggregateStream { pub fn create( method: &Method, params: &Arc, - payload: HashTablePayload, + payload: SerializePayload, ) -> Self { unsafe { let payload = Box::pin(payload); - let point = NonNull::from(&payload.cell.hashtable); - let iter = point.as_ref().iter(); + let iter = if let SerializePayload::HashTablePayload(p) = payload.as_ref().get_ref() { + Some(NonNull::from(&p.cell.hashtable).as_ref().iter()) + } else { + None + }; SerializeAggregateStream:: { iter, @@ -244,49 +257,100 @@ impl SerializeAggregateStream { return Ok(None); } - let max_block_rows = std::cmp::min(8192, self.payload.cell.hashtable.len()); - let max_block_bytes = std::cmp::min( - 8 * 1024 * 1024 + 1024, - self.payload - .cell - .hashtable - .unsize_key_size() - .unwrap_or(usize::MAX), - ); - - let funcs = &self.params.aggregate_functions; - let offsets_aggregate_states = &self.params.offsets_aggregate_states; - - let mut state_builders: Vec = funcs - .iter() - .map(|func| create_state_serializer(func, max_block_rows)) - .collect(); - - let mut group_key_builder = self - .method - .keys_column_builder(max_block_rows, max_block_bytes); - - #[allow(clippy::while_let_on_iterator)] - while let Some(group_entity) = self.iter.next() { - let mut bytes = 0; - let place = Into::::into(*group_entity.get()); - - for (idx, func) in funcs.iter().enumerate() { - let arg_place = place.next(offsets_aggregate_states[idx]); - func.serialize(arg_place, &mut state_builders[idx].data)?; - state_builders[idx].commit_row(); - bytes += state_builders[idx].memory_size(); + match self.payload.as_ref().get_ref() { + SerializePayload::HashTablePayload(p) => { + let max_block_rows = std::cmp::min(8192, p.cell.hashtable.len()); + let max_block_bytes = std::cmp::min( + 8 * 1024 * 1024 + 1024, + p.cell.hashtable.unsize_key_size().unwrap_or(usize::MAX), + ); + + let funcs = &self.params.aggregate_functions; + let offsets_aggregate_states = &self.params.offsets_aggregate_states; + + let mut state_builders: Vec = funcs + .iter() + .map(|func| create_state_serializer(func, max_block_rows)) + .collect(); + + let mut group_key_builder = self + .method + .keys_column_builder(max_block_rows, max_block_bytes); + + #[allow(clippy::while_let_on_iterator)] + while let Some(group_entity) = self.iter.as_mut().and_then(|iter| iter.next()) { + let mut bytes = 0; + let place = Into::::into(*group_entity.get()); + + for (idx, func) in funcs.iter().enumerate() { + let arg_place = place.next(offsets_aggregate_states[idx]); + func.serialize(arg_place, &mut state_builders[idx].data)?; + state_builders[idx].commit_row(); + bytes += state_builders[idx].memory_size(); + } + + group_key_builder.append_value(group_entity.key()); + + if bytes >= 8 * 1024 * 1024 { + return self.finish(state_builders, group_key_builder); + } + } + + self.end_iter = true; + self.finish(state_builders, group_key_builder) } + SerializePayload::AggregatePayload(p) => { + let mut state = PayloadFlushState::default(); + let mut blocks = vec![]; + + while p.payload.flush(&mut state) { + let row_count = state.row_count; + + let mut state_builders: Vec = p + .payload + .aggrs + .iter() + .map(|agg| create_state_serializer(agg, row_count)) + .collect(); + + for place in state.state_places.as_slice()[0..row_count].iter() { + for (idx, (addr_offset, aggr)) in p + .payload + .state_addr_offsets + .iter() + .zip(p.payload.aggrs.iter()) + .enumerate() + { + let arg_place = place.next(*addr_offset); + aggr.serialize(arg_place, &mut state_builders[idx].data)?; + state_builders[idx].commit_row(); + } + } + + let mut cols = + Vec::with_capacity(p.payload.aggrs.len() + p.payload.group_types.len()); + for builder in state_builders.into_iter() { + let col = Column::Binary(builder.build()); + cols.push(col); + } + + cols.extend_from_slice(&state.take_group_columns()); - group_key_builder.append_value(group_entity.key()); + blocks.push(DataBlock::new_from_columns(cols)); + } - if bytes >= 8 * 1024 * 1024 { - return self.finish(state_builders, group_key_builder); + self.end_iter = true; + + let data_block = if blocks.is_empty() { + empty_block(p) + } else { + DataBlock::concat(&blocks).unwrap() + }; + Ok(Some(data_block.add_meta(Some( + AggregateSerdeMeta::create_agg_payload(p.bucket, p.max_partition_count), + ))?)) } } - - self.end_iter = true; - self.finish(state_builders, group_key_builder) } fn finish( @@ -300,7 +364,12 @@ impl SerializeAggregateStream { columns.push(Column::Binary(builder.build())); } - let bucket = self.payload.bucket; + let bucket = if let SerializePayload::HashTablePayload(p) = self.payload.as_ref().get_ref() + { + p.bucket + } else { + 0 + }; columns.push(group_key_builder.finish()); let block = DataBlock::new_from_columns(columns); Ok(Some( diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs index 01a2a3a50c062..e8dafbb02358b 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs @@ -118,6 +118,7 @@ impl TransformDeserializer { diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs index ef14759d3e189..f34b8a23435f1 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs @@ -48,6 +48,7 @@ use futures_util::future::BoxFuture; use log::info; use opendal::Operator; +use super::SerializePayload; use crate::api::serialize_block; use crate::api::ExchangeShuffleMeta; use crate::pipelines::processors::transforms::aggregator::aggregate_meta::AggregateMeta; @@ -134,6 +135,7 @@ impl BlockMetaTransform Some(AggregateMeta::Serialized(_)) => unreachable!(), Some(AggregateMeta::BucketSpilled(_)) => unreachable!(), Some(AggregateMeta::Partitioned { .. }) => unreachable!(), + Some(AggregateMeta::AggregateHashTable(_)) => unreachable!(), Some(AggregateMeta::Spilling(payload)) => { serialized_blocks.push(FlightSerialized::Future( match index == self.local_pos { @@ -156,7 +158,6 @@ impl BlockMetaTransform }, )); } - Some(AggregateMeta::HashTable(payload)) => { if index == self.local_pos { serialized_blocks.push(FlightSerialized::DataBlock(block.add_meta( @@ -165,9 +166,12 @@ impl BlockMetaTransform continue; } - let mut stream = - SerializeAggregateStream::create(&self.method, &self.params, payload); - let bucket = stream.payload.bucket; + let bucket = payload.bucket; + let mut stream = SerializeAggregateStream::create( + &self.method, + &self.params, + SerializePayload::::HashTablePayload(payload), + ); serialized_blocks.push(FlightSerialized::DataBlock(match stream.next() { None => DataBlock::empty(), Some(data_block) => { @@ -175,8 +179,29 @@ impl BlockMetaTransform } })); } + Some(AggregateMeta::AggregatePayload(p)) => { + if index == self.local_pos { + serialized_blocks.push(FlightSerialized::DataBlock(block.add_meta( + Some(Box::new(AggregateMeta::::AggregatePayload( + p, + ))), + )?)); + continue; + } - Some(AggregateMeta::AggregateHashTable(_)) => todo!("AGG_HASHTABLE"), + let bucket = p.bucket; + let mut stream = SerializeAggregateStream::create( + &self.method, + &self.params, + SerializePayload::::AggregatePayload(p), + ); + serialized_blocks.push(FlightSerialized::DataBlock(match stream.next() { + None => DataBlock::empty(), + Some(data_block) => { + serialize_block(bucket, data_block?, &self.ipc_fields, &self.options)? + } + })); + } }; } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_group_by_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_group_by_serializer.rs index 815afda2743f1..5990abaa102ce 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_group_by_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_group_by_serializer.rs @@ -54,6 +54,7 @@ use futures_util::future::BoxFuture; use log::info; use opendal::Operator; +use super::SerializePayload; use crate::api::serialize_block; use crate::api::ExchangeShuffleMeta; use crate::pipelines::processors::transforms::aggregator::exchange_defines; @@ -187,6 +188,7 @@ impl BlockMetaTransform Some(AggregateMeta::BucketSpilled(_)) => unreachable!(), Some(AggregateMeta::Serialized(_)) => unreachable!(), Some(AggregateMeta::Partitioned { .. }) => unreachable!(), + Some(AggregateMeta::AggregateHashTable(_)) => unreachable!(), Some(AggregateMeta::Spilling(payload)) => { serialized_blocks.push(FlightSerialized::Future( match index == self.local_pos { @@ -207,7 +209,6 @@ impl BlockMetaTransform }, )); } - Some(AggregateMeta::AggregateHashTable(_)) => todo!("AGG_HASHTABLE"), Some(AggregateMeta::HashTable(payload)) => { if index == self.local_pos { serialized_blocks.push(FlightSerialized::DataBlock(block.add_meta( @@ -216,8 +217,31 @@ impl BlockMetaTransform continue; } - let mut stream = SerializeGroupByStream::create(&self.method, payload); - let bucket = stream.payload.bucket; + let bucket = payload.bucket; + let mut stream = SerializeGroupByStream::create( + &self.method, + SerializePayload::::HashTablePayload(payload), + ); + serialized_blocks.push(FlightSerialized::DataBlock(match stream.next() { + None => DataBlock::empty(), + Some(data_block) => { + serialize_block(bucket, data_block?, &self.ipc_fields, &self.options)? + } + })); + } + Some(AggregateMeta::AggregatePayload(p)) => { + if index == self.local_pos { + serialized_blocks.push(FlightSerialized::DataBlock(block.add_meta( + Some(Box::new(AggregateMeta::::AggregatePayload(p))), + )?)); + continue; + } + + let bucket = p.bucket; + let mut stream = SerializeGroupByStream::create( + &self.method, + SerializePayload::::AggregatePayload(p), + ); serialized_blocks.push(FlightSerialized::DataBlock(match stream.next() { None => DataBlock::empty(), Some(data_block) => { diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_group_by_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_group_by_serializer.rs index ecc137af40a64..2d87c2c81ad9b 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_group_by_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_group_by_serializer.rs @@ -19,7 +19,9 @@ use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::BlockMetaInfoDowncast; +use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; +use databend_common_expression::PayloadFlushState; use databend_common_hashtable::HashtableEntryRefLike; use databend_common_hashtable::HashtableLike; use databend_common_pipeline_core::processors::Event; @@ -27,9 +29,11 @@ use databend_common_pipeline_core::processors::InputPort; use databend_common_pipeline_core::processors::OutputPort; use databend_common_pipeline_core::processors::Processor; use databend_common_pipeline_core::processors::ProcessorPtr; +use itertools::Itertools; use crate::pipelines::processors::transforms::aggregator::estimated_key_size; use crate::pipelines::processors::transforms::aggregator::AggregateMeta; +use crate::pipelines::processors::transforms::aggregator::AggregatePayload; use crate::pipelines::processors::transforms::aggregator::AggregateSerdeMeta; use crate::pipelines::processors::transforms::aggregator::HashTablePayload; use crate::pipelines::processors::transforms::group_by::HashMethodBounds; @@ -126,10 +130,19 @@ impl TransformGroupBySerializer { AggregateMeta::Serialized(_) => unreachable!(), AggregateMeta::BucketSpilled(_) => unreachable!(), AggregateMeta::Partitioned { .. } => unreachable!(), - AggregateMeta::AggregateHashTable(_) => todo!("AGG_HASHTABLE"), + AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(p) => { + self.input_data = Some(SerializeGroupByStream::create( + &self.method, + SerializePayload::::AggregatePayload(p), + )); + return Ok(Event::Sync); + } AggregateMeta::HashTable(payload) => { - self.input_data = - Some(SerializeGroupByStream::create(&self.method, payload)); + self.input_data = Some(SerializeGroupByStream::create( + &self.method, + SerializePayload::::HashTablePayload(payload), + )); return Ok(Event::Sync); } } @@ -157,10 +170,16 @@ pub fn serialize_group_by( ])) } +pub enum SerializePayload { + HashTablePayload(HashTablePayload), + AggregatePayload(AggregatePayload), +} + pub struct SerializeGroupByStream { method: Method, - pub payload: Pin>>, - iter: as HashtableLike>::Iterator<'static>, + pub payload: Pin>>, + // old hashtable' iter + iter: Option< as HashtableLike>::Iterator<'static>>, end_iter: bool, } @@ -169,10 +188,15 @@ unsafe impl Send for SerializeGroupByStream {} unsafe impl Sync for SerializeGroupByStream {} impl SerializeGroupByStream { - pub fn create(method: &Method, payload: HashTablePayload) -> Self { + pub fn create(method: &Method, payload: SerializePayload) -> Self { unsafe { let payload = Box::pin(payload); - let iter = NonNull::from(&payload.cell.hashtable).as_ref().iter(); + + let iter = if let SerializePayload::HashTablePayload(p) = payload.as_ref().get_ref() { + Some(NonNull::from(&p.cell.hashtable).as_ref().iter()) + } else { + None + }; SerializeGroupByStream:: { iter, @@ -192,34 +216,74 @@ impl Iterator for SerializeGroupByStream { return None; } - let max_block_rows = std::cmp::min(8192, self.payload.cell.hashtable.len()); - let max_block_bytes = std::cmp::min( - 8 * 1024 * 1024 + 1024, - self.payload - .cell - .hashtable - .unsize_key_size() - .unwrap_or(usize::MAX), - ); - - let mut group_key_builder = self - .method - .keys_column_builder(max_block_rows, max_block_bytes); - - #[allow(clippy::while_let_on_iterator)] - while let Some(group_entity) = self.iter.next() { - group_key_builder.append_value(group_entity.key()); - - if group_key_builder.bytes_size() >= 8 * 1024 * 1024 { - let bucket = self.payload.bucket; + match self.payload.as_ref().get_ref() { + SerializePayload::HashTablePayload(p) => { + let max_block_rows = std::cmp::min(8192, p.cell.hashtable.len()); + let max_block_bytes = std::cmp::min( + 8 * 1024 * 1024 + 1024, + p.cell.hashtable.unsize_key_size().unwrap_or(usize::MAX), + ); + + let mut group_key_builder = self + .method + .keys_column_builder(max_block_rows, max_block_bytes); + + #[allow(clippy::while_let_on_iterator)] + while let Some(group_entity) = self.iter.as_mut()?.next() { + group_key_builder.append_value(group_entity.key()); + + if group_key_builder.bytes_size() >= 8 * 1024 * 1024 { + let bucket = p.bucket; + let data_block = + DataBlock::new_from_columns(vec![group_key_builder.finish()]); + return Some(data_block.add_meta(Some(AggregateSerdeMeta::create(bucket)))); + } + } + + self.end_iter = true; + let bucket = p.bucket; let data_block = DataBlock::new_from_columns(vec![group_key_builder.finish()]); - return Some(data_block.add_meta(Some(AggregateSerdeMeta::create(bucket)))); + Some(data_block.add_meta(Some(AggregateSerdeMeta::create(bucket)))) } - } + SerializePayload::AggregatePayload(p) => { + let mut state = PayloadFlushState::default(); + let mut blocks = vec![]; + + while p.payload.flush(&mut state) { + let col = state.take_group_columns(); + blocks.push(DataBlock::new_from_columns(col)); + } - self.end_iter = true; - let bucket = self.payload.bucket; - let data_block = DataBlock::new_from_columns(vec![group_key_builder.finish()]); - Some(data_block.add_meta(Some(AggregateSerdeMeta::create(bucket)))) + self.end_iter = true; + + let data_block = if blocks.is_empty() { + empty_block(p) + } else { + DataBlock::concat(&blocks).unwrap() + }; + Some( + data_block.add_meta(Some(AggregateSerdeMeta::create_agg_payload( + p.bucket, + p.max_partition_count, + ))), + ) + } + } } } + +pub fn empty_block(p: &AggregatePayload) -> DataBlock { + let columns = p + .payload + .aggrs + .iter() + .map(|f| ColumnBuilder::with_capacity(&f.return_type().unwrap(), 0).build()) + .chain( + p.payload + .group_types + .iter() + .map(|t| ColumnBuilder::with_capacity(t, 0).build()), + ) + .collect_vec(); + DataBlock::new_from_columns(columns) +} diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_spill_reader.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_spill_reader.rs index 93dabf6290f02..8d69e57b03220 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_spill_reader.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_spill_reader.rs @@ -137,6 +137,7 @@ impl Processor AggregateMeta::Spilled(_) => unreachable!(), AggregateMeta::Spilling(_) => unreachable!(), AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(_) => unreachable!(), AggregateMeta::HashTable(_) => unreachable!(), AggregateMeta::Serialized(_) => unreachable!(), AggregateMeta::BucketSpilled(payload) => { @@ -179,6 +180,7 @@ impl Processor AggregateMeta::Spilling(_) => unreachable!(), AggregateMeta::HashTable(_) => unreachable!(), AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(_) => unreachable!(), AggregateMeta::Serialized(_) => unreachable!(), AggregateMeta::BucketSpilled(payload) => { let instant = Instant::now(); @@ -298,6 +300,7 @@ impl TransformSpillReader::Serialized(SerializedPayload { bucket: payload.bucket, data_block: DataBlock::new_from_columns(columns), + max_partition_count: 0, }) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs index 953cc0301f485..556062bb53dc6 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs @@ -240,6 +240,7 @@ where Method: HashMethodBounds } }, AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(_) => unreachable!(), } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs index 06b95674a0596..3094d21c28e7f 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs @@ -419,9 +419,22 @@ impl AccumulatingTransform for TransformPartialAggrega blocks } - HashTable::AggregateHashTable(hashtable) => vec![DataBlock::empty_with_meta( - AggregateMeta::::create_agg_hashtable(hashtable.payload), - )], + HashTable::AggregateHashTable(hashtable) => { + let partition_count = hashtable.payload.partition_count(); + let mut blocks = Vec::with_capacity(partition_count); + for (bucket, mut payload) in hashtable.payload.payloads.into_iter().enumerate() { + payload.arenas.extend_from_slice(&hashtable.payload.arenas); + blocks.push(DataBlock::empty_with_meta( + AggregateMeta::::create_agg_payload( + bucket as isize, + payload, + partition_count, + ), + )); + } + + blocks + } }) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_final.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_final.rs index 9eb4a59fb3541..1e3653bee01e1 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_final.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_final.rs @@ -161,6 +161,7 @@ where Method: HashMethodBounds } }, AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(_) => unreachable!(), } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs index 139e349f57202..584e155704367 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs @@ -264,9 +264,22 @@ impl AccumulatingTransform for TransformPartialGroupBy blocks } - HashTable::AggregateHashTable(hashtable) => vec![DataBlock::empty_with_meta( - AggregateMeta::::create_agg_hashtable(hashtable.payload), - )], + HashTable::AggregateHashTable(hashtable) => { + let partition_count = hashtable.payload.partition_count(); + let mut blocks = Vec::with_capacity(partition_count); + for (bucket, mut payload) in hashtable.payload.payloads.into_iter().enumerate() { + payload.arenas.extend_from_slice(&hashtable.payload.arenas); + blocks.push(DataBlock::empty_with_meta( + AggregateMeta::::create_agg_payload( + bucket as isize, + payload, + partition_count, + ), + )); + } + + blocks + } }) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_partition_bucket.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_partition_bucket.rs index 719239cb9e1ac..0ef1d25529896 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_partition_bucket.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_partition_bucket.rs @@ -21,10 +21,13 @@ use std::sync::Arc; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::AggregateHashTable; use databend_common_expression::BlockMetaInfoDowncast; use databend_common_expression::DataBlock; +use databend_common_expression::HashTableConfig; use databend_common_expression::PartitionedPayload; use databend_common_expression::PayloadFlushState; +use databend_common_expression::ProbeState; use databend_common_hashtable::hash2bucket; use databend_common_hashtable::HashtableLike; use databend_common_pipeline_core::processors::Event; @@ -36,8 +39,8 @@ use databend_common_pipeline_core::Pipe; use databend_common_pipeline_core::PipeItem; use databend_common_pipeline_core::Pipeline; use databend_common_storage::DataOperator; -use itertools::Itertools; +use super::AggregatePayload; use crate::pipelines::processors::transforms::aggregator::aggregate_meta::AggregateMeta; use crate::pipelines::processors::transforms::aggregator::aggregate_meta::HashTablePayload; use crate::pipelines::processors::transforms::aggregator::aggregate_meta::SerializedPayload; @@ -61,14 +64,14 @@ struct InputPortState { pub struct TransformPartitionBucket { output: Arc, inputs: Vec, - + params: Arc, method: Method, working_bucket: isize, pushing_bucket: isize, initialized_all_inputs: bool, buckets_blocks: BTreeMap>, flush_state: PayloadFlushState, - partition_payloads: Vec, + agg_payloads: Vec, unsplitted_blocks: Vec, max_partition_count: usize, _phantom: PhantomData, @@ -77,7 +80,11 @@ pub struct TransformPartitionBucket TransformPartitionBucket { - pub fn create(method: Method, input_nums: usize) -> Result { + pub fn create( + method: Method, + input_nums: usize, + params: Arc, + ) -> Result { let mut inputs = Vec::with_capacity(input_nums); for _index in 0..input_nums { @@ -89,7 +96,7 @@ impl Ok(TransformPartitionBucket { method, - // params, + params, inputs, working_bucket: 0, pushing_bucket: 0, @@ -97,7 +104,7 @@ impl buckets_blocks: BTreeMap::new(), unsplitted_blocks: vec![], flush_state: PayloadFlushState::default(), - partition_payloads: vec![], + agg_payloads: vec![], initialized_all_inputs: false, max_partition_count: 0, _phantom: Default::default(), @@ -127,7 +134,8 @@ impl } // We pull the first unsplitted data block - if self.inputs[index].bucket > SINGLE_LEVEL_BUCKET_NUM { + if self.inputs[index].bucket > SINGLE_LEVEL_BUCKET_NUM && self.max_partition_count == 0 + { continue; } @@ -140,7 +148,8 @@ impl let data_block = self.inputs[index].port.pull_data().unwrap()?; self.inputs[index].bucket = self.add_bucket(data_block); - if self.inputs[index].bucket <= SINGLE_LEVEL_BUCKET_NUM { + if self.inputs[index].bucket <= SINGLE_LEVEL_BUCKET_NUM || self.max_partition_count > 0 + { self.inputs[index].port.set_need_data(); self.initialized_all_inputs = false; } @@ -155,10 +164,17 @@ impl let (bucket, res) = match block_meta { AggregateMeta::Spilling(_) => unreachable!(), AggregateMeta::Partitioned { .. } => unreachable!(), + AggregateMeta::AggregateHashTable(_) => unreachable!(), AggregateMeta::BucketSpilled(payload) => { (payload.bucket, SINGLE_LEVEL_BUCKET_NUM) } - AggregateMeta::Serialized(payload) => (payload.bucket, payload.bucket), + AggregateMeta::Serialized(payload) => { + if payload.max_partition_count > 0 { + self.max_partition_count = + self.max_partition_count.max(payload.max_partition_count); + } + (payload.bucket, payload.bucket) + } AggregateMeta::HashTable(payload) => (payload.bucket, payload.bucket), AggregateMeta::Spilled(_) => { let meta = data_block.take_meta().unwrap(); @@ -190,11 +206,10 @@ impl unreachable!() } - AggregateMeta::AggregateHashTable(p) => { + AggregateMeta::AggregatePayload(p) => { self.max_partition_count = - self.max_partition_count.max(p.partition_count()); - - (0, 0) + self.max_partition_count.max(p.max_partition_count); + (p.bucket, p.bucket) } }; @@ -215,10 +230,73 @@ impl if self.max_partition_count > 0 { let meta = data_block.take_meta().unwrap(); - if let Some(AggregateMeta::AggregateHashTable(p)) = - AggregateMeta::::downcast_from(meta) - { - self.partition_payloads.push(p); + if let Some(block_meta) = AggregateMeta::::downcast_from(meta) { + return match block_meta { + AggregateMeta::AggregatePayload(p) => { + let res = p.bucket; + self.agg_payloads.push(p); + res + } + AggregateMeta::Serialized(p) => { + let rows_num = p.data_block.num_rows(); + let radix_bits = p.max_partition_count.trailing_zeros() as u64; + let config = HashTableConfig::default().with_initial_radix_bits(radix_bits); + let mut state = ProbeState::default(); + let capacity = AggregateHashTable::get_capacity_for_count(rows_num); + let mut hashtable = AggregateHashTable::new_with_capacity( + self.params.group_data_types.clone(), + self.params.aggregate_functions.clone(), + config, + capacity, + ); + hashtable.direct_append = true; + + let agg_len = self.params.aggregate_functions.len(); + let group_len = self.params.group_columns.len(); + let agg_states = (0..agg_len) + .map(|i| { + p.data_block + .get_by_offset(i) + .value + .as_column() + .unwrap() + .clone() + }) + .collect::>(); + let group_columns = (agg_len..(agg_len + group_len)) + .map(|i| { + p.data_block + .get_by_offset(i) + .value + .as_column() + .unwrap() + .clone() + }) + .collect::>(); + + let _ = hashtable + .add_groups( + &mut state, + &group_columns, + &[vec![]], + &agg_states, + rows_num, + ) + .unwrap(); + + for (bucket, payload) in hashtable.payload.payloads.into_iter().enumerate() + { + self.agg_payloads.push(AggregatePayload { + bucket: bucket as isize, + payload, + max_partition_count: p.max_partition_count, + }); + } + + p.bucket + } + _ => unreachable!(), + }; } return 0; } @@ -293,7 +371,7 @@ impl blocks.push(match data_block.is_empty() { true => None, false => Some(DataBlock::empty_with_meta( - AggregateMeta::::create_serialized(bucket as isize, data_block), + AggregateMeta::::create_serialized(bucket as isize, data_block, 0), )), }); } @@ -349,8 +427,10 @@ impl Processor return Ok(Event::NeedData); } - if self.partition_payloads.len() == self.inputs.len() - || (!self.buckets_blocks.is_empty() && !self.unsplitted_blocks.is_empty()) + if !self.agg_payloads.is_empty() + || (!self.buckets_blocks.is_empty() + && !self.unsplitted_blocks.is_empty() + && self.max_partition_count == 0) { // Split data blocks if it's unsplitted. return Ok(Event::Sync); @@ -423,50 +503,42 @@ impl Processor } fn process(&mut self) -> Result<()> { - if !self.partition_payloads.is_empty() { - let mut payloads = Vec::with_capacity(self.partition_payloads.len()); - - for p in self.partition_payloads.drain(0..) { - if p.partition_count() != self.max_partition_count { - let p = p.repartition(self.max_partition_count, &mut self.flush_state); - payloads.push(p); + if !self.agg_payloads.is_empty() { + let group_types = self.params.group_data_types.clone(); + let aggrs = self.params.aggregate_functions.clone(); + + let mut partitioned_payload = PartitionedPayload::new( + group_types.clone(), + aggrs.clone(), + self.max_partition_count as u64, + ); + + for agg_payload in self.agg_payloads.drain(0..) { + partitioned_payload + .arenas + .extend_from_slice(&agg_payload.payload.arenas); + if agg_payload.max_partition_count != self.max_partition_count { + debug_assert!(agg_payload.max_partition_count < self.max_partition_count); + partitioned_payload.combine_single(agg_payload.payload, &mut self.flush_state); } else { - payloads.push(p); - }; - } - - let group_types = payloads[0].group_types.clone(); - let aggrs = payloads[0].aggrs.clone(); - - let mut payload_map = (0..self.max_partition_count).map(|_| vec![]).collect_vec(); - - // All arenas should be kept in the bucket partition payload - let mut arenas = vec![]; - - for mut payload in payloads.into_iter() { - for (bucket, p) in payload.payloads.into_iter().enumerate() { - payload_map[bucket].push(p); + partitioned_payload.payloads[agg_payload.bucket as usize] + .combine(agg_payload.payload); } - arenas.append(&mut payload.arenas); } - for (bucket, mut payloads) in payload_map.into_iter().enumerate() { - let mut partition_payload = - PartitionedPayload::new(group_types.clone(), aggrs.clone(), 1); - - for payload in payloads.drain(0..) { - partition_payload.combine_single(payload, &mut self.flush_state); - } - - partition_payload.arenas.extend_from_slice(&arenas); + for (bucket, payload) in partitioned_payload.payloads.into_iter().enumerate() { + let mut part = PartitionedPayload::new(group_types.clone(), aggrs.clone(), 1); + part.arenas.extend_from_slice(&partitioned_payload.arenas); + part.combine_single(payload, &mut self.flush_state); - if partition_payload.len() != 0 { + if part.len() != 0 { self.buckets_blocks .insert(bucket as isize, vec![DataBlock::empty_with_meta( - AggregateMeta::::create_agg_hashtable(partition_payload), + AggregateMeta::::create_agg_hashtable(part), )]); } } + return Ok(()); } @@ -489,6 +561,7 @@ impl Processor AggregateMeta::Serialized(payload) => self.partition_block(payload)?, AggregateMeta::HashTable(payload) => self.partition_hashtable(payload)?, AggregateMeta::AggregateHashTable(_) => unreachable!(), + AggregateMeta::AggregatePayload(_) => unreachable!(), }; for (bucket, block) in data_blocks.into_iter().enumerate() { @@ -516,7 +589,8 @@ pub fn build_partition_bucket, ) -> Result<()> { let input_nums = pipeline.output_len(); - let transform = TransformPartitionBucket::::create(method.clone(), input_nums)?; + let transform = + TransformPartitionBucket::::create(method.clone(), input_nums, params.clone())?; let output = transform.get_output(); let inputs_port = transform.get_inputs(); diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index e1a258e50fa61..d536298c14800 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -178,6 +178,7 @@ pub trait PhysicalPlanReplacer { Ok(PhysicalPlan::AggregatePartial(AggregatePartial { plan_id: plan.plan_id, input: Box::new(input), + enable_experimental_aggregate_hashtable: plan.enable_experimental_aggregate_hashtable, group_by: plan.group_by.clone(), group_by_display: plan.group_by_display.clone(), agg_funcs: plan.agg_funcs.clone(), diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs index e68e0e3704adf..f5285e628608b 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs @@ -157,6 +157,8 @@ impl PhysicalPlanBuilder { let settings = self.ctx.get_settings(); let group_by_shuffle_mode = settings.get_group_by_shuffle_mode()?; + let enable_experimental_aggregate_hashtable = + settings.get_enable_experimental_aggregate_hashtable()?; if let Some(grouping_sets) = agg.grouping_sets.as_ref() { assert_eq!(grouping_sets.dup_group_items.len(), group_items.len() - 1); // ignore `_grouping_id`. @@ -191,6 +193,7 @@ impl PhysicalPlanBuilder { plan_id: 0, input: Box::new(PhysicalPlan::AggregateExpand(expand)), agg_funcs, + enable_experimental_aggregate_hashtable, group_by_display, group_by: group_items, stat_info: Some(stat_info), @@ -200,6 +203,7 @@ impl PhysicalPlanBuilder { plan_id: 0, input, agg_funcs, + enable_experimental_aggregate_hashtable, group_by_display, group_by: group_items, stat_info: Some(stat_info), @@ -208,17 +212,42 @@ impl PhysicalPlanBuilder { let settings = self.ctx.get_settings(); let efficiently_memory = settings.get_efficiently_memory_group_by()?; - - let group_by_key_index = - aggregate_partial.output_schema()?.num_fields() - 1; - let group_by_key_data_type = DataBlock::choose_hash_method_with_types( - &agg.group_items - .iter() - .map(|v| v.scalar.data_type()) - .collect::>>()?, - efficiently_memory, - )? - .data_type(); + let enable_experimental_aggregate_hashtable = + settings.get_enable_experimental_aggregate_hashtable()?; + + let keys = if enable_experimental_aggregate_hashtable { + let schema = aggregate_partial.output_schema()?; + let start = aggregate_partial.agg_funcs.len(); + let end = schema.num_fields(); + let mut groups = Vec::with_capacity(end - start); + for idx in start..end { + let group_key = RemoteExpr::ColumnRef { + span: None, + id: idx, + data_type: schema.field(idx).data_type().clone(), + display_name: (idx - start).to_string(), + }; + groups.push(group_key); + } + groups + } else { + let group_by_key_index = + aggregate_partial.output_schema()?.num_fields() - 1; + let group_by_key_data_type = DataBlock::choose_hash_method_with_types( + &agg.group_items + .iter() + .map(|v| v.scalar.data_type()) + .collect::>>()?, + efficiently_memory, + )? + .data_type(); + vec![RemoteExpr::ColumnRef { + span: None, + id: group_by_key_index, + data_type: group_by_key_data_type, + display_name: "_group_by_key".to_string(), + }] + }; PhysicalPlan::Exchange(Exchange { plan_id: 0, @@ -226,12 +255,7 @@ impl PhysicalPlanBuilder { allow_adjust_parallelism: true, ignore_exchange: false, input: Box::new(PhysicalPlan::AggregatePartial(aggregate_partial)), - keys: vec![RemoteExpr::ColumnRef { - span: None, - id: group_by_key_index, - data_type: group_by_key_data_type, - display_name: "_group_by_key".to_string(), - }], + keys, }) } _ => { @@ -246,6 +270,7 @@ impl PhysicalPlanBuilder { PhysicalPlan::AggregatePartial(AggregatePartial { plan_id: 0, agg_funcs, + enable_experimental_aggregate_hashtable, group_by_display, group_by: group_items, input: Box::new(PhysicalPlan::AggregateExpand(expand)), @@ -255,6 +280,7 @@ impl PhysicalPlanBuilder { PhysicalPlan::AggregatePartial(AggregatePartial { plan_id: 0, agg_funcs, + enable_experimental_aggregate_hashtable, group_by_display, group_by: group_items, input: Box::new(input), diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs index f17b1ccb0d9ad..88d110102cafc 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs @@ -14,6 +14,7 @@ use databend_common_exception::Result; use databend_common_expression::types::DataType; +#[allow(unused_imports)] use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchemaRef; @@ -31,7 +32,7 @@ pub struct AggregatePartial { pub input: Box, pub group_by: Vec, pub agg_funcs: Vec, - + pub enable_experimental_aggregate_hashtable: bool, pub group_by_display: Vec, // Only used for explain @@ -41,6 +42,33 @@ pub struct AggregatePartial { impl AggregatePartial { pub fn output_schema(&self) -> Result { let input_schema = self.input.output_schema()?; + + if self.enable_experimental_aggregate_hashtable { + let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.len()); + for agg in self.agg_funcs.iter() { + fields.push(DataField::new( + &agg.output_column.to_string(), + DataType::Binary, + )); + } + + let group_types = self + .group_by + .iter() + .map(|index| { + Ok(input_schema + .field_with_name(&index.to_string())? + .data_type() + .clone()) + }) + .collect::>>()?; + + for (idx, data_type) in group_types.iter().enumerate() { + fields.push(DataField::new(&idx.to_string(), data_type.clone())); + } + return Ok(DataSchemaRefExt::create(fields)); + } + let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.is_empty() as usize); for agg in self.agg_funcs.iter() { @@ -65,6 +93,7 @@ impl AggregatePartial { )?; fields.push(DataField::new("_group_by_key", method.data_type())); } + Ok(DataSchemaRefExt::create(fields)) } }