
Commit 2e07003

joroKr21 authored and findepi committed
Respect ignore_nulls in array_agg (apache#15544)
* Respect ignore_nulls in array_agg
* Reduce code duplication
* Add another test

(cherry picked from commit 5bb0a98)
1 parent a96af27 commit 2e07003
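For context, a minimal sketch of the user-visible behaviour this commit targets, not taken from the diff itself; it assumes the datafusion and tokio crates, and the table name t is purely illustrative. With IGNORE NULLS respected, array_agg now skips NULL inputs instead of collecting them.

    // Minimal sketch (assumptions: `datafusion` + `tokio` crates, illustrative table `t`).
    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // DDL such as CREATE TABLE ... AS VALUES is executed eagerly by SessionContext::sql.
        ctx.sql("CREATE TABLE t AS VALUES (1), (NULL), (2), (NULL), (3)")
            .await?;
        // With ignore_nulls respected, the aggregated list should be [1, 2, 3];
        // previously the NULL rows were collected as well.
        ctx.sql("SELECT array_agg(column1) IGNORE NULLS FROM t")
            .await?
            .show()
            .await?;
        Ok(())
    }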

3 files changed: +106 -30 lines changed

datafusion/functions-aggregate/benches/array_agg.rs

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) {
     b.iter(|| {
         #[allow(clippy::unit_arg)]
         black_box(
-            ArrayAggAccumulator::try_new(&list_item_data_type)
+            ArrayAggAccumulator::try_new(&list_item_data_type, false)
                 .unwrap()
                 .merge_batch(&[values.clone()])
                 .unwrap(),

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 73 additions & 26 deletions
@@ -17,11 +17,10 @@
 
 //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
 
-use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
-use arrow::compute::SortOptions;
-use arrow::datatypes::DataType;
+use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray};
+use arrow::compute::{filter, SortOptions};
+use arrow::datatypes::{DataType, Field, Fields};
 
-use arrow_schema::{Field, Fields};
 use datafusion_common::cast::as_list_array;
 use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
 use datafusion_common::{exec_err, ScalarValue};
@@ -141,6 +140,8 @@ impl AggregateUDFImpl for ArrayAgg {
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+        let ignore_nulls =
+            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
 
         if acc_args.is_distinct {
             // Limitation similar to Postgres. The aggregation function can only mix
@@ -167,14 +168,19 @@ impl AggregateUDFImpl for ArrayAgg {
                 }
                 sort_option = Some(order.options)
             }
+
             return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
                 &data_type,
                 sort_option,
+                ignore_nulls,
             )?));
         }
 
         if acc_args.ordering_req.is_empty() {
-            return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
+            return Ok(Box::new(ArrayAggAccumulator::try_new(
+                &data_type,
+                ignore_nulls,
+            )?));
         }
 
         let ordering_dtypes = acc_args
@@ -188,6 +194,7 @@ impl AggregateUDFImpl for ArrayAgg {
             &ordering_dtypes,
             acc_args.ordering_req.clone(),
             acc_args.is_reversed,
+            ignore_nulls,
         )
         .map(|acc| Box::new(acc) as _)
     }
@@ -205,18 +212,20 @@ impl AggregateUDFImpl for ArrayAgg {
 pub struct ArrayAggAccumulator {
     values: Vec<ArrayRef>,
     datatype: DataType,
+    ignore_nulls: bool,
 }
 
 impl ArrayAggAccumulator {
     /// new array_agg accumulator based on given item data type
-    pub fn try_new(datatype: &DataType) -> Result<Self> {
+    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
         Ok(Self {
             values: vec![],
             datatype: datatype.clone(),
+            ignore_nulls,
         })
     }
 
-    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
+    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
     /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
     fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
         let offsets = list_array.value_offsets();
@@ -240,15 +249,15 @@ impl ArrayAggAccumulator {
             return Some(list_array.values().slice(0, 0));
         }
 
-        // According to the Arrow spec, null values can point to non empty lists
+        // According to the Arrow spec, null values can point to non-empty lists
         // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
 
         // Unwrapping is safe as we just checked if there is a null value
        let nulls = list_array.nulls().unwrap();
 
         let mut valid_slices_iter = nulls.valid_slices();
 
-        // This is safe as we validated that that are at least 1 valid value in the array
+        // This is safe as we validated that there is at least 1 valid value in the array
         let (start, end) = valid_slices_iter.next().unwrap();
 
         let start_offset = offsets[start];
@@ -258,7 +267,7 @@
         let mut end_offset_of_last_valid_value = offsets[end];
 
         for (start, end) in valid_slices_iter {
-            // If there is a null value that point to a non empty list than the start offset of the valid value
+            // If there is a null value that point to a non-empty list than the start offset of the valid value
             // will be different that the end offset of the last valid value
             if offsets[start] != end_offset_of_last_valid_value {
                 return None;
@@ -289,10 +298,23 @@ impl Accumulator for ArrayAggAccumulator {
             return internal_err!("expects single batch");
         }
 
-        let val = Arc::clone(&values[0]);
-        if val.len() > 0 {
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let val = match nulls {
+            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
+            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
+            None => Arc::clone(val),
+        };
+
+        if !val.is_empty() {
             self.values.push(val);
         }
+
         Ok(())
     }
 
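The new update_batch body above leans on one arrow-rs trick: the column's validity bitmap returned by logical_nulls() doubles as the boolean mask handed to filter. A self-contained sketch of just that step, assuming only the arrow crate; the helper name drop_nulls is made up for illustration and is not part of the commit.

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, BooleanArray, Int32Array};
    use arrow::compute::filter;
    use arrow::error::ArrowError;

    // Hypothetical helper, not from the commit: keep only the non-null elements.
    fn drop_nulls(values: &ArrayRef) -> Result<ArrayRef, ArrowError> {
        match values.logical_nulls() {
            // Everything is null: nothing survives.
            Some(nulls) if nulls.null_count() >= values.len() => Ok(values.slice(0, 0)),
            // Some nulls: the validity buffer is reused as the keep/drop predicate.
            Some(nulls) => filter(values, &BooleanArray::new(nulls.inner().clone(), None)),
            // No nulls: hand the input back untouched.
            None => Ok(Arc::clone(values)),
        }
    }

    fn main() -> Result<(), ArrowError> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(2), None]));
        let kept = drop_nulls(&input)?;
        assert_eq!(kept.len(), 2); // only the two valid values remain
        Ok(())
    }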
@@ -361,17 +383,20 @@ struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
+    ignore_nulls: bool,
 }
 
 impl DistinctArrayAggAccumulator {
     pub fn try_new(
         datatype: &DataType,
         sort_options: Option<SortOptions>,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         Ok(Self {
             values: HashSet::new(),
             datatype: datatype.clone(),
             sort_options,
+            ignore_nulls,
         })
     }
 }
@@ -386,11 +411,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
             return Ok(());
         }
 
-        let array = &values[0];
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
 
-        for i in 0..array.len() {
-            let scalar = ScalarValue::try_from_array(&array, i)?;
-            self.values.insert(scalar);
+        let nulls = nulls.as_ref();
+        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
+            for i in 0..val.len() {
+                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
+                    self.values.insert(ScalarValue::try_from_array(val, i)?);
+                }
+            }
         }
 
         Ok(())
@@ -472,6 +506,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
     ordering_req: LexOrdering,
     /// Whether the aggregation is running in reverse.
     reverse: bool,
+    /// Whether the aggregation should ignore null values.
+    ignore_nulls: bool,
 }
 
 impl OrderSensitiveArrayAggAccumulator {
@@ -482,6 +518,7 @@ impl OrderSensitiveArrayAggAccumulator {
         ordering_dtypes: &[DataType],
         ordering_req: LexOrdering,
         reverse: bool,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         let mut datatypes = vec![datatype.clone()];
         datatypes.extend(ordering_dtypes.iter().cloned());
@@ -491,6 +528,7 @@ impl OrderSensitiveArrayAggAccumulator {
             datatypes,
             ordering_req,
             reverse,
+            ignore_nulls,
         })
     }
 }
@@ -501,11 +539,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
             return Ok(());
         }
 
-        let n_row = values[0].len();
-        for index in 0..n_row {
-            let row = get_row_at_idx(values, index)?;
-            self.values.push(row[0].clone());
-            self.ordering_values.push(row[1..].to_vec());
+        let val = &values[0];
+        let ord = &values[1..];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let nulls = nulls.as_ref();
+        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
+            for i in 0..val.len() {
+                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
+                    self.values.push(ScalarValue::try_from_array(val, i)?);
+                    self.ordering_values.push(get_row_at_idx(ord, i)?)
+                }
+            }
         }
 
         Ok(())
@@ -666,7 +715,7 @@ impl OrderSensitiveArrayAggAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::datatypes::{FieldRef, Schema};
+    use arrow::datatypes::Schema;
     use datafusion_common::cast::as_generic_string_array;
     use datafusion_common::internal_err;
     use datafusion_physical_expr::expressions::Column;
@@ -947,14 +996,12 @@ mod tests {
     fn new(data_type: DataType) -> Self {
         Self {
             data_type: data_type.clone(),
-            distinct: Default::default(),
+            distinct: false,
             ordering: Default::default(),
             schema: Schema {
                 fields: Fields::from(vec![Field::new(
                     "col",
-                    DataType::List(FieldRef::new(Field::new(
-                        "item", data_type, true,
-                    ))),
+                    DataType::new_list(data_type, true),
                     true,
                 )]),
                 metadata: Default::default(),
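From the caller's side, a hypothetical sketch of driving the now two-argument ArrayAggAccumulator directly, in the spirit of the benchmark above; it assumes the arrow, datafusion-common, datafusion-expr, and datafusion-functions-aggregate crates as dependencies and is not taken from this commit.

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::Accumulator;
    use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator;

    fn main() -> Result<()> {
        // `true` enables the new ignore_nulls path; the benchmark above passes `false`.
        let mut acc = ArrayAggAccumulator::try_new(&DataType::Int64, true)?;
        let batch: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(2)]));
        acc.update_batch(&[batch])?;
        // evaluate() should yield a list scalar containing [1, 2], with the NULL skipped.
        println!("{:?}", acc.evaluate()?);
        Ok(())
    }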

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 32 additions & 3 deletions
@@ -289,17 +289,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES
   ('b', [1,0]),
   ('b', [1,0]),
   ('b', [1,0]),
-  ('b', [0,1])
+  ('b', [0,1]),
+  (NULL, [0,1]),
+  ('b', NULL)
 ;
 
 # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort,
 # so they are covered in `datafusion/functions-aggregate/src/array_agg.rs`
 query ??
 select array_sort(c1), array_sort(c2) from (
-  select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table
+  select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table
 );
 ----
-[b, w] [[0, 1], [1, 0]]
+[NULL, b, w] [[0, 1], [1, 0]]
 
 statement ok
 drop table array_agg_distinct_list_table;
@@ -3194,6 +3196,33 @@ select array_agg(column1) from t;
 statement ok
 drop table t;
 
+# array_agg_ignore_nulls
+statement ok
+create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a');
+
+query ?
+select array_agg(column1) ignore nulls as c1 from t;
+----
+[1, 2, 4, 5]
+
+query II
+select count(*), array_length(array_agg(distinct column2) ignore nulls) from t;
+----
+7 4
+
+query ?
+select array_agg(column2 order by column1) ignore nulls from t;
+----
+[c, a, a, , b]
+
+query ?
+select array_agg(DISTINCT column2 order by column2) ignore nulls from t;
+----
+[, a, b, c]
+
+statement ok
+drop table t;
+
 # variance_single_value
 query RRRR
 select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;
