diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 13a9c752e3..bbb3d9aa34 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -294,6 +294,8 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.RLike.enabled` | Enable Comet acceleration for `RLike` | true | | `spark.comet.expression.Rand.enabled` | Enable Comet acceleration for `Rand` | true | | `spark.comet.expression.Randn.enabled` | Enable Comet acceleration for `Randn` | true | +| `spark.comet.expression.RegExpExtract.enabled` | Enable Comet acceleration for `RegExpExtract` | true | +| `spark.comet.expression.RegExpExtractAll.enabled` | Enable Comet acceleration for `RegExpExtractAll` | true | | `spark.comet.expression.RegExpReplace.enabled` | Enable Comet acceleration for `RegExpReplace` | true | | `spark.comet.expression.Remainder.enabled` | Enable Comet acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 56b12a9e48..cccb6220a4 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,7 +23,7 @@ use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, - SparkStringSpace, + SparkRegExpExtract, SparkRegExpExtractAll, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -194,6 +194,8 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtract::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtractAll::default())), ] } diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..2026ec5fec 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +mod regexp_extract; mod string_space; mod substring; +pub use regexp_extract::{SparkRegExpExtract, SparkRegExpExtractAll}; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..38c80d8129 --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,599 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, GenericStringArray, GenericStringBuilder, ListArray, OffsetSizeTrait, +}; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::logical_expr_common::signature::TypeSignature::Exact; +use regex::Regex; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible regexp_extract function +/// +/// Extracts a substring matching a [regular expression](https://docs.rs/regex/latest/regex/#syntax) +/// and returns the specified capture group. +/// +/// The function signature is: `regexp_extract(str, regexp, idx)` +/// where: +/// - `str`: The input string to search in +/// - `regexp`: The regular expression pattern (must be a literal) +/// - `idx`: The capture group index (0 for the entire match, must be a literal) +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtract { + signature: Signature, +} + +impl Default for SparkRegExpExtract { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtract { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int32]), + Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(match &_arg_types[0] { + DataType::Utf8 => DataType::Utf8, + DataType::LargeUtf8 => DataType::LargeUtf8, + _ => { + return exec_err!( + "regexp_extract expects utf8 or largeutf8 input but got {:?}", + _arg_types[0] + ) + } + }) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = &args.args; + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let is_scalar = len.is_none(); + let result = match args[0].data_type() { + DataType::Utf8 => regexp_extract_func::(args), + DataType::LargeUtf8 => regexp_extract_func::(args), + _ => { + return exec_err!( + "regexp_extract expects the data type of subject to be utf8 or largeutf8 but got {:?}", + args[0].data_type() + ); + } + }; + if is_scalar { + result + .and_then(|arr| ScalarValue::try_from_array(&arr, 0)) + .map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } +} + +/// Spark-compatible regexp_extract_all function +/// +/// Extracts all substrings matching a [regular expression](https://docs.rs/regex/latest/regex/#syntax) +/// and returns them as an array. +/// +/// The function signature is: `regexp_extract_all(str, regexp, idx)` +/// where: +/// - `str`: The input string to search in +/// - `regexp`: The regular expression pattern (must be a literal) +/// - `idx`: The capture group index (0 for the entire match, must be a literal) +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtractAll { + signature: Signature, +} + +impl Default for SparkRegExpExtractAll { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtractAll { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int32]), + Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtractAll { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(match &_arg_types[0] { + DataType::Utf8 => DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::Utf8, + false, + ))), + DataType::LargeUtf8 => DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::LargeUtf8, + false, + ))), + _ => { + return exec_err!( + "regexp_extract_all expects utf8 or largeutf8 input but got {:?}", + _arg_types[0] + ) + } + }) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = &args.args; + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let is_scalar = len.is_none(); + let result = match args[0].data_type() { + DataType::Utf8 => regexp_extract_all_func::(args), + DataType::LargeUtf8 => regexp_extract_all_func::(args), + _ => { + return exec_err!( + "regexp_extract_all expects the data type of subject to be utf8 or largeutf8 but got {:?}", + args[0].data_type() + ); + } + }; + if is_scalar { + result + .and_then(|arr| ScalarValue::try_from_array(&arr, 0)) + .map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } +} + +// Helper functions + +fn regexp_extract_func(args: &[ColumnarValue]) -> Result { + let (subject, regex, idx) = parse_args(args, "regexp_extract")?; + + let subject_array = match subject { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, + }; + + regexp_extract_array::(&subject_array, ®ex, idx) +} + +fn regexp_extract_all_func(args: &[ColumnarValue]) -> Result { + let (subject, regex, idx) = parse_args(args, "regexp_extract_all")?; + + let subject_array = match subject { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, + }; + + regexp_extract_all_array::(&subject_array, ®ex, idx) +} + +fn parse_args<'a>( + args: &'a [ColumnarValue], + fn_name: &str, +) -> Result<(&'a ColumnarValue, Regex, i32)> { + if args.len() != 3 { + return exec_err!("{} expects 3 arguments, got {}", fn_name, args.len()); + } + + let subject = &args[0]; + let pattern = &args[1]; + let idx = &args[2]; + + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("{} pattern must be a string literal", fn_name); + } + }; + + let idx_val = match idx { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, + _ => { + return exec_err!("{} idx must be an integer literal", fn_name); + } + }; + if idx_val < 0 { + return exec_err!("{fn_name} group index must be non-negative"); + } + + let regex = Regex::new(&pattern_str).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern_str, e)) + })?; + + Ok((subject, regex, idx_val)) +} + +fn regexp_extract_array( + array: &ArrayRef, + regex: &Regex, + idx: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Execution("regexp_extract expects string array input".to_string()) + })?; + + let mut builder = GenericStringBuilder::::new(); + for s in string_array.iter() { + match s { + Some(text) => { + let extracted = regexp_extract(text, regex, idx)?; + builder.append_value(extracted); + } + None => { + builder.append_null(); + } + } + } + + Ok(Arc::new(builder.finish())) +} + +fn regexp_extract(text: &str, regex: &Regex, idx: i32) -> Result { + let idx = idx as usize; + match regex.captures(text) { + Some(caps) => { + // Spark behavior: throw error if group index is out of bounds + let group_cnt = caps.len() - 1; + if idx > group_cnt { + return exec_err!( + "Regex group index out of bounds, group count: {}, index: {}", + group_cnt, + idx + ); + } + Ok(caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default()) + } + None => { + // No match: return empty string (Spark behavior) + Ok(String::new()) + } + } +} + +fn regexp_extract_all_array( + array: &ArrayRef, + regex: &Regex, + idx: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Execution("regexp_extract_all expects string array input".to_string()) + })?; + + let item_data_type = match array.data_type() { + DataType::Utf8 => DataType::Utf8, + DataType::LargeUtf8 => DataType::LargeUtf8, + _ => { + return exec_err!( + "regexp_extract_all expects utf8 or largeutf8 array but got {:?}", + array.data_type() + ); + } + }; + let item_field = Arc::new(arrow::datatypes::Field::new("item", item_data_type, false)); + + let string_builder = GenericStringBuilder::::new(); + let mut list_builder = + arrow::array::ListBuilder::new(string_builder).with_field(Arc::clone(&item_field)); + + for s in string_array.iter() { + match s { + Some(text) => { + let matches = regexp_extract_all(text, regex, idx)?; + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + } + None => { + list_builder.append(false); + } + } + } + + let list_array = list_builder.finish(); + + // Manually create a new ListArray with the correct field schema to ensure nullable is false + // This ensures the schema matches what we declared in return_type + Ok(Arc::new(ListArray::new( + FieldRef::from(Arc::clone(&item_field)), + list_array.offsets().clone(), + Arc::clone(list_array.values()), + list_array.nulls().cloned(), + ))) +} + +fn regexp_extract_all(text: &str, regex: &Regex, idx: i32) -> Result> { + let idx = idx as usize; + let mut results = Vec::new(); + + for caps in regex.captures_iter(text) { + // Check bounds for each capture (matches Spark behavior) + let group_num = caps.len() - 1; + if idx > group_num { + return exec_err!( + "Regex group index out of bounds, group count: {}, index: {}", + group_num, + idx + ); + } + + let matched = caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + results.push(matched); + } + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{LargeStringArray, StringArray}; + + #[test] + fn test_regexp_extract_basic() { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + + // Spark behavior: return "" on no match, not None + assert_eq!(regexp_extract("123-abc", ®ex, 0).unwrap(), "123-abc"); + assert_eq!(regexp_extract("123-abc", ®ex, 1).unwrap(), "123"); + assert_eq!(regexp_extract("123-abc", ®ex, 2).unwrap(), "abc"); + assert_eq!(regexp_extract("no match", ®ex, 0).unwrap(), ""); // no match → "" + + // Spark behavior: group index out of bounds → error + assert!(regexp_extract("123-abc", ®ex, 3).is_err()); + assert!(regexp_extract("123-abc", ®ex, 99).is_err()); + assert!(regexp_extract("123-abc", ®ex, -1).is_err()); + } + + #[test] + fn test_regexp_extract_all_basic() { + let regex = Regex::new(r"(\d+)").unwrap(); + + // Multiple matches + let matches = regexp_extract_all("a1b2c3", ®ex, 0).unwrap(); + assert_eq!(matches, vec!["1", "2", "3"]); + + // Same with group index 1 + let matches = regexp_extract_all("a1b2c3", ®ex, 1).unwrap(); + assert_eq!(matches, vec!["1", "2", "3"]); + + // No match: returns empty vec, not error + let matches = regexp_extract_all("no digits", ®ex, 0).unwrap(); + assert!(matches.is_empty()); + assert_eq!(matches, Vec::::new()); + + // Group index out of bounds → error + assert!(regexp_extract_all("a1b2c3", ®ex, 2).is_err()); + } + + #[test] + fn test_regexp_extract_all_array() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("a1b2"), + Some("no digits"), + None, + Some("c3d4e5"), + ])) as ArrayRef; + + let result = regexp_extract_all_array::(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2" → ["1", "2"] + let row0 = list_array.value(0); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row0_str.len(), 2); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + + // Row 1: "no digits" → [] (empty array, not NULL) + let row1 = list_array.value(1); + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "c3d4e5" → ["3", "4", "5"] + let row3 = list_array.value(3); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row3_str.len(), 3); + assert_eq!(row3_str.value(0), "3"); + assert_eq!(row3_str.value(1), "4"); + assert_eq!(row3_str.value(2), "5"); + + Ok(()) + } + + #[test] + fn test_regexp_extract_array() -> Result<()> { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("123-abc"), + Some("456-def"), + None, + Some("no-match"), + ])) as ArrayRef; + + let result = regexp_extract_array::(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "123"); + assert_eq!(result_array.value(1), "456"); + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } + + #[test] + fn test_regexp_extract_largeutf8() -> Result<()> { + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(LargeStringArray::from(vec![ + Some("a1b2c3"), + Some("x5y6"), + None, + Some("no digits"), + ])) as ArrayRef; + + let result = regexp_extract_array::(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "1"); // First digit from "a1b2c3" + assert_eq!(result_array.value(1), "5"); // First digit from "x5y6" + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } + + #[test] + fn test_regexp_extract_all_largeutf8() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(LargeStringArray::from(vec![ + Some("a1b2c3"), + Some("x5y6"), + None, + Some("no digits"), + ])) as ArrayRef; + + let result = regexp_extract_all_array::(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2c3" → ["1", "2", "3"] (all matches) + let row0 = list_array.value(0); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row0_str.len(), 3); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + assert_eq!(row0_str.value(2), "3"); + + // Row 1: "x5y6" → ["5", "6"] (all matches) + let row1 = list_array.value(1); + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 2); + assert_eq!(row1_str.value(0), "5"); + assert_eq!(row1_str.value(1), "6"); + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "no digits" → [] (empty array, not NULL) + let row3 = list_array.value(3); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row3_str.len(), 0); // Empty array + assert!(!list_array.is_null(3)); // Not NULL, just empty + + Ok(()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 83917d33fc..9514c47da9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -154,6 +154,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 15f4b238f2..51b41e2593 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -286,3 +286,81 @@ trait CommonStringExprs { } } } + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + // Check if the pattern is compatible with Spark or allow incompatible patterns + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProto("regexp_extract", subjectExpr, patternExpr, idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + // Check if the pattern is compatible with Spark or allow incompatible patterns + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + // For regexp_extract_all, idx will default to 1 if not specified + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = + scalarFunctionExprToProto("regexp_extract_all", subjectExpr, patternExpr, idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index f9882780c8..0bdaba62b1 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -391,4 +391,315 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("regexp_extract basic") { + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("100-200", 1), + ("300-400", 1), + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" + ("abc123def456", 1), + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test basic extraction: group 0 (full match) + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + // Test group 2 + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + // Test non-existent group → should error + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, NULL, 0) FROM tbl") + } + } + } + + test("regexp_extract edge cases") { + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = + Seq(("email@example.com", 1), ("phone: 123-456-7890", 1), ("price: $99.99", 1), (null, 1)) + + withParquetTable(data, "tbl") { + // Extract email domain + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + // Extract phone number + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") + // Extract price + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all basic") { + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("test123test456", 1), + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test with explicit group 0 (full match on no-group pattern) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") + // Test with explicit group 0 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, NULL, 0) FROM tbl") + } + } + } + + test("regexp_extract_all multiple matches") { + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("The prices are $10, $20, and $30", 1), + ("colors: red, green, blue", 1), + ("words: hello world", 1), + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract all prices + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + // Extract all words + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all with dictionary encoding") { + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + // Mix short strings, long strings, and various patterns + val longString1 = "prefix" + ("abc" * 100) + "123" + ("xyz" * 100) + "456" + val longString2 = "start" + ("test" * 200) + "789" + ("end" * 150) + + val data = (0 until 2000).map(i => { + val text = i % 7 match { + case 0 => "a1b2c3" // Simple repeated pattern + case 1 => "x5y6" // Another simple pattern + case 2 => "no-match" // No digits + case 3 => longString1 // Long string with digits + case 4 => longString2 // Another long string + case 5 => "email@test.com-phone:123-456-7890" // Complex pattern + case 6 => "" // Empty string + } + (text, 1) + }) + + withParquetTable(data, "tbl") { + // Test simple pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + + // Test complex patterns + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d{3}-\\d{3}-\\d{4})', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '@([a-z]*)', 1) FROM tbl") + + // Test with multiple groups + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d*)', 1) FROM tbl") + } + } + } + + test("regexp_extract with dictionary encoding") { + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + // Mix short and long strings with various patterns + val longString1 = "data" + ("x" * 500) + "999" + ("y" * 500) + val longString2 = ("a" * 1000) + "777" + ("b" * 1000) + + val data = (0 until 2000).map(i => { + val text = i % 7 match { + case 0 => "a1b2c3" + case 1 => "x5y6" + case 2 => "no-match" + case 3 => longString1 + case 4 => longString2 + case 5 => "IP:192.168.1.100-PORT:8080" + case 6 => "" + } + (text, 1) + }) + + withParquetTable(data, "tbl") { + // Test extracting first match with simple pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)', 1) FROM tbl") + + // Test with complex patterns + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, 'PORT:(\\d+)', 1) FROM tbl") + + // Test with multiple groups - extract second group + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '([a-z])(\\d+)', 2) FROM tbl") + } + } + } + + test("regexp_extract unicode and special characters") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("测试123test", 1), // Chinese characters + ("日本語456にほんご", 1), // Japanese characters + ("한글789Korean", 1), // Korean characters + ("Привет999Hello", 1), // Cyrillic + ("line1\nline2", 1), // Newline + ("tab\there", 1), // Tab + ("special: $#@!%^&*", 1), // Special chars + ("mixed测试123test日本語", 1), // Mixed unicode + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract digits from unicode text + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + + // Test word boundaries with unicode + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '([a-zA-Z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-zA-Z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all multiple groups") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("x5y6z7", 1), + ("test123demo456end789", 1), + (null, 1), + ("no match here", 1)) + + withParquetTable(data, "tbl") { + // Test extracting different groups - full match + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 0) FROM tbl") + // Test extracting group 1 (letters) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 1) FROM tbl") + // Test extracting group 2 (digits) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 2) FROM tbl") + + // Test with three groups + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 2) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 3) FROM tbl") + } + } + } + + test("regexp_extract_all group index out of bounds") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq(("a1b2c3", 1), ("test123", 1), (null, 1)) + + withParquetTable(data, "tbl") { + // Group index out of bounds - should match Spark's behavior (error) + // Pattern has only 1 group, asking for group 2 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 2) FROM tbl") + + // Pattern has no groups, asking for group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 1) FROM tbl") + } + } + } + + test("regexp_extract complex patterns") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("2024-01-15", 1), // Date + ("192.168.1.1", 1), // IP address + ("user@domain.co.uk", 1), // Complex email + ("content", 1), // HTML-like + ("Time: 14:30:45.123", 1), // Timestamp + ("Version: 1.2.3-beta", 1), // Version string + ("RGB(255,128,0)", 1), // RGB color + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract year from date + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{4})-(\\d{2})-(\\d{2})', 1) FROM tbl") + + // Extract month from date + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{4})-(\\d{2})-(\\d{2})', 2) FROM tbl") + + // Extract IP octets + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)', 2) FROM tbl") + + // Extract email domain + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([a-z.]+)', 1) FROM tbl") + + // Extract time components + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{2}):(\\d{2}):(\\d{2})', 1) FROM tbl") + + // Extract RGB values + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, 'RGB\\((\\d+),(\\d+),(\\d+)\\)', 2) FROM tbl") + + // Test regexp_extract_all with complex patterns + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract vs regexp_extract_all comparison") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq(("a1b2c3", 1), ("x5y6", 1), (null, 1), ("no digits", 1), ("single7match", 1)) + + withParquetTable(data, "tbl") { + // Compare single extraction vs all extractions in one query + checkSparkAnswerAndOperator("""SELECT + | regexp_extract(_1, '(\\d+)', 1) as first_match, + | regexp_extract_all(_1, '(\\d+)', 1) as all_matches + |FROM tbl""".stripMargin) + + // Verify regexp_extract returns first match only while regexp_extract_all returns all + checkSparkAnswerAndOperator("""SELECT + | _1, + | regexp_extract(_1, '(\\d+)', 1) as first_digit, + | regexp_extract_all(_1, '(\\d+)', 1) as all_digits + |FROM tbl""".stripMargin) + + // Test with multiple groups + checkSparkAnswerAndOperator("""SELECT + | regexp_extract(_1, '([a-z])(\\d+)', 1) as first_letter, + | regexp_extract_all(_1, '([a-z])(\\d+)', 1) as all_letters, + | regexp_extract(_1, '([a-z])(\\d+)', 2) as first_digit, + | regexp_extract_all(_1, '([a-z])(\\d+)', 2) as all_digits + |FROM tbl""".stripMargin) + } + } + } + }