Skip to content

Commit fe85ff8

Browse files
committed
Optimize regex_replace for scalar patterns
1 parent ebb28f5 commit fe85ff8

File tree

2 files changed

+155
-20
lines changed

2 files changed

+155
-20
lines changed

datafusion/physical-expr/src/functions.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,20 +500,22 @@ pub fn create_physical_fun(
500500
BuiltinScalarFunction::RegexpReplace => {
501501
Arc::new(|args| match args[0].data_type() {
502502
DataType::Utf8 => {
503-
let func = invoke_if_regex_expressions_feature_flag!(
504-
regexp_replace,
503+
let specializer_func = invoke_if_regex_expressions_feature_flag!(
504+
specialize_regexp_replace,
505505
i32,
506506
"regexp_replace"
507507
);
508-
make_scalar_function(func)(args)
508+
let func = specializer_func(args)?;
509+
func(args)
509510
}
510511
DataType::LargeUtf8 => {
511-
let func = invoke_if_regex_expressions_feature_flag!(
512-
regexp_replace,
512+
let specializer_func = invoke_if_regex_expressions_feature_flag!(
513+
specialize_regexp_replace,
513514
i64,
514515
"regexp_replace"
515516
);
516-
make_scalar_function(func)(args)
517+
let func = specializer_func(args)?;
518+
func(args)
517519
}
518520
other => Err(DataFusionError::Internal(format!(
519521
"Unsupported data type {:?} for function regexp_replace",

datafusion/physical-expr/src/regex_expressions.rs

Lines changed: 147 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,32 @@
2121

2222
//! Regex expressions
2323
24-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
24+
use arrow::array::{
25+
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
26+
};
2527
use arrow::compute;
2628
use datafusion_common::{DataFusionError, Result};
29+
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
2730
use hashbrown::HashMap;
2831
use lazy_static::lazy_static;
2932
use regex::Regex;
3033
use std::any::type_name;
3134
use std::sync::Arc;
3235

33-
macro_rules! downcast_string_arg {
36+
use crate::functions::make_scalar_function;
37+
38+
macro_rules! fetch_string_arg {
39+
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
40+
let array = downcast_string_array_arg!($ARG, $NAME, $T);
41+
if array.is_null(0) {
42+
return $EARLY_ABORT(array);
43+
} else {
44+
array.value(0)
45+
}
46+
}};
47+
}
48+
49+
macro_rules! downcast_string_array_arg {
3450
($ARG:expr, $NAME:expr, $T:ident) => {{
3551
$ARG.as_any()
3652
.downcast_ref::<GenericStringArray<T>>()
@@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
4864
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
4965
match args.len() {
5066
2 => {
51-
let values = downcast_string_arg!(args[0], "string", T);
52-
let regex = downcast_string_arg!(args[1], "pattern", T);
67+
let values = downcast_string_array_arg!(args[0], "string", T);
68+
let regex = downcast_string_array_arg!(args[1], "pattern", T);
5369
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
5470
}
5571
3 => {
56-
let values = downcast_string_arg!(args[0], "string", T);
57-
let regex = downcast_string_arg!(args[1], "pattern", T);
58-
let flags = Some(downcast_string_arg!(args[2], "flags", T));
72+
let values = downcast_string_array_arg!(args[0], "string", T);
73+
let regex = downcast_string_array_arg!(args[1], "pattern", T);
74+
let flags = Some(downcast_string_array_arg!(args[2], "flags", T));
5975
compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
6076
}
6177
other => Err(DataFusionError::Internal(format!(
@@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
8096
///
8197
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
8298
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
99+
// Default implementation for regexp_replace, assumes all args are arrays
100+
// and args is a sequence of 3 or 4 elements.
101+
83102
// creating Regex is expensive so create hashmap for memoization
84103
let mut patterns: HashMap<String, Regex> = HashMap::new();
85104

86105
match args.len() {
87106
3 => {
88-
let string_array = downcast_string_arg!(args[0], "string", T);
89-
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
90-
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
107+
let string_array = downcast_string_array_arg!(args[0], "string", T);
108+
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
109+
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
91110

92111
let result = string_array
93112
.iter()
@@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
120139
Ok(Arc::new(result) as ArrayRef)
121140
}
122141
4 => {
123-
let string_array = downcast_string_arg!(args[0], "string", T);
124-
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
125-
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
126-
let flags_array = downcast_string_arg!(args[3], "flags", T);
142+
let string_array = downcast_string_array_arg!(args[0], "string", T);
143+
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
144+
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
145+
let flags_array = downcast_string_array_arg!(args[3], "flags", T);
127146

128147
let result = string_array
129148
.iter()
@@ -178,6 +197,120 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
178197
}
179198
}
180199

200+
fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
201+
input_array: &GenericStringArray<T>,
202+
) -> Result<ArrayRef> {
203+
// Mimicing the existing behavior of regexp_replace, if any of the scalar arguments
204+
// are actuall null, then the result will be an array of the same size but with nulls.
205+
Ok(new_null_array(input_array.data_type(), input_array.len()))
206+
}
207+
208+
fn _regexp_replace_static_pattern<T: OffsetSizeTrait>(
209+
args: &[ArrayRef],
210+
) -> Result<ArrayRef> {
211+
// Special cased regex_replace implementation for the scenerio where
212+
// both the pattern itself and the flags are scalars. This means we can
213+
// skip regex caching system and basically hold a single Regex object
214+
// for the replace operation.
215+
216+
let string_array = downcast_string_array_arg!(args[0], "string", T);
217+
let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort);
218+
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
219+
let flags = match args.len() {
220+
3 => None,
221+
4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)),
222+
other => {
223+
return Err(DataFusionError::Internal(format!(
224+
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
225+
other
226+
)))
227+
}
228+
};
229+
230+
// Embed the flag (if it exists) into the pattern
231+
let (pattern, replace_all) = match flags {
232+
Some("g") => (pattern.to_string(), true),
233+
Some(flags) => (
234+
format!("(?{}){}", flags.to_string().replace('g', ""), pattern),
235+
flags.contains('g'),
236+
),
237+
None => (pattern.to_string(), false),
238+
};
239+
240+
let re = Regex::new(&pattern)
241+
.map_err(|err| DataFusionError::Execution(err.to_string()))?;
242+
243+
let result = string_array
244+
.iter()
245+
.zip(replacement_array.iter())
246+
.map(|(string, replacement)| match (string, replacement) {
247+
(Some(string), Some(replacement)) => {
248+
let replacement = regex_replace_posix_groups(replacement);
249+
250+
if replace_all {
251+
Some(re.replace_all(string, replacement.as_str()))
252+
} else {
253+
Some(re.replace(string, replacement.as_str()))
254+
}
255+
}
256+
_ => None,
257+
})
258+
.collect::<GenericStringArray<T>>();
259+
Ok(Arc::new(result) as ArrayRef)
260+
}
261+
262+
/// Determine which implementation of the regexp_replace to use based
263+
/// on the given set of arguments.
264+
pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
265+
args: &[ColumnarValue],
266+
) -> Result<ScalarFunctionImplementation> {
267+
// This will serve as a dispatch table where we can
268+
// leverage it in order to determine whether the scalarity
269+
// of the given set of arguments fits a better specialized
270+
// function.
271+
let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
272+
matches!(args[0], ColumnarValue::Scalar(_)),
273+
matches!(args[1], ColumnarValue::Scalar(_)),
274+
matches!(args[2], ColumnarValue::Scalar(_)),
275+
// The fourth argument (flags) is optional; so in the event that
276+
// it is not available, we'll claim that it is scalar.
277+
matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None),
278+
);
279+
280+
match (
281+
is_source_scalar,
282+
is_pattern_scalar,
283+
is_replacement_scalar,
284+
is_flags_scalar,
285+
) {
286+
// This represents a very hot path for the case where there is
287+
// a single pattern that is being matched against. This is extremely
288+
// important to specialize on since it removes the overhead of DF's
289+
// in-house regex pattern cache (since there will be at most a single
290+
// pattern).
291+
//
292+
// The flags need to be a scalar as well since each pattern is actually
293+
// constructed with the flags embedded into the pattern itself. This means
294+
// even if the pattern itself is scalar, if the flags are an array then
295+
// we will create many regexes and it is best to use the implementation
296+
// that caches it. If there are no flags, we can simply ignore it here,
297+
// and let the specialized function handle it.
298+
(_, true, _, true) => {
299+
// We still don't know the scalarity of source/replacement, so we
300+
// need the adapter even if it will do some extra work for the pattern
301+
// and the flags.
302+
//
303+
// TODO: maybe we need a way of telling the adapter on which arguments
304+
// it can skip filling (so that we won't create N - 1 redundant cols).
305+
Ok(make_scalar_function(_regexp_replace_static_pattern::<T>))
306+
}
307+
308+
// If there are no specialized implementations, we'll fall back to the
309+
// generic implementation.
310+
(_, _, _, _) => Ok(make_scalar_function(regexp_replace::<T>)),
311+
}
312+
}
313+
181314
#[cfg(test)]
182315
mod tests {
183316
use super::*;

0 commit comments

Comments
 (0)