diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 0aa971c0d3ab1..56cadbe8e2c07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.SQLException
+import java.sql.{SQLException, Types}
 import java.util.Locale
 
 import scala.util.control.NonFatal
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchT
 import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType}
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
@@ -90,6 +92,15 @@ private object H2Dialect extends JdbcDialect {
     )
   }
 
+  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+    case StringType => Option(JdbcType("CLOB", Types.CLOB))
+    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
+    case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT))
+    case t: DecimalType => Some(
+      JdbcType(s"NUMERIC(${t.precision},${t.scale})", Types.NUMERIC))
+    case _ => JdbcUtils.getCommonJDBCType(dt)
+  }
+
   override def classifyException(message: String, e: Throwable): AnalysisException = {
     e match {
       case exception: SQLException =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index e5e9c32ff62f4..fd186b764fb4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -522,6 +522,23 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkFiltersRemoved(df7, false)
         checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
         checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
+
+        val df8 = sql(
+          """
+            |SELECT * FROM h2.test.employee
+            |WHERE cast(bonus as string) like '%30%'
+            |AND cast(dept as byte) > 1
+            |AND cast(dept as short) > 1
+            |AND cast(bonus as decimal(20, 2)) > 1200""".stripMargin)
+        checkFiltersRemoved(df8, ansiMode)
+        val expectedPlanFragment8 = if (ansiMode) {
+          "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, " +
+            "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, ...,"
+        } else {
+          "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL],"
+        }
+        checkPushedInfo(df8, expectedPlanFragment8)
+        checkAnswer(df8, Seq(Row(2, "david", 10000, 1300, true)))
       }
     }
   }