Skip to content

Commit

Permalink
remove unnecessary indirection
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kev1n8 committed Aug 18, 2024
1 parent 934516a commit 48e1643
Showing 1 changed file with 16 additions and 33 deletions.
49 changes: 16 additions & 33 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ use arrow::array::{
};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_int64_array, as_string_view_array};
use datafusion_common::cast::as_int64_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_str_type};
use crate::utils::{make_scalar_function, optimized_utf8_to_str_type, utf8_to_str_type};

#[derive(Debug)]
pub struct SubstrFunc {
Expand Down Expand Up @@ -79,7 +79,7 @@ impl ScalarUDFImpl for SubstrFunc {

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types[0] == DataType::Utf8View {
Ok(DataType::Utf8View)
optimized_utf8_to_str_type(&arg_types[0], "substr")
} else {
utf8_to_str_type(&arg_types[0], "substr")
}
Expand All @@ -94,21 +94,28 @@ impl ScalarUDFImpl for SubstrFunc {
}
}

/// Extracts the substring of `string` starting at the `start`'th character and, when a
/// count is supplied, extending for `count` characters.
/// (Same as `substring(string from start for count)`.)
///
/// substr('alphabet', 3) = 'phabet'
/// substr('alphabet', 3, 2) = 'ph'
///
/// The implementation uses UTF-8 code points as characters.
///
/// `args[0]` is the string array; every remaining argument (`args[1..]`) is forwarded
/// unchanged to the implementation selected by the string array's data type.
///
/// # Errors
/// Returns an execution error when `args[0]` is not `Utf8`, `LargeUtf8` or `Utf8View`.
///
/// # Panics
/// Panics if `args` is empty, because `args[0]` is indexed unconditionally.
pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args[0].data_type() {
        DataType::Utf8 => {
            let string_array = args[0].as_string::<i32>();
            string_substr::<_, i32>(string_array, &args[1..])
        }
        DataType::LargeUtf8 => {
            let string_array = args[0].as_string::<i64>();
            string_substr::<_, i64>(string_array, &args[1..])
        }
        DataType::Utf8View => {
            let string_array = args[0].as_string_view();
            string_view_substr(string_array, &args[1..])
        }
        // A trailing `\` in a string literal drops the newline and all leading whitespace
        // of the next line, so the space separating the two clauses must come before it.
        other => exec_err!(
            "Unsupported data type {other:?} for function substr, \
             expected Utf8View, Utf8 or LargeUtf8."
        ),
    }
}

Expand Down Expand Up @@ -139,7 +146,7 @@ fn get_true_start_count(input: &str, start: usize, count: i64) -> (usize, usize)

// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
// From<u128> for ByteView
fn calculate_string_view(
fn string_view_substr(
string_array: &StringViewArray,
args: &[ArrayRef],
) -> Result<ArrayRef> {
Expand Down Expand Up @@ -273,7 +280,7 @@ fn calculate_string_view(
Ok(Arc::new(result) as ArrayRef)
}

fn calculate_string<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
T: OffsetSizeTrait,
Expand Down Expand Up @@ -330,30 +337,6 @@ where
}
}

/// Routes `string_array` to the substring implementation that matches its data type.
///
/// Contract of the substring operation being dispatched:
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
/// substr('alphabet', 3) = 'phabet'
/// substr('alphabet', 3, 2) = 'ph'
/// The implementation uses UTF-8 code points as characters
///
/// `args` is passed through untouched to whichever implementation is chosen.
/// Returns an error for any data type other than `Utf8View`, `Utf8` or `LargeUtf8`,
/// and propagates a failure of the string-view cast.
// NOTE(review): this looks like the indirection layer named in the commit title, i.e. the
// deleted side of the final hunk — confirm against the repository before relying on it.
fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
T: OffsetSizeTrait,
{
// Second look at the data type: the caller has already matched on it once.
match string_array.data_type() {
DataType::Utf8View => {
// The generic `V` is cast back to a concrete string-view array; the cast is
// fallible, hence the `?`.
calculate_string_view(as_string_view_array(&string_array)?, args)
}
DataType::Utf8 | DataType::LargeUtf8 => {
// Offset-based strings share one generic path; `T` is the offset type picked
// by the caller.
calculate_string::<V, T>(string_array, args)
}
other => {
// Anything else is surfaced as an error value instead of panicking.
exec_err!(
"unexpected datatype {other}, expected Utf8View, Utf8 or LargeUtf8."
)
}
}
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray, StringViewArray};
Expand Down

0 comments on commit 48e1643

Please sign in to comment.