Skip to content

Commit a4c6f42

Browse files
committed
add new emit mode.
1 parent f690940 commit a4c6f42

File tree

18 files changed

+102
-26
lines changed

18 files changed

+102
-26
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
307307

308308
/// Generate output, as specified by `emit_to` and update the intermediate state
309309
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
310-
let counts = emit_to.take_needed(&mut self.counts);
311-
let prods = emit_to.take_needed(&mut self.prods);
310+
let counts = emit_to.take_needed_rows(&mut self.counts);
311+
let prods = emit_to.take_needed_rows(&mut self.prods);
312312
let nulls = self.null_state.build(emit_to);
313313

314314
assert_eq!(nulls.len(), prods.len());
@@ -346,10 +346,10 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
346346
let nulls = self.null_state.build(emit_to);
347347
let nulls = Some(nulls);
348348

349-
let counts = emit_to.take_needed(&mut self.counts);
349+
let counts = emit_to.take_needed_rows(&mut self.counts);
350350
let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy
351351

352-
let prods = emit_to.take_needed(&mut self.prods);
352+
let prods = emit_to.take_needed_rows(&mut self.prods);
353353
let prods = PrimitiveArray::<Float64Type>::new(prods.into(), nulls) // zero copy
354354
.with_data_type(self.prod_data_type.clone());
355355

datafusion/expr-common/src/groups_accumulator.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
//! Vectorized [`GroupsAccumulator`]
1919
20+
use std::collections::VecDeque;
21+
2022
use arrow::array::{ArrayRef, BooleanArray};
2123
use datafusion_common::{not_impl_err, DataFusionError, Result};
2224

@@ -31,15 +33,40 @@ pub enum EmitTo {
3133
/// For example, if `n=10`, group_index `0, 1, ... 9` are emitted
3234
/// and group indexes `10, 11, 12, ...` become `0, 1, 2, ...`.
3335
First(usize),
36+
/// Emit next block in the blocked managed groups
37+
///
38+
/// The flag's meaning:
39+
/// - `true` represents new groups still will be added,
40+
/// and we need to shift the values down.
41+
/// - `false` represents no new groups will be added again,
42+
/// and we don't need to shift the values down.
43+
NextBlock(bool),
3444
}
3545

3646
impl EmitTo {
47+
/// Remove and return `needed values` from `values`.
48+
pub fn take_needed<T>(
49+
&self,
50+
values: &mut VecDeque<Vec<T>>,
51+
is_blocked_groups: bool,
52+
) -> Vec<T> {
53+
if is_blocked_groups {
54+
self.take_needed_block(values)
55+
} else {
56+
assert_eq!(values.len(), 1);
57+
self.take_needed_rows(values.back_mut().unwrap())
58+
}
59+
}
60+
3761
/// Removes the number of rows from `v` required to emit the right
3862
/// number of rows, returning a `Vec` with elements taken, and the
3963
/// remaining values in `v`.
4064
///
4165
/// This avoids copying if Self::All
42-
pub fn take_needed<T>(&self, v: &mut Vec<T>) -> Vec<T> {
66+
///
67+
/// NOTICE: only support emit strategies: `Self::All` and `Self::First`
68+
///
69+
pub fn take_needed_rows<T>(&self, v: &mut Vec<T>) -> Vec<T> {
4370
match self {
4471
Self::All => {
4572
// Take the entire vector, leave new (empty) vector
@@ -52,8 +79,23 @@ impl EmitTo {
5279
std::mem::swap(v, &mut t);
5380
t
5481
}
82+
Self::NextBlock(_) => unreachable!("don't support take block in take_needed"),
5583
}
5684
}
85+
86+
/// Removes one block required to emit and return it
87+
///
88+
/// NOTICE: only support emit strategy `Self::NextBlock`
89+
///
90+
fn take_needed_block<T>(&self, blocks: &mut VecDeque<Vec<T>>) -> Vec<T> {
91+
assert!(
92+
matches!(self, Self::NextBlock(_)),
93+
"only support take block in take_needed_block"
94+
);
95+
blocks
96+
.pop_front()
97+
.expect("should not call emit for empty blocks")
98+
}
5799
}
58100

59101
/// `GroupsAccumulator` implements a single aggregate (e.g. AVG) and

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
321321
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
322322
let vec_size_pre = self.states.allocated_size();
323323

324-
let states = emit_to.take_needed(&mut self.states);
324+
let states = emit_to.take_needed_rows(&mut self.states);
325325

326326
let results: Vec<ScalarValue> = states
327327
.into_iter()
@@ -341,7 +341,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
341341
// filtered_null_mask(opt_filter, &values);
342342
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
343343
let vec_size_pre = self.states.allocated_size();
344-
let states = emit_to.take_needed(&mut self.states);
344+
let states = emit_to.take_needed_rows(&mut self.states);
345345

346346
// each accumulator produces a potential vector of values
347347
// which we need to form into columns

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ impl NullState {
228228
}
229229
first_n_null
230230
}
231+
EmitTo::NextBlock(_) => todo!(),
231232
};
232233
NullBuffer::new(nulls)
233234
}

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ where
117117
}
118118
first_n
119119
}
120+
EmitTo::NextBlock(_) => {
121+
unreachable!("this accumulator still not support blocked groups")
122+
}
120123
};
121124

