Skip to content

Commit

Permalink
A unified and low-cost way to compute the different types of `BlockedGroupIndex`.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Sep 1, 2024
1 parent c2cb573 commit ef91012
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};

pub const MAX_PREALLOC_BLOCK_SIZE: usize = 8192;
/// Flat layout, block-id mask: selects no bits, so the decoded block id is always 0.
const FLAT_GROUP_INDEX_ID_MASK: u64 = 0;
/// Flat layout, offset mask: keeps every bit, so the whole packed index is the offset.
const FLAT_GROUP_INDEX_OFFSET_MASK: u64 = u64::MAX;
/// Blocked layout, block-id mask: selects the high 32 bits of the packed index
/// (they are shifted right by 32 afterwards to obtain the block id).
const BLOCKED_GROUP_INDEX_ID_MASK: u64 = 0xffffffff00000000;
/// Blocked layout, offset mask: selects the low 32 bits of the packed index.
const BLOCKED_GROUP_INDEX_OFFSET_MASK: u64 = 0x00000000ffffffff;

/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
///
Expand Down Expand Up @@ -452,37 +456,14 @@ impl EmitToExt for EmitTo {
pub struct BlockedGroupIndex {
pub block_id: u32,
pub block_offset: u64,
pub is_blocked: bool,
}

impl BlockedGroupIndex {
#[inline]
pub fn new_from_parts(block_id: u32, block_offset: u64, is_blocked: bool) -> Self {
pub fn new_from_parts(block_id: u32, block_offset: u64) -> Self {
Self {
block_id,
block_offset,
is_blocked,
}
}

#[inline]
pub fn new_flat(raw_index: usize) -> Self {
Self {
block_id: 0,
block_offset: raw_index as u64,
is_blocked: false,
}
}

#[inline]
pub fn new_blocked(raw_index: usize) -> Self {
let block_id = ((raw_index as u64 >> 32) & 0x00000000ffffffff) as u32;
let block_offset = (raw_index as u64) & 0x00000000ffffffff;

Self {
block_id,
block_offset,
is_blocked: true,
}
}

Expand All @@ -496,11 +477,41 @@ impl BlockedGroupIndex {
self.block_offset as usize
}

#[inline]
pub fn as_packed_index(&self) -> usize {
if self.is_blocked {
(((self.block_id as u64) << 32) | self.block_offset) as usize
(((self.block_id as u64) << 32) | self.block_offset) as usize
}
}

/// Decodes packed `usize` group indices into [`BlockedGroupIndex`] values.
///
/// The builder stores the two bit masks applied to every packed index, so
/// the choice of masks is made once at construction rather than per index.
#[derive(Debug, Clone, Copy)]
pub struct BlockedGroupIndexBuilder {
    /// Mask applied to the packed index before shifting right by 32 to get the block id.
    block_id_mask: u64,
    /// Mask applied to the packed index to get the offset.
    block_offset_mask: u64,
}

impl BlockedGroupIndexBuilder {
    /// Creates a builder for the blocked (`is_blocked == true`) or flat
    /// (`is_blocked == false`) index layout.
    ///
    /// The blocked layout uses `BLOCKED_GROUP_INDEX_ID_MASK` and
    /// `BLOCKED_GROUP_INDEX_OFFSET_MASK`; the flat layout uses
    /// `FLAT_GROUP_INDEX_ID_MASK` and `FLAT_GROUP_INDEX_OFFSET_MASK`.
    #[inline]
    pub fn new(is_blocked: bool) -> Self {
        if is_blocked {
            Self {
                block_id_mask: BLOCKED_GROUP_INDEX_ID_MASK,
                block_offset_mask: BLOCKED_GROUP_INDEX_OFFSET_MASK,
            }
        } else {
            Self {
                block_id_mask: FLAT_GROUP_INDEX_ID_MASK,
                block_offset_mask: FLAT_GROUP_INDEX_OFFSET_MASK,
            }
        }
    }

    /// Decodes `packed_index` into a [`BlockedGroupIndex`].
    ///
    /// The block id is the bits selected by `block_id_mask`, shifted right by
    /// 32; the block offset is the bits selected by `block_offset_mask`. Both
    /// layouts go through the same two mask operations, so there is no branch
    /// on the layout here.
    #[inline]
    pub fn build(&self, packed_index: usize) -> BlockedGroupIndex {
        let block_id = (((packed_index as u64) & self.block_id_mask) >> 32) as u32;
        let block_offset = (packed_index as u64) & self.block_offset_mask;

        BlockedGroupIndex {
            block_id,
            block_offset,
        }
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::datatypes::ArrowPrimitiveType;
use datafusion_expr_common::groups_accumulator::EmitTo;

use crate::aggregate::groups_accumulator::{
BlockedGroupIndex, Blocks, MAX_PREALLOC_BLOCK_SIZE,
BlockedGroupIndex, BlockedGroupIndexBuilder, Blocks, MAX_PREALLOC_BLOCK_SIZE,
};

/// Track the accumulator null state per row: if any values for that
Expand Down Expand Up @@ -303,32 +303,20 @@ impl BlockedNullState {
false,
);
let seen_values_blocks = &mut self.seen_values_blocks;
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());

if self.block_size.is_some() {
do_blocked_accumulate(
group_indices,
values,
opt_filter,
BlockedGroupIndex::new_blocked,
value_fn,
|index: &BlockedGroupIndex| {
seen_values_blocks[index.block_id()]
.set_bit(index.block_offset(), true);
},
)
} else {
do_blocked_accumulate(
group_indices,
values,
opt_filter,
BlockedGroupIndex::new_flat,
value_fn,
|index: &BlockedGroupIndex| {
seen_values_blocks[index.block_id()]
.set_bit(index.block_offset(), true);
},
);
}
do_blocked_accumulate(
group_indices,
values,
opt_filter,
&group_index_builder,
value_fn,
|group_index| {
seen_values_blocks[group_index.block_id()]
.set_bit(group_index.block_offset(), true);
},
)
}

/// Similar as [NullState::build] but support the blocked version accumulator
Expand Down Expand Up @@ -598,16 +586,15 @@ pub fn accumulate_indices<F>(
}
}

fn do_blocked_accumulate<T, F1, F2, G>(
fn do_blocked_accumulate<T, F1, F2>(
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
group_index_parse_fn: G,
group_index_builder: &BlockedGroupIndexBuilder,
mut value_fn: F1,
mut set_valid_fn: F2,
) where
T: ArrowPrimitiveType + Send,
G: Fn(usize) -> BlockedGroupIndex,
F1: FnMut(&BlockedGroupIndex, T::Native) + Send,
F2: FnMut(&BlockedGroupIndex) + Send,
{
Expand All @@ -617,7 +604,7 @@ fn do_blocked_accumulate<T, F1, F2, G>(
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
let blocked_index = group_index_parse_fn(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand Down Expand Up @@ -645,7 +632,8 @@ fn do_blocked_accumulate<T, F1, F2, G>(
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
let blocked_index = group_index_parse_fn(group_index);
let blocked_index =
group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -663,7 +651,7 @@ fn do_blocked_accumulate<T, F1, F2, G>(
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
let blocked_index = group_index_parse_fn(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -681,7 +669,7 @@ fn do_blocked_accumulate<T, F1, F2, G>(
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
let blocked_index = group_index_parse_fn(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -700,9 +688,9 @@ fn do_blocked_accumulate<T, F1, F2, G>(
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let blocked_index = group_index_parse_fn(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value)
value_fn(&blocked_index, new_value);
}
}
})
Expand Down Expand Up @@ -933,7 +921,6 @@ mod test {
BlockedGroupIndex::new_from_parts(
block_id as u32,
block_offset as u64,
true,
)
.as_packed_index()
})
Expand Down
78 changes: 27 additions & 51 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use ahash::RandomState;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
ensure_enough_room_for_values, BlockedGroupIndex, Blocks, EmitToExt, VecBlocks,
ensure_enough_room_for_values, BlockedGroupIndexBuilder, Blocks, EmitToExt, VecBlocks,
};
use std::collections::HashSet;
use std::ops::BitAnd;
Expand Down Expand Up @@ -399,31 +399,19 @@ impl GroupsAccumulator for CountGroupsAccumulator {
0,
);

if self.block_size.is_some() {
accumulate_indices(
group_indices,
values.logical_nulls().as_ref(),
opt_filter,
|group_index| {
let blocked_index = BlockedGroupIndex::new_blocked(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += 1;
},
);
} else {
accumulate_indices(
group_indices,
values.logical_nulls().as_ref(),
opt_filter,
|group_index| {
let blocked_index = BlockedGroupIndex::new_flat(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += 1;
},
);
}
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
accumulate_indices(
group_indices,
values.logical_nulls().as_ref(),
opt_filter,
|group_index| {
let blocked_index = group_index_builder.build(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += 1;
},
);

Ok(())
}
Expand All @@ -450,31 +438,19 @@ impl GroupsAccumulator for CountGroupsAccumulator {
0,
);

if self.block_size.is_some() {
do_count_merge_batch(
values,
group_indices,
opt_filter,
|group_index, partial_count| {
let blocked_index = BlockedGroupIndex::new_blocked(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += partial_count;
},
);
} else {
do_count_merge_batch(
values,
group_indices,
opt_filter,
|group_index, partial_count| {
let blocked_index = BlockedGroupIndex::new_flat(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += partial_count;
},
);
}
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
do_count_merge_batch(
values,
group_indices,
opt_filter,
|group_index, partial_count| {
let blocked_index = group_index_builder.build(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += partial_count;
},
);

Ok(())
}
Expand Down
Loading

0 comments on commit ef91012

Please sign in to comment.