Skip to content

Commit

Permalink
improvements to (i)starts_with and (i)ends_with performance (#6118)
Browse files Browse the repository at this point in the history
* improvements to "starts_with" and "ends_with"

* add tests and refactor slightly

* add comments
  • Loading branch information
samuelcolvin committed Jul 30, 2024
1 parent bf9ce47 commit bf0ea91
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 20 deletions.
10 changes: 8 additions & 2 deletions arrow-string/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,14 @@ fn op_binary<'a>(
Op::Like(neg) => binary_predicate(l, r, neg, Predicate::like),
Op::ILike(neg) => binary_predicate(l, r, neg, |s| Predicate::ilike(s, false)),
Op::Contains => Ok(l.zip(r).map(|(l, r)| Some(str_contains(l?, r?))).collect()),
Op::StartsWith => Ok(l.zip(r).map(|(l, r)| Some(l?.starts_with(r?))).collect()),
Op::EndsWith => Ok(l.zip(r).map(|(l, r)| Some(l?.ends_with(r?))).collect()),
Op::StartsWith => Ok(l
.zip(r)
.map(|(l, r)| Some(Predicate::StartsWith(r?).evaluate(l?)))
.collect()),
Op::EndsWith => Ok(l
.zip(r)
.map(|(l, r)| Some(Predicate::EndsWith(r?).evaluate(l?)))
.collect()),
}
}

Expand Down
139 changes: 121 additions & 18 deletions arrow-string/src/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use arrow_schema::ArrowError;
use memchr::memchr2;
use memchr::memmem::Finder;
use regex::{Regex, RegexBuilder};
use std::iter::zip;

/// A string based predicate
pub enum Predicate<'a> {
Expand Down Expand Up @@ -88,10 +89,12 @@ impl<'a> Predicate<'a> {
Predicate::Eq(v) => *v == haystack,
Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
Predicate::StartsWith(v) => haystack.starts_with(v),
Predicate::IStartsWithAscii(v) => starts_with_ignore_ascii_case(haystack, v),
Predicate::EndsWith(v) => haystack.ends_with(v),
Predicate::IEndsWithAscii(v) => ends_with_ignore_ascii_case(haystack, v),
Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
Predicate::IStartsWithAscii(v) => {
starts_with(haystack, v, equals_ignore_ascii_case_kernel)
}
Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
Predicate::Regex(v) => v.is_match(haystack),
}
}
Expand All @@ -114,17 +117,17 @@ impl<'a> Predicate<'a> {
Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
finder.find(haystack.as_bytes()).is_some() != negate
}),
Predicate::StartsWith(v) => {
BooleanArray::from_unary(array, |haystack| haystack.starts_with(v) != negate)
}
Predicate::StartsWith(v) => BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_kernel) != negate
}),
Predicate::IStartsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
starts_with_ignore_ascii_case(haystack, v) != negate
starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
}),
Predicate::EndsWith(v) => BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_kernel) != negate
}),
Predicate::EndsWith(v) => {
BooleanArray::from_unary(array, |haystack| haystack.ends_with(v) != negate)
}
Predicate::IEndsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
ends_with_ignore_ascii_case(haystack, v) != negate
ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
}),
Predicate::Regex(v) => {
BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
Expand All @@ -133,14 +136,36 @@ impl<'a> Predicate<'a> {
}
}

fn starts_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
let end = haystack.len().min(needle.len());
haystack.is_char_boundary(end) && needle.eq_ignore_ascii_case(&haystack[..end])
/// Returns `true` when `haystack` begins with `needle`, comparing the two
/// byte by byte with `byte_eq_kernel`.
///
/// The kernel is called with `(haystack_byte, needle_byte)` pairs, front to
/// back, and the comparison stops at the first pair it rejects. A needle
/// longer than the haystack never matches; an empty needle always matches.
///
/// For small strings this is faster than `str::starts_with`, see
/// <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `get` yields `None` exactly when the needle is longer than the haystack.
    match haystack.as_bytes().get(..needle.len()) {
        Some(prefix) => zip(prefix, needle.as_bytes()).all(byte_eq_kernel),
        None => false,
    }
}

fn ends_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
let start = haystack.len().saturating_sub(needle.len());
haystack.is_char_boundary(start) && needle.eq_ignore_ascii_case(&haystack[start..])
/// Returns `true` when `haystack` finishes with `needle`, comparing the two
/// byte by byte with `byte_eq_kernel`.
///
/// The kernel is called with `(haystack_byte, needle_byte)` pairs, walking
/// both strings from their last byte backwards, and the comparison stops at
/// the first pair it rejects. A needle longer than the haystack never
/// matches; an empty needle always matches.
///
/// For small strings this is faster than `str::ends_with`, see
/// <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // A needle that does not fit inside the haystack cannot be a suffix.
    if haystack.len() < needle.len() {
        return false;
    }

    // Zipping the reversed iterators stops once the (shorter) needle runs out.
    let haystack_from_end = haystack.as_bytes().iter().rev();
    let needle_from_end = needle.as_bytes().iter().rev();
    zip(haystack_from_end, needle_from_end).all(byte_eq_kernel)
}

/// Byte comparison kernel for exact (case-sensitive) matching: the two bytes
/// of the pair must be identical.
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    *lhs == *rhs
}

/// Byte comparison kernel for ASCII case-insensitive matching: the two bytes
/// of the pair must be equal once ASCII letters are folded to a single case.
/// Bytes outside the ASCII letter ranges only match when they are identical.
fn equals_ignore_ascii_case_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs.eq_ignore_ascii_case(rhs)
}

/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does:
Expand Down Expand Up @@ -263,6 +288,7 @@ mod tests {
let r = regex_like(a_eq, false).unwrap();
assert_eq!(r.to_string(), expected);
}

#[test]
fn test_contains() {
assert!(Predicate::contains("hay").evaluate("haystack"));
Expand All @@ -281,4 +307,81 @@ mod tests {
assert!(!Predicate::contains("x").evaluate("haystack"));
assert!(!Predicate::contains("haystack haystack").evaluate("haystack"));
}

/// Exercises `Predicate::StartsWith` through `Predicate::evaluate`.
#[test]
fn test_starts_with() {
// Prefixes that must match, including a needle containing the non-ASCII
// character `£`, the full haystack, and the empty needle.
assert!(Predicate::StartsWith("hay").evaluate("haystack"));
assert!(Predicate::StartsWith("h£ay").evaluate("h£aystack"));
assert!(Predicate::StartsWith("haystack").evaluate("haystack"));
assert!(Predicate::StartsWith("ha").evaluate("haystack"));
assert!(Predicate::StartsWith("h").evaluate("haystack"));
assert!(Predicate::StartsWith("").evaluate("haystack"));

// Needles that must not match: a suffix, a needle longer than the haystack,
// a needle differing only in case, and `£` present on only one side.
assert!(!Predicate::StartsWith("stack").evaluate("haystack"));
assert!(!Predicate::StartsWith("haystacks").evaluate("haystack"));
assert!(!Predicate::StartsWith("HAY").evaluate("haystack"));
assert!(!Predicate::StartsWith("h£ay").evaluate("haystack"));
assert!(!Predicate::StartsWith("hay").evaluate("h£aystack"));
}

/// Exercises `Predicate::EndsWith` through `Predicate::evaluate`.
#[test]
fn test_ends_with() {
// Suffixes that must match, including a needle containing the non-ASCII
// character `£`, the full haystack, and the empty needle.
assert!(Predicate::EndsWith("stack").evaluate("haystack"));
assert!(Predicate::EndsWith("st£ack").evaluate("hayst£ack"));
assert!(Predicate::EndsWith("haystack").evaluate("haystack"));
assert!(Predicate::EndsWith("ck").evaluate("haystack"));
assert!(Predicate::EndsWith("k").evaluate("haystack"));
assert!(Predicate::EndsWith("").evaluate("haystack"));

// Needles that must not match: a prefix, a needle differing only in case,
// needles longer than the haystack, and `£` present on only one side.
assert!(!Predicate::EndsWith("hay").evaluate("haystack"));
assert!(!Predicate::EndsWith("STACK").evaluate("haystack"));
assert!(!Predicate::EndsWith("haystacks").evaluate("haystack"));
assert!(!Predicate::EndsWith("xhaystack").evaluate("haystack"));
assert!(!Predicate::EndsWith("st£ack").evaluate("haystack"));
assert!(!Predicate::EndsWith("stack").evaluate("hayst£ack"));
}

/// Exercises `Predicate::IStartsWithAscii` through `Predicate::evaluate`.
#[test]
fn test_istarts_with() {
// Prefixes that must match regardless of the ASCII case used on either
// side, plus the empty needle.
assert!(Predicate::IStartsWithAscii("hay").evaluate("haystack"));
assert!(Predicate::IStartsWithAscii("hay").evaluate("HAYSTACK"));
assert!(Predicate::IStartsWithAscii("HAY").evaluate("haystack"));
assert!(Predicate::IStartsWithAscii("HaY").evaluate("haystack"));
assert!(Predicate::IStartsWithAscii("hay").evaluate("HaYsTaCk"));
assert!(Predicate::IStartsWithAscii("HAY").evaluate("HaYsTaCk"));
assert!(Predicate::IStartsWithAscii("haystack").evaluate("HaYsTaCk"));
assert!(Predicate::IStartsWithAscii("HaYsTaCk").evaluate("HaYsTaCk"));
assert!(Predicate::IStartsWithAscii("").evaluate("HaYsTaCk"));

// Needles that must not match: a suffix, a needle longer than the haystack,
// and an extra character (`.` in the needle, `£` in the haystack).
assert!(!Predicate::IStartsWithAscii("stack").evaluate("haystack"));
assert!(!Predicate::IStartsWithAscii("haystacks").evaluate("haystack"));
assert!(!Predicate::IStartsWithAscii("h.ay").evaluate("haystack"));
assert!(!Predicate::IStartsWithAscii("hay").evaluate("h£aystack"));
}

/// Exercises `Predicate::IEndsWithAscii` through `Predicate::evaluate`.
#[test]
fn test_iends_with() {
// Suffixes that must match for every combination of lower, upper and mixed
// ASCII case in the needle and the haystack.
assert!(Predicate::IEndsWithAscii("stack").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("STACK").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("StAcK").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYSTACK"));
assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYSTACK"));
assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYSTACK"));
assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYsTaCk"));
assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYsTaCk"));
assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYsTaCk"));
// Full-length needles, short needles, and the empty needle.
assert!(Predicate::IEndsWithAscii("haystack").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("HAYSTACK").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("haystack").evaluate("HAYSTACK"));
assert!(Predicate::IEndsWithAscii("ck").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("cK").evaluate("haystack"));
assert!(Predicate::IEndsWithAscii("ck").evaluate("haystacK"));
assert!(Predicate::IEndsWithAscii("").evaluate("haystack"));

// Needles that must not match: a prefix, an interior substring, needles
// longer than the haystack, and a haystack with `£` inside the suffix.
assert!(!Predicate::IEndsWithAscii("hay").evaluate("haystack"));
assert!(!Predicate::IEndsWithAscii("stac").evaluate("HAYSTACK"));
assert!(!Predicate::IEndsWithAscii("haystacks").evaluate("haystack"));
assert!(!Predicate::IEndsWithAscii("stack").evaluate("haystac£k"));
assert!(!Predicate::IEndsWithAscii("xhaystack").evaluate("haystack"));
}
}

0 comments on commit bf0ea91

Please sign in to comment.