Skip to content

Commit 2226dba

Browse files
authored
Fix array_sort for empty record batch (#290)
1 parent b6ddc6c commit 2226dba

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

datafusion/functions-array/src/sort.rs

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@ use crate::utils::make_scalar_function;
2121
use arrow::compute;
2222
use arrow_array::{Array, ArrayRef, ListArray};
2323
use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
24-
use arrow_schema::DataType::{FixedSizeList, LargeList, List};
2524
use arrow_schema::{DataType, Field, SortOptions};
2625
use datafusion_common::cast::{as_list_array, as_string_array};
2726
use datafusion_common::{exec_err, Result};
@@ -67,19 +66,9 @@ impl ScalarUDFImpl for ArraySort {
6766

6867
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
6968
match &arg_types[0] {
70-
List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new(
71-
"item",
72-
field.data_type().clone(),
73-
true,
74-
)))),
75-
LargeList(field) => Ok(LargeList(Arc::new(Field::new(
76-
"item",
77-
field.data_type().clone(),
78-
true,
79-
)))),
80-
_ => exec_err!(
81-
"Not reachable, data_type should be List, LargeList or FixedSizeList"
82-
),
69+
DataType::Null => Ok(DataType::Null),
70+
arg_type @ DataType::List(_) => Ok(arg_type.clone()),
71+
arg_type => exec_err!("{} does not support type {arg_type}", self.name()),
8372
}
8473
}
8574

@@ -98,6 +87,16 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
9887
return exec_err!("array_sort expects one to three arguments");
9988
}
10089

90+
if args[0].data_type().is_null() {
91+
return Ok(Arc::clone(&args[0]));
92+
}
93+
94+
let list_array = as_list_array(&args[0])?;
95+
let row_count = list_array.len();
96+
if row_count == 0 || list_array.value_type().is_null() {
97+
return Ok(Arc::clone(&args[0]));
98+
}
99+
101100
let sort_option = match args.len() {
102101
1 => None,
103102
2 => {
@@ -118,12 +117,6 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
118117
_ => return exec_err!("array_sort expects 1 to 3 arguments"),
119118
};
120119

121-
let list_array = as_list_array(&args[0])?;
122-
let row_count = list_array.len();
123-
if row_count == 0 {
124-
return Ok(args[0].clone());
125-
}
126-
127120
let mut array_lengths = vec![];
128121
let mut arrays = vec![];
129122
let mut valid = BooleanBufferBuilder::new(row_count);

datafusion/sqllogictest/test_files/array.slt

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2038,11 +2038,16 @@ NULL
20382038
[, 51, 52, 54, 55, 56, 57, 58, 59, 60]
20392039
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
20402040

2041-
# test with empty array
2041+
# test with empty table
20422042
query ?
2043-
select array_sort([]);
2043+
select array_sort(column1, 'DESC', 'NULLS FIRST') from arrays_values where false;
20442044
----
2045-
[]
2045+
2046+
# test with empty array
2047+
query ??
2048+
select array_sort([]), array_sort(NULL);
2049+
----
2050+
[] NULL
20462051

20472052
# test with empty row, the row that does not match the condition has row count 0
20482053
statement ok

0 commit comments

Comments
 (0)