@@ -250,6 +250,48 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
jdbc(url, table, connectionProperties)
}

/**
* Construct a `DataFrame` representing the database table named `table`, accessible via the
* JDBC URL `url` and the given connection properties. The `predicates` parameter gives a list
* of expressions suitable for inclusion in WHERE clauses; each one defines one partition of
* the `DataFrame`.
*
* Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
* your external database systems.
*
* You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
* in <a
* href="https://spark.apache.org/docs/latest/sql-data-sources-jdbc.html#data-source-option">
* Data Source Option</a> in the version you use.
*
* @param url
* JDBC database url of the form `jdbc:subprotocol:subname`.
* @param table
* Name of the table in the external database.
* @param predicates
* Condition in the where clause for each partition.
* @param connectionProperties
* JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
* a "user" and "password" property should be included. "fetchsize" can be used to control the
* number of rows per fetch.
* @since 3.4.0
*/
def jdbc(
url: String,
table: String,
predicates: Array[String],
connectionProperties: Properties): DataFrame = {
sparkSession.newDataFrame { builder =>
Contributor:

Can you please set the format to JDBC? We are now relying on the presence of predicates to figure out that something is a JDBC table. That relies far too heavily on the client doing the right thing; for example, what would happen if you set format = parquet and still defined predicates?

Contributor Author:

Yeah, we can't rely on the client.

val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
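// format("jdbc") records the source on this reader; the same value is then copied into the
// DataSource proto via setFormat, together with the predicates and the merged options below.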
format("jdbc")
dataSourceBuilder.setFormat(source)
predicates.foreach(predicate => dataSourceBuilder.addPredicates(predicate))
this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
val params = extraOptions ++ connectionProperties.asScala
params.foreach { case (k, v) =>
dataSourceBuilder.putOptions(k, v)
}
builder.build()
}
}
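
A minimal usage sketch of the new overload, for illustration only (it assumes `spark` is an existing Spark Connect SparkSession; the H2 URL, table name, and predicate strings mirror the test fixtures added elsewhere in this PR):

import java.util.Properties

val props = new Properties()
props.setProperty("user", "testUser")
props.setProperty("password", "testPass")

// Each predicate becomes one partition of the resulting DataFrame.
val predicates = Array("THEID < 2", "THEID >= 2")
val people = spark.read.jdbc("jdbc:h2:mem:testdb0", "TEST.PEOPLE", predicates, props)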

/**
* Loads a JSON file and returns the results as a `DataFrame`.
*
21 changes: 9 additions & 12 deletions ...ctor/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
100755 → 100644
@@ -163,6 +163,8 @@ class PlanGenerationTestSuite
}
}

private val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"

private val simpleSchema = new StructType()
.add("id", "long")
.add("a", "int")
@@ -236,21 +238,16 @@
}

test("read jdbc") {
session.read.jdbc(
"jdbc:h2:mem:testdb0;user=testUser;password=testPass",
"TEST.TIMETYPES",
new Properties())
session.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties())
}

test("read jdbc with partition") {
session.read.jdbc(
"jdbc:h2:mem:testdb0;user=testUser;password=testPass",
"TEST.EMP",
"THEID",
0,
4,
3,
new Properties())
session.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
}

test("read jdbc with predicates") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
session.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
}

test("read json") {
@@ -140,6 +140,11 @@ message Read {

// (Optional) A list of path for file-system backed data sources.
repeated string paths = 4;

// (Optional) Condition in the where clause for each partition.
Contributor:

Please add a comment that this currently only works for JDBC.

Contributor Author:

OK

//
// This is only supported by the JDBC data source.
repeated string predicates = 5;
}
}

@@ -0,0 +1 @@
Relation [NAME#0,THEID#0] JDBCRelation(TEST.PEOPLE) [numPartitions=2]
@@ -0,0 +1,15 @@
{
"common": {
"planId": "0"
},
"read": {
"dataSource": {
"format": "jdbc",
"options": {
"url": "jdbc:h2:mem:testdb0;user\u003dtestUser;password\u003dtestPass",
"dbtable": "TEST.PEOPLE"
},
"predicates": ["THEID \u003c 2", "THEID \u003e\u003d 2"]
}
}
}
@@ -24,7 +24,7 @@ import com.google.common.collect.{Lists, Maps}
import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.stub.StreamObserver

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
@@ -48,6 +48,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.types._
@@ -688,25 +690,46 @@ class SparkConnectPlanner(val session: SparkSession) {
reader.format(rel.getDataSource.getFormat)
}
localMap.foreach { case (key, value) => reader.option(key, value) }
if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {

DataType.parseTypeWithFallback(
rel.getDataSource.getSchema,
StructType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: StructType => reader.schema(s)
case other => throw InvalidPlanInput(s"Invalid schema $other")

if (rel.getDataSource.getFormat == "jdbc" && rel.getDataSource.getPredicatesCount > 0) {
if (!localMap.contains(JDBCOptions.JDBC_URL) ||
!localMap.contains(JDBCOptions.JDBC_TABLE_NAME)) {
throw InvalidPlanInput(s"Invalid jdbc params, please specify jdbc url and table.")
}

val url = rel.getDataSource.getOptionsMap.get(JDBCOptions.JDBC_URL)
val table = rel.getDataSource.getOptionsMap.get(JDBCOptions.JDBC_TABLE_NAME)
val options = new JDBCOptions(url, table, localMap)
val predicates = rel.getDataSource.getPredicatesList.asScala.toArray
val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i): Partition
}
val relation = JDBCRelation(parts, options)(session)
LogicalRelation(relation)
} else if (rel.getDataSource.getPredicatesCount == 0) {
if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {

DataType.parseTypeWithFallback(
rel.getDataSource.getSchema,
StructType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: StructType => reader.schema(s)
case other => throw InvalidPlanInput(s"Invalid schema $other")
}
}
if (rel.getDataSource.getPathsCount == 0) {
reader.load().queryExecution.analyzed
} else if (rel.getDataSource.getPathsCount == 1) {
reader.load(rel.getDataSource.getPaths(0)).queryExecution.analyzed
} else {
reader.load(rel.getDataSource.getPathsList.asScala.toSeq: _*).queryExecution.analyzed
}
}
if (rel.getDataSource.getPathsCount == 0) {
reader.load().queryExecution.analyzed
} else if (rel.getDataSource.getPathsCount == 1) {
reader.load(rel.getDataSource.getPaths(0)).queryExecution.analyzed
} else {
reader.load(rel.getDataSource.getPathsList.asScala.toSeq: _*).queryExecution.analyzed
throw InvalidPlanInput(
s"Predicates are not supported for ${rel.getDataSource.getFormat} data sources.")
}

case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
case _ => throw InvalidPlanInput(s"Does not support ${rel.getReadTypeCase.name()}")
}
}

@@ -74,6 +74,10 @@ class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession {

conn = DriverManager.getConnection(url, properties)
conn.prepareStatement("create schema test").executeUpdate()
conn
.prepareStatement(
"create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)")
.executeUpdate()
conn
.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP(7))")
.executeUpdate()
190 changes: 95 additions & 95 deletions python/pyspark/sql/connect/proto/relations_pb2.py

12 changes: 12 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -610,6 +610,7 @@ class Read(google.protobuf.message.Message):
SCHEMA_FIELD_NUMBER: builtins.int
OPTIONS_FIELD_NUMBER: builtins.int
PATHS_FIELD_NUMBER: builtins.int
PREDICATES_FIELD_NUMBER: builtins.int
format: builtins.str
"""(Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.

@@ -633,13 +634,22 @@
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional) A list of path for file-system backed data sources."""
@property
def predicates(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional) Condition in the where clause for each partition.

This is only supported by the JDBC data source.
"""
def __init__(
self,
*,
format: builtins.str | None = ...,
schema: builtins.str | None = ...,
options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
paths: collections.abc.Iterable[builtins.str] | None = ...,
predicates: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
@@ -667,6 +677,8 @@
b"options",
"paths",
b"paths",
"predicates",
b"predicates",
"schema",
b"schema",
],