122125
let nulls = self.null_state.build(emit_to);

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ where
115115
}
116116

117117
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
118-
let values = emit_to.take_needed(&mut self.values);
118+
let values = emit_to.take_needed_rows(&mut self.values);
119119
let nulls = self.null_state.build(emit_to);
120120
let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
121121
.with_data_type(self.data_type.clone());

datafusion/functions-aggregate/src/average.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,8 @@ where
486486
}
487487

488488
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
489-
let counts = emit_to.take_needed(&mut self.counts);
490-
let sums = emit_to.take_needed(&mut self.sums);
489+
let counts = emit_to.take_needed_rows(&mut self.counts);
490+
let sums = emit_to.take_needed_rows(&mut self.sums);
491491
let nulls = self.null_state.build(emit_to);
492492

493493
assert_eq!(nulls.len(), sums.len());
@@ -526,10 +526,10 @@ where
526526
let nulls = self.null_state.build(emit_to);
527527
let nulls = Some(nulls);
528528

529-
let counts = emit_to.take_needed(&mut self.counts);
529+
let counts = emit_to.take_needed_rows(&mut self.counts);
530530
let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
531531

532-
let sums = emit_to.take_needed(&mut self.sums);
532+
let sums = emit_to.take_needed_rows(&mut self.sums);
533533
let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
534534
.with_data_type(self.sum_data_type.clone());
535535

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,9 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
448448
let n = match emit_to {
449449
EmitTo::All => self.count.len(),
450450
EmitTo::First(n) => n,
451+
EmitTo::NextBlock(_) => {
452+
unreachable!("this accumulator still not support blocked groups")
453+
}
451454
};
452455

453456
let mut values = Vec::with_capacity(n);
@@ -501,6 +504,9 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
501504
let n = match emit_to {
502505
EmitTo::All => self.count.len(),
503506
EmitTo::First(n) => n,
507+
EmitTo::NextBlock(_) => {
508+
unreachable!("this accumulator still not support blocked groups")
509+
}
504510
};
505511

506512
Ok(vec![

datafusion/functions-aggregate/src/count.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ impl GroupsAccumulator for CountGroupsAccumulator {
537537
}
538538

539539
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
540-
let counts = emit_to.take_needed(&mut self.counts);
540+
let counts = emit_to.take_needed_rows(&mut self.counts);
541541

542542
// Count is always non null (null inputs just don't contribute to the overall values)
543543
let nulls = None;
@@ -548,7 +548,7 @@ impl GroupsAccumulator for CountGroupsAccumulator {
548548

549549
// return arrays for counts
550550
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
551-
let counts = emit_to.take_needed(&mut self.counts);
551+
let counts = emit_to.take_needed_rows(&mut self.counts);
552552
let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
553553
Ok(vec![Arc::new(counts) as ArrayRef])
554554
}

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,14 +396,17 @@ where
396396
}
397397

398398
fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
399-
let result = emit_to.take_needed(&mut self.orderings);
399+
let result = emit_to.take_needed_rows(&mut self.orderings);
400400

401401
match emit_to {
402402
EmitTo::All => self.size_of_orderings = 0,
403403
EmitTo::First(_) => {
404404
self.size_of_orderings -=
405405
result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
406406
}
407+
EmitTo::NextBlock(_) => {
408+
unreachable!("this accumulator still not support blocked groups")
409+
}
407410
}
408411

409412
result
@@ -428,6 +431,9 @@ where
428431
}
429432
first_n
430433
}
434+
EmitTo::NextBlock(_) => {
435+
unreachable!("this group values still not support blocked groups")
436+
}
431437
}
432438
}
433439

@@ -481,7 +487,7 @@ where
481487
&mut self,
482488
emit_to: EmitTo,
483489
) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
484-
emit_to.take_needed(&mut self.min_of_each_group_buf.0);
490+
emit_to.take_needed_rows(&mut self.min_of_each_group_buf.0);
485491
self.min_of_each_group_buf
486492
.1
487493
.truncate(self.min_of_each_group_buf.0.len());
@@ -579,7 +585,7 @@ where
579585
}
580586

581587
fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
582-
let r = emit_to.take_needed(&mut self.vals);
588+
let r = emit_to.take_needed_rows(&mut self.vals);
583589

584590
let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
585591

0 commit comments

Comments
 (0)