Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions datafusion/functions/benches/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use arrow::util::bench_util::{
};
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
use datafusion_common::DataFusionError;
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::string;
Expand Down Expand Up @@ -80,6 +81,44 @@ fn invoke_repeat_with_args(
}

fn criterion_benchmark(c: &mut Criterion) {
let repeat_fn = string::repeat();
let config_options = Arc::new(ConfigOptions::default());

// Scalar benchmarks (outside loop)
c.bench_function("repeat/scalar_utf8", |b| {
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
],
arg_fields: vec![
Field::new("a", DataType::Utf8, false).into(),
Field::new("b", DataType::Int64, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
};
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
});

c.bench_function("repeat/scalar_utf8view", |b| {
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
],
arg_fields: vec![
Field::new("a", DataType::Utf8View, false).into(),
Field::new("b", DataType::Int64, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
};
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
});

for size in [1024, 4096] {
// REPEAT 3 TIMES
let repeat_times = 3;
Expand Down
170 changes: 134 additions & 36 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
use std::any::Any;
use std::sync::Arc;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use crate::utils::utf8_to_str_type;
use arrow::array::{
ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringArrayType, StringViewArray,
};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_common::cast::as_int64_array;
use datafusion_common::types::{NativeType, logical_int64, logical_string};
use datafusion_common::{DataFusionError, Result, exec_err};
use datafusion_common::utils::take_function_args;
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
Expand Down Expand Up @@ -99,39 +100,121 @@ impl ScalarUDFImpl for RepeatFunc {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(repeat, vec![])(&args.args)
let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;

// Helper to create null result with correct type (follows utf8_to_str_type)
let null_result = |dt: &DataType| -> ColumnarValue {
let scalar = if matches!(dt, LargeUtf8) {
ScalarValue::LargeUtf8(None)
} else {
ScalarValue::Utf8(None)
};
ColumnarValue::Scalar(scalar)
};

// Early return if either argument is a scalar null
if let ColumnarValue::Scalar(s) = &string_arg
&& s.is_null()
{
return Ok(null_result(&s.data_type()));
}
if let ColumnarValue::Scalar(c) = &count_arg
&& c.is_null()
{
let dt = match &string_arg {
ColumnarValue::Scalar(s) => s.data_type(),
ColumnarValue::Array(a) => a.data_type().clone(),
};
Comment thread
Jefffrey marked this conversation as resolved.
Outdated
return Ok(null_result(&dt));
}

match (&string_arg, &count_arg) {
(
ColumnarValue::Scalar(string_scalar),
ColumnarValue::Scalar(count_scalar),
) => {
let count = match count_scalar {
ScalarValue::Int64(Some(n)) => *n,
_ => {
return internal_err!(
"Unexpected data type {:?} for repeat count",
count_scalar.data_type()
);
}
};

let result = match string_scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
ScalarValue::Utf8(Some(compute_repeat(s, count)?))
}
ScalarValue::LargeUtf8(Some(s)) => {
ScalarValue::LargeUtf8(Some(compute_repeat(s, count)?))
}
_ => {
return internal_err!(
"Unexpected data type {:?} for function repeat",
string_scalar.data_type()
);
}
};

Ok(ColumnarValue::Scalar(result))
}
_ => {
let string_array = string_arg.to_array(args.number_rows)?;
let count_array = count_arg.to_array(args.number_rows)?;
Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
}
}
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// Computes repeat for a single string value
#[inline]
fn compute_repeat(s: &str, count: i64) -> Result<String> {
if count <= 0 {
return Ok(String::new());
}
let result_len = s.len().saturating_mul(count as usize);
if result_len > i32::MAX as usize {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
i32::MAX,
Comment thread
Jefffrey marked this conversation as resolved.
Outdated
result_len
);
}
Ok(s.repeat(count as usize))
}

/// Repeats string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
let number_array = as_int64_array(&args[1])?;
match args[0].data_type() {
fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
let number_array = as_int64_array(count_array)?;
match string_array.data_type() {
Utf8View => {
let string_view_array = args[0].as_string_view();
let string_view_array = string_array.as_string_view();
repeat_impl::<i32, &StringViewArray>(
&string_view_array,
number_array,
i32::MAX as usize,
)
}
Utf8 => {
let string_array = args[0].as_string::<i32>();
let string_arr = string_array.as_string::<i32>();
repeat_impl::<i32, &GenericStringArray<i32>>(
&string_array,
&string_arr,
number_array,
i32::MAX as usize,
)
}
LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let string_arr = string_array.as_string::<i64>();
repeat_impl::<i64, &GenericStringArray<i64>>(
&string_array,
&string_arr,
number_array,
i64::MAX as usize,
)
Expand All @@ -150,7 +233,7 @@ fn repeat_impl<'a, T, S>(
) -> Result<ArrayRef>
where
T: OffsetSizeTrait,
S: StringArrayType<'a>,
S: StringArrayType<'a> + 'a,
{
let mut total_capacity = 0;
let mut max_item_capacity = 0;
Expand Down Expand Up @@ -181,37 +264,52 @@ where
// Reusable buffer to avoid allocations in string.repeat()
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);

string_array
.iter()
.zip(number_array.iter())
.for_each(|(string, number)| {
// Helper function to repeat a string into a buffer using doubling strategy
// count must be > 0
#[inline]
fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
buffer.clear();
if !string.is_empty() {
let src = string.as_bytes();
buffer.extend_from_slice(src);
while buffer.len() < src.len() * count {
let copy_len = buffer.len().min(src.len() * count - buffer.len());
buffer.extend_from_within(..copy_len);
}
}
}

// Fast path: no nulls in either array
if string_array.null_count() == 0 && number_array.null_count() == 0 {
for i in 0..string_array.len() {
// SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
let string = unsafe { string_array.value_unchecked(i) };
let count = number_array.value(i);
if count > 0 {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
} else {
builder.append_value("");
}
}
} else {
// Slow path: handle nulls
for (string, number) in string_array.iter().zip(number_array.iter()) {
match (string, number) {
(Some(string), Some(number)) if number >= 0 => {
buffer.clear();
let count = number as usize;
if count > 0 && !string.is_empty() {
let src = string.as_bytes();
// Initial copy
buffer.extend_from_slice(src);
// Doubling strategy: copy what we have so far until we reach the target
while buffer.len() < src.len() * count {
let copy_len =
buffer.len().min(src.len() * count - buffer.len());
// SAFETY: we're copying valid UTF-8 bytes that we already verified
Comment thread
Jefffrey marked this conversation as resolved.
buffer.extend_from_within(..copy_len);
}
}
// SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str
(Some(string), Some(count)) if count > 0 => {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder
.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
}
(Some(_), Some(_)) => builder.append_value(""),
_ => builder.append_null(),
}
});
let array = builder.finish();
}
}

Ok(Arc::new(array) as ArrayRef)
Ok(Arc::new(builder.finish()) as ArrayRef)
}

#[cfg(test)]
Expand Down