diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5177f1e55829e..cf63d9ea6d1b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -349,6 +349,7 @@ object FunctionRegistry {
     expression[StringLocate]("position"),
     expression[FormatString]("printf"),
     expression[RegExpExtract]("regexp_extract"),
+    expression[RegExpExtractAll]("regexp_extract_all"),
     expression[RegExpReplace]("regexp_replace"),
     expression[StringRepeat]("repeat"),
     expression[StringReplace]("replace"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 9229ef2039fed..25e85ab9164e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Locale
 import java.util.regex.{MatchResult, Pattern}
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.commons.text.StringEscapeUtils
 
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -386,6 +388,50 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   }
 }
 
+/**
+ * Common base for regexp extraction expressions over (subject, regexp, group index).
+ * It caches the last compiled [[Pattern]] and recompiles only when the regexp value changes.
+ */
+abstract class RegExpExtractBase extends TernaryExpression with ImplicitCastInputTypes {
+  // last regex in string; the pattern is updated iff the regexp value changed.
+  @transient protected var lastRegex: UTF8String = _
+  // last compiled regex pattern, cached for performance
+  @transient protected var pattern: Pattern = _
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
+
+  /** Returns a matcher for `s`, recompiling the cached pattern only when `p` has changed. */
+  protected def getMatcher(s: Any, p: Any) = {
+    if (!p.equals(lastRegex)) {
+      // regex value changed
+      lastRegex = p.asInstanceOf[UTF8String].clone()
+      pattern = Pattern.compile(lastRegex.toString)
+    }
+    pattern.matcher(s.toString)
+  }
+
+  /**
+   * Allocates the fresh variable names and mutable state used by `doGenCode` in subclasses.
+   * Returns (pattern class name, matcher name, match result name, last-regex term,
+   * pattern term, statement setting `ev.isNull` to false if nullable or "" otherwise).
+   */
+  protected def getDoGenCodeVals(ctx: CodegenContext, ev: ExprCode) = {
+    val classNamePattern = classOf[Pattern].getCanonicalName
+    val matcher = ctx.freshName("matcher")
+    val matchResult = ctx.freshName("matchResult")
+
+    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
+    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
+
+    val setEvNotNull = if (nullable) {
+      s"${ev.isNull} = false;"
+    } else {
+      ""
+    }
+
+    (classNamePattern, matcher, matchResult, termLastRegex, termPattern, setEvNotNull)
+  }
+}
+
 /**
  * Extract a specific(idx) group identified by a Java regex.
  *
@@ -400,21 +446,13 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   """,
   since = "1.5.0")
 case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
-  def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+  extends RegExpExtractBase {
 
-  // last regex in string, we will update the pattern iff regexp value changed.
-  @transient private var lastRegex: UTF8String = _
-  // last regex pattern, we cache it for performance concern
-  @transient private var pattern: Pattern = _
+  def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
 
   override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
-    if (!p.equals(lastRegex)) {
-      // regex value changed
-      lastRegex = p.asInstanceOf[UTF8String].clone()
-      pattern = Pattern.compile(lastRegex.toString)
-    }
-    val m = pattern.matcher(s.toString)
+    val m = getMatcher(s, p)
     if (m.find) {
       val mr: MatchResult = m.toMatchResult
       val group = mr.group(r.asInstanceOf[Int])
@@ -429,24 +467,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   }
 
   override def dataType: DataType = StringType
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
-  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
   override def prettyName: String = "regexp_extract"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val classNamePattern = classOf[Pattern].getCanonicalName
-    val matcher = ctx.freshName("matcher")
-    val matchResult = ctx.freshName("matchResult")
-
-    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
-    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
-
-    val setEvNotNull = if (nullable) {
-      s"${ev.isNull} = false;"
-    } else {
-      ""
-    }
-
+    val (classNamePattern, matcher, matchResult, termLastRegex, termPattern, setEvNotNull) =
+      getDoGenCodeVals(ctx, ev)
     nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
       s"""
       if (!$regexp.equals($termLastRegex)) {
@@ -471,3 +496,76 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
     })
   }
 }
+
+/**
+ * Extract all specific(idx) groups identified by a Java regex.
+ *
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(str, regexp[, idx]) - Extracts all groups that matches `regexp`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('100-200,300-400', '(\\d+)-(\\d+)', 1);
+       [100, 300]
+  """)
+case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression)
+  extends RegExpExtractBase {
+
+  def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
+
+  /** Collects group `r` of each successive match; a null group is stored as an empty string. */
+  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+    val m = getMatcher(s, p)
+    val groupArrayBuffer = new ArrayBuffer[UTF8String]()
+
+    while (m.find) {
+      val mr: MatchResult = m.toMatchResult
+      val group = mr.group(r.asInstanceOf[Int])
+      if (group == null) { // Pattern matched, but the optional group did not participate
+        groupArrayBuffer += UTF8String.EMPTY_UTF8
+      } else {
+        groupArrayBuffer += UTF8String.fromString(group)
+      }
+    }
+
+    new GenericArrayData(groupArrayBuffer.toArray.asInstanceOf[Array[Any]])
+  }
+
+  override def dataType: DataType = ArrayType(StringType)
+  override def prettyName: String = "regexp_extract_all"
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val groupArray = ctx.freshName("groupArray")
+    val arrayClass = classOf[GenericArrayData].getName
+    val (classNamePattern, matcher, matchResult, termLastRegex, termPattern, setEvNotNull) =
+      getDoGenCodeVals(ctx, ev)
+
+    // stripMargin drops the leading '|' markers so the emitted template is valid Java.
+    nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
+      s"""
+         |if (!$regexp.equals($termLastRegex)) {
+         |  // regex value changed
+         |  $termLastRegex = $regexp.clone();
+         |  $termPattern = $classNamePattern.compile($termLastRegex.toString());
+         |}
+         |java.util.regex.Matcher $matcher =
+         |  $termPattern.matcher($subject.toString());
+         |java.util.ArrayList $groupArray =
+         |  new java.util.ArrayList();
+
+         |while ($matcher.find()) {
+         |  java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
+         |  if ($matchResult.group($idx) == null) {
+         |    $groupArray.add(UTF8String.EMPTY_UTF8);
+         |  } else {
+         |    $groupArray.add(UTF8String.fromString($matchResult.group($idx)));
+         |  }
+         |}
+         |${ev.value} = new $arrayClass($groupArray.toArray(new UTF8String[$groupArray.size()]));
+         |$setEvNotNull
+         |""".stripMargin
+    })
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 06fb73ad83923..2861516b0511e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -217,6 +217,36 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(nonNullExpr, "100", row1)
   }
 
+  test("RegexExtractAll") {
+    val row1 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 1)
+    val row2 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 2)
+    val row3 = create_row("100-200,300-400,500-600", "(\\d+).*", 1)
+    val row4 = create_row("100-200,300-400,500-600", "([a-z])", 1)
+    val row5 = create_row(null, "([a-z])", 1)
+    val row6 = create_row("100-200,300-400,500-600", null, 1)
+    val row7 = create_row("100-200,300-400,500-600", "([a-z])", null)
+
+    val s = 's.string.at(0)
+    val p = 'p.string.at(1)
+    val r = 'r.int.at(2)
+
+    val expr = RegExpExtractAll(s, p, r)
+    checkEvaluation(expr, Seq("100", "300", "500"), row1)
+    checkEvaluation(expr, Seq("200", "400", "600"), row2)
+    checkEvaluation(expr, Seq("100"), row3)
+    checkEvaluation(expr, Seq(), row4)
+    checkEvaluation(expr, null, row5)
+    checkEvaluation(expr, null, row6)
+    checkEvaluation(expr, null, row7)
+
+    val expr1 = new RegExpExtractAll(s, p)
+    checkEvaluation(expr1, Seq("100", "300", "500"), row1)
+
+    val nonNullExpr = RegExpExtractAll(Literal("100-200,300-400,500-600"),
+      Literal("(\\d+)-(\\d+)"), Literal(1))
+    checkEvaluation(nonNullExpr, Seq("100", "300", "500"), row1)
+  }
+
   test("SPLIT") {
     val s1 = 'a.string.at(0)
     val s2 = 'b.string.at(1)