 
 //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
 
-use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
-use arrow::compute::SortOptions;
-use arrow::datatypes::DataType;
+use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray};
+use arrow::compute::{filter, SortOptions};
+use arrow::datatypes::{DataType, Field, Fields};
 
-use arrow_schema::{Field, Fields};
 use datafusion_common::cast::as_list_array;
 use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
 use datafusion_common::{exec_err, ScalarValue};
@@ -141,6 +140,8 @@ impl AggregateUDFImpl for ArrayAgg {
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+        let ignore_nulls =
+            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
 
         if acc_args.is_distinct {
             // Limitation similar to Postgres. The aggregation function can only mix
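
Note on the change above: `ignore_nulls` is only honored when the aggregated expression can actually produce nulls, so a NOT NULL input skips the filtering work entirely. A minimal sketch of that short-circuit, using a hypothetical single-column schema (the schema and column are illustrative, not part of this diff):

```rust
// Sketch only: the schema and column here are made up, not from this PR.
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn main() -> Result<()> {
    // The aggregated column is declared NOT NULL.
    let schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
    let expr = Column::new("x", 0);

    // Even if the query requested IGNORE NULLS, a non-nullable input makes it a no-op.
    let requested_ignore_nulls = true;
    let effective_ignore_nulls = requested_ignore_nulls && expr.nullable(&schema)?;
    assert!(!effective_ignore_nulls);
    Ok(())
}
```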
@@ -167,14 +168,19 @@ impl AggregateUDFImpl for ArrayAgg {
                 }
                 sort_option = Some(order.options)
             }
+
             return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
                 &data_type,
                 sort_option,
+                ignore_nulls,
             )?));
         }
 
         if acc_args.ordering_req.is_empty() {
-            return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
+            return Ok(Box::new(ArrayAggAccumulator::try_new(
+                &data_type,
+                ignore_nulls,
+            )?));
         }
 
         let ordering_dtypes = acc_args
@@ -188,6 +194,7 @@ impl AggregateUDFImpl for ArrayAgg {
             &ordering_dtypes,
             acc_args.ordering_req.clone(),
             acc_args.is_reversed,
+            ignore_nulls,
         )
         .map(|acc| Box::new(acc) as _)
     }
@@ -205,18 +212,20 @@ impl AggregateUDFImpl for ArrayAgg {
 pub struct ArrayAggAccumulator {
     values: Vec<ArrayRef>,
     datatype: DataType,
+    ignore_nulls: bool,
 }
 
 impl ArrayAggAccumulator {
     /// new array_agg accumulator based on given item data type
-    pub fn try_new(datatype: &DataType) -> Result<Self> {
+    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
         Ok(Self {
             values: vec![],
             datatype: datatype.clone(),
+            ignore_nulls,
         })
     }
 
-    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
+    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value points to a non-empty list)
     /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
     fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
         let offsets = list_array.value_offsets();
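
The doc comment above leans on an Arrow detail that is worth seeing concretely: a list slot can be flagged null while its offsets still span child values. A small illustrative sketch (the data is made up; it only demonstrates the layout the helper has to guard against):

```rust
// Illustrative only: a ListArray whose null slot still spans two child values,
// i.e. the "null value pointing to a non-empty list" case described above.
use std::sync::Arc;
use arrow::array::{Array, Int32Array, ListArray};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, Field};

fn main() {
    let child = Int32Array::from(vec![1, 2, 3, 4]);
    // Offsets [0, 2, 4, 4]: slot 0 = [1, 2], slot 1 = [3, 4], slot 2 = [].
    let offsets = OffsetBuffer::new(vec![0, 2, 4, 4].into());
    // Slot 1 is marked null even though its offsets cover two child values.
    let nulls = NullBuffer::from(vec![true, false, true]);
    let list = ListArray::new(
        Arc::new(Field::new("item", DataType::Int32, true)),
        offsets,
        Arc::new(child),
        Some(nulls),
    );
    assert_eq!(list.len(), 3);
    // Slicing the child array 0..4 would resurrect the two values hidden behind
    // the null slot, so the merge fast-path has to bail out (return None).
}
```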
@@ -240,15 +249,15 @@ impl ArrayAggAccumulator {
             return Some(list_array.values().slice(0, 0));
         }
 
-        // According to the Arrow spec, null values can point to non empty lists
+        // According to the Arrow spec, null values can point to non-empty lists
         // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
 
         // Unwrapping is safe as we just checked if there is a null value
         let nulls = list_array.nulls().unwrap();
 
         let mut valid_slices_iter = nulls.valid_slices();
 
-        // This is safe as we validated that that are at least 1 valid value in the array
+        // This is safe as we validated that there is at least 1 valid value in the array
         let (start, end) = valid_slices_iter.next().unwrap();
 
         let start_offset = offsets[start];
@@ -258,7 +267,7 @@ impl ArrayAggAccumulator {
         let mut end_offset_of_last_valid_value = offsets[end];
 
         for (start, end) in valid_slices_iter {
-            // If there is a null value that point to a non empty list than the start offset of the valid value
+            // If there is a null value that points to a non-empty list then the start offset of the valid value
             // will be different that the end offset of the last valid value
             if offsets[start] != end_offset_of_last_valid_value {
                 return None;
@@ -289,10 +298,23 @@ impl Accumulator for ArrayAggAccumulator {
             return internal_err!("expects single batch");
         }
 
-        let val = Arc::clone(&values[0]);
-        if val.len() > 0 {
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let val = match nulls {
+            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
+            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
+            None => Arc::clone(val),
+        };
+
+        if !val.is_empty() {
             self.values.push(val);
         }
+
         Ok(())
     }
 
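
The three-way `match` above is the heart of the non-distinct path: all-null batches are dropped, partially-null batches are compacted with `filter`, and batches without a null buffer are passed through untouched. A standalone sketch of the same mechanism (`compact_nulls` is a made-up helper name, not DataFusion API):

```rust
// Standalone sketch of the mechanism used above; `compact_nulls` is invented.
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, BooleanArray, Int32Array};
use arrow::compute::filter;
use arrow::error::ArrowError;

fn compact_nulls(values: &ArrayRef) -> Result<Option<ArrayRef>, ArrowError> {
    match values.logical_nulls() {
        // Every slot is null: the batch contributes nothing.
        Some(nulls) if nulls.null_count() >= values.len() => Ok(None),
        // Some slots are null: keep only the slots marked valid.
        Some(nulls) => Ok(Some(filter(
            values,
            &BooleanArray::new(nulls.inner().clone(), None),
        )?)),
        // No null buffer: reuse the array without copying.
        None => Ok(Some(Arc::clone(values))),
    }
}

fn main() -> Result<(), ArrowError> {
    let batch: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let kept = compact_nulls(&batch)?.expect("not all null");
    assert_eq!(kept.len(), 2); // the null slot has been filtered out
    Ok(())
}
```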
@@ -361,17 +383,20 @@ struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
+    ignore_nulls: bool,
 }
 
 impl DistinctArrayAggAccumulator {
     pub fn try_new(
         datatype: &DataType,
         sort_options: Option<SortOptions>,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         Ok(Self {
             values: HashSet::new(),
             datatype: datatype.clone(),
             sort_options,
+            ignore_nulls,
         })
     }
 }
@@ -386,11 +411,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
             return Ok(());
         }
 
-        let array = &values[0];
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
 
-        for i in 0..array.len() {
-            let scalar = ScalarValue::try_from_array(&array, i)?;
-            self.values.insert(scalar);
+        let nulls = nulls.as_ref();
+        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
+            for i in 0..val.len() {
+                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
+                    self.values.insert(ScalarValue::try_from_array(val, i)?);
+                }
+            }
         }
 
         Ok(())
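
`Option::is_none_or`, used twice in this loop, is doing double duty: `None` means the batch has no validity buffer at all (keep every row), while `Some` defers to the buffer per row. A tiny demo of that std method (stable since Rust 1.82):

```rust
// Quick demo of Option::is_none_or, which the loop above leans on:
// None -> "no validity info, keep everything", Some -> "consult the buffer".
fn main() {
    let no_validity: Option<Vec<bool>> = None;
    let validity = Some(vec![true, false, true]);

    // No null buffer: every row is treated as valid.
    assert!(no_validity.is_none_or(|v| v[1]));
    // With a buffer: row 1 is invalid, so the predicate decides.
    assert!(!validity.is_none_or(|v| v[1]));
}
```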
@@ -472,6 +506,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
     ordering_req: LexOrdering,
     /// Whether the aggregation is running in reverse.
     reverse: bool,
+    /// Whether the aggregation should ignore null values.
+    ignore_nulls: bool,
 }
 
 impl OrderSensitiveArrayAggAccumulator {
@@ -482,6 +518,7 @@ impl OrderSensitiveArrayAggAccumulator {
         ordering_dtypes: &[DataType],
         ordering_req: LexOrdering,
         reverse: bool,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         let mut datatypes = vec![datatype.clone()];
         datatypes.extend(ordering_dtypes.iter().cloned());
@@ -491,6 +528,7 @@ impl OrderSensitiveArrayAggAccumulator {
             datatypes,
             ordering_req,
             reverse,
+            ignore_nulls,
         })
     }
 }
@@ -501,11 +539,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
             return Ok(());
         }
 
-        let n_row = values[0].len();
-        for index in 0..n_row {
-            let row = get_row_at_idx(values, index)?;
-            self.values.push(row[0].clone());
-            self.ordering_values.push(row[1..].to_vec());
+        let val = &values[0];
+        let ord = &values[1..];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let nulls = nulls.as_ref();
+        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
+            for i in 0..val.len() {
+                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
+                    self.values.push(ScalarValue::try_from_array(val, i)?);
+                    self.ordering_values.push(get_row_at_idx(ord, i)?)
+                }
+            }
         }
 
         Ok(())
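
The order-sensitive variant must keep the ordering columns aligned with the surviving values, so a single validity check drives both pushes. A self-contained version of that row loop for reference (the helper name and sample data are invented, not DataFusion's code):

```rust
// Illustrative, self-contained version of the row loop above.
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int32Array};
use datafusion_common::utils::get_row_at_idx;
use datafusion_common::{Result, ScalarValue};

fn collect_valid_rows(
    value: &ArrayRef,
    ordering: &[ArrayRef],
) -> Result<Vec<(ScalarValue, Vec<ScalarValue>)>> {
    let nulls = value.logical_nulls();
    let nulls = nulls.as_ref();
    let mut rows = vec![];
    for i in 0..value.len() {
        // A row is kept when there is no null buffer at all or its slot is valid.
        if nulls.is_none_or(|n| n.is_valid(i)) {
            let v = ScalarValue::try_from_array(value, i)?;
            let ord = get_row_at_idx(ordering, i)?;
            rows.push((v, ord));
        }
    }
    Ok(rows)
}

fn main() -> Result<()> {
    let value: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(30)]));
    let ts: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let rows = collect_valid_rows(&value, &[ts])?;
    assert_eq!(rows.len(), 2); // the null row and its ordering key are both dropped
    Ok(())
}
```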
@@ -666,7 +715,7 @@ impl OrderSensitiveArrayAggAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::datatypes::{FieldRef, Schema};
+    use arrow::datatypes::Schema;
     use datafusion_common::cast::as_generic_string_array;
     use datafusion_common::internal_err;
     use datafusion_physical_expr::expressions::Column;
@@ -947,14 +996,12 @@ mod tests {
         fn new(data_type: DataType) -> Self {
             Self {
                 data_type: data_type.clone(),
-                distinct: Default::default(),
+                distinct: false,
                 ordering: Default::default(),
                 schema: Schema {
                     fields: Fields::from(vec![Field::new(
                         "col",
-                        DataType::List(FieldRef::new(Field::new(
-                            "item", data_type, true,
-                        ))),
+                        DataType::new_list(data_type, true),
                         true,
                     )]),
                     metadata: Default::default(),
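
The test change above replaces the hand-rolled `DataType::List(FieldRef::new(Field::new("item", ...)))` with `DataType::new_list`, which builds the same type with the conventional nullable `item` child field. A quick equivalence check:

```rust
// Quick equivalence check for the test change: new_list wraps the element type
// in the conventional "item" field, matching the code it replaced.
use arrow::datatypes::{DataType, Field, FieldRef};

fn main() {
    let shorthand = DataType::new_list(DataType::Int32, true);
    let spelled_out =
        DataType::List(FieldRef::new(Field::new("item", DataType::Int32, true)));
    assert_eq!(shorthand, spelled_out);
}
```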