Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: add arguments length check in array_expressions #8622

Merged
merged 1 commit into from
Dec 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 107 additions & 3 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ where
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_element needs two arguments");
}

match &args[0].data_type() {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
Expand Down Expand Up @@ -557,6 +561,10 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
///
/// See test cases in `array.slt` for more details.
pub fn array_slice(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!("array_slice needs three arguments");
}

let array_data_type = args[0].data_type();
match array_data_type {
DataType::List(_) => {
Expand Down Expand Up @@ -708,6 +716,10 @@ where

/// array_pop_back SQL function
pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_pop_back needs one argument");
}

let list_array = as_list_array(&args[0])?;
let from_array = Int64Array::from(vec![1; list_array.len()]);
let to_array = Int64Array::from(
Expand Down Expand Up @@ -857,6 +869,10 @@ pub fn array_pop_front(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_append SQL function
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}

let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

Expand All @@ -883,6 +899,10 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to make it more usable in future for all built in functions. To provide the exact available signatures, to let user quickly find what is his mistake

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to make it more usable in future for all built in functions. To provide the exact available signatures, to let user quickly find what is his mistake

@comphead
But we will also handle this while checking signature of array. Why do we need to check it here again? Is there any usage that skip the signature checking but jump to this call directly? If yes, we need to find a way to reuse the checking because the coercion is not supported here currently, length checking may not be enough.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking input arguments validity in public methods is good practice imho.

But the comment above was mostly to thinking on having a generic approach to return a friendly message to the user about what is wrong and what is the next step. Now we usually saying the input arguments has to be 1, 2, 3 arguments but it is mostly meaningless, we can probably improve it with more detailed message

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reply. I agree with the message that more specific information would be better.

return exec_err!("array_sort expects one to three arguments");
}

let sort_option = match args.len() {
1 => None,
2 => {
Expand Down Expand Up @@ -962,6 +982,10 @@ fn order_nulls_first(modifier: &str) -> Result<bool> {

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}

let list_array = as_list_array(&args[1])?;
let element_array = &args[0];

Expand Down Expand Up @@ -1082,6 +1106,10 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() {
return exec_err!("array_concat expects at least one arguments");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return exec_err!("array_concat expects at least one arguments");
return exec_err!("array_concat expects at least one argument");

}

let mut new_args = vec![];
for arg in args {
let ndim = list_ndims(arg.data_type());
Expand All @@ -1098,6 +1126,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_empty SQL function
pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_empty expects one argument");
}

if as_null_array(&args[0]).is_ok() {
// Make sure to return Boolean type.
return Ok(Arc::new(BooleanArray::new_null(args[0].len())));
Expand All @@ -1122,6 +1154,10 @@ fn array_empty_dispatch<O: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef

/// Array_repeat SQL function
pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_repeat expects two arguments");
}

let element = &args[0];
let count_array = as_int64_array(&args[1])?;

Expand Down Expand Up @@ -1257,6 +1293,10 @@ fn general_list_repeat(

/// Array_position SQL function
pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_position expects two or three arguments");
}

let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

Expand Down Expand Up @@ -1321,6 +1361,10 @@ fn general_position<OffsetSize: OffsetSizeTrait>(

/// Array_positions SQL function
pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_positions expects two arguments");
}

let element = &args[1];

match &args[0].data_type() {
Expand Down Expand Up @@ -1480,16 +1524,28 @@ fn array_remove_internal(
}

pub fn array_remove_all(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_remove_all expects two arguments");
}

let arr_n = vec![i64::MAX; args[0].len()];
array_remove_internal(&args[0], &args[1], arr_n)
}

pub fn array_remove(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_remove expects two arguments");
}

let arr_n = vec![1; args[0].len()];
array_remove_internal(&args[0], &args[1], arr_n)
}

pub fn array_remove_n(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!("array_remove_n expects three arguments");
}

let arr_n = as_int64_array(&args[2])?.values().to_vec();
array_remove_internal(&args[0], &args[1], arr_n)
}
Expand Down Expand Up @@ -1593,18 +1649,30 @@ fn general_replace(
}

pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!("array_replace expects three arguments");
}

// replace at most one occurence for each element
let arr_n = vec![1; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}

pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 4 {
return exec_err!("array_replace_n expects four arguments");
}

// replace the specified number of occurences
let arr_n = as_int64_array(&args[3])?.values().to_vec();
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}

pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!("array_replace_all expects three arguments");
}

// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
Expand Down Expand Up @@ -1682,7 +1750,7 @@ fn union_generic_lists<OffsetSize: OffsetSizeTrait>(
/// Array_union SQL function
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_union needs two arguments");
return exec_err!("array_union needs 2 arguments");
}
let array1 = &args[0];
let array2 = &args[1];
Expand Down Expand Up @@ -1724,6 +1792,10 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_to_string SQL function
pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_to_string expects two or three arguments");
}

let arr = &args[0];

let delimiters = as_string_array(&args[1])?;
Expand Down Expand Up @@ -1833,6 +1905,10 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Cardinality SQL function
pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("cardinality expects one argument");
}

let list_array = as_list_array(&args[0])?.clone();

let result = list_array
Expand Down Expand Up @@ -1889,6 +1965,10 @@ fn flatten_internal(

/// Flatten SQL function
pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("flatten expects one argument");
}

let flattened_array = flatten_internal(&args[0], None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
Expand All @@ -1913,6 +1993,10 @@ fn array_length_dispatch<O: OffsetSizeTrait>(array: &[ArrayRef]) -> Result<Array

/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 && args.len() != 2 {
return exec_err!("array_length expects one or two arguments");
}

match &args[0].data_type() {
DataType::List(_) => array_length_dispatch::<i32>(args),
DataType::LargeList(_) => array_length_dispatch::<i64>(args),
Expand Down Expand Up @@ -1959,6 +2043,10 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_ndims SQL function
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_ndims needs one argument");
}

if let Some(list_array) = args[0].as_list_opt::<i32>() {
let ndims = datafusion_common::utils::list_ndims(list_array.data_type());

Expand Down Expand Up @@ -2049,6 +2137,10 @@ fn general_array_has_dispatch<O: OffsetSizeTrait>(

/// Array_has SQL function
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_has needs two arguments");
}

let array_type = args[0].data_type();

match array_type {
Expand All @@ -2064,6 +2156,10 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_has_any SQL function
pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_has_any needs two arguments");
}

let array_type = args[0].data_type();

match array_type {
Expand All @@ -2079,6 +2175,10 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_has_all SQL function
pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_has_all needs two arguments");
}

let array_type = args[0].data_type();

match array_type {
Expand Down Expand Up @@ -2183,7 +2283,9 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef

/// array_intersect SQL function
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);
if args.len() != 2 {
return exec_err!("array_intersect needs two arguments");
}

let first_array = &args[0];
let second_array = &args[1];
Expand Down Expand Up @@ -2286,7 +2388,9 @@ pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
/// array_distinct SQL function
/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 1);
if args.len() != 1 {
return exec_err!("array_distinct needs one argument");
}

// handle null
if args[0].data_type() == &DataType::Null {
Expand Down