Patch summary (free text placed ahead of the first file header).

Registers two criterion benches, `starts_with` and `ends_with`, in
datafusion/functions/Cargo.toml and adds their sources. Each bench calls the
function with a StringArray or StringViewArray of seeded random alphanumeric
strings and either a scalar or an array pattern.

Rewrites `invoke_with_args` of `EndsWithFunc` and `StartsWithFunc` so that it no
longer goes through `make_scalar_function`. It now validates that exactly two
arguments were passed, picks a common string type with `string_coercion` /
`binary_to_string_coercion`, casts each side only when its type differs, and
dispatches on the four scalar/array combinations. A scalar operand paired with
an array is wrapped in arrow's `Scalar` before calling the arrow comparison
kernel; when both operands are scalars the single result is returned as a
`ScalarValue`. Unsupported types and wrong arity are reported with `exec_err!`.

Adds unit tests for the scalar/scalar, array/scalar, scalar/array and
array/array combinations of both functions.

diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index 3e832691f96b0..765f5d865a60e 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -254,3 +254,13 @@ required-features = ["unicode_expressions"]
 harness = false
 name = "find_in_set"
 required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "starts_with"
+required-features = ["string_expressions"]
+
+[[bench]]
+harness = false
+name = "ends_with"
+required-features = ["string_expressions"]
diff --git a/datafusion/functions/benches/ends_with.rs b/datafusion/functions/benches/ends_with.rs
new file mode 100644
index 0000000000000..926fd9ff72a5a
--- /dev/null
+++ b/datafusion/functions/benches/ends_with.rs
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distr::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+/// Generate a StringArray/StringViewArray with random ASCII strings
+fn gen_string_array(
+    n_rows: usize,
+    str_len: usize,
+    is_string_view: bool,
+) -> ColumnarValue {
+    let mut rng = StdRng::seed_from_u64(42);
+    let strings: Vec<Option<String>> = (0..n_rows)
+        .map(|_| {
+            let s: String = (&mut rng)
+                .sample_iter(&Alphanumeric)
+                .take(str_len)
+                .map(char::from)
+                .collect();
+            Some(s)
+        })
+        .collect();
+
+    if is_string_view {
+        ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+    } else {
+        ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+    }
+}
+
+/// Generate a scalar suffix string
+fn gen_scalar_suffix(suffix_str: &str, is_string_view: bool) -> ColumnarValue {
+    if is_string_view {
+        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(suffix_str.to_string())))
+    } else {
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(suffix_str.to_string())))
+    }
+}
+
+/// Generate an array of suffix strings (same string repeated)
+fn gen_array_suffix(
+    suffix_str: &str,
+    n_rows: usize,
+    is_string_view: bool,
+) -> ColumnarValue {
+    let strings: Vec<Option<String>> =
+        (0..n_rows).map(|_| Some(suffix_str.to_string())).collect();
+
+    if is_string_view {
+        ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+    } else {
+        ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+    }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let ends_with = datafusion_functions::string::ends_with();
+    let n_rows = 8192;
+    let str_len = 128;
+    let suffix_str = "xyz"; // A pattern that likely won't match
+
+    // Benchmark: StringArray with scalar suffix (the optimized path)
+    let str_array = gen_string_array(n_rows, str_len, false);
+    let scalar_suffix = gen_scalar_suffix(suffix_str, false);
+    let arg_fields = vec![
+        Field::new("a", DataType::Utf8, true).into(),
+        Field::new("b", DataType::Utf8, true).into(),
+    ];
+    let return_field = Field::new("f", DataType::Boolean, true).into();
+    let config_options = Arc::new(ConfigOptions::default());
+
+    c.bench_function("ends_with_StringArray_scalar_suffix", |b| {
+        b.iter(|| {
+            black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_array.clone(), scalar_suffix.clone()],
+                arg_fields: arg_fields.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringArray with array suffix (for comparison)
+    let array_suffix = gen_array_suffix(suffix_str, n_rows, false);
+    c.bench_function("ends_with_StringArray_array_suffix", |b| {
+        b.iter(|| {
+            black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_array.clone(), array_suffix.clone()],
+                arg_fields: arg_fields.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringViewArray with scalar suffix (the optimized path)
+    let str_view_array = gen_string_array(n_rows, str_len, true);
+    let scalar_suffix_view = gen_scalar_suffix(suffix_str, true);
+    let arg_fields_view = vec![
+        Field::new("a", DataType::Utf8View, true).into(),
+        Field::new("b", DataType::Utf8View, true).into(),
+    ];
+
+    c.bench_function("ends_with_StringViewArray_scalar_suffix", |b| {
+        b.iter(|| {
+            black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_view_array.clone(), scalar_suffix_view.clone()],
+                arg_fields: arg_fields_view.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringViewArray with array suffix (for comparison)
+    let array_suffix_view = gen_array_suffix(suffix_str, n_rows, true);
+    c.bench_function("ends_with_StringViewArray_array_suffix", |b| {
+        b.iter(|| {
+            black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_view_array.clone(), array_suffix_view.clone()],
+                arg_fields: arg_fields_view.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark different string lengths with scalar suffix
+    for str_len in [8, 32, 128, 512] {
+        let str_array = gen_string_array(n_rows, str_len, true);
+        let scalar_suffix = gen_scalar_suffix(suffix_str, true);
+        let arg_fields = vec![
+            Field::new("a", DataType::Utf8View, true).into(),
+            Field::new("b", DataType::Utf8View, true).into(),
+        ];
+
+        c.bench_function(
+            &format!("ends_with_StringViewArray_scalar_strlen_{str_len}"),
+            |b| {
+                b.iter(|| {
+                    black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+                        args: vec![str_array.clone(), scalar_suffix.clone()],
+                        arg_fields: arg_fields.clone(),
+                        number_rows: n_rows,
+                        return_field: Arc::clone(&return_field),
+                        config_options: Arc::clone(&config_options),
+                    }))
+                })
+            },
+        );
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/starts_with.rs b/datafusion/functions/benches/starts_with.rs
new file mode 100644
index 0000000000000..9ee39b694539c
--- /dev/null
+++ b/datafusion/functions/benches/starts_with.rs
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distr::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+/// Generate a StringArray/StringViewArray with random ASCII strings
+fn gen_string_array(
+    n_rows: usize,
+    str_len: usize,
+    is_string_view: bool,
+) -> ColumnarValue {
+    let mut rng = StdRng::seed_from_u64(42);
+    let strings: Vec<Option<String>> = (0..n_rows)
+        .map(|_| {
+            let s: String = (&mut rng)
+                .sample_iter(&Alphanumeric)
+                .take(str_len)
+                .map(char::from)
+                .collect();
+            Some(s)
+        })
+        .collect();
+
+    if is_string_view {
+        ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+    } else {
+        ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+    }
+}
+
+/// Generate a scalar prefix string
+fn gen_scalar_prefix(prefix_str: &str, is_string_view: bool) -> ColumnarValue {
+    if is_string_view {
+        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(prefix_str.to_string())))
+    } else {
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(prefix_str.to_string())))
+    }
+}
+
+/// Generate an array of prefix strings (same string repeated)
+fn gen_array_prefix(
+    prefix_str: &str,
+    n_rows: usize,
+    is_string_view: bool,
+) -> ColumnarValue {
+    let strings: Vec<Option<String>> =
+        (0..n_rows).map(|_| Some(prefix_str.to_string())).collect();
+
+    if is_string_view {
+        ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+    } else {
+        ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+    }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let starts_with = datafusion_functions::string::starts_with();
+    let n_rows = 8192;
+    let str_len = 128;
+    let prefix_str = "xyz"; // A pattern that likely won't match
+
+    // Benchmark: StringArray with scalar prefix (the optimized path)
+    let str_array = gen_string_array(n_rows, str_len, false);
+    let scalar_prefix = gen_scalar_prefix(prefix_str, false);
+    let arg_fields = vec![
+        Field::new("a", DataType::Utf8, true).into(),
+        Field::new("b", DataType::Utf8, true).into(),
+    ];
+    let return_field = Field::new("f", DataType::Boolean, true).into();
+    let config_options = Arc::new(ConfigOptions::default());
+
+    c.bench_function("starts_with_StringArray_scalar_prefix", |b| {
+        b.iter(|| {
+            black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_array.clone(), scalar_prefix.clone()],
+                arg_fields: arg_fields.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringArray with array prefix (for comparison)
+    let array_prefix = gen_array_prefix(prefix_str, n_rows, false);
+    c.bench_function("starts_with_StringArray_array_prefix", |b| {
+        b.iter(|| {
+            black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_array.clone(), array_prefix.clone()],
+                arg_fields: arg_fields.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringViewArray with scalar prefix (the optimized path)
+    let str_view_array = gen_string_array(n_rows, str_len, true);
+    let scalar_prefix_view = gen_scalar_prefix(prefix_str, true);
+    let arg_fields_view = vec![
+        Field::new("a", DataType::Utf8View, true).into(),
+        Field::new("b", DataType::Utf8View, true).into(),
+    ];
+
+    c.bench_function("starts_with_StringViewArray_scalar_prefix", |b| {
+        b.iter(|| {
+            black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_view_array.clone(), scalar_prefix_view.clone()],
+                arg_fields: arg_fields_view.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark: StringViewArray with array prefix (for comparison)
+    let array_prefix_view = gen_array_prefix(prefix_str, n_rows, true);
+    c.bench_function("starts_with_StringViewArray_array_prefix", |b| {
+        b.iter(|| {
+            black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+                args: vec![str_view_array.clone(), array_prefix_view.clone()],
+                arg_fields: arg_fields_view.clone(),
+                number_rows: n_rows,
+                return_field: Arc::clone(&return_field),
+                config_options: Arc::clone(&config_options),
+            }))
+        })
+    });
+
+    // Benchmark different string lengths with scalar prefix
+    for str_len in [8, 32, 128, 512] {
+        let str_array = gen_string_array(n_rows, str_len, true);
+        let scalar_prefix = gen_scalar_prefix(prefix_str, true);
+        let arg_fields = vec![
+            Field::new("a", DataType::Utf8View, true).into(),
+            Field::new("b", DataType::Utf8View, true).into(),
+        ];
+
+        c.bench_function(
+            &format!("starts_with_StringViewArray_scalar_strlen_{str_len}"),
+            |b| {
+                b.iter(|| {
+                    black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+                        args: vec![str_array.clone(), scalar_prefix.clone()],
+                        arg_fields: arg_fields.clone(),
+                        number_rows: n_rows,
+                        return_field: Arc::clone(&return_field),
+                        config_options: Arc::clone(&config_options),
+                    }))
+                })
+            },
+        );
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs
index e3fa7c92ca62b..a1fa124548d0a 100644
--- a/datafusion/functions/src/string/ends_with.rs
+++ b/datafusion/functions/src/string/ends_with.rs
@@ -18,12 +18,12 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, Scalar};
+use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
 use arrow::datatypes::DataType;
 
-use crate::utils::make_scalar_function;
 use datafusion_common::types::logical_string;
-use datafusion_common::{Result, internal_err};
+use datafusion_common::{Result, ScalarValue, exec_err};
 use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
 use datafusion_expr::{
     Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
@@ -95,13 +95,76 @@ impl ScalarUDFImpl for EndsWithFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        match args.args[0].data_type() {
-            DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
-                make_scalar_function(ends_with, vec![])(&args.args)
+        let [str_arg, suffix_arg] = args.args.as_slice() else {
+            return exec_err!(
+                "ends_with was called with {} arguments, expected 2",
+                args.args.len()
+            );
+        };
+
+        // Determine the common type for coercion
+        let coercion_type = string_coercion(
+            &str_arg.data_type(),
+            &suffix_arg.data_type(),
+        )
+        .or_else(|| {
+            binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
+        });
+
+        let Some(coercion_type) = coercion_type else {
+            return exec_err!(
+                "Unsupported data types {:?}, {:?} for function `ends_with`.",
+                str_arg.data_type(),
+                suffix_arg.data_type()
+            );
+        };
+
+        // Helper to cast an array if needed
+        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
+            if arr.data_type() == target {
+                Ok(Arc::clone(arr))
+            } else {
+                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
+            }
+        };
+
+        match (str_arg, suffix_arg) {
+            // Both scalars - just compute directly
+            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
+                let str_arr = str_scalar.to_array_of_size(1)?;
+                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
+                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
+                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+                    &result, 0,
+                )?))
+            }
+            // String is array, suffix is scalar - use Scalar wrapper for optimization
+            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
+                let str_arr = maybe_cast(str_arr, &coercion_type)?;
+                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
+                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
+                let suffix_scalar = Scalar::new(suffix_arr);
+                let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            // String is scalar, suffix is array - use Scalar wrapper for string
+            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
+                let str_arr = str_scalar.to_array_of_size(1)?;
+                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+                let str_scalar = Scalar::new(str_arr);
+                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
+                let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            // Both arrays - pass directly
+            (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
+                let str_arr = maybe_cast(str_arr, &coercion_type)?;
+                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
+                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
             }
-            other => internal_err!(
-                "Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View"
-            )?,
         }
     }
 
@@ -110,47 +173,24 @@ impl ScalarUDFImpl for EndsWithFunc {
     }
 }
 
-/// Returns true if string ends with suffix.
-/// ends_with('alphabet', 'abet') = 't'
-fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
-    if let Some(coercion_data_type) =
-        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
-            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
-        })
-    {
-        let arg0 = if args[0].data_type() == &coercion_data_type {
-            Arc::clone(&args[0])
-        } else {
-            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
-        };
-        let arg1 = if args[1].data_type() == &coercion_data_type {
-            Arc::clone(&args[1])
-        } else {
-            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
-        };
-        let result = arrow::compute::kernels::comparison::ends_with(&arg0, &arg1)?;
-        Ok(Arc::new(result) as ArrayRef)
-    } else {
-        internal_err!(
-            "Unsupported data types for ends_with. Expected Utf8, LargeUtf8 or Utf8View"
-        )
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use arrow::array::{Array, BooleanArray};
+    use arrow::array::{Array, BooleanArray, StringArray};
     use arrow::datatypes::DataType::Boolean;
+    use arrow::datatypes::{DataType, Field};
+    use std::sync::Arc;
 
     use datafusion_common::Result;
     use datafusion_common::ScalarValue;
-    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
 
     use crate::string::ends_with::EndsWithFunc;
     use crate::utils::test::test_function;
 
     #[test]
-    fn test_functions() -> Result<()> {
+    fn test_scalar_scalar() -> Result<()> {
+        // Test Scalar + Scalar combinations
         test_function!(
             EndsWithFunc::new(),
             vec![
@@ -196,6 +236,186 @@ mod tests {
             BooleanArray
         );
 
+        // Test with LargeUtf8
+        test_function!(
+            EndsWithFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
+                    "alphabet".to_string()
+                ))),
+                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+
+        // Test with Utf8View
+        test_function!(
+            EndsWithFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                    "alphabet".to_string()
+                ))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array_scalar() -> Result<()> {
+        // Test Array + Scalar (the optimized path)
+        let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("alphabet"),
+            Some("beta"),
+            None,
+        ])));
+        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string())));
+
+        let args = vec![array, scalar];
+        test_function!(
+            EndsWithFunc::new(),
+            args,
+            Ok(Some(true)), // First element result: "alphabet" ends with "bet"
+            bool,
+            Boolean,
+            BooleanArray
+        );
+
         Ok(())
     }
