@@ -250,6 +250,46 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
jdbc(url, table, connectionProperties)
}

/**
* Construct a `DataFrame` representing the database table accessible via JDBC URL `url`, named
* `table`, using connection properties. The `predicates` parameter gives a list of expressions
* suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`.
*
* Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
* your external database systems.
*
* You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
* in <a
* href="https://spark.apache.org/docs/latest/sql-data-sources-jdbc.html#data-source-option">
* Data Source Option</a> in the version you use.
*
* @param table
* Name of the table in the external database.
* @param predicates
* Condition in the where clause for each partition.
* @param connectionProperties
* JDBC database connection arguments, a list of arbitrary string tag/value pairs. Normally at least
* a "user" and "password" property should be included. "fetchsize" can be used to control the
* number of rows per fetch.
* @since 3.4.0
*/
def jdbc(
url: String,
table: String,
predicates: Array[String],
connectionProperties: Properties): DataFrame = {
sparkSession.newDataFrame { builder =>
Contributor:
Can you please set the format to JDBC? We are now relying on the presence of predicates to figure out that something is a JDBC table. That relies far too heavily on the client doing the right thing; for example, what would happen if you set format = parquet and still defined predicates?

Contributor Author:
Yeah, we can't rely on the client.

val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
predicates.foreach(predicate => dataSourceBuilder.addPredicates(predicate))
this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
val params = extraOptions ++ connectionProperties.asScala
params.foreach { case (k, v) =>
dataSourceBuilder.putOptions(k, v)
}
builder.build()
}
}
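
For illustration, here is a minimal client-side sketch (not part of this diff) of how the new overload might be called. The Connect endpoint, JDBC URL, table name, credentials, and predicate strings are placeholder assumptions; a reachable Spark Connect server and an accessible database are required.

import java.util.Properties

import org.apache.spark.sql.SparkSession

object PartitionedJdbcReadSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder Connect endpoint; any reachable Spark Connect server works.
    val spark = SparkSession.builder().remote("sc://localhost").build()

    val props = new Properties()
    props.setProperty("user", "testUser")     // illustrative credentials
    props.setProperty("password", "testPass")
    props.setProperty("fetchsize", "100")     // optional: rows fetched per round trip

    // Each predicate string becomes one partition of the resulting DataFrame.
    val predicates = Array("THEID < 2", "THEID >= 2")

    val people = spark.read.jdbc(
      "jdbc:h2:mem:testdb0", // placeholder JDBC URL
      "TEST.PEOPLE",         // placeholder table name
      predicates,
      props)

    people.collect().foreach(println)
  }
}

With two predicates the read yields a two-partition DataFrame, matching the "read jdbc with predicates" test added in PlanGenerationTestSuite below.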

/**
* Loads a JSON file and returns the results as a `DataFrame`.
*
21 changes: 9 additions & 12 deletions ...ctor/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
100755 → 100644
@@ -163,6 +163,8 @@ class PlanGenerationTestSuite
}
}

private val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"

private val simpleSchema = new StructType()
.add("id", "long")
.add("a", "int")
@@ -236,21 +238,16 @@
}

test("read jdbc") {
session.read.jdbc(
"jdbc:h2:mem:testdb0;user=testUser;password=testPass",
"TEST.TIMETYPES",
new Properties())
session.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties())
}

test("read jdbc with partition") {
session.read.jdbc(
"jdbc:h2:mem:testdb0;user=testUser;password=testPass",
"TEST.EMP",
"THEID",
0,
4,
3,
new Properties())
session.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
}

test("read jdbc with predicates") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
session.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
}

test("read json") {
@@ -140,6 +140,9 @@ message Read {

// (Optional) A list of path for file-system backed data sources.
repeated string paths = 4;

// (Optional) Condition in the where clause for each partition.
Contributor:
Please add a comment that this currently only works for jdbc.

Contributor Author:
OK

repeated string predicates = 5;
}
}
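
For reference, a hedged sketch of how a client could populate this message through the generated Java bindings, including an explicit format as suggested earlier in the review; the URL, table, predicate, and option values are illustrative only.

import org.apache.spark.connect.proto

object ReadDataSourceProtoSketch {
  def main(args: Array[String]): Unit = {
    val dataSource = proto.Read.DataSource
      .newBuilder()
      .setFormat("jdbc") // explicit format, per the review discussion
      .putOptions("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass")
      .putOptions("dbtable", "TEST.PEOPLE")
      .addPredicates("THEID < 2")
      .addPredicates("THEID >= 2")
      .build()

    val read = proto.Read.newBuilder().setDataSource(dataSource).build()

    // The populated fields mirror the golden JSON plan file shown below.
    println(read)
  }
}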

@@ -0,0 +1 @@
Relation [NAME#0,THEID#0] JDBCRelation(TEST.PEOPLE) [numPartitions=2]
@@ -0,0 +1,14 @@
{
"common": {
"planId": "0"
},
"read": {
"dataSource": {
"options": {
"url": "jdbc:h2:mem:testdb0;user\u003dtestUser;password\u003dtestPass",
"dbtable": "TEST.PEOPLE"
},
"predicates": ["THEID \u003c 2", "THEID \u003e\u003d 2"]
}
}
}
Binary file not shown.
@@ -24,7 +24,7 @@ import com.google.common.collect.{Lists, Maps}
import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.stub.StreamObserver

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
@@ -48,6 +48,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.types._
@@ -684,26 +686,43 @@ class SparkConnectPlanner(val session: SparkSession) {
case proto.Read.ReadTypeCase.DATA_SOURCE =>
val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
val reader = session.read
if (rel.getDataSource.hasFormat) {
reader.format(rel.getDataSource.getFormat)
}
localMap.foreach { case (key, value) => reader.option(key, value) }
if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {

DataType.parseTypeWithFallback(
rel.getDataSource.getSchema,
StructType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: StructType => reader.schema(s)
case other => throw InvalidPlanInput(s"Invalid schema $other")

if (rel.getDataSource.getPredicatesCount == 0) {
Contributor:
Please make the logic a bit like this:

if (format == "jdbc" && rel.getDataSource.getPredicatesCount) {
  // Plan JDBC with predicates
} else id (rel.getDataSource.getPredicatesCount == 0) {
 // Plan datasource
} else {
  throw InvalidPlan(s"Predicates are not supported for $format datasources.)"
}

if (rel.getDataSource.hasFormat) {
reader.format(rel.getDataSource.getFormat)
}
if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {

DataType.parseTypeWithFallback(
rel.getDataSource.getSchema,
StructType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: StructType => reader.schema(s)
case other => throw InvalidPlanInput(s"Invalid schema $other")
}
}
if (rel.getDataSource.getPathsCount == 0) {
reader.load().queryExecution.analyzed
} else if (rel.getDataSource.getPathsCount == 1) {
reader.load(rel.getDataSource.getPaths(0)).queryExecution.analyzed
} else {
reader.load(rel.getDataSource.getPathsList.asScala.toSeq: _*).queryExecution.analyzed
}
}
if (rel.getDataSource.getPathsCount == 0) {
reader.load().queryExecution.analyzed
} else if (rel.getDataSource.getPathsCount == 1) {
reader.load(rel.getDataSource.getPaths(0)).queryExecution.analyzed
} else {
reader.load(rel.getDataSource.getPathsList.asScala.toSeq: _*).queryExecution.analyzed
if (!localMap.contains(JDBCOptions.JDBC_URL) ||
!localMap.contains(JDBCOptions.JDBC_TABLE_NAME)) {
throw InvalidPlanInput(s"Invalid jdbc params, please specify jdbc url and table.")
}
val url = rel.getDataSource.getOptionsMap.get(JDBCOptions.JDBC_URL)
val table = rel.getDataSource.getOptionsMap.get(JDBCOptions.JDBC_TABLE_NAME)
val options = new JDBCOptions(url, table, localMap)
val predicates = rel.getDataSource.getPredicatesList.asScala.toArray
val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i): Partition
}
val relation = JDBCRelation(parts, options)(session)
LogicalRelation(relation)
}

case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
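
As a companion to the planner hunk above, a rough standalone sketch of the server-side mechanics: each predicate string becomes a JDBCPartition, and the partitions plus JDBCOptions are wrapped in a JDBCRelation inside a LogicalRelation. This is a sketch under assumptions only: the H2 driver must be on the classpath, the in-memory table is created the same way ProtoToParsedPlanTestSuite does below, and the file is placed under the org.apache.spark.sql package so the private[sql] JDBC internals are visible.

package org.apache.spark.sql.sketch

import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.Partition
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}

object JdbcPredicatePlanningSketch {
  def main(args: Array[String]): Unit = {
    val url = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"

    // Keep one connection open so the in-memory H2 database stays alive,
    // mirroring the setup in ProtoToParsedPlanTestSuite.
    val conn = DriverManager.getConnection(url, new Properties())
    conn.prepareStatement("create schema test").executeUpdate()
    conn
      .prepareStatement(
        "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)")
      .executeUpdate()

    val session = SparkSession.builder()
      .appName("jdbc-predicate-sketch")
      .master("local[1]")
      .getOrCreate()

    // One JDBCPartition per predicate, as in the DATA_SOURCE branch above.
    val predicates = Array("THEID < 2", "THEID >= 2")
    val parts: Array[Partition] = predicates.zipWithIndex.map { case (p, i) =>
      JDBCPartition(p, i): Partition
    }

    val options = new JDBCOptions(url, "TEST.PEOPLE", Map.empty[String, String])
    val plan = LogicalRelation(JDBCRelation(parts, options)(session))

    // Expected shape: Relation [NAME#0,THEID#0] JDBCRelation(TEST.PEOPLE) [numPartitions=2]
    println(plan)

    session.stop()
    conn.close()
  }
}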
@@ -74,6 +74,10 @@ class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession {

conn = DriverManager.getConnection(url, properties)
conn.prepareStatement("create schema test").executeUpdate()
conn
.prepareStatement(
"create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)")
.executeUpdate()
conn
.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP(7))")
.executeUpdate()
224 changes: 126 additions & 98 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

79 changes: 76 additions & 3 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -680,33 +680,106 @@ class Read(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
) -> typing_extensions.Literal["schema"] | None: ...

class PartitionedJDBC(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

class OptionsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
value: builtins.str = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...

URL_FIELD_NUMBER: builtins.int
TABLE_FIELD_NUMBER: builtins.int
PREDICATES_FIELD_NUMBER: builtins.int
OPTIONS_FIELD_NUMBER: builtins.int
url: builtins.str
"""(Required) JDBC URL."""
table: builtins.str
"""(Required) Name of the table in the external database."""
@property
def predicates(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional) Condition in the where clause for each partition."""
@property
def options(
self,
) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
"""Options or connection arguments for the JDBC database.
The map key is case insensitive.
"""
def __init__(
self,
*,
url: builtins.str = ...,
table: builtins.str = ...,
predicates: collections.abc.Iterable[builtins.str] | None = ...,
options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"options", b"options", "predicates", b"predicates", "table", b"table", "url", b"url"
],
) -> None: ...

NAMED_TABLE_FIELD_NUMBER: builtins.int
DATA_SOURCE_FIELD_NUMBER: builtins.int
PARTITIONED_JDBC_FIELD_NUMBER: builtins.int
@property
def named_table(self) -> global___Read.NamedTable: ...
@property
def data_source(self) -> global___Read.DataSource: ...
@property
def partitioned_jdbc(self) -> global___Read.PartitionedJDBC: ...
def __init__(
self,
*,
named_table: global___Read.NamedTable | None = ...,
data_source: global___Read.DataSource | None = ...,
partitioned_jdbc: global___Read.PartitionedJDBC | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type"
"data_source",
b"data_source",
"named_table",
b"named_table",
"partitioned_jdbc",
b"partitioned_jdbc",
"read_type",
b"read_type",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type"
"data_source",
b"data_source",
"named_table",
b"named_table",
"partitioned_jdbc",
b"partitioned_jdbc",
"read_type",
b"read_type",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["read_type", b"read_type"]
) -> typing_extensions.Literal["named_table", "data_source"] | None: ...
) -> typing_extensions.Literal["named_table", "data_source", "partitioned_jdbc"] | None: ...

global___Read = Read
