From b30555e680c6eda6003f9e771cef2207857a91bd Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 3 Dec 2025 15:25:21 +0000
Subject: [PATCH 1/5] Making fetchsize option case-insensitive for Postgres connector

---
 .../execution/datasources/jdbc/JDBCRDD.scala  |  3 +-
 .../spark/sql/jdbc/PostgresDialectSuite.scala | 87 +++++++++++++++++++
 2 files changed, 89 insertions(+), 1 deletion(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 47f5f180789e..17412671994e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.SQL_TEXT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.execution.datasources.{DataSourceMetricsMixin, ExternalEngineDatasourceRDD}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -306,7 +307,7 @@ class JDBCRDD(
     val part = thePart.asInstanceOf[JDBCPartition]
     conn = getConnection(part.idx)
     import scala.jdk.CollectionConverters._
-    dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
+    dialect.beforeFetch(conn, CaseInsensitiveMap(options.asProperties.asScala.toMap))
 
     // This executes a generic SQL statement (or PL/SQL block) before reading
     // the table/query via JDBC. Use this feature to initialize the database
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
new file mode 100644
index 000000000000..3b54045bf71b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class PostgresDialectSuite extends SparkFunSuite with MockitoSugar with SharedSparkSession {
+
+  gridTest("PostgresDialect sets autoCommit correctly with fetchSize option")(
+    Seq(
+      ("fetchsize", Some("100"), true),
+      ("fetchSize", Some("100"), true),
+      ("FETCHSIZE", Some("100"), true),
+      ("fetchsize", Some("0"), false),
+      ("fetchsize", None, false)
+    )
+  ) { case (optionKey: String, optionValue: Option[String], shouldSetAutoCommit: Boolean) =>
+    val conn = mock[Connection]
+    when(conn.prepareStatement(any[String], any[Int], any[Int]))
+      .thenReturn(mock[java.sql.PreparedStatement])
+
+    val optionsMap = Map(
+      "url" -> "jdbc:postgresql://localhost/test",
+      "dbtable" -> "test_table"
+    ) ++ optionValue.map(v => Map(optionKey -> v)).getOrElse(Map.empty)
+
+    val options = new JDBCOptions(CaseInsensitiveMap(optionsMap), None)
+
+    val schema = StructType(Seq(StructField("id", IntegerType)))
+    val partition = new JDBCPartition(null, 0)
+    val rdd = new JDBCRDD(
+      spark.sparkContext,
+      _ => conn,
+      schema,
+      Array.empty,
+      Array.empty,
+      Array(partition),
+      "jdbc:postgresql://localhost/test",
+      options,
+      Some(PostgresDialect()),
+      None,
+      None,
+      0,
+      Array.empty,
+      0,
+      Map.empty
+    )
+
+    try {
+      rdd.compute(partition, org.apache.spark.TaskContext.empty())
+    } catch {
+      case _: Exception => // Expected to fail, we just want beforeFetch to be called
+    }
+
+    if (shouldSetAutoCommit) {
+      verify(conn).setAutoCommit(false)
+    } else {
+      verify(conn, never()).setAutoCommit(false)
+    }
+  }
+}

From e2d23e6d4c0e7ee70ec143bc46ea18021c163a90 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 3 Dec 2025 18:53:25 +0000
Subject: [PATCH 2/5] Resolved build errors

---
 .../org/apache/spark/sql/jdbc/PostgresDialectSuite.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
index 3b54045bf71b..979545ddd3e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
@@ -26,7 +26,7 @@ import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCDatabaseMetadata, JDBCOptions, JDBCPartition, JDBCRDD}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
@@ -40,7 +40,7 @@ class PostgresDialectSuite extends SparkFunSuite with MockitoSugar with SharedSp
       ("fetchsize", Some("0"), false),
       ("fetchsize", None, false)
     )
-  ) { case (optionKey: String, optionValue: Option[String], shouldSetAutoCommit: Boolean) =>
+  ) { case (optionKey, optionValue, shouldSetAutoCommit) =>
     val conn = mock[Connection]
     when(conn.prepareStatement(any[String], any[Int], any[Int]))
       .thenReturn(mock[java.sql.PreparedStatement])
 
@@ -50,7 +50,7 @@ class PostgresDialectSuite extends SparkFunSuite with MockitoSugar with SharedSp
       "dbtable" -> "test_table"
     ) ++ optionValue.map(v => Map(optionKey -> v)).getOrElse(Map.empty)
 
-    val options = new JDBCOptions(CaseInsensitiveMap(optionsMap), None)
+    val options = new JDBCOptions(CaseInsensitiveMap(optionsMap))
 
     val schema = StructType(Seq(StructField("id", IntegerType)))
     val partition = new JDBCPartition(null, 0)
@@ -63,7 +63,7 @@ class PostgresDialectSuite extends SparkFunSuite with MockitoSugar with SharedSp
       Array(partition),
       "jdbc:postgresql://localhost/test",
       options,
-      Some(PostgresDialect()),
+      JDBCDatabaseMetadata(None, None, None, None),
      None,
      None,
      0,

From 9a87095e22b36b4c94d92ce17263481b9641b431 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 3 Dec 2025 19:18:10 +0000
Subject: [PATCH 3/5] Switched to JDBCOptions

---
 .../sql/execution/datasources/jdbc/JDBCRDD.scala |  4 +---
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 11 +++++++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 17412671994e..8534a24d0110 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -27,7 +27,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.SQL_TEXT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.execution.datasources.{DataSourceMetricsMixin, ExternalEngineDatasourceRDD}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -306,8 +305,7 @@ class JDBCRDD(
     val inputMetrics = context.taskMetrics().inputMetrics
     val part = thePart.asInstanceOf[JDBCPartition]
     conn = getConnection(part.idx)
-    import scala.jdk.CollectionConverters._
-    dialect.beforeFetch(conn, CaseInsensitiveMap(options.asProperties.asScala.toMap))
+    dialect.beforeFetch(conn, options)
 
     // This executes a generic SQL statement (or PL/SQL block) before reading
     // the table/query via JDBC. Use this feature to initialize the database
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index ce4c347cad34..548d370d611d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -342,9 +342,20 @@ abstract class JdbcDialect extends Serializable with Logging {
    * @param connection The connection object
    * @param properties The connection properties. This is passed through from the relation.
    */
+  @deprecated("Use beforeFetch(Connection, JDBCOptions) instead", "4.0.0")
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }
 
+  /**
+   * Override connection-specific properties to run before a select is made. This is in place to
+   * allow dialects that need special treatment to optimize behavior.
+   * @param connection The connection object
+   * @param options The JDBC options for the connection. Option keys are case-insensitive.
+   */
+  def beforeFetch(connection: Connection, options: JDBCOptions): Unit = {
+    beforeFetch(connection, options.parameters)
+  }
+
   /**
    * Escape special characters in SQL string literals.
    * @param value The string to be escaped.

From 5e10a767e71ab280e0352424e66242a119ff97c3 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Fri, 5 Dec 2025 12:34:53 +0000
Subject: [PATCH 4/5] Change deprecated min Spark version to 4.2.0

---
 .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 548d370d611d..875bfeb011bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -342,7 +342,7 @@ abstract class JdbcDialect extends Serializable with Logging {
    * @param connection The connection object
    * @param properties The connection properties. This is passed through from the relation.
    */
-  @deprecated("Use beforeFetch(Connection, JDBCOptions) instead", "4.0.0")
+  @deprecated("Use beforeFetch(Connection, JDBCOptions) instead", "4.2.0")
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }

From 85fc530dc30f905cc77386219f2645d4328d7cd6 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Fri, 5 Dec 2025 15:17:26 +0000
Subject: [PATCH 5/5] Simplified test

---
 .../spark/sql/jdbc/PostgresDialectSuite.scala | 82 +++++++------------
 1 file changed, 30 insertions(+), 52 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
index 979545ddd3e0..15682bcf68f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
@@ -20,68 +20,46 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.Connection
 
-import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCDatabaseMetadata, JDBCOptions, JDBCPartition, JDBCRDD}
-import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 
-class PostgresDialectSuite extends SparkFunSuite with MockitoSugar with SharedSparkSession {
+class PostgresDialectSuite extends SparkFunSuite with MockitoSugar {
 
-  gridTest("PostgresDialect sets autoCommit correctly with fetchSize option")(
-    Seq(
-      ("fetchsize", Some("100"), true),
-      ("fetchSize", Some("100"), true),
-      ("FETCHSIZE", Some("100"), true),
-      ("fetchsize", Some("0"), false),
-      ("fetchsize", None, false)
-    )
-  ) { case (optionKey, optionValue, shouldSetAutoCommit) =>
-    val conn = mock[Connection]
-    when(conn.prepareStatement(any[String], any[Int], any[Int]))
-      .thenReturn(mock[java.sql.PreparedStatement])
-
-    val optionsMap = Map(
-      "url" -> "jdbc:postgresql://localhost/test",
+  private def createJDBCOptions(extraOptions: Map[String, String]): JDBCOptions = {
+    new JDBCOptions(Map(
+      "url" -> "jdbc:postgresql://localhost:5432/test",
+      "dbtable" -> "test_table"
+    ) ++ extraOptions)
+  }
-    val options = new JDBCOptions(CaseInsensitiveMap(optionsMap))
+  test("beforeFetch sets autoCommit=false with lowercase fetchsize") {
+    val conn = mock[Connection]
+    val dialect = PostgresDialect()
+    dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "100")))
+    verify(conn).setAutoCommit(false)
+  }
 
-    val schema = StructType(Seq(StructField("id", IntegerType)))
-    val partition = new JDBCPartition(null, 0)
-    val rdd = new JDBCRDD(
-      spark.sparkContext,
-      _ => conn,
-      schema,
-      Array.empty,
-      Array.empty,
-      Array(partition),
-      "jdbc:postgresql://localhost/test",
-      options,
-      JDBCDatabaseMetadata(None, None, None, None),
-      None,
-      None,
-      0,
-      Array.empty,
-      0,
-      Map.empty
-    )
+  test("beforeFetch sets autoCommit=false with camelCase fetchSize") {
+    val conn = mock[Connection]
+    val dialect = PostgresDialect()
+    dialect.beforeFetch(conn, createJDBCOptions(Map("fetchSize" -> "100")))
+    verify(conn).setAutoCommit(false)
+  }
 
-    try {
-      rdd.compute(partition, org.apache.spark.TaskContext.empty())
-    } catch {
-      case _: Exception => // Expected to fail, we just want beforeFetch to be called
-    }
+  test("beforeFetch sets autoCommit=false with uppercase FETCHSIZE") {
+    val conn = mock[Connection]
+    val dialect = PostgresDialect()
+    dialect.beforeFetch(conn, createJDBCOptions(Map("FETCHSIZE" -> "100")))
+    verify(conn).setAutoCommit(false)
+  }
 
-    if (shouldSetAutoCommit) {
-      verify(conn).setAutoCommit(false)
-    } else {
-      verify(conn, never()).setAutoCommit(false)
-    }
+  test("beforeFetch does not set autoCommit when fetchSize is 0") {
+    val conn = mock[Connection]
+    val dialect = PostgresDialect()
+    dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "0")))
+    verify(conn, never()).setAutoCommit(false)
+  }
 }
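
Usage note: with the beforeFetch(Connection, JDBCOptions) overload added in PATCH 3/5, a
custom dialect no longer needs to probe multiple casings of "fetchsize" by hand, because
JDBCOptions stores its parameters in a CaseInsensitiveMap and exposes the parsed value as
options.fetchSize. Below is a minimal sketch of a third-party dialect adopting the new hook,
assuming these patches are applied; the dialect name ExampleDialect and the "jdbc:example"
URL prefix are illustrative, not names from the patches.

    import java.sql.Connection

    import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
    import org.apache.spark.sql.jdbc.JdbcDialect

    // Hypothetical custom dialect built on the new overload.
    case class ExampleDialect() extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:example")

      override def beforeFetch(connection: Connection, options: JDBCOptions): Unit = {
        // options.fetchSize resolves "fetchsize", "fetchSize", and "FETCHSIZE" alike,
        // since JDBCOptions is backed by a case-insensitive parameter map.
        if (options.fetchSize > 0) {
          // Mirror PostgresDialect: disable autocommit so the driver can stream
          // results with a cursor instead of materializing the whole result set.
          connection.setAutoCommit(false)
        }
      }
    }

Dialects that only override the deprecated Map[String, String] variant keep working: the
default implementation of the new overload forwards options.parameters, which is itself a
CaseInsensitiveMap, so the old signature now also sees case-insensitive keys.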