Skip to content

Commit

Permalink
Implement native support StringView for CONTAINS function (#12168)
Browse files Browse the repository at this point in the history
* Implement native support StringView for contains function

Signed-off-by: Tai Le Manh <[email protected]>

* Fix cargo fmt

* Implement native support StringView for contains function

Signed-off-by: Tai Le Manh <[email protected]>

* Fix cargo check

* Fix unresolved doc link

* Implement native support StringView for contains function

Signed-off-by: Tai Le Manh <[email protected]>

* Update datafusion/functions/src/regexp_common.rs

---------

Signed-off-by: Tai Le Manh <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
tlm365 and alamb committed Sep 10, 2024
1 parent 376a0b8 commit c71a9d7
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 36 deletions.
2 changes: 1 addition & 1 deletion datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ math_expressions = []
# enable regular expressions
regex_expressions = ["regex"]
# enable string functions
string_expressions = ["uuid"]
string_expressions = ["regex_expressions", "uuid"]
# enable unicode functions
unicode_expressions = ["hashbrown", "unicode-segmentation"]

Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ pub mod macros;
pub mod string;
make_stub_package!(string, "string_expressions");

#[cfg(feature = "string_expressions")]
mod regexp_common;

/// Core datafusion expressions
/// Enabled via feature flag `core_expressions`
#[cfg(feature = "core_expressions")]
Expand Down
3 changes: 2 additions & 1 deletion datafusion/functions/src/regex/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

//! "regx" DataFusion functions
//! "regex" DataFusion functions

pub mod regexplike;
pub mod regexpmatch;
pub mod regexpreplace;

// create UDFs
make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match);
make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like);
Expand Down
123 changes: 123 additions & 0 deletions datafusion/functions/src/regexp_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Common utilities for implementing regex functions

use crate::string::common::StringArrayType;

use arrow::array::{Array, ArrayDataBuilder, BooleanArray};
use arrow::datatypes::DataType;
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use datafusion_common::DataFusionError;
use regex::Regex;

use std::collections::HashMap;

#[cfg(doc)]
use arrow::array::{LargeStringArray, StringArray, StringViewArray};
/// Perform SQL `array ~ regex_array` operation on
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
///
/// If `regex_array` element has an empty value, the corresponding result value is always true.
///
/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
/// which allow special search modes, such as case-insensitive and multi-line mode.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
///
/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs].
///
/// Can remove when <https://github.com/apache/arrow-rs/issues/6370> is implemented upstream
///
/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37
pub fn regexp_is_match_utf8<'a, S1, S2, S3>(
array: &'a S1,
regex_array: &'a S2,
flags_array: Option<&'a S3>,
) -> datafusion_common::Result<BooleanArray, DataFusionError>
where
&'a S1: StringArrayType<'a>,
&'a S2: StringArrayType<'a>,
&'a S3: StringArrayType<'a>,
{
if array.len() != regex_array.len() {
return Err(DataFusionError::Execution(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());

let mut patterns: HashMap<String, Regex> = HashMap::new();
let mut result = BooleanBufferBuilder::new(array.len());

let complete_pattern = match flags_array {
Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
|(pattern, flags)| {
pattern.map(|pattern| match flags {
Some(flag) => format!("(?{flag}){pattern}"),
None => pattern.to_string(),
})
},
)) as Box<dyn Iterator<Item = Option<String>>>,
None => Box::new(
regex_array
.iter()
.map(|pattern| pattern.map(|pattern| pattern.to_string())),
),
};

array
.iter()
.zip(complete_pattern)
.map(|(value, pattern)| {
match (value, pattern) {
(Some(_), Some(pattern)) if pattern == *"" => {
result.append(true);
}
(Some(value), Some(pattern)) => {
let existing_pattern = patterns.get(&pattern);
let re = match existing_pattern {
Some(re) => re,
None => {
let re = Regex::new(pattern.as_str()).map_err(|e| {
DataFusionError::Execution(format!(
"Regular expression did not compile: {e:?}"
))
})?;
patterns.entry(pattern).or_insert(re)
}
};
result.append(re.is_match(value));
}
_ => result.append(false),
}
Ok(())
})
.collect::<datafusion_common::Result<Vec<()>, DataFusionError>>()?;

let data = unsafe {
ArrayDataBuilder::new(DataType::Boolean)
.len(array.len())
.buffers(vec![result.into()])
.nulls(nulls)
.build_unchecked()
};

Ok(BooleanArray::from(data))
}
190 changes: 167 additions & 23 deletions datafusion/functions/src/string/contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use crate::regexp_common::regexp_is_match_utf8;
use crate::utils::make_scalar_function;
use arrow::array::{ArrayRef, OffsetSizeTrait};

use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Boolean;
use datafusion_common::cast::as_generic_string_array;
use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
use datafusion_common::exec_err;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{arrow_datafusion_err, exec_err};
use datafusion_expr::ScalarUDFImpl;
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, Signature, Volatility};

use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct ContainsFunc {
signature: Signature,
Expand All @@ -44,7 +47,17 @@ impl ContainsFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
vec![
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8View, Utf8]),
Exact(vec![Utf8View, LargeUtf8]),
Exact(vec![Utf8, Utf8View]),
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8View]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
Volatility::Immutable,
),
}
Expand All @@ -69,28 +82,116 @@ impl ScalarUDFImpl for ContainsFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
other => {
exec_err!("unsupported data type {other:?} for function contains")
}
}
make_scalar_function(contains, vec![])(args)
}
}

/// use regexp_is_match_utf8_scalar to do the calculation for contains
pub fn contains<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef, DataFusionError> {
let mod_str = as_generic_string_array::<T>(&args[0])?;
let match_str = as_generic_string_array::<T>(&args[1])?;
let res = arrow::compute::kernels::comparison::regexp_is_match_utf8(
mod_str, match_str, None,
)
.map_err(|e| arrow_datafusion_err!(e))?;

Ok(Arc::new(res) as ArrayRef)
pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
match (args[0].data_type(), args[1].data_type()) {
(Utf8View, Utf8View) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string_view();
let res = regexp_is_match_utf8::<
StringViewArray,
StringViewArray,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(Utf8View, Utf8) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match_utf8::<
StringViewArray,
GenericStringArray<i32>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(Utf8View, LargeUtf8) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match_utf8::<
StringViewArray,
GenericStringArray<i64>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(Utf8, Utf8View) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string_view();
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
StringViewArray,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(Utf8, Utf8) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i32>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(Utf8, LargeUtf8) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i64>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(LargeUtf8, Utf8View) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string_view();
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
StringViewArray,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(LargeUtf8, Utf8) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
GenericStringArray<i32>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
(LargeUtf8, LargeUtf8) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
GenericStringArray<i64>,
GenericStringArray<i32>,
>(mod_str, match_str, None)?;

Ok(Arc::new(res) as ArrayRef)
}
other => {
exec_err!("Unsupported data type {other:?} for function `contains`.")
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -138,6 +239,49 @@ mod tests {
Boolean,
BooleanArray
);

test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
"Apache"
)))),
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))),
],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
"Apache"
)))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))),
],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);
test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
"Apache"
)))),
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
"DataFusion"
)))),
],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);

Ok(())
}
}
Loading

0 comments on commit c71a9d7

Please sign in to comment.