Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,22 @@ pub fn create_physical_fun(
BuiltinScalarFunction::RegexpReplace => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_replace,
let specializer_func = invoke_if_regex_expressions_feature_flag!(
specialize_regexp_replace,
i32,
"regexp_replace"
);
make_scalar_function(func)(args)
let func = specializer_func(args)?;
func(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_replace,
let specializer_func = invoke_if_regex_expressions_feature_flag!(
specialize_regexp_replace,
i64,
"regexp_replace"
);
make_scalar_function(func)(args)
let func = specializer_func(args)?;
func(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function regexp_replace",
Expand Down
288 changes: 274 additions & 14 deletions datafusion/physical-expr/src/regex_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,32 @@

//! Regex expressions

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::array::{
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
};
use arrow::compute;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use regex::Regex;
use std::any::type_name;
use std::sync::Arc;

macro_rules! downcast_string_arg {
use crate::functions::make_scalar_function;

macro_rules! fetch_string_arg {
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
let array = downcast_string_array_arg!($ARG, $NAME, $T);
if array.is_null(0) {
return $EARLY_ABORT(array);
} else {
array.value(0)
}
}};
}

macro_rules! downcast_string_array_arg {
($ARG:expr, $NAME:expr, $T:ident) => {{
$ARG.as_any()
.downcast_ref::<GenericStringArray<T>>()
Expand All @@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let values = downcast_string_array_arg!(args[0], "string", T);
let regex = downcast_string_array_arg!(args[1], "pattern", T);
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
}
3 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_arg!(args[2], "flags", T));
let values = downcast_string_array_arg!(args[0], "string", T);
let regex = downcast_string_array_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_array_arg!(args[2], "flags", T));
compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
}
other => Err(DataFusionError::Internal(format!(
Expand All @@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
///
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// Default implementation for regexp_replace, assumes all args are arrays
// and args is a sequence of 3 or 4 elements.

// creating Regex is expensive so create hashmap for memoization
let mut patterns: HashMap<String, Regex> = HashMap::new();

match args.len() {
3 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);

let result = string_array
.iter()
Expand Down Expand Up @@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let flags_array = downcast_string_arg!(args[3], "flags", T);
let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
let flags_array = downcast_string_array_arg!(args[3], "flags", T);

let result = string_array
.iter()
Expand Down Expand Up @@ -178,10 +197,125 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
}
}

fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
input_array: &GenericStringArray<T>,
) -> Result<ArrayRef> {
// Mimicking the existing behavior of regexp_replace, if any of the scalar arguments
// are actuall null, then the result will be an array of the same size but with nulls.
Ok(new_null_array(input_array.data_type(), input_array.len()))
}

/// Special cased regex_replace implementation for the scenerio where
/// the pattern, replacement and flags are static (arrays that are derived
/// from scalars). This means we can skip regex caching system and basically
/// hold a single Regex object for the replace operation. This also speeds
/// up the pre-processing time of the replacement string, since it only
/// needs to processed once.
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort);
let replacement =
fetch_string_arg!(args[2], "replacement", T, _regexp_replace_early_abort);
let flags = match args.len() {
3 => None,
4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)),
other => {
return Err(DataFusionError::Internal(format!(
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
other
)))
}
};

// Embed the flag (if it exists) into the pattern. Limit will determine
// whether this is a global match (as in replace all) or just a single
// replace operation.
let (pattern, limit) = match flags {
Some("g") => (pattern.to_string(), 0),
Some(flags) => (
format!("(?{}){}", flags.to_string().replace('g', ""), pattern),
!flags.contains('g') as usize,
),
None => (pattern.to_string(), 1),
};

let re = Regex::new(&pattern)
.map_err(|err| DataFusionError::Execution(err.to_string()))?;

// Replaces the posix groups in the replacement string
// with rust ones.
let replacement = regex_replace_posix_groups(replacement);

let result = string_array
.iter()
.map(|string| {
string.map(|string| re.replacen(string, limit, replacement.as_str()))
})
.collect::<GenericStringArray<T>>();
Ok(Arc::new(result) as ArrayRef)
}

/// Determine which implementation of the regexp_replace to use based
/// on the given set of arguments.
pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
args: &[ColumnarValue],
) -> Result<ScalarFunctionImplementation> {
// This will serve as a dispatch table where we can
// leverage it in order to determine whether the scalarity
// of the given set of arguments fits a better specialized
// function.
let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
matches!(args[0], ColumnarValue::Scalar(_)),
matches!(args[1], ColumnarValue::Scalar(_)),
matches!(args[2], ColumnarValue::Scalar(_)),
// The forth argument (flags) is optional; so in the event that
// it is not available, we'll claim that it is scalar.
matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None),
);

match (
is_source_scalar,
is_pattern_scalar,
is_replacement_scalar,
is_flags_scalar,
) {
// This represents a very hot path for the case where the there is
// a single pattern that is being matched against and a single replacement.
// This is extremely important to specialize on since it removes the overhead
// of DF's in-house regex pattern cache (since there will be at most a single
// pattern) and the pre-processing of the same replacement pattern at each
// query.
//
// The flags needs to be a scalar as well since each pattern is actually
// constructed with the flags embedded into the pattern itself. This means
// even if the pattern itself is scalar, if the flags are an array then
// we will create many regexes and it is best to use the implementation
// that caches it. If there are no flags, we can simply ignore it here,
// and let the specialized function handle it.
(_, true, true, true) => {
// We still don't know the scalarity of source, so we need the adapter
// even if it will do some extra work for the pattern and the flags.
//
// TODO: maybe we need a way of telling the adapter on which arguments
// it can skip filling (so that we won't create N - 1 redundant cols).
Ok(make_scalar_function(
_regexp_replace_static_pattern_replace::<T>,
))
}

// If there are no specialized implementations, we'll fall back to the
// generic implementation.
(_, _, _, _) => Ok(make_scalar_function(regexp_replace::<T>)),
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::*;
use datafusion_common::ScalarValue;

#[test]
fn test_case_sensitive_regexp_match() {
Expand Down Expand Up @@ -231,4 +365,130 @@ mod tests {

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec!["afooc"; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_with_flags() {
let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let flags = StringArray::from(vec!["i"; 5]);
let expected =
StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_early_abort() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec![None; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec![None; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_early_abort_flags() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec!["a"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let flags = StringArray::from(vec![None; 5]);
let expected = StringArray::from(vec![None; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_pattern_error() {
let values = StringArray::from(vec!["abc"; 5]);
// Delibaretely using an invalid pattern to see how the single pattern
// error is propagated on regexp_replace.
let patterns = StringArray::from(vec!["["; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
]);
let pattern_err = re.expect_err("broken pattern should have failed");
assert_eq!(
pattern_err.to_string(),
"Execution error: regex parse error:\n [\n ^\nerror: unclosed character class"
);
}

#[test]
fn test_regexp_can_specialize_all_cases() {
macro_rules! make_scalar {
() => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some("foo".to_string())))
};
}

macro_rules! make_array {
() => {
ColumnarValue::Array(
Arc::new(StringArray::from(vec!["bar"; 2])) as ArrayRef
)
};
}

for source in [make_scalar!(), make_array!()] {
for pattern in [make_scalar!(), make_array!()] {
for replacement in [make_scalar!(), make_array!()] {
for flags in [Some(make_scalar!()), Some(make_array!()), None] {
let mut args =
vec![source.clone(), pattern.clone(), replacement.clone()];
if let Some(flags) = flags {
args.push(flags.clone());
}
let regex_func = specialize_regexp_replace::<i32>(&args);
assert!(regex_func.is_ok());
}
}
}
}
}
}