Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ package object dsl {
}

def like(other: Expression, escapeChar: Char = '\\'): Expression =
Like(expr, other, Literal(escapeChar.toString))
Like(expr, other, escapeChar)
def rlike(other: Expression): Expression = RLike(expr, other)
def contains(other: Expression): Expression = Contains(expr, other)
def startsWith(other: Expression): Expression = StartsWith(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,24 @@ import java.util.regex.{MatchResult, Pattern}

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


trait StringRegexExpression extends Expression
abstract class StringRegexExpression extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

def str: Expression
def pattern: Expression

def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean

override def dataType: DataType = BooleanType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

// Try to cache the compiled pattern when the pattern expression is foldable.
private lazy val cache: Pattern = pattern match {
private lazy val cache: Pattern = right match {
case p: Expression if p.foldable =>
compile(p.eval().asInstanceOf[UTF8String].toString)
case _ => null
Expand All @@ -55,17 +52,18 @@ trait StringRegexExpression extends Expression
Pattern.compile(escape(str))
}

def nullSafeMatch(input1: Any, input2: Any): Any = {
val s = input2.asInstanceOf[UTF8String].toString
val regex = if (cache == null) compile(s) else cache
// Returns the regex to match against: the cached compiled pattern when `cache` is non-null,
// otherwise a pattern compiled from `str` for this call.
protected def pattern(str: String) = if (cache == null) compile(str) else cache

// Matches `input1` against the regex obtained from `input2`; both are cast to UTF8String.
// Returns null when no regex is available, otherwise the Boolean result of `matches`.
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
// `pattern` reuses the cached regex if there is one, else compiles the text of input2.
val regex = pattern(input2.asInstanceOf[UTF8String].toString)
if(regex == null) {
// No regex to match with: yield null rather than false.
null
} else {
matches(regex, input1.asInstanceOf[UTF8String].toString)
}
}

override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
}

// scalastyle:off line.contains.tab
Expand Down Expand Up @@ -110,65 +108,46 @@ trait StringRegexExpression extends Expression
true
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
true
> SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
true
""",
note = """
Use RLIKE to match with standard regular expressions.
""",
since = "1.0.0")
// scalastyle:on line.contains.tab
case class Like(str: Expression, pattern: Expression, escape: Expression)
extends TernaryExpression with StringRegexExpression {

def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))

override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def children: Seq[Expression] = Seq(str, pattern, escape)
case class Like(left: Expression, right: Expression, escapeChar: Char)
extends StringRegexExpression {

// Resolves the escape character from the `escape` expression; computed lazily on first access.
// Throws AnalysisException unless `escape` is foldable and evaluates to a one-character string.
private lazy val escapeChar: Char = if (escape.foldable) {
escape.eval() match {
// Accept only a UTF8String that is exactly one character long.
case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
// Any other evaluation result (including null) is rejected; the value is echoed in the message.
case s => throw new AnalysisException(
s"The 'escape' parameter must be a string literal of one char but it is $s.")
}
} else {
// Reject an `escape` expression that is not foldable.
throw new AnalysisException("The 'escape' parameter must be a string literal.")
}
def this(left: Expression, right: Expression) = this(left, right, '\\')

// Translates the LIKE pattern text `v` into the string handed to Pattern.compile,
// delegating to StringUtils.escapeLikeRegex with this expression's escape character.
override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)

// LIKE semantics: Matcher.matches() succeeds only if the regex matches the entire string.
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()

override def toString: String = escapeChar match {
case '\\' => s"$str LIKE $pattern"
case c => s"$str LIKE $pattern ESCAPE '$c'"
}

protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
nullSafeMatch(input1, input2)
case '\\' => s"$left LIKE $right"
case c => s"$left LIKE $right ESCAPE '$c'"
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"

if (pattern.foldable) {
val patternVal = pattern.eval()
if (patternVal != null) {
if (right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
val pattern = ctx.addMutableState(patternClass, "patternLike",
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to evaluate `right` again.
val eval = str.genCode(ctx)
val eval = left.genCode(ctx)
ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
}
""")
} else {
Expand All @@ -178,8 +157,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
""")
}
} else {
val patternStr = ctx.freshName("patternStr")
val compiledPattern = ctx.freshName("compiledPattern")
val pattern = ctx.freshName("pattern")
val rightStr = ctx.freshName("rightStr")
// We need double escape to avoid org.codehaus.commons.compiler.CompileException.
// '\\' will cause exception 'Single quote must be backslash-escaped in character literal'.
// '\"' will cause exception 'Line break in literal not allowed'.
Expand All @@ -188,12 +167,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
} else {
escapeChar
}
nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $patternStr = $eval2.toString();
$patternClass $compiledPattern = $patternClass.compile(
$escapeFunc($patternStr, '$newEscapeChar'));
${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile(
$escapeFunc($rightStr, '$newEscapeChar'));
${ev.value} = $pattern.matcher($eval1.toString()).matches();
Copy link
Member

@dongjoon-hyun dongjoon-hyun Feb 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: this part is a logical revert.

"""
})
}
Expand Down Expand Up @@ -232,20 +211,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
""",
since = "1.0.0")
// scalastyle:on line.contains.tab
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

override def str: Expression = left
override def pattern: Expression = right
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {

// RLIKE applies no escaping: the pattern text is returned unchanged.
override def escape(v: String): String = v
// RLIKE semantics: Matcher.find(0) succeeds if the regex matches any subsequence of `str`,
// searching from index 0, so the match need not cover the whole string.
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
override def toString: String = s"$left RLIKE $right"

protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException("Invalid escape string." +
"Escape string must contains only one character.", ctx)
}
str
str.charAt(0)
}.getOrElse('\\')
invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
case SqlBaseParser.RLIKE =>
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
case SqlBaseParser.NULL if ctx.NOT != null =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3560,21 +3560,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(1)))
}

// Checks how like(str, pattern, escape) handles its third (escape) argument.
test("the like function with the escape parameter") {
// Single-row DataFrame whose columns are named str, pattern and escape.
val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
// A one-character literal escape is accepted; "abc" LIKE "a_c" is expected to be true.
checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))

// A two-character literal escape is expected to raise an AnalysisException.
val longEscapeError = intercept[AnalysisException] {
df.selectExpr("like(str, pattern, '@%')").collect()
}.getMessage
assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))

// Passing the `escape` column instead of a literal is expected to raise an AnalysisException.
val nonFoldableError = intercept[AnalysisException] {
df.selectExpr("like(str, pattern, escape)").collect()
}.getMessage
assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
}

test("SPARK-29462: Empty array of NullType for array function with no arguments") {
Seq((true, StringType), (false, NullType)).foreach {
case (arrayDefaultToString, expectedType) =>
Expand Down