+
+    #[test]
+    fn test_array_scalar_full_result() {
+        // Test Array + Scalar and verify all results
+        let func = EndsWithFunc::new();
+        let array = Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("alphabet"),
+            Some("beta"),
+            None,
+        ]));
+        let args = vec![
+            ColumnarValue::Array(array),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+        assert!(bool_array.value(1)); // "alphabet" ends with "bet"
+        assert!(!bool_array.value(2)); // "beta" does not end with "bet"
+        assert!(bool_array.is_null(3)); // null input -> null output
+    }
+
+    #[test]
+    fn test_scalar_array() {
+        // Test Scalar + Array
+        let func = EndsWithFunc::new();
+        let suffixes = Arc::new(StringArray::from(vec![
+            Some("bet"),
+            Some("alph"),
+            Some("phabet"),
+            None,
+        ]));
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
+            ColumnarValue::Array(suffixes),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+        assert!(!bool_array.value(1)); // "alphabet" does not end with "alph"
+        assert!(bool_array.value(2)); // "alphabet" ends with "phabet"
+        assert!(bool_array.is_null(3)); // null suffix -> null output
+    }
+
+    #[test]
+    fn test_array_array() {
+        // Test Array + Array
+        let func = EndsWithFunc::new();
+        let strings = Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("rust"),
+            Some("datafusion"),
+            None,
+        ]));
+        let suffixes = Arc::new(StringArray::from(vec![
+            Some("bet"),
+            Some("st"),
+            Some("hello"),
+            Some("test"),
+        ]));
+        let args = vec![
+            ColumnarValue::Array(strings),
+            ColumnarValue::Array(suffixes),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+        assert!(bool_array.value(1)); // "rust" ends with "st"
+        assert!(!bool_array.value(2)); // "datafusion" does not end with "hello"
+        assert!(bool_array.is_null(3)); // null string -> null output
+    }
 }
diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs
index 1a60eb91aa621..259612c42997e 100644
--- a/datafusion/functions/src/string/starts_with.rs
+++ b/datafusion/functions/src/string/starts_with.rs
@@ -18,49 +18,22 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, Scalar};
+use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
 use arrow::datatypes::DataType;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::type_coercion::binary::{
     binary_to_string_coercion, string_coercion,
 };
 
-use crate::utils::make_scalar_function;
 use datafusion_common::types::logical_string;
-use datafusion_common::{Result, ScalarValue, internal_err};
+use datafusion_common::{Result, ScalarValue, exec_err};
 use datafusion_expr::{
     Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs, ScalarUDFImpl,
     Signature, TypeSignatureClass, Volatility, cast,
 };
 use datafusion_macros::user_doc;
 
-/// Returns true if string starts with prefix.
-/// starts_with('alphabet', 'alph') = 't'
-fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
-    if let Some(coercion_data_type) =
-        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
-            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
-        })
-    {
-        let arg0 = if args[0].data_type() == &coercion_data_type {
-            Arc::clone(&args[0])
-        } else {
-            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
-        };
-        let arg1 = if args[1].data_type() == &coercion_data_type {
-            Arc::clone(&args[1])
-        } else {
-            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
-        };
-        let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
-        Ok(Arc::new(result) as ArrayRef)
-    } else {
-        internal_err!(
-            "Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View"
-        )
-    }
-}
-
 #[user_doc(
     doc_section(label = "String Functions"),
     description = "Tests if a string starts with a substring.",
