diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 6c1b7fdd1be5..5cdd8fa8fd9f 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -65,6 +65,12 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
     connection.prepareStatement(
       "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)")
       .executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+      """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
index 72edfc9f1bf1..a42caeafe6fe 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
@@ -38,6 +38,17 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
       .executeUpdate()
     connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)")
       .executeUpdate()
+
+    connection.prepareStatement(
+      s"""
+         |INSERT INTO pattern_testing_table VALUES
+         |('special_character_quote\\'_present'),
+         |('special_character_quote_not_present'),
+         |('special_character_percent%_present'),
+         |('special_character_percent_not_present'),
+         |('special_character_underscore_present'),
+         |('special_character_underscorenot_present')
+      """.stripMargin).executeUpdate()
   }
 
   def tablePreparation(connection: Connection): Unit
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index 65f7579de820..8c82e4faa7f4 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -74,6 +74,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
     connection.prepareStatement(
       "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)")
       .executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+      """.stripMargin
+    ).executeUpdate()
   }
 
   override def notSupportsTableComment: Boolean = true
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 4997d335fda6..4c1d4924a41c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -77,6 +77,12 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
     connection.prepareStatement(
       "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," +
       " bonus DOUBLE)").executeUpdate()
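+    // Backing table for the SPARK-48172 string-predicate pushdown tests below.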
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col LONGTEXT
+         |)
+      """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index a011afac1772..a14c765d76ad 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -97,6 +97,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
     connection.prepareStatement(
       "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
       " bonus BINARY_DOUBLE)").executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR2(50)
+         |)
+      """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 1f09c2fd3fc5..24ad97af6449 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -59,6 +59,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
     connection.prepareStatement(
       "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
       " bonus double precision)").executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+      """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index c80fbfc748dd..2b61dd4b5515 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -359,6 +359,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
     assert(scan.schema.names.sameElements(Seq(col)))
   }
 
+  test("SPARK-48172: Test CONTAINS") {
+    val df1 = spark.sql(
+      s"""
+         |SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
+    df1.explain("formatted")
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
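+    // '%' in the search string must be escaped on pushdown so it matches literally.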
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE contains(pattern_testing_col, 'character')
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test ENDSWITH") {
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE endswith(pattern_testing_col, 'present')
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test STARTSWITH") {
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
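+    // '_' must likewise be escaped so it is not treated as a single-character wildcard.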
+    val df3 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE startswith(pattern_testing_col, 'special_character')
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test LIKE") {
+    // this one should map to contains
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%character%'
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+
+    // these should map to startsWith
+    val df5 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
+    val rows5 = df5.collect()
+    assert(rows5.length === 1)
+    assert(rows5(0).getString(0) === "special_character_quote'_present")
+
+    val df6 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
+    val rows6 = df6.collect()
+    assert(rows6.length === 1)
+    assert(rows6(0).getString(0) === "special_character_percent%_present")
+
+    val df7 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
+    val rows7 = df7.collect()
+    assert(rows7.length === 1)
+    assert(rows7(0).getString(0) === "special_character_underscore_present")
+
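+    // Prefix with no LIKE metacharacters; matches every seeded row.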
+    val df8 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE 'special_character%'
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows8 = df8.collect()
+    assert(rows8.length === 6)
+    assert(rows8(0).getString(0) === "special_character_percent%_present")
+    assert(rows8(1).getString(0) === "special_character_percent_not_present")
+    assert(rows8(2).getString(0) === "special_character_quote'_present")
+    assert(rows8(3).getString(0) === "special_character_quote_not_present")
+    assert(rows8(4).getString(0) === "special_character_underscore_present")
+    assert(rows8(5).getString(0) === "special_character_underscorenot_present")
+
+    // these should map to endsWith
+    val df9 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
+    val rows9 = df9.collect()
+    assert(rows9.length === 1)
+    assert(rows9(0).getString(0) === "special_character_quote'_present")
+
+    val df10 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
+    val rows10 = df10.collect()
+    assert(rows10.length === 1)
+    assert(rows10(0).getString(0) === "special_character_percent%_present")
+
+    val df11 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
+    val rows11 = df11.collect()
+    assert(rows11.length === 1)
+    assert(rows11(0).getString(0) === "special_character_underscore_present")
+
+    val df12 = spark.sql(
+      s"""SELECT * FROM $catalogName.pattern_testing_table
+         |WHERE pattern_testing_col LIKE '%present'
+         |ORDER BY pattern_testing_col""".stripMargin)
+    val rows12 = df12.collect()
+    assert(rows12.length === 6)
+    assert(rows12(0).getString(0) === "special_character_percent%_present")
+    assert(rows12(1).getString(0) === "special_character_percent_not_present")
+    assert(rows12(2).getString(0) === "special_character_quote'_present")
+    assert(rows12(3).getString(0) === "special_character_quote_not_present")
+    assert(rows12(4).getString(0) === "special_character_underscore_present")
+    assert(rows12(5).getString(0) === "special_character_underscorenot_present")
+  }
+
   test("SPARK-37038: Test TABLESAMPLE") {
     if (supportsTableSample) {
       withTable(s"$catalogName.new_table") {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index e42d9193ea39..11f4389245d9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -65,7 +65,6 @@ protected String escapeSpecialCharsForLikePattern(String str) {
       switch (c) {
         case '_' -> builder.append("\\_");
         case '%' -> builder.append("\\%");
-        case '\'' -> builder.append("\\\'");
         default -> builder.append(c);
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index fc41d5a98e4a..b43e627c0eec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.expressions
 
+import org.apache.commons.lang3.StringUtils
+
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -388,7 +390,7 @@ private[sql] object HoursTransform {
 private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
   override def toString: String = {
     if (dataType.isInstanceOf[StringType]) {
-      s"'$value'"
+      s"'${StringUtils.replace(s"$value", "'", "''")}'"
     } else {
       s"$value"
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index ebfc6093dc16..949455b248ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -259,13 +259,6 @@ private[sql] case class H2Dialect() extends JdbcDialect {
 }
 
 class H2SQLBuilder extends JDBCSQLBuilder {
-  override def escapeSpecialCharsForLikePattern(str: String): String = {
-    str.map {
-      case '_' => "\\_"
-      case '%' => "\\%"
-      case c => c.toString
-    }.mkString
-  }
 
   override def visitAggregateFunction(
       funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index d98fcdfd0b23..50951042737a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -66,6 +66,21 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
     }
   }
 
+  override def visitStartsWith(l: String, r: String): String = {
+    val value = r.substring(1, r.length() - 1)
+    s"$l LIKE '${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
+  }
+
+  override def visitEndsWith(l: String, r: String): String = {
+    val value = r.substring(1, r.length() - 1)
+    s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}' ESCAPE '\\\\'"
+  }
+
+  override def visitContains(l: String, r: String): String = {
+    val value = r.substring(1, r.length() - 1)
+    s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
+  }
+
   override def visitAggregateFunction(
       funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
     if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 1b3672cdba5a..8e98181a9802 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -1305,7 +1305,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df5 = spark.table("h2.test.address").filter($"email".startsWith("abc_'%"))
     checkFiltersRemoved(df5)
     checkPushedInfo(df5,
-      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\'\%%' ESCAPE '\']")
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_''\%%' ESCAPE '\']")
     checkAnswer(df5, Seq(Row("abc_'%def@gmail.com")))
 
     val df6 = spark.table("h2.test.address").filter($"email".endsWith("_def@gmail.com"))
@@ -1336,7 +1336,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df10 = spark.table("h2.test.address").filter($"email".endsWith("_'%def@gmail.com"))
     checkFiltersRemoved(df10)
     checkPushedInfo(df10,
EMAIL LIKE '%\_\'\%def@gmail.com' ESCAPE '\']") + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_''\%def@gmail.com' ESCAPE '\']") checkAnswer(df10, Seq(Row("abc_'%def@gmail.com"))) val df11 = spark.table("h2.test.address").filter($"email".contains("c_d")) @@ -1364,7 +1364,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df15 = spark.table("h2.test.address").filter($"email".contains("c_'%d")) checkFiltersRemoved(df15) checkPushedInfo(df15, - raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\'\%d%' ESCAPE '\']") + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_''\%d%' ESCAPE '\']") checkAnswer(df15, Seq(Row("abc_'%def@gmail.com"))) }