Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add .str.count_matches() #2580

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ def minhash(
seed: int = 1,
) -> PyExpr: ...
def sql(sql: str, catalog: PyCatalog) -> LogicalPlanBuilder: ...
def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...

class PyCatalog:
@staticmethod
Expand Down Expand Up @@ -1319,6 +1320,7 @@ class PySeries:
def utf8_to_date(self, format: str) -> PySeries: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PySeries: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PySeries: ...
def utf8_count_matches(self, patterns: PySeries, whole_word: bool, case_sensitive: bool) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def is_inf(self) -> PySeries: ...
def not_nan(self) -> PySeries: ...
Expand Down
34 changes: 34 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from daft.daft import tokenize_encode as _tokenize_encode
from daft.daft import udf as _udf
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.datatype import DataType, TimeUnit
from daft.expressions.testing import expr_structurally_equal
from daft.logical.schema import Field, Schema
Expand Down Expand Up @@ -2619,6 +2620,39 @@ def tokenize_decode(
"""
return Expression._from_pyexpr(_tokenize_decode(self._expr, tokens_path, io_config, pattern, special_tokens))

def count_matches(
    self,
    patterns: Any,
    whole_words: bool = False,
    case_sensitive: bool = True,
) -> Expression:
    """
    Counts the number of times a pattern, or multiple patterns, appear in a string.

    .. NOTE::
        If a pattern is a substring of another pattern, the longest pattern is matched first.
        For example, in the string "hello world", with patterns "hello", "world", and "hello world",
        one match is counted for "hello world".

        If whole_words is true, then matches are only counted if they are whole words. This
        also applies to multi-word strings. For example, on the string "abc def", the strings
        "def" and "abc def" would be matched, but "bc de", "abc d", and "abc " (with the space)
        would not.

        If case_sensitive is false, then case will be ignored. This only applies to ASCII
        characters; unicode uppercase/lowercase will still be considered distinct.

    Args:
        patterns: A pattern or a list of patterns.
        whole_words: Whether to only match whole word(s). Defaults to false.
        case_sensitive: Whether the matching should be case sensitive. Defaults to true.

    Returns:
        Expression: an expression with the per-row count of matches.
    """
    if not isinstance(patterns, Expression):
        # Coerce a plain pattern (or list of patterns) into a Series-backed expression.
        # NOTE(review): the series is named "items" — presumably copied from a similar
        # coercion elsewhere; confirm whether "patterns" was the intended name.
        series = item_to_series("items", patterns)
        patterns = Expression._to_expression(series)

    return Expression._from_pyexpr(_utf8_count_matches(self._expr, patterns._expr, whole_words, case_sensitive))


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
10 changes: 10 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,16 @@
assert self._series is not None
return Series._from_pyseries(self._series.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space))

def count_matches(self, patterns: Series, whole_words: bool = False, case_sensitive: bool = True) -> Series:
    """Counts occurrences of the given patterns in each string of this series.

    Args:
        patterns: a Series of patterns to search for.
        whole_words: only count matches that form whole words. Defaults to False.
        case_sensitive: whether matching is case sensitive. Defaults to True.

    Returns:
        Series: per-row counts of pattern matches.

    Raises:
        ValueError: if ``patterns`` is not a Series, or a flag is not a bool.
    """
    if not isinstance(patterns, Series):
        raise ValueError(f"expected another Series but got {type(patterns)}")
    if not isinstance(whole_words, bool):
        # Fixed: message previously said "whole_word", but the parameter is "whole_words".
        raise ValueError(f"expected bool for whole_words but got {type(whole_words)}")
    if not isinstance(case_sensitive, bool):
        raise ValueError(f"expected bool for case_sensitive but got {type(case_sensitive)}")
    assert self._series is not None and patterns._series is not None
    return Series._from_pyseries(self._series.utf8_count_matches(patterns._series, whole_words, case_sensitive))


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.normalize
Expression.str.tokenize_encode
Expression.str.tokenize_decode
Expression.str.count_matches

.. _api-float-expression-operations:

Expand Down
1 change: 1 addition & 0 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[dependencies]
aho-corasick = "1.1.3"
arrow2 = {workspace = true, features = [
"chrono-tz",
"compute_take",
Expand Down
45 changes: 45 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
},
DataType, Series,
};
use aho_corasick::{AhoCorasickBuilder, MatchKind};
use arrow2::{array::Array, temporal_conversions};
use chrono::Datelike;
use common_error::{DaftError, DaftResult};
Expand Down Expand Up @@ -1383,6 +1384,50 @@
))
}

// Uses the Aho-Corasick algorithm to count occurrences of a number of patterns.
pub fn count_matches(
&self,
patterns: &Self,
whole_word: bool,
case_sensitive: bool,
Comment on lines +1388 to +1392
Copy link
Collaborator

@universalmind303 universalmind303 Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it'd be simpler to instead accept a regex?

the regex crate uses aho-corasick under the hood, so the perf implications should be negligible as long as we're only compiling the regex once.

Regex would likely be a lot more intuitive/flexible from an end user perspective as well.

res = s.str.count_matches('\b(fox|over|lazy dog|dog)\b').to_pylist()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'll change it to that. I didn't originally do this because I was worried about performance - I'll run some tests to make sure it isn't too affected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it does seem like using a regex is around 7x slower. I guess it just can't handle the large pattern created by concatenating the list of strings. So I think I'll keep it this way for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though it's slower, I think we should still allow for regex just for usability sake.

We can do this in a follow up PR, and have the python frontend map to a different backend implementation

count_matches('<text>') -> count_matches
count_matches(r'<pattern>') -> count_matches_regex

) -> DaftResult<UInt64Array> {
if patterns.null_count() == patterns.len() {
// no matches
return UInt64Array::from_iter(self.name(), iter::repeat(Some(0)).take(self.len()))
.with_validity(self.validity().cloned());

Check warning on line 1397 in src/daft-core/src/array/ops/utf8.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/utf8.rs#L1396-L1397

Added lines #L1396 - L1397 were not covered by tests
}

let patterns = patterns.as_arrow().iter().flatten();
let ac = AhoCorasickBuilder::new()
.ascii_case_insensitive(!case_sensitive)
.match_kind(MatchKind::LeftmostLongest)
.build(patterns)
.map_err(|e| {
DaftError::ComputeError(format!("Error creating string automaton: {}", e))

Check warning on line 1406 in src/daft-core/src/array/ops/utf8.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/utf8.rs#L1406

Added line #L1406 was not covered by tests
})?;
let iter = self.as_arrow().iter().map(|opt| {
opt.map(|s| {
let results = ac.find_iter(s);
if whole_word {
results
.filter(|m| {
// ensure this match is a whole word (or set of words)
// don't want to filter out things like "brass"
let prev_char = s.get(m.start() - 1..m.start());
let next_char = s.get(m.end()..m.end() + 1);
!(prev_char.is_some_and(|s| s.chars().next().unwrap().is_alphabetic())
|| next_char
.is_some_and(|s| s.chars().next().unwrap().is_alphabetic()))
})
.count() as u64
} else {
results.count() as u64
}
})
});
Ok(UInt64Array::from_iter(self.name(), iter))
}

fn unary_broadcasted_op<ScalarKernel>(&self, operation: ScalarKernel) -> DaftResult<Utf8Array>
where
ScalarKernel: Fn(&str) -> Cow<'_, str>,
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,18 @@ impl PySeries {
Ok(self.series.utf8_normalize(opts)?.into())
}

/// Python-facing wrapper: counts occurrences of `patterns` in each string
/// of this series and returns the counts as a new PySeries.
pub fn utf8_count_matches(
    &self,
    patterns: &Self,
    whole_word: bool,
    case_sensitive: bool,
) -> PyResult<Self> {
    let counted = self
        .series
        .utf8_count_matches(&patterns.series, whole_word, case_sensitive)?;
    Ok(counted.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
15 changes: 15 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,19 @@ impl Series {
pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult<Series> {
self.with_utf8_array(|arr| Ok(arr.normalize(opts)?.into_series()))
}

/// Counts occurrences of `patterns` in each string of this series.
/// Both `self` and `patterns` must be utf8 series.
pub fn utf8_count_matches(
    &self,
    patterns: &Series,
    whole_word: bool,
    case_sensitive: bool,
) -> DaftResult<Series> {
    self.with_utf8_array(|arr| {
        patterns.with_utf8_array(|pattern_arr| {
            let counts = arr.count_matches(pattern_arr, whole_word, case_sensitive)?;
            Ok(counts.into_series())
        })
    })
}
}
89 changes: 89 additions & 0 deletions src/daft-functions/src/count_matches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use common_error::{DaftError, DaftResult};

use daft_core::{datatypes::Field, schema::Schema, DataType, Series};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
// Scalar-function payload for `count_matches`: carries the two matching
// options that are fixed when the expression is constructed.
struct CountMatchesFunction {
    // Only count matches that form whole words.
    pub(super) whole_words: bool,
    // Whether matching is case sensitive.
    pub(super) case_sensitive: bool,
}

#[typetag::serde]

Check warning on line 16 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L16

Added line #L16 was not covered by tests
impl ScalarUDF for CountMatchesFunction {
fn as_any(&self) -> &dyn std::any::Any {
self
}

Check warning on line 20 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L18-L20

Added lines #L18 - L20 were not covered by tests

fn name(&self) -> &'static str {
"count_matches"
}

Check warning on line 24 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L22-L24

Added lines #L22 - L24 were not covered by tests

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[data, _] => match data.to_field(schema) {
Ok(field) => match &field.dtype {
DataType::Utf8 => Ok(Field::new(field.name, DataType::UInt64)),
a => Err(DaftError::TypeError(format!(
"Expects inputs to count_matches to be utf8, but received {a}",
))),

Check warning on line 33 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L31-L33

Added lines #L31 - L33 were not covered by tests
},
Err(e) => Err(e),

Check warning on line 35 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L35

Added line #L35 was not covered by tests
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 40 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L37-L40

Added lines #L37 - L40 were not covered by tests
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[data, patterns] => {
data.utf8_count_matches(patterns, self.whole_words, self.case_sensitive)
}
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 52 in src/daft-functions/src/count_matches.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/count_matches.rs#L49-L52

Added lines #L49 - L52 were not covered by tests
}
}
}

/// Builds a `count_matches` scalar-function expression over `input`,
/// searching for `patterns` with the given matching options.
pub fn utf8_count_matches(
    input: ExprRef,
    patterns: ExprRef,
    whole_words: bool,
    case_sensitive: bool,
) -> ExprRef {
    let udf = CountMatchesFunction {
        whole_words,
        case_sensitive,
    };
    ScalarFunction::new(udf, vec![input, patterns]).into()
}

#[cfg(feature = "python")]
pub mod python {
    use daft_dsl::python::PyExpr;
    use pyo3::{pyfunction, PyResult};

    /// Python binding for the `count_matches` expression constructor.
    #[pyfunction]
    pub fn utf8_count_matches(
        expr: PyExpr,
        patterns: PyExpr,
        whole_words: bool,
        case_sensitive: bool,
    ) -> PyResult<PyExpr> {
        Ok(
            super::utf8_count_matches(expr.into(), patterns.into(), whole_words, case_sensitive)
                .into(),
        )
    }
}
2 changes: 2 additions & 0 deletions src/daft-functions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![feature(async_closure)]
pub mod count_matches;
pub mod distance;
pub mod hash;
pub mod minhash;
Expand All @@ -19,6 +20,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(tokenize::python::tokenize_encode))?;
parent.add_wrapped(wrap_pyfunction!(tokenize::python::tokenize_decode))?;
parent.add_wrapped(wrap_pyfunction!(minhash::python::minhash))?;
parent.add_wrapped(wrap_pyfunction!(count_matches::python::utf8_count_matches))?;

Ok(())
}
Expand Down
41 changes: 41 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,3 +1550,44 @@ def test_series_utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space
).to_pylist()
b = [manual_normalize(t, remove_punct, lowercase, nfd_unicode, white_space) for t in NORMALIZE_TEST_DATA]
assert a == b


def test_series_utf8_count_matches():
    """Checks count_matches across every (whole_words, case_sensitive) combination."""
    s = Series.from_pylist(
        [
            "the quick brown fox jumped over the lazy dog",
            "the quick brown foe jumped o'er the lazy dot",
            "the fox fox fox jumped over over dog lazy dog",
            "the quick brown foxes hovered above the lazy dogs",
            "the quick brown-fox jumped over the 'lazy dog'",
            "thequickbrownfoxjumpedoverthelazydog",
            "THE QUICK BROWN FOX JUMPED over THE Lazy DOG",
            " fox dog over ",
        ]
    )
    p = Series.from_pylist(
        [
            "fox",
            "over",
            "lazy dog",
            "dog",
        ]
    )

    # (whole_words, case_sensitive) -> expected per-row counts.
    expected = {
        (False, False): [3, 0, 7, 3, 3, 3, 3, 3],
        (True, False): [3, 0, 7, 0, 3, 0, 3, 3],
        (False, True): [3, 0, 7, 3, 3, 3, 1, 3],
        (True, True): [3, 0, 7, 0, 3, 0, 1, 3],
    }
    for (whole_words, case_sensitive), want in expected.items():
        assert s.str.count_matches(p, whole_words, case_sensitive).to_pylist() == want


@pytest.mark.parametrize("whole_words", [False, True])
@pytest.mark.parametrize("case_sensitive", [False, True])
def test_series_utf8_count_matches_overlap(whole_words, case_sensitive):
    # "hello world" contains all three patterns, but leftmost-longest matching
    # means only the longest pattern ("hello world") is counted — exactly one
    # match, regardless of the whole_words/case_sensitive flags.
    s = Series.from_pylist(["hello world"])
    p = Series.from_pylist(["hello world", "hello", "world"])
    res = s.str.count_matches(p, whole_words, case_sensitive).to_pylist()
    assert res == [1]
Loading
Loading