11 changes: 10 additions & 1 deletion docs/sql-data-sources-jdbc.md
@@ -250,7 +250,16 @@ logging into the data sources.
<td><code>pushDownLimit</code></td>
<td><code>false</code></td>
<td>
The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down.
The option to enable or disable LIMIT push-down into the V2 JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. If the value is set to true, LIMIT is pushed down to the JDBC data source. Spark still applies LIMIT on the result from the data source even if LIMIT is pushed down.
</td>
<td>read</td>
</tr>

<tr>
<td><code>pushDownTableSample</code></td>
<td><code>false</code></td>
<td>
The option to enable or disable TABLESAMPLE push-down into the V2 JDBC data source. The default value is false, in which case Spark does not push down TABLESAMPLE to the JDBC data source. If the value is set to true, TABLESAMPLE is pushed down to the JDBC data source.
</td>
<td>read</td>
</tr>
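As an illustrative sketch (not part of this diff; the catalog name, URL, and table are assumptions), both options are typically enabled on a V2 JDBC catalog, after which TABLESAMPLE and LIMIT in queries against that catalog may be pushed to the database:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("jdbc-pushdown-sketch")
  // Register a V2 JDBC catalog backed by JDBCTableCatalog.
  .config("spark.sql.catalog.postgresql",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.postgresql.url", "jdbc:postgresql://localhost:5432/test")
  .config("spark.sql.catalog.postgresql.pushDownLimit", "true")
  .config("spark.sql.catalog.postgresql.pushDownTableSample", "true")
  .getOrCreate()

// Both the TABLESAMPLE and the LIMIT below are candidates for push-down.
spark.sql("SELECT * FROM postgresql.test_table TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 5").show()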
@@ -49,6 +49,9 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.postgresql.pushDownTableSample", "true")
.set("spark.sql.catalog.postgresql.pushDownLimit", "true")

override def dataPreparation(conn: Connection): Unit = {}

override def testUpdateColumnType(tbl: String): Unit = {
@@ -75,4 +78,6 @@
val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata)
assert(t.schema === expectedSchema)
}

override def supportsTableSample: Boolean = true
}
@@ -22,17 +22,22 @@ import java.util
import org.apache.log4j.Level

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample}
import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.DockerTest

@DockerTest
private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite {
import testImplicits._

val catalogName: String
// dialect specific update column type test
def testUpdateColumnType(tbl: String): Unit
@@ -284,4 +289,109 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite
testIndexUsingSQL(s"$catalogName.new_table")
}
}

def supportsTableSample: Boolean = false

private def samplePushed(df: DataFrame): Boolean = {
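// If the sample was pushed to the data source, no Sample node remains in the
// optimized plan, so an empty collect result below means push-down happened.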
val sample = df.queryExecution.optimizedPlan.collect {
case s: Sample => s
}
sample.isEmpty
}

private def filterPushed(df: DataFrame): Boolean = {
val filter = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
filter.isEmpty
}

private def limitPushed(df: DataFrame, limit: Int): Boolean = {
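// A pushed LIMIT is recorded on the V1ScanWrapper's pushedDownOperators rather
// than removed from the plan, so inspect the scan relation directly.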
val filter = df.queryExecution.optimizedPlan.collect {
case relation: DataSourceV2ScanRelation => relation.scan match {
case v1: V1ScanWrapper =>
return v1.pushedDownOperators.limit == Some(limit)
}
}
false
}

private def columnPruned(df: DataFrame, col: String): Boolean = {
val scan = df.queryExecution.optimizedPlan.collectFirst {
case s: DataSourceV2ScanRelation => s
}.get
scan.schema.names.sameElements(Seq(col))
}

test("SPARK-37038: Test TABLESAMPLE") {
Review comment (Member):
This test is flaky:

- SPARK-37038: Test TABLESAMPLE *** FAILED *** (972 milliseconds)
  Array([18]) had length 1 instead of expected length 2 (V2JDBCTest.scala:330)
  org.scalatest.exceptions.TestFailedException:
  at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472)
  at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471)
  at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1231)
  at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:1295)
  at org.apache.spark.sql.jdbc.v2.V2JDBCTest.$anonfun$$init$$35(V2JDBCTest.scala:330)
  at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1468)
  at org.apache.spark.sql.test.SQLTestUtilsBase.withTable(SQLTestUtils.scala:306)
  at org.apache.spark.sql.test.SQLTestUtilsBase.withTable$(SQLTestUtils.scala:304)
  at org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite.withTable(DockerJDBCIntegrationSuite.scala:95)
  at org.apache.spark.sql.jdbc.v2.V2JDBCTest.$anonfun$$init$$34(V2JDBCTest.scala:280)
  at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
  at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
  at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
  at org.scalatest.Transformer.apply(Transformer.scala:22)
  at org.scalatest.Transformer.apply(Transformer.scala:20)
  at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
  at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:190)
  at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
  at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
  at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)

https://github.com/apache/spark/runs/4259414295?check_suite_focus=true

I saw this a few times but retriggered all of them, so I can't find it now 😢. I am going to monitor this a bit more, but thought it's worth mentioning :-).

if (supportsTableSample) {
withTable(s"$catalogName.new_table") {
sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)")
spark.range(10).select($"id" * 2, $"id" * 2 + 1).write.insertInto(s"$catalogName.new_table")

// sample push down + column pruning
val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" +
" REPEATABLE (12345)")
assert(samplePushed(df1))
assert(columnPruned(df1, "col1"))
assert(df1.collect().length < 10)

// sample push down only
val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" +
" REPEATABLE (12345)")
assert(samplePushed(df2))
assert(df2.collect().length < 10)

// sample(BUCKET ... OUT OF) push down + limit push down + column pruning
val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" +
" LIMIT 2")
assert(samplePushed(df3))
assert(limitPushed(df3, 2))
assert(columnPruned(df3, "col1"))
assert(df3.collect().length == 2)
Review comment (Contributor):

I think this should be <= 2, as the TABLESAMPLE is not repeatable and may produce only one row.

Reply (Contributor):

Another way is to always specify the seed in table sample tests.

@huaxingao can you help fix this? I don't have a strong opinion on which approach is better.
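As a sketch of that second suggestion (illustrative only, reusing the table, seed, and helpers from this test), the sample can be made repeatable so the row count is stable for a fixed dataset:

// SQL form: add REPEATABLE so the BUCKET sample is deterministic before LIMIT applies.
val dfSeeded = sql(s"SELECT col1 FROM $catalogName.new_table" +
  " TABLESAMPLE (BUCKET 6 OUT OF 10) REPEATABLE (12345) LIMIT 2")

// DataFrame form: Dataset.sample accepts an explicit seed.
val sampled = spark.read.table(s"$catalogName.new_table").sample(false, 0.5, 12345L)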


// sample(... PERCENT) push down + limit push down + column pruning
val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" +
" TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2")
assert(samplePushed(df4))
assert(limitPushed(df4, 2))
assert(columnPruned(df4, "col1"))
assert(df4.collect().length == 2)

// sample push down + filter push down + limit push down
val df5 = sql(s"SELECT * FROM $catalogName.new_table" +
" TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
assert(samplePushed(df5))
assert(filterPushed(df5))
assert(limitPushed(df5, 2))
assert(df5.collect().length == 2)

// sample + filter + limit + column pruning
// sample pushed down, filter/limit not pushed down, column pruned
// TODO: push down filter/limit
val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" +
" TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
assert(samplePushed(df6))
assert(!filterPushed(df6))
assert(!limitPushed(df6, 2))
assert(columnPruned(df6, "col1"))
assert(df6.collect().length == 2)

// sample + limit
// Push down order is sample -> filter -> limit
// only limit is pushed down because in this test sample is after limit
val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5)
assert(!samplePushed(df7))
assert(limitPushed(df7, 2))

// sample + filter
// Push down order is sample -> filter -> limit
// only filter is pushed down because in this test sample is after filter
val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5)
assert(!samplePushed(df8))
assert(filterPushed(df8))
assert(df8.collect().length < 10)
}
}
}
}
@@ -23,7 +23,7 @@
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
* interfaces to do operator push down, and keep the operator push down result in the returned
* {@link Scan}. When pushing down operators, the push down order is:
* filter -&gt; aggregate -&gt; limit -&gt; column pruning.
* sample -&gt; filter -&gt; aggregate -&gt; limit -&gt; column pruning.
*
* @since 3.0.0
*/
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* push down SAMPLE.
*
* @since 3.3.0
*/
@Evolving
public interface SupportsPushDownTableSample extends ScanBuilder {

/**
* Pushes down SAMPLE to the data source.
*/
boolean pushTableSample(
double lowerBound,
double upperBound,
boolean withReplacement,
long seed);
}
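For context, a minimal sketch of a ScanBuilder mixing in this interface (the class name, schema, and field are illustrative assumptions, not part of this change); returning true from pushTableSample tells Spark the source will apply the sample itself:

import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTableSample}
import org.apache.spark.sql.types.{IntegerType, StructType}

// Hypothetical builder for a source that can evaluate TABLESAMPLE natively.
class ExampleScanBuilder extends ScanBuilder with SupportsPushDownTableSample {
  // Remember the pushed sample so build() can thread it into the Scan.
  private var pushedSample: Option[(Double, Double, Boolean, Long)] = None

  override def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long): Boolean = {
    pushedSample = Some((lowerBound, upperBound, withReplacement, seed))
    // true signals that the source applies the sample, so Spark can drop its own Sample.
    true
  }

  override def build(): Scan = new Scan {
    // A real implementation would use pushedSample when producing batches.
    override def readSchema(): StructType = new StructType().add("id", IntegerType)
  }
}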
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
@@ -103,8 +103,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Option[Aggregation],
limit: Option[Int],
pushedDownOperators: PushedDownOperators,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -135,9 +134,9 @@ case class RowDataSourceScanExec(

def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")

val (aggString, groupByString) = if (aggregation.nonEmpty) {
(seqToString(aggregation.get.aggregateExpressions),
seqToString(aggregation.get.groupByColumns))
val (aggString, groupByString) = if (pushedDownOperators.aggregation.nonEmpty) {
(seqToString(pushedDownOperators.aggregation.get.aggregateExpressions),
seqToString(pushedDownOperators.aggregation.get.groupByColumns))
} else {
("[]", "[]")
}
@@ -155,7 +154,10 @@ case class RowDataSourceScanExec(
"PushedFilters" -> seqToString(markedFilters.toSeq),
"PushedAggregates" -> aggString,
"PushedGroupby" -> groupByString) ++
Review comment (Contributor):

We should not have the above two entries if agg is not pushed. We can fix it in a followup.

limit.map(value => "PushedLimit" -> s"LIMIT $value")
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
)
}

// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
@@ -45,6 +45,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Coun
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.sources._
@@ -335,8 +336,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
None,
None,
PushedDownOperators(None, None, None),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,8 +410,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
None,
None,
PushedDownOperators(None, None, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -434,8 +433,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
None,
None,
PushedDownOperators(None, None, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -191,9 +191,14 @@ class JDBCOptions(
// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean

// An option to allow/disallow pushing down LIMIT into JDBC data source
// An option to allow/disallow pushing down LIMIT into V2 JDBC data source
// This only applies to Data Source V2 JDBC
val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean

// An option to allow/disallow pushing down TABLESAMPLE into JDBC data source
// This only applies to Data Source V2 JDBC
val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean

// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -270,6 +275,7 @@ object JDBCOptions {
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -181,6 +182,7 @@ object JDBCRDD extends Logging {
* @param groupByColumns - The pushed down group by columns.
* @param limit - The pushed down limit. If the value is 0, it means no limit or limit
* is not pushed down.
* @param sample - The pushed down tableSample.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
@@ -193,6 +195,7 @@
options: JDBCOptions,
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
@@ -212,6 +215,7 @@
url,
options,
groupByColumns,
sample,
limit)
}
}
@@ -231,6 +235,7 @@
url: String,
options: JDBCOptions,
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int)
extends RDD[InternalRow](sc, Nil) {

@@ -354,10 +359,16 @@

val myWhereClause = getWhereClause(part)

val myTableSampleClause: String = if (sample.nonEmpty) {
JdbcDialects.get(url).getTableSample(sample.get)
} else {
""
}

val myLimitClause: String = dialect.getLimitClause(limit)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
s" $getGroupByClause $myLimitClause"
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
s" $myWhereClause $getGroupByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)