@@ -97,6 +97,24 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("CREATE TABLE TBL_GEOMETRY (col0 GEOMETRY)").executeUpdate()
conn.prepareStatement("INSERT INTO TBL_GEOMETRY VALUES (ST_GeomFromText('POINT(0 0)'))")
.executeUpdate()

conn.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()

conn.prepareStatement(
s"""
|INSERT INTO pattern_testing_table VALUES
|('special_character_quote\\'_present'),
|('special_character_quote_not_present'),
|('special_character_percent%_present'),
|('special_character_percent_not_present'),
|('special_character_underscore_present'),
|('special_character_underscorenot_present')
""".stripMargin).executeUpdate()
}
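
Note on the seed data: the rows deliberately come in present/absent pairs so that wildcard escaping is observable in the answers. For example, special_character_underscorenot_present would match a pattern like '%underscore_%' only if the underscore acted as a wildcard (consuming the n); its absence from the expected single-row answers below shows the pushed-down pattern escaped the underscore.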

def testConnection(): Unit = {
@@ -358,6 +376,184 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new Properties)
assert(df.schema.fields.head.dataType === ShortType)
}

test("test contains pushdown") {
// this one should map to contains
val df1 = spark.sql(
s"""
|SELECT * FROM pattern_testing_table
|WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
df1.explain("formatted")

checkAnswer(df1, Row("special_character_quote'_present"))

val df2 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
checkAnswer(df2, Row("special_character_percent%_present"))

val df3 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
checkAnswer(df3, Row("special_character_underscore_present"))

val df4 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE contains(pattern_testing_col, 'character')
|ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df4, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))
}

test("endswith pushdown") {
val df9 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
checkAnswer(df9, Row("special_character_quote'_present"))
val df10 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
checkAnswer(df10, Row("special_character_percent%_present"))
val df11 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
checkAnswer(df11, Row("special_character_underscore_present"))
val df12 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE endswith(pattern_testing_col, 'present')
|ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df12, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))
}

test("startswith pushdown") {
val df5 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
checkAnswer(df5, Row("special_character_quote'_present"))
val df6 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
checkAnswer(df6, Row("special_character_percent%_present"))
val df7 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
checkAnswer(df7, Row("special_character_underscore_present"))
val df8 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE startswith(pattern_testing_col, 'special_character')
|ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df8, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))
}

test("test like pushdown") {
// this one should map to contains
val df1 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)

checkAnswer(df1, Row("special_character_quote'_present"))

val df2 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
checkAnswer(df2, Row("special_character_percent%_present"))

val df3 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
checkAnswer(df3, Row("special_character_underscore_present"))

val df4 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%character%'
|ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df4, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))

// these should map to startsWith
val df5 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
checkAnswer(df5, Row("special_character_quote'_present"))
val df6 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
checkAnswer(df6, Row("special_character_percent%_present"))
val df7 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
checkAnswer(df7, Row("special_character_underscore_present"))
val df8 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE 'special_character%'
|ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df8, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))
// these should map to endsWith
val df9 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
checkAnswer(df9, Row("special_character_quote'_present"))
val df10 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
checkAnswer(df10, Row("special_character_percent%_present"))
val df11 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
checkAnswer(df11, Row("special_character_underscore_present"))
val df12 = spark.sql(
s"""SELECT * FROM pattern_testing_table
|WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin)
checkAnswer(df12, Seq(
Row("special_character_percent%_present"),
Row("special_character_percent_not_present"),
Row("special_character_quote'_present"),
Row("special_character_quote_not_present"),
Row("special_character_underscore_present"),
Row("special_character_underscorenot_present")))
}
}


@@ -22,6 +22,8 @@
import java.util.Map;
import java.util.StringJoiner;

import org.apache.commons.lang3.StringUtils;

import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.sql.connector.expressions.Cast;
@@ -43,6 +45,7 @@
import org.apache.spark.sql.connector.expressions.aggregate.Sum;
import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StringType;

/**
* The builder to generate SQL from V2 expressions.
Expand All @@ -65,7 +68,6 @@ protected String escapeSpecialCharsForLikePattern(String str) {
switch (c) {
case '_' -> builder.append("\\_");
case '%' -> builder.append("\\%");
case '\'' -> builder.append("\\\'");
Contributor:
Do you see the comment on escapeSpecialCharsForLikePattern?

Contributor (Author):
Yes, I do. Unfortunately, ' is not a special character that should be escaped this way for LIKE expressions across all JDBC dialects. The first red flag is that H2 had to remove this change; we would expect dialect-specific overrides only to add special characters, not remove them. The second red flag was in JDBCV2Suite, which had a real problem: when displaying the pushdown result it was calling not the visitLiteral implemented in JDBCDialect but the one from V2ExpressionSQLBuilder, which is presumably why this escape was added here in the first place. We need to escape ' only when emitting plain string literals, since literals appear in SQL as 'value'. That escaping is already done in visitLiteral and should not be repeated here.
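
To make the double-escaping concrete, here is a minimal standalone sketch (hypothetical helper names, not the actual Spark classes) of the two steps and why ' must be handled in only one of them:

object DoubleEscapeSketch {
  // Step 1 -- what visitLiteral does for strings: quote the raw value and
  // double any embedded single quotes.
  def quoteLiteral(raw: String): String =
    "'" + raw.replace("'", "''") + "'"

  // Step 2 -- what escapeSpecialCharsForLikePattern should do: escape only
  // the LIKE wildcards; quotes were already handled in step 1.
  def escapeLikePattern(s: String): String =
    s.flatMap {
      case '_' => "\\_"
      case '%' => "\\%"
      case c => c.toString
    }

  def main(args: Array[String]): Unit = {
    val lit = quoteLiteral("quote'_present")    // 'quote''_present'
    val body = lit.substring(1, lit.length - 1) // quote''_present
    println(s"col LIKE '${escapeLikePattern(body)}%' ESCAPE '\\'")
    // Prints: col LIKE 'quote''\_present%' ESCAPE '\'
    // Escaping ' again in step 2 would instead yield quote\'\'\_present,
    // the malformed pattern this change removes.
  }
}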

Contributor @cloud-fan, May 10, 2024:
The input of escapeSpecialCharsForLikePattern is already a valid SQL string literal (produced by visitLiteral), so the ' is already escaped.

  private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
    override def visitLiteral(literal: Literal[_]): String = {
      Option(literal.value()).map(v =>
        compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)
        .getOrElse(super.visitLiteral(literal))
    }

Contributor:
Got it.

default -> builder.append(c);
}
}
@@ -169,7 +171,16 @@ yield visitBinaryArithmetic(
}

protected String visitLiteral(Literal<?> literal) {
return literal.toString();
String litString = literal.toString();
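// A string literal's toString form is 'value'; strip the outer quotes and
// double any embedded single quotes so the emitted SQL literal stays well-formed.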
if (literal.dataType() instanceof StringType) {
return "'"
+ StringUtils.replace(
litString.substring(1, litString.length() - 1),
"'",
"''")
+ "'";
}
return litString;
}
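
As a quick illustration of the new behavior, a throwaway Scala re-implementation (a sketch only; it assumes, per the discussion above, that a string literal's toString form is the already-quoted 'value'):

// Sketch of the new visitLiteral logic, for illustration only.
def visitLiteralSketch(litString: String, isStringType: Boolean): String =
  if (isStringType) {
    "'" + litString.substring(1, litString.length - 1).replace("'", "''") + "'"
  } else {
    litString
  }

assert(visitLiteralSketch("'O'Reilly'", isStringType = true) == "'O''Reilly'")
assert(visitLiteralSketch("42", isStringType = false) == "42")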

protected String visitNamedReference(NamedReference namedRef) {
@@ -259,13 +259,6 @@ private[sql] case class H2Dialect() extends JdbcDialect {
}

class H2SQLBuilder extends JDBCSQLBuilder {
override def escapeSpecialCharsForLikePattern(str: String): String = {
  str.map {
    case '_' => "\\_"
    case '%' => "\\%"
    case c => c.toString
  }.mkString
}

Contributor:
I should have noticed this at the beginning... This bug is hidden because we fixed it only for H2 and we only test it with H2.

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
@@ -66,6 +66,21 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
}
}

override def visitStartsWith(l: String, r: String): String = {
val value = r.substring(1, r.length() - 1)
s"$l LIKE '${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
}

override def visitEndsWith(l: String, r: String): String = {
val value = r.substring(1, r.length() - 1)
s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}' ESCAPE '\\\\'"
}

override def visitContains(l: String, r: String): String = {
val value = r.substring(1, r.length() - 1)
s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
}
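
A worked example of what these overrides emit (a hypothetical trace; the column and value are taken from the integration test above):

// r is the already-quoted literal produced by visitLiteral.
val r = "'special_character_percent%'"
val value = r.substring(1, r.length - 1) // special_character_percent%
// escapeSpecialCharsForLikePattern escapes _ and % but leaves ' alone,
// giving special\_character\_percent\%, so visitStartsWith emits:
//   pattern_testing_col LIKE 'special\_character\_percent\%%' ESCAPE '\\'
// Under MySQL's default handling of backslashes in string literals, '\\'
// denotes a single backslash escape character, the escaped \_ and \% match
// literally, and only the trailing unescaped % acts as a wildcard.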

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
@@ -1305,7 +1305,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df5 = spark.table("h2.test.address").filter($"email".startsWith("abc_'%"))
checkFiltersRemoved(df5)
checkPushedInfo(df5,
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\'\%%' ESCAPE '\']")
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_''\%%' ESCAPE '\']")
checkAnswer(df5, Seq(Row("abc_'%[email protected]")))

val df6 = spark.table("h2.test.address").filter($"email".endsWith("[email protected]"))
@@ -1336,7 +1336,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df10 = spark.table("h2.test.address").filter($"email".endsWith("_'%[email protected]"))
checkFiltersRemoved(df10)
checkPushedInfo(df10,
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\'\%[email protected]' ESCAPE '\']")
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_''\%[email protected]' ESCAPE '\']")
checkAnswer(df10, Seq(Row("abc_'%[email protected]")))

val df11 = spark.table("h2.test.address").filter($"email".contains("c_d"))
@@ -1364,7 +1364,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df15 = spark.table("h2.test.address").filter($"email".contains("c_'%d"))
checkFiltersRemoved(df15)
checkPushedInfo(df15,
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\'\%d%' ESCAPE '\']")
raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_''\%d%' ESCAPE '\']")
checkAnswer(df15, Seq(Row("abc_'%[email protected]")))
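
As a reading aid for the expected strings above: in a pattern such as '%c\_''\%d%' ESCAPE '\', the doubled '' is the standard SQL escape for a single quote inside a string literal (now produced by visitLiteral), \_ and \% are LIKE-level escapes for the wildcard characters, and only the unescaped % at each end acts as a wildcard.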
}
