Skip to content

Commit

Permalink
add faster path for strpos in ascii-only case
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Sep 9, 2024
1 parent 25f7aff commit 8ebd4d1
Showing 1 changed file with 48 additions and 17 deletions.
65 changes: 48 additions & 17 deletions datafusion/functions/src/unicode/strpos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray,
};
use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};

use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_int_type};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_int_type};

#[derive(Debug)]
pub struct StrposFunc {
signature: Signature,
Expand Down Expand Up @@ -140,24 +138,43 @@ fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
substring_array: V2,
) -> Result<ArrayRef>
where
V1: ArrayAccessor<Item = &'a str>,
V2: ArrayAccessor<Item = &'a str>,
V1: StringArrayType<'a, Item = &'a str>,
V2: StringArrayType<'a, Item = &'a str>,
{
let string_iter = ArrayIter::new(string_array);
let substring_iter = ArrayIter::new(substring_array);
let ascii_only = string_array.is_ascii() && substring_array.is_ascii();
let string_iter = string_array.iter();
let substring_iter = substring_array.iter();

let result = string_iter
.zip(substring_iter)
.map(|(string, substring)| match (string, substring) {
(Some(string), Some(substring)) => {
// The `find` method returns the byte index of the substring.
// We count the number of chars up to that byte index.
T::Native::from_usize(
string
.find(substring)
.map(|x| string[..x].chars().count() + 1)
.unwrap_or(0),
)
// If only ASCII characters are present, every character is a single byte, so we can
// slide a window over the string's bytes to locate the substring and use the byte
// offset directly as the character position. This is faster than the `string.find()` method.
if ascii_only {
// If the substring is empty, the result is 1.
if substring.as_bytes().is_empty() {
return T::Native::from_usize(1);
} else {
T::Native::from_usize(
string
.as_bytes()
.windows(substring.as_bytes().len())
.position(|w| w == substring.as_bytes())
.map(|x| x + 1)
.unwrap_or(0),
)
}
} else {
// The `find` method returns the byte index of the substring.
// We count the number of chars up to that byte index.
T::Native::from_usize(
string
.find(substring)
.map(|x| string[..x].chars().count() + 1)
.unwrap_or(0),
)
}
}
_ => None,
})
Expand Down Expand Up @@ -201,47 +218,61 @@ mod tests {
test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

// LargeUtf8 and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

// Utf8 and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

// LargeUtf8 and Utf8 combinations
test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

// Utf8View and Utf8View combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

// Utf8View and Utf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

// Utf8View and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
}
}

0 comments on commit 8ebd4d1

Please sign in to comment.