Skip to content

Commit

Permalink
add test for BlockedNullState.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Aug 21, 2024
1 parent 76b91ce commit 5af6be3
Showing 1 changed file with 255 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,238 @@ mod test {
use rand::{rngs::ThreadRng, Rng};
use std::collections::HashSet;

trait TestNullState {
fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send;

fn build_bool_buffer(&self) -> BooleanBuffer;

fn build_null_buffer(&mut self) -> NullBuffer;
}

// The original `NullState`
impl TestNullState for NullState {
fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
self.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
value_fn,
);
}

fn build_bool_buffer(&self) -> BooleanBuffer {
self.seen_values.finish_cloned()
}

fn build_null_buffer(&mut self) -> NullBuffer {
self.build(EmitTo::All)
}
}

// The new `BlockedNullState` in flat mode
struct BlockedNullStateInFlatMode(BlockedNullState);

impl BlockedNullStateInFlatMode {
fn new() -> Self {
let null_state = BlockedNullState::new(GroupStatesMode::Flat);

Self(null_state)
}
}

impl TestNullState for BlockedNullStateInFlatMode {
fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
self.0.accumulate_for_flat(
group_indices,
values,
opt_filter,
total_num_groups,
value_fn,
);
}

fn build_bool_buffer(&self) -> BooleanBuffer {
self.0.seen_values_blocks.current().unwrap().finish_cloned()
}

fn build_null_buffer(&mut self) -> NullBuffer {
self.0.build(EmitTo::All)
}
}

// The new `BlockedNullState` in blocked mode
struct BlockedNullStateInBlockedMode {
null_state: BlockedNullState,
block_size: usize,
}

impl BlockedNullStateInBlockedMode {
fn new() -> Self {
let null_state = BlockedNullState::new(GroupStatesMode::Blocked(4));

Self {
null_state,
block_size: 4,
}
}
}

impl TestNullState for BlockedNullStateInBlockedMode {
fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
self.null_state.accumulate_for_blocked(
group_indices,
values,
opt_filter,
total_num_groups,
self.block_size,
value_fn,
);
}

fn build_bool_buffer(&self) -> BooleanBuffer {
let mut ret_builder = BooleanBufferBuilder::new(0);
for blk in self.null_state.seen_values_blocks.iter() {
let buf = blk.finish_cloned();
for seen in buf.iter() {
ret_builder.append(seen);
}
}
ret_builder.finish()
}

fn build_null_buffer(&mut self) -> NullBuffer {
let mut init_buffer = NullBuffer::new(BooleanBufferBuilder::new(0).finish());
loop {
let blk = self.null_state.build(EmitTo::NextBlock(false));
if blk.is_empty() {
break;
}

init_buffer = NullBuffer::union(Some(&init_buffer), Some(&blk)).unwrap();
}

init_buffer
}
}

#[derive(Debug, Clone, Copy)]
enum AccumulateTest {
Original,
Flat,
Blocked,
}

impl AccumulateTest {
fn run(
&self,
group_indices: &[usize],
values: &UInt32Array,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) {
match self {
AccumulateTest::Original => {
Fixture::accumulate_test(
group_indices,
&values,
opt_filter,
total_num_groups,
NullState::new(),
);
}
AccumulateTest::Flat => {
Fixture::accumulate_test(
group_indices,
&values,
opt_filter,
total_num_groups,
BlockedNullStateInFlatMode::new(),
);
}
AccumulateTest::Blocked => {
Fixture::accumulate_test(
group_indices,
&values,
opt_filter,
total_num_groups,
BlockedNullStateInBlockedMode::new(),
);
}
}
}
}

#[test]
fn accumulate_test_original() {
do_accumulate_test(AccumulateTest::Original);
}

#[test]
fn accumulate_test_flat() {
do_accumulate_test(AccumulateTest::Flat);
}

#[test]
fn accumulate_test_blocked() {
do_accumulate_test(AccumulateTest::Blocked);
}

#[test]
fn accumulate_fuzz_test_original() {
do_accumulate_fuzz_test(AccumulateTest::Original);
}

#[test]
fn accumulate_fuzz_test_flat() {
do_accumulate_fuzz_test(AccumulateTest::Flat);
}

#[test]
fn accumulate() {
fn accumulate_fuzz_test_blocked() {
do_accumulate_fuzz_test(AccumulateTest::Blocked);
}

fn do_accumulate_test(accumulate_test: AccumulateTest) {
let group_indices = (0..100).collect();
let values = (0..100).map(|i| (i + 1) * 10).collect();
let values_with_nulls = (0..100)
Expand Down Expand Up @@ -828,15 +1058,15 @@ mod test {
values,
values_with_nulls,
filter,
accumulate_test,
}
.run()
}

#[test]
fn accumulate_fuzz() {
fn do_accumulate_fuzz_test(accumulate_test: AccumulateTest) {
let mut rng = rand::thread_rng();
for _ in 0..100 {
Fixture::new_random(&mut rng).run();
Fixture::new_random(&mut rng, accumulate_test).run();
}
}

Expand All @@ -854,10 +1084,13 @@ mod test {

/// filter (defaults to None)
filter: BooleanArray,

/// tested null state for value test
accumulate_test: AccumulateTest,
}

impl Fixture {
fn new_random(rng: &mut ThreadRng) -> Self {
fn new_random(rng: &mut ThreadRng, accumulate_test: AccumulateTest) -> Self {
// Number of input values in a batch
let num_values: usize = rng.gen_range(1..200);
// number of distinct groups
Expand Down Expand Up @@ -905,6 +1138,7 @@ mod test {
values,
values_with_nulls,
filter,
accumulate_test,
}
}

Expand All @@ -929,26 +1163,31 @@ mod test {
let filter = &self.filter;

// no null, no filters
Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
self.accumulate_test.run(
group_indices,
&values_array,
None,
total_num_groups,
);

// nulls, no filters
Self::accumulate_test(
self.accumulate_test.run(
group_indices,
&values_with_nulls_array,
None,
total_num_groups,
);

// no nulls, filters
Self::accumulate_test(
self.accumulate_test.run(
group_indices,
&values_array,
Some(filter),
total_num_groups,
);

// nulls, filters
Self::accumulate_test(
self.accumulate_test.run(
group_indices,
&values_with_nulls_array,
Some(filter),
Expand All @@ -959,17 +1198,19 @@ mod test {
/// Calls `NullState::accumulate` and `accumulate_indices` to
/// ensure it generates the correct values.
///
fn accumulate_test(
fn accumulate_test<SV: TestNullState>(
group_indices: &[usize],
values: &UInt32Array,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
null_state_for_value_test: SV,
) {
Self::accumulate_values_test(
group_indices,
values,
opt_filter,
total_num_groups,
null_state_for_value_test,
);
Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);

Expand All @@ -988,14 +1229,14 @@ mod test {

/// This is effectively a different implementation of
/// accumulate that we compare with the above implementation
fn accumulate_values_test(
fn accumulate_values_test<S: TestNullState>(
group_indices: &[usize],
values: &UInt32Array,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut null_state: S,
) {
let mut accumulated_values = vec![];
let mut null_state = NullState::new();

null_state.accumulate(
group_indices,
Expand Down Expand Up @@ -1039,13 +1280,13 @@ mod test {

assert_eq!(accumulated_values, expected_values,
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
let seen_values = null_state.seen_values.finish_cloned();
let seen_values = null_state.build_bool_buffer();
mock.validate_seen_values(&seen_values);

// Validate the final buffer (one value per group)
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);

let null_buffer = null_state.build(EmitTo::All);
let null_buffer = null_state.build_null_buffer();

assert_eq!(null_buffer, expected_null_buffer);
}
Expand Down

0 comments on commit 5af6be3

Please sign in to comment.