@@ -119,13 +92,76 @@ impl ScalarUDFImpl for StartsWithFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        match args.args[0].data_type() {
-            DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
-                make_scalar_function(starts_with, vec![])(&args.args)
+        let [str_arg, prefix_arg] = args.args.as_slice() else {
+            return exec_err!(
+                "starts_with was called with {} arguments, expected 2",
+                args.args.len()
+            );
+        };
+
+        // Determine the common type for coercion
+        let coercion_type = string_coercion(
+            &str_arg.data_type(),
+            &prefix_arg.data_type(),
+        )
+        .or_else(|| {
+            binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
+        });
+
+        let Some(coercion_type) = coercion_type else {
+            return exec_err!(
+                "Unsupported data types {:?}, {:?} for function `starts_with`.",
+                str_arg.data_type(),
+                prefix_arg.data_type()
+            );
+        };
+
+        // Helper to cast an array if needed
+        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
+            if arr.data_type() == target {
+                Ok(Arc::clone(arr))
+            } else {
+                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
+            }
+        };
+
+        match (str_arg, prefix_arg) {
+            // Both scalars - just compute directly
+            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
+                let str_arr = str_scalar.to_array_of_size(1)?;
+                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
+                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
+                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+                    &result, 0,
+                )?))
+            }
+            // String is array, prefix is scalar - use Scalar wrapper for optimization
+            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
+                let str_arr = maybe_cast(str_arr, &coercion_type)?;
+                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
+                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
+                let prefix_scalar = Scalar::new(prefix_arr);
+                let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            // String is scalar, prefix is array - use Scalar wrapper for string
+            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
+                let str_arr = str_scalar.to_array_of_size(1)?;
+                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+                let str_scalar = Scalar::new(str_arr);
+                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
+                let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            // Both arrays - pass directly
+            (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
+                let str_arr = maybe_cast(str_arr, &coercion_type)?;
+                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
+                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
+                Ok(ColumnarValue::Array(Arc::new(result)))
             }
-            _ => internal_err!(
-                "Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View"
-            )?,
         }
     }
 
@@ -195,16 +231,19 @@ impl ScalarUDFImpl for StartsWithFunc {
 #[cfg(test)]
 mod tests {
     use crate::utils::test::test_function;
-    use arrow::array::{Array, BooleanArray};
+    use arrow::array::{Array, BooleanArray, StringArray};
     use arrow::datatypes::DataType::Boolean;
+    use arrow::datatypes::{DataType, Field};
+    use datafusion_common::config::ConfigOptions;
     use datafusion_common::{Result, ScalarValue};
-    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+    use std::sync::Arc;
 
     use super::*;
 
     #[test]
-    fn test_functions() -> Result<()> {
-        // Generate test cases for starts_with
+    fn test_scalar_scalar() -> Result<()> {
+        // Test Scalar + Scalar combinations
         let test_cases = vec![
             (Some("alphabet"), Some("alph"), Some(true)),
             (Some("alphabet"), Some("bet"), Some(false)),
@@ -248,4 +287,154 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_array_scalar() -> Result<()> {
+        // Test Array + Scalar (the optimized path)
+        let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("alphabet"),
+            Some("beta"),
+            None,
+        ])));
+        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string())));
+
+        let args = vec![array, scalar];
+        test_function!(
+            StartsWithFunc::new(),
+            args,
+            Ok(Some(true)), // First element result
+            bool,
+            Boolean,
+            BooleanArray
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array_scalar_full_result() {
+        // Test Array + Scalar and verify all results
+        let func = StartsWithFunc::new();
+        let array = Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("alphabet"),
+            Some("beta"),
+            None,
+        ]));
+        let args = vec![
+            ColumnarValue::Array(array),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+        assert!(bool_array.value(1)); // "alphabet" starts with "alph"
+        assert!(!bool_array.value(2)); // "beta" does not start with "alph"
+        assert!(bool_array.is_null(3)); // null input -> null output
+    }
+
+    #[test]
+    fn test_scalar_array() {
+        // Test Scalar + Array
+        let func = StartsWithFunc::new();
+        let prefixes = Arc::new(StringArray::from(vec![
+            Some("alph"),
+            Some("bet"),
+            Some("alpha"),
+            None,
+        ]));
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
+            ColumnarValue::Array(prefixes),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+        assert!(!bool_array.value(1)); // "alphabet" does not start with "bet"
+        assert!(bool_array.value(2)); // "alphabet" starts with "alpha"
+        assert!(bool_array.is_null(3)); // null prefix -> null output
+    }
+
+    #[test]
+    fn test_array_array() {
+        // Test Array + Array
+        let func = StartsWithFunc::new();
+        let strings = Arc::new(StringArray::from(vec![
+            Some("alphabet"),
+            Some("rust"),
+            Some("datafusion"),
+            None,
+        ]));
+        let prefixes = Arc::new(StringArray::from(vec![
+            Some("alph"),
+            Some("ru"),
+            Some("hello"),
+            Some("test"),
+        ]));
+        let args = vec![
+            ColumnarValue::Array(strings),
+            ColumnarValue::Array(prefixes),
+        ];
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![
+                    Field::new("a", DataType::Utf8, true).into(),
+                    Field::new("b", DataType::Utf8, true).into(),
+                ],
+                number_rows: 4,
+                return_field: Field::new("f", Boolean, true).into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })
+            .unwrap();
+
+        let result_array = result.into_array(4).unwrap();
+        let bool_array = result_array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+
+        assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+        assert!(bool_array.value(1)); // "rust" starts with "ru"
+        assert!(!bool_array.value(2)); // "datafusion" does not start with "hello"
+        assert!(bool_array.is_null(3)); // null string -> null output
+    }
 }