2121
2222//! Regex expressions
2323
24- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
24+ use arrow:: array:: {
25+ new_null_array, Array , ArrayRef , GenericStringArray , OffsetSizeTrait ,
26+ } ;
2527use arrow:: compute;
2628use datafusion_common:: { DataFusionError , Result } ;
29+ use datafusion_expr:: { ColumnarValue , ScalarFunctionImplementation } ;
2730use hashbrown:: HashMap ;
2831use lazy_static:: lazy_static;
2932use regex:: Regex ;
3033use std:: any:: type_name;
3134use std:: sync:: Arc ;
3235
33- macro_rules! downcast_string_arg {
36+ use crate :: functions:: make_scalar_function;
37+
/// Reads a scalar string argument out of an array that carries the scalar in
/// row 0.
///
/// `$ARG` is downcast with `downcast_string_array_arg!`. If the resulting
/// array is empty, or its first row is null, the *enclosing function* returns
/// `$EARLY_ABORT(array)`; otherwise the macro evaluates to the string stored
/// in row 0.
macro_rules! fetch_string_arg {
    ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
        let array = downcast_string_array_arg!($ARG, $NAME, $T);
        // An empty array has no row 0, so reading it would be out of bounds.
        // Treat it like a null scalar and let the early-abort handler build
        // the (equally empty) result.
        if array.is_empty() || array.is_null(0) {
            return $EARLY_ABORT(array);
        } else {
            array.value(0)
        }
    }};
}
48+
49+ macro_rules! downcast_string_array_arg {
3450 ( $ARG: expr, $NAME: expr, $T: ident) => { {
3551 $ARG. as_any( )
3652 . downcast_ref:: <GenericStringArray <T >>( )
@@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
4864pub fn regexp_match < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
4965 match args. len ( ) {
5066 2 => {
51- let values = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
52- let regex = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
67+ let values = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
68+ let regex = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
5369 compute:: regexp_match ( values, regex, None ) . map_err ( DataFusionError :: ArrowError )
5470 }
5571 3 => {
56- let values = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
57- let regex = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
58- let flags = Some ( downcast_string_arg ! ( args[ 2 ] , "flags" , T ) ) ;
72+ let values = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
73+ let regex = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
74+ let flags = Some ( downcast_string_array_arg ! ( args[ 2 ] , "flags" , T ) ) ;
5975 compute:: regexp_match ( values, regex, flags) . map_err ( DataFusionError :: ArrowError )
6076 }
6177 other => Err ( DataFusionError :: Internal ( format ! (
@@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
8096///
8197/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
8298pub fn regexp_replace < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
99+ // Default implementation for regexp_replace, assumes all args are arrays
100+ // and args is a sequence of 3 or 4 elements.
101+
83102 // creating Regex is expensive so create hashmap for memoization
84103 let mut patterns: HashMap < String , Regex > = HashMap :: new ( ) ;
85104
86105 match args. len ( ) {
87106 3 => {
88- let string_array = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
89- let pattern_array = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
90- let replacement_array = downcast_string_arg ! ( args[ 2 ] , "replacement" , T ) ;
107+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
108+ let pattern_array = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
109+ let replacement_array = downcast_string_array_arg ! ( args[ 2 ] , "replacement" , T ) ;
91110
92111 let result = string_array
93112 . iter ( )
@@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
120139 Ok ( Arc :: new ( result) as ArrayRef )
121140 }
122141 4 => {
123- let string_array = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
124- let pattern_array = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
125- let replacement_array = downcast_string_arg ! ( args[ 2 ] , "replacement" , T ) ;
126- let flags_array = downcast_string_arg ! ( args[ 3 ] , "flags" , T ) ;
142+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
143+ let pattern_array = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
144+ let replacement_array = downcast_string_array_arg ! ( args[ 2 ] , "replacement" , T ) ;
145+ let flags_array = downcast_string_array_arg ! ( args[ 3 ] , "flags" , T ) ;
127146
128147 let result = string_array
129148 . iter ( )
@@ -178,6 +197,120 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
178197 }
179198}
180199
200+ fn _regexp_replace_early_abort < T : OffsetSizeTrait > (
201+ input_array : & GenericStringArray < T > ,
202+ ) -> Result < ArrayRef > {
203+ // Mimicing the existing behavior of regexp_replace, if any of the scalar arguments
204+ // are actuall null, then the result will be an array of the same size but with nulls.
205+ Ok ( new_null_array ( input_array. data_type ( ) , input_array. len ( ) ) )
206+ }
207+
208+ fn _regexp_replace_static_pattern < T : OffsetSizeTrait > (
209+ args : & [ ArrayRef ] ,
210+ ) -> Result < ArrayRef > {
211+ // Special cased regex_replace implementation for the scenerio where
212+ // both the pattern itself and the flags are scalars. This means we can
213+ // skip regex caching system and basically hold a single Regex object
214+ // for the replace operation.
215+
216+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
217+ let pattern = fetch_string_arg ! ( args[ 1 ] , "pattern" , T , _regexp_replace_early_abort) ;
218+ let replacement_array = downcast_string_array_arg ! ( args[ 2 ] , "replacement" , T ) ;
219+ let flags = match args. len ( ) {
220+ 3 => None ,
221+ 4 => Some ( fetch_string_arg ! ( args[ 3 ] , "flags" , T , _regexp_replace_early_abort) ) ,
222+ other => {
223+ return Err ( DataFusionError :: Internal ( format ! (
224+ "regexp_replace was called with {} arguments. It requires at least 3 and at most 4." ,
225+ other
226+ ) ) )
227+ }
228+ } ;
229+
230+ // Embed the flag (if it exists) into the pattern
231+ let ( pattern, replace_all) = match flags {
232+ Some ( "g" ) => ( pattern. to_string ( ) , true ) ,
233+ Some ( flags) => (
234+ format ! ( "(?{}){}" , flags. to_string( ) . replace( 'g' , "" ) , pattern) ,
235+ flags. contains ( 'g' ) ,
236+ ) ,
237+ None => ( pattern. to_string ( ) , false ) ,
238+ } ;
239+
240+ let re = Regex :: new ( & pattern)
241+ . map_err ( |err| DataFusionError :: Execution ( err. to_string ( ) ) ) ?;
242+
243+ let result = string_array
244+ . iter ( )
245+ . zip ( replacement_array. iter ( ) )
246+ . map ( |( string, replacement) | match ( string, replacement) {
247+ ( Some ( string) , Some ( replacement) ) => {
248+ let replacement = regex_replace_posix_groups ( replacement) ;
249+
250+ if replace_all {
251+ Some ( re. replace_all ( string, replacement. as_str ( ) ) )
252+ } else {
253+ Some ( re. replace ( string, replacement. as_str ( ) ) )
254+ }
255+ }
256+ _ => None ,
257+ } )
258+ . collect :: < GenericStringArray < T > > ( ) ;
259+ Ok ( Arc :: new ( result) as ArrayRef )
260+ }
261+
262+ /// Determine which implementation of the regexp_replace to use based
263+ /// on the given set of arguments.
264+ pub fn specialize_regexp_replace < T : OffsetSizeTrait > (
265+ args : & [ ColumnarValue ] ,
266+ ) -> Result < ScalarFunctionImplementation > {
267+ // This will serve as a dispatch table where we can
268+ // leverage it in order to determine whether the scalarity
269+ // of the given set of arguments fits a better specialized
270+ // function.
271+ let ( is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
272+ matches ! ( args[ 0 ] , ColumnarValue :: Scalar ( _) ) ,
273+ matches ! ( args[ 1 ] , ColumnarValue :: Scalar ( _) ) ,
274+ matches ! ( args[ 2 ] , ColumnarValue :: Scalar ( _) ) ,
275+ // The forth argument (flags) is optional; so in the event that
276+ // it is not available, we'll claim that it is scalar.
277+ matches ! ( args. get( 3 ) , Some ( ColumnarValue :: Scalar ( _) ) | None ) ,
278+ ) ;
279+
280+ match (
281+ is_source_scalar,
282+ is_pattern_scalar,
283+ is_replacement_scalar,
284+ is_flags_scalar,
285+ ) {
286+ // This represents a very hot path for the case where the there is
287+ // a single pattern that is being matched against. This is extremely
288+ // important to specialize on since it removes the overhead of DF's
289+ // in-house regex pattern cache (since there will be at most a single
290+ // pattern).
291+ //
292+ // The flags needs to be a scalar as well since each pattern is actually
293+ // constructed with the flags embedded into the pattern itself. This means
294+ // even if the pattern itself is scalar, if the flags are an array then
295+ // we will create many regexes and it is best to use the implementation
296+ // that caches it. If there are no flags, we can simply ignore it here,
297+ // and let the specialized function handle it.
298+ ( _, true , _, true ) => {
299+ // We still don't know the scalarity of source/replacement, so we
300+ // need the adapter even if it will do some extra work for the pattern
301+ // and the flags.
302+ //
303+ // TODO: maybe we need a way of telling the adapter on which arguments
304+ // it can skip filling (so that we won't create N - 1 redundant cols).
305+ Ok ( make_scalar_function ( _regexp_replace_static_pattern :: < T > ) )
306+ }
307+
308+ // If there are no specialized implementations, we'll fall back to the
309+ // generic implementation.
310+ ( _, _, _, _) => Ok ( make_scalar_function ( regexp_replace :: < T > ) ) ,
311+ }
312+ }
313+
181314#[ cfg( test) ]
182315mod tests {
183316 use super :: * ;
0 commit comments