Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions datafusion/functions-nested/src/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
use crate::utils::make_scalar_function;
use arrow::array::{Capacities, MutableArrayData};
use arrow::compute;
use arrow::compute::cast;
use arrow_array::{
new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray,
OffsetSizeTrait,
new_null_array, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait,
UInt64Array,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List};
use arrow_schema::{DataType, Field};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -86,7 +87,7 @@ impl Default for ArrayRepeat {
impl ArrayRepeat {
/// Creates an `ArrayRepeat` whose signature uses `Volatility::Immutable` and
/// whose only alias is `list_repeat`.
pub fn new() -> Self {
Self {
// NOTE(review): this diff view shows two `signature:` lines. Presumably
// `variadic_any` is the removed line and `user_defined` is the added one
// (the same change introduces `coerce_types`) — confirm against the merged file.
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("list_repeat")],
}
}
Expand Down Expand Up @@ -124,19 +125,47 @@ impl ScalarUDFImpl for ArrayRepeat {
&self.aliases
}

/// Computes the argument types `array_repeat` should be called with.
///
/// The first argument (the element to repeat) keeps its type unchanged. The
/// second argument (the repeat count) is widened to `Int64` when it is a signed
/// integer and to `UInt64` when it is an unsigned integer.
///
/// # Errors
///
/// Returns an execution error when `arg_types` does not contain exactly two
/// entries, or when the count type is not one of the listed integer types.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("array_repeat expects two arguments");
}

// The element type is passed through as-is.
let element_type = &arg_types[0];
let first = element_type.clone();

let count_type = &arg_types[1];

// Coerce the count to Int64 (signed inputs) or UInt64 (unsigned inputs);
// any other type is rejected below.
// Signed inputs are deliberately NOT coerced straight to UInt64 here: per the
// review thread embedded in this arm, doing so raised an arrow cast error for
// negative counts. `array_repeat_inner` casts Int64 to UInt64 itself.
let second = match count_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
Copy link
Contributor Author

@jatin510 jatin510 Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@korowa @alamb
converting negative numbers to UInt64 in this function was giving: arrow error, can't cast negative numbers to Uint.
during runtime for the negative integers.

So, converted the Int values to Int64 .

Then using array_repeat inner function to convert Int64 to UInt64 type

Copy link
Contributor

@korowa korowa Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-necessary suggestion: due to safe cast converting negative values to nulls, perhaps we should also add a simple test to verify that count array with multiple values with nulls (after casting to uint) will be processed as expected, like

select array_repeat('x', column1) from (values (-1), (2), (-3));

DataType::Int64
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
DataType::UInt64
}
_ => return exec_err!("count must be an integer type"),
};

// Result order mirrors the argument order: [element type, count type].
Ok(vec![first, second])
}

/// Returns the documentation for this function by delegating to `self.doc()`.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// Array_repeat SQL function
pub fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_repeat expects two arguments");
}

let element = &args[0];
let count_array = as_int64_array(&args[1])?;
let count_array = &args[1];

let count_array = match count_array.data_type() {
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
DataType::UInt64 => count_array,
_ => return exec_err!("count must be an integer type"),
};

let count_array = as_uint64_array(count_array)?;

match element.data_type() {
List(_) => {
Expand Down Expand Up @@ -165,7 +194,7 @@ pub fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &Int64Array,
count_array: &UInt64Array,
) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];
Expand Down Expand Up @@ -219,7 +248,7 @@ fn general_repeat<O: OffsetSizeTrait>(
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &Int64Array,
count_array: &UInt64Array,
) -> Result<ArrayRef> {
let data_type = list_array.data_type();
let value_type = list_array.value_type();
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2750,6 +2750,30 @@ select
----
[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[NULL, NULL], [NULL, NULL], [NULL, NULL]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]

# array_repeat scalar function with the count given as each signed and unsigned integer type
query ????????
Select
array_repeat(1, arrow_cast(2,'Int8')),
array_repeat(2, arrow_cast(2,'Int16')),
array_repeat(3, arrow_cast(2,'Int32')),
array_repeat(4, arrow_cast(2,'Int64')),
array_repeat(1, arrow_cast(2,'UInt8')),
array_repeat(2, arrow_cast(2,'UInt16')),
array_repeat(3, arrow_cast(2,'UInt32')),
array_repeat(4, arrow_cast(2,'UInt64'));
----
[1, 1] [2, 2] [3, 3] [4, 4] [1, 1] [2, 2] [3, 3] [4, 4]

# array_repeat scalar function with a negative count of each signed integer type (expects empty arrays)
query ????
Select
array_repeat(1, arrow_cast(-2,'Int8')),
array_repeat(2, arrow_cast(-2,'Int16')),
array_repeat(3, arrow_cast(-2,'Int32')),
array_repeat(4, arrow_cast(-2,'Int64'));
----
[] [] [] []

# array_repeat with columns #1

statement ok
Expand Down
Loading