Skip to content

Commit

Permalink
feat(datafusion): add RegexSearch, StringContains and StringJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Jul 24, 2023
1 parent 64ea921 commit 4edaab5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
40 changes: 40 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,43 @@ def string_find(op):
) - df.lit(1)

return df.functions.strpos(arg, pattern) - df.lit(1)


@translate.register(ops.RegexSearch)
def regex_search(op):
arg = translate(op.arg)
pattern = translate(op.pattern)

def search(arr):
default = pa.scalar(0, type=pa.int64())
lengths = pc.list_value_length(arr).fill_null(default)
return pc.greater(lengths, default)

string_regex_search = df.udf(
search,
input_types=[PyArrowType.from_ibis(dt.Array(dt.string))],
return_type=PyArrowType.from_ibis(dt.bool),
volatility="immutable",
name="string_regex_search",
)

return string_regex_search(df.functions.regexp_match(arg, pattern))


@translate.register(ops.StringContains)
def string_contains(op):
haystack = translate(op.haystack)
needle = translate(op.needle)

return df.functions.strpos(haystack, needle) > df.lit(0)


@translate.register(ops.StringJoin)
def string_join(op):
if (sep := getattr(op.sep, "value", None)) is None:
raise ValueError(
"join `sep` expressions must be literals. "
"Arbitrary expressions are not supported in the DataFusion backend"
)

return df.functions.concat_ws(sep, *map(translate, op.arg))
16 changes: 4 additions & 12 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ def test_string_col_is_unicode(alltypes, df):
lambda t: t.string_col.str.contains('6'),
id='contains',
marks=[
pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError
),
pytest.mark.broken(
["mssql"],
raises=sa.exc.ProgrammingError,
Expand Down Expand Up @@ -188,7 +185,7 @@ def test_string_col_is_unicode(alltypes, df):
id="rlike",
marks=[
pytest.mark.notimpl(
["datafusion", "mssql", "oracle"],
["mssql", "oracle"],
raises=com.OperationNotDefinedError,
),
],
Expand All @@ -199,7 +196,7 @@ def test_string_col_is_unicode(alltypes, df):
id="re_search_substring",
marks=[
pytest.mark.notimpl(
["datafusion", "mssql", "oracle"],
["mssql", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(["impala"], raises=AssertionError),
Expand All @@ -211,7 +208,7 @@ def test_string_col_is_unicode(alltypes, df):
id='re_search',
marks=[
pytest.mark.notimpl(
["datafusion", "mssql", "oracle"],
["mssql", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(["impala"], raises=AssertionError),
Expand All @@ -223,7 +220,7 @@ def test_string_col_is_unicode(alltypes, df):
id='re_search_posix',
marks=[
pytest.mark.notimpl(
["datafusion", "mssql", "oracle"],
["mssql", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken(["pyspark"], raises=AssertionError),
Expand Down Expand Up @@ -826,11 +823,6 @@ def test_string_col_is_unicode(alltypes, df):
lambda t: ibis.literal('-').join(['a', t.string_col, 'c']),
lambda t: 'a-' + t.string_col + '-c',
id='join',
marks=[
pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError
),
],
),
param(
lambda t: t.string_col + t.date_string_col,
Expand Down

0 comments on commit 4edaab5

Please sign in to comment.