diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 361b92be67031..f8852ad197a72 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -250,7 +250,16 @@ logging into the data sources.
pushDownLimit |
false |
- The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down.
+ The option to enable or disable LIMIT push-down into the V2 JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if the value is set to true, LIMIT is pushed down to the JDBC data source. Spark still applies LIMIT on the result from the data source even if LIMIT is pushed down.
+ |
+ read |
+
+
+
+ pushDownTableSample |
+ false |
+
+ The option to enable or disable TABLESAMPLE push-down into the V2 JDBC data source. The default value is false, in which case Spark does not push down TABLESAMPLE to the JDBC data source. Otherwise, if the value is set to true, TABLESAMPLE is pushed down to the JDBC data source.
|
read |
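Editorial note: a minimal sketch of wiring both options onto a V2 JDBC catalog, for readers of these docs. The catalog name "pg", the connection URL, and the table name are assumed for illustration and are not part of this patch.

    // Sketch only: catalog name, URL and table are illustrative.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.pg",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.pg.url", "jdbc:postgresql://localhost:5432/testdb")
      .config("spark.sql.catalog.pg.pushDownTableSample", "true")
      .config("spark.sql.catalog.pg.pushDownLimit", "true")
      .getOrCreate()

    // With both options enabled, the sample and the limit below can be compiled into the
    // remote statement instead of being evaluated on the Spark side.
    spark.sql(
      "SELECT col1 FROM pg.public.some_table TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 10"
    ).show()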
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index de0916deb0154..1a1a592d00bca 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -49,6 +49,9 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort))
+ .set("spark.sql.catalog.postgresql.pushDownTableSample", "true")
+ .set("spark.sql.catalog.postgresql.pushDownLimit", "true")
+
override def dataPreparation(conn: Connection): Unit = {}
override def testUpdateColumnType(tbl: String): Unit = {
@@ -75,4 +78,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata)
assert(t.schema === expectedSchema)
}
+
+ override def supportsTableSample: Boolean = true
}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index a97adf94ed1e4..5e2504aafff2f 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -22,10 +22,13 @@ import java.util
import org.apache.log4j.Level
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample}
import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -33,6 +36,8 @@ import org.apache.spark.tags.DockerTest
@DockerTest
private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite {
+ import testImplicits._
+
val catalogName: String
// dialect specific update column type test
def testUpdateColumnType(tbl: String): Unit
@@ -284,4 +289,109 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
testIndexUsingSQL(s"$catalogName.new_table")
}
}
+
+ def supportsTableSample: Boolean = false
+
+ private def samplePushed(df: DataFrame): Boolean = {
+ val sample = df.queryExecution.optimizedPlan.collect {
+ case s: Sample => s
+ }
+ sample.isEmpty
+ }
+
+ private def filterPushed(df: DataFrame): Boolean = {
+ val filter = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ filter.isEmpty
+ }
+
+ private def limitPushed(df: DataFrame, limit: Int): Boolean = {
+ df.queryExecution.optimizedPlan.collectFirst {
+ case DataSourceV2ScanRelation(_, v1: V1ScanWrapper, _) =>
+ v1.pushedDownOperators.limit.contains(limit)
+ }.getOrElse(false)
+ }
+
+ private def columnPruned(df: DataFrame, col: String): Boolean = {
+ val scan = df.queryExecution.optimizedPlan.collectFirst {
+ case s: DataSourceV2ScanRelation => s
+ }.get
+ scan.schema.names.sameElements(Seq(col))
+ }
+
+ test("SPARK-37038: Test TABLESAMPLE") {
+ if (supportsTableSample) {
+ withTable(s"$catalogName.new_table") {
+ sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)")
+ spark.range(10).select($"id" * 2, $"id" * 2 + 1).write.insertInto(s"$catalogName.new_table")
+
+ // sample push down + column pruning
+ val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" +
+ " REPEATABLE (12345)")
+ assert(samplePushed(df1))
+ assert(columnPruned(df1, "col1"))
+ assert(df1.collect().length < 10)
+
+ // sample push down only
+ val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" +
+ " REPEATABLE (12345)")
+ assert(samplePushed(df2))
+ assert(df2.collect().length < 10)
+
+ // sample(BUCKET ... OUT OF) push down + limit push down + column pruning
+ val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" +
+ " LIMIT 2")
+ assert(samplePushed(df3))
+ assert(limitPushed(df3, 2))
+ assert(columnPruned(df3, "col1"))
+ assert(df3.collect().length == 2)
+
+ // sample(... PERCENT) push down + limit push down + column pruning
+ val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" +
+ " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2")
+ assert(samplePushed(df4))
+ assert(limitPushed(df4, 2))
+ assert(columnPruned(df4, "col1"))
+ assert(df4.collect().length == 2)
+
+ // sample push down + filter push down + limit push down
+ val df5 = sql(s"SELECT * FROM $catalogName.new_table" +
+ " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
+ assert(samplePushed(df5))
+ assert(filterPushed(df5))
+ assert(limitPushed(df5, 2))
+ assert(df5.collect().length == 2)
+
+ // sample + filter + limit + column pruning
+ // sample pushed down, filter/limit not pushed down, column pruned
+ // TODO: push down filter/limit
+ val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" +
+ " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
+ assert(samplePushed(df6))
+ assert(!filterPushed(df6))
+ assert(!limitPushed(df6, 2))
+ assert(columnPruned(df6, "col1"))
+ assert(df6.collect().length == 2)
+
+ // sample + limit
+ // Push down order is sample -> filter -> limit
+ // Only the limit is pushed down because the sample is applied on top of the limit here
+ val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5)
+ assert(!samplePushed(df7))
+ assert(limitPushed(df7, 2))
+
+ // sample + filter
+ // Push down order is sample -> filter -> limit
+ // Only the filter is pushed down because the sample is applied on top of the filter here
+ val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5)
+ assert(!samplePushed(df8))
+ assert(filterPushed(df8))
+ assert(df8.collect().length < 10)
+ }
+ }
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
index 20c9d2e883923..27ee534d804ff 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
@@ -23,7 +23,7 @@
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
* interfaces to do operator push down, and keep the operator push down result in the returned
* {@link Scan}. When pushing down operators, the push down order is:
- * filter -> aggregate -> limit -> column pruning.
+ * sample -> filter -> aggregate -> limit -> column pruning.
*
* @since 3.0.0
*/
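Editorial note: to make the new push-down ordering concrete, a hedged illustration follows. It assumes a running SparkSession named spark and a V2 JDBC catalog "pg" with pushDownTableSample and pushDownLimit enabled; table and column names are made up.

    import org.apache.spark.sql.functions.col

    // The sample sits directly on the scan, so the sample, the filter and the limit can all
    // be pushed, in that order.
    val allPushable = spark.read.table("pg.public.t").sample(0.5).where(col("c") > 0).limit(10)

    // The sample is applied on top of the limit, so only the limit is pushed and Spark
    // still evaluates the Sample operator itself.
    val onlyLimit = spark.read.table("pg.public.t").limit(10).sample(0.5)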
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
new file mode 100644
index 0000000000000..3630feb4680ea
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down TABLESAMPLE.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownTableSample extends ScanBuilder {
+
+ /**
+ * Pushes down TABLESAMPLE to the data source.
+ */
+ boolean pushTableSample(
+ double lowerBound,
+ double upperBound,
+ boolean withReplacement,
+ long seed);
+}
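Editorial note: a minimal, hypothetical connector-side sketch of the new mix-in. SampleAwareScanBuilder and its schema parameter are invented for illustration and are not part of this patch.

    import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownTableSample}
    import org.apache.spark.sql.types.StructType

    // Hypothetical builder: records the pushed sample and exposes a trivial Scan.
    class SampleAwareScanBuilder(schema: StructType) extends SupportsPushDownTableSample {

      private var sample: Option[(Double, Double, Boolean, Long)] = None

      override def pushTableSample(
          lowerBound: Double,
          upperBound: Double,
          withReplacement: Boolean,
          seed: Long): Boolean = {
        // Remember the sample so the Scan can apply it; returning true tells Spark that the
        // source takes responsibility, so the logical Sample node can be dropped from the plan.
        sample = Some((lowerBound, upperBound, withReplacement, seed))
        true
      }

      override def build(): Scan = new Scan {
        override def readSchema(): StructType = schema
        // A real Scan would carry `sample` into its batch / reader factories.
      }
    }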
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 86b29d4698202..a3e3a2d54e32a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
+import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
@@ -103,8 +103,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
- aggregation: Option[Aggregation],
- limit: Option[Int],
+ pushedDownOperators: PushedDownOperators,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -135,9 +134,9 @@ case class RowDataSourceScanExec(
def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
- val (aggString, groupByString) = if (aggregation.nonEmpty) {
- (seqToString(aggregation.get.aggregateExpressions),
- seqToString(aggregation.get.groupByColumns))
+ val (aggString, groupByString) = if (pushedDownOperators.aggregation.nonEmpty) {
+ (seqToString(pushedDownOperators.aggregation.get.aggregateExpressions),
+ seqToString(pushedDownOperators.aggregation.get.groupByColumns))
} else {
("[]", "[]")
}
@@ -155,7 +154,10 @@ case class RowDataSourceScanExec(
"PushedFilters" -> seqToString(markedFilters.toSeq),
"PushedAggregates" -> aggString,
"PushedGroupby" -> groupByString) ++
- limit.map(value => "PushedLimit" -> s"LIMIT $value")
+ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
+ pushedDownOperators.sample.map(v => "PushedSample" ->
+ s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
+ )
}
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
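Editorial note: once a sample is pushed, the scan's metadata gains a PushedSample entry built from the string above, e.g. SAMPLE (50.0) false SEED(12345) for a 50 percent sample with seed 12345. A hedged way to observe it (catalog and table names assumed):

    // Assumes a V2 JDBC catalog "pg" with pushDownTableSample enabled.
    spark.sql(
      "SELECT col1 FROM pg.public.some_table TABLESAMPLE (50 PERCENT) REPEATABLE (12345)"
    ).explain()
    // The scan node in the printed plan should list PushedSample alongside
    // PushedFilters / PushedLimit.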
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 81cd37f35aa15..232559714ccbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -45,6 +45,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Coun
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.sources._
@@ -335,8 +336,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
- None,
- None,
+ PushedDownOperators(None, None, None),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,8 +410,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- None,
- None,
+ PushedDownOperators(None, None, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -434,8 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- None,
- None,
+ PushedDownOperators(None, None, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index e0730f39035d3..774d8ed228609 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -191,9 +191,14 @@ class JDBCOptions(
// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
- // An option to allow/disallow pushing down LIMIT into JDBC data source
+ // An option to allow/disallow pushing down LIMIT into V2 JDBC data source
+ // This only applies to Data Source V2 JDBC
val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean
+ // An option to allow/disallow pushing down TABLESAMPLE into V2 JDBC data source
+ // This only applies to Data Source V2 JDBC
+ val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean
+
// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -270,6 +275,7 @@ object JDBCOptions {
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
+ val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 7973850201826..1b8d33b94fbd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -181,6 +182,7 @@ object JDBCRDD extends Logging {
* @param groupByColumns - The pushed down group by columns.
* @param limit - The pushed down limit. If the value is 0, it means no limit or limit
* is not pushed down.
+ * @param sample - The pushed down tableSample.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
@@ -193,6 +195,7 @@ object JDBCRDD extends Logging {
options: JDBCOptions,
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None,
+ sample: Option[TableSampleInfo] = None,
limit: Int = 0): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
@@ -212,6 +215,7 @@ object JDBCRDD extends Logging {
url,
options,
groupByColumns,
+ sample,
limit)
}
}
@@ -231,6 +235,7 @@ private[jdbc] class JDBCRDD(
url: String,
options: JDBCOptions,
groupByColumns: Option[Array[String]],
+ sample: Option[TableSampleInfo],
limit: Int)
extends RDD[InternalRow](sc, Nil) {
@@ -354,10 +359,16 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
+ val myTableSampleClause: String = if (sample.nonEmpty) {
+ dialect.getTableSample(sample.get)
+ } else {
+ ""
+ }
+
val myLimitClause: String = dialect.getLimitClause(limit)
- val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
- s" $getGroupByClause $myLimitClause"
+ val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
+ s" $myWhereClause $getGroupByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
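Editorial note: for orientation, with the PostgreSQL dialect a pushed 50 percent sample with seed 12345 plus a pushed LIMIT 2 produces a statement shaped roughly like the string below. Table and column names are assumed, and empty clauses simply leave extra spaces.

    // Illustrative only; exact quoting and spacing come from the dialect and options.
    val illustrativeSql =
      """SELECT "col1" FROM some_table TABLESAMPLE BERNOULLI (50.0) REPEATABLE (12345)  LIMIT 2"""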
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index ff9fcd493f600..cd1eae89ee890 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
@@ -299,6 +300,7 @@ private[sql] case class JDBCRelation(
finalSchema: StructType,
filters: Array[Filter],
groupByColumns: Option[Array[String]],
+ tableSample: Option[TableSampleInfo],
limit: Int): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
@@ -310,6 +312,7 @@ private[sql] case class JDBCRelation(
jdbcOptions,
Some(finalSchema),
groupByColumns,
+ tableSample,
limit).asInstanceOf[RDD[Row]]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index b688c325c5636..e4f00021a74a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -93,8 +93,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case PhysicalOperation(project, filters,
- DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, limit), output)) =>
+ case PhysicalOperation(project, filters, DataSourceV2ScanRelation(
+ _, V1ScanWrapper(scan, pushed, pushedDownOperators), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@@ -108,8 +108,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
output.toStructType,
Set.empty,
pushed.toSet,
- aggregate,
- limit,
+ pushedDownOperators,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index a8c251a812a6d..f837ab54546d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
@@ -138,6 +138,18 @@ object PushDownUtils extends PredicateHelper {
}
}
+ /**
+ * Pushes down TABLESAMPLE to the data source Scan
+ */
+ def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = {
+ scanBuilder match {
+ case s: SupportsPushDownTableSample =>
+ s.pushTableSample(
+ sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed)
+ case _ => false
+ }
+ }
+
/**
* Pushes down LIMIT to the data source Scan
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
new file mode 100644
index 0000000000000..c21354d646164
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+
+/**
+ * Pushed down operators
+ */
+case class PushedDownOperators(
+ aggregation: Option[Aggregation],
+ sample: Option[TableSampleInfo],
+ limit: Option[Int])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
new file mode 100644
index 0000000000000..cb4fb9eb0809a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+case class TableSampleInfo(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long)
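Editorial note: a short worked example of how the bounds map to a fraction. DataFrame.sample(fraction) and TABLESAMPLE (p PERCENT) both produce a logical Sample with lowerBound = 0.0 and upperBound = fraction, so dialects derive the percentage as (upperBound - lowerBound) * 100.

    import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo

    // df.sample(0.5) / TABLESAMPLE (50 PERCENT) with seed 12345:
    val info = TableSampleInfo(lowerBound = 0.0, upperBound = 0.5, withReplacement = false, seed = 12345L)
    val percent = (info.upperBound - info.lowerBound) * 100  // 50.0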
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 960a1ea60598b..f73f831903364 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
@@ -36,7 +36,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._
def apply(plan: LogicalPlan): LogicalPlan = {
- applyColumnPruning(applyLimit(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))))
+ applyColumnPruning(
+ applyLimit(pushDownAggregates(pushDownFilters(pushDownSample(createScanBuilder(plan))))))
}
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -225,13 +226,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
withProjection
}
+ def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case sample: Sample => sample.child match {
+ case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
+ val tableSample = TableSampleInfo(
+ sample.lowerBound,
+ sample.upperBound,
+ sample.withReplacement,
+ sample.seed)
+ val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample)
+ if (pushed) {
+ sHolder.pushedSample = Some(tableSample)
+ sample.child
+ } else {
+ sample
+ }
+
+ case _ => sample
+ }
+ }
+
def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
if (limitPushed) {
- sHolder.setLimit(Some(limitValue))
+ sHolder.pushedLimit = Some(limitValue)
}
globalLimit
case _ => globalLimit
@@ -249,7 +270,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
- V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedLimit)
+ val pushedDownOperators =
+ PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
+ V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
}
}
@@ -260,16 +283,16 @@ case class ScanBuilderHolder(
relation: DataSourceV2Relation,
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None
- private[sql] def setLimit(limit: Option[Int]): Unit = pushedLimit = limit
+
+ var pushedSample: Option[TableSampleInfo] = None
}
-// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
-// the physical v1 scan node.
+// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
+// other pushed down operators. This is required by the physical v1 scan node.
case class V1ScanWrapper(
v1Scan: V1Scan,
handledFilters: Seq[sources.Filter],
- pushedAggregate: Option[Aggregation],
- pushedLimit: Option[Int]) extends Scan {
+ pushedDownOperators: PushedDownOperators) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}
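Editorial note: a hedged illustration of the filter.length == 0 guard in pushDownSample. It assumes a SparkSession spark and a catalog "pg" with push-down enabled; names are made up.

    import org.apache.spark.sql.functions.col

    // Sample is directly above the scan: it is pushed, and the filter can still be pushed
    // afterwards by pushDownFilters.
    val samplePushed = spark.read.table("pg.public.t").sample(0.5).where(col("c") > 0)

    // Sample sits above a Filter node: the guard fails, so the Sample stays in the Spark
    // plan (this mirrors the df8 case in V2JDBCTest above).
    val sampleKept = spark.read.table("pg.public.t").where(col("c") > 0).sample(0.5)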
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index 94d9d1433f9d4..ff79d1a5c4144 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -20,6 +20,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
import org.apache.spark.sql.types.StructType
@@ -29,6 +30,7 @@ case class JDBCScan(
pushedFilters: Array[Filter],
pushedAggregateColumn: Array[String] = Array(),
groupByColumns: Option[Array[String]],
+ tableSample: Option[TableSampleInfo],
pushedLimit: Int) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@@ -44,7 +46,8 @@ case class JDBCScan(
} else {
pushedAggregateColumn
}
- relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, pushedLimit)
+ relation.buildScan(
+ columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit)
}
}.asInstanceOf[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 14826748dd432..7605b03f49ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -21,9 +21,10 @@ import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
@@ -37,6 +38,7 @@ case class JDBCScanBuilder(
with SupportsPushDownRequiredColumns
with SupportsPushDownAggregates
with SupportsPushDownLimit
+ with SupportsPushDownTableSample
with Logging {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@@ -45,15 +47,9 @@ case class JDBCScanBuilder(
private var finalSchema = schema
- private var pushedLimit = 0
+ private var tableSample: Option[TableSampleInfo] = None
- override def pushLimit(limit: Int): Boolean = {
- if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) {
- pushedLimit = limit
- return true
- }
- false
- }
+ private var pushedLimit = 0
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (jdbcOptions.pushDownPredicate) {
@@ -109,6 +105,27 @@ case class JDBCScanBuilder(
}
}
+ override def pushTableSample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long): Boolean = {
+ if (jdbcOptions.pushDownTableSample &&
+ JdbcDialects.get(jdbcOptions.url).supportsTableSample) {
+ this.tableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed))
+ return true
+ }
+ false
+ }
+
+ override def pushLimit(limit: Int): Boolean = {
+ if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) {
+ pushedLimit = limit
+ return true
+ }
+ false
+ }
+
override def pruneColumns(requiredSchema: StructType): Unit = {
// JDBC doesn't support nested column pruning.
// TODO (SPARK-32593): JDBC support nested column and nested column pruning.
@@ -134,6 +151,6 @@ case class JDBCScanBuilder(
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
// be used in sql string.
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
- pushedAggregateList, pushedGroupByCols, pushedLimit)
+ pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index ac6fd2f5a1b58..568318c76f329 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -370,6 +371,11 @@ abstract class JdbcDialect extends Serializable with Logging{
* returns whether the dialect supports limit or not
*/
def supportsLimit(): Boolean = true
+
+ /**
+ * returns whether the dialect supports TABLESAMPLE or not
+ */
+ def supportsTableSample: Boolean = false
+
+ /**
+ * returns the TABLESAMPLE clause generated from the pushed down sample information
+ */
+ def getTableSample(sample: TableSampleInfo): String =
+ throw new UnsupportedOperationException("TableSample is not supported by this data source")
}
/**
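Editorial note: a hypothetical third-party dialect showing how another engine could opt into the new hooks. The URL prefix and the SYSTEM sampling syntax are invented for illustration and are not part of this patch.

    import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

    object MySamplingDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

      override def supportsTableSample: Boolean = true

      // Render the pushed sample as this engine's sampling clause.
      override def getTableSample(sample: TableSampleInfo): String =
        s"TABLESAMPLE SYSTEM (${(sample.upperBound - sample.lowerBound) * 100})" +
          s" REPEATABLE (${sample.seed})"
    }

    // Registration uses the usual JdbcDialects mechanism.
    JdbcDialects.registerDialect(MySamplingDialect)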
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 3ce785ed844c5..317ae19ed914b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -21,6 +21,7 @@ import java.sql.{Connection, Types}
import java.util.Locale
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.types._
@@ -154,4 +155,13 @@ private object PostgresDialect extends JdbcDialect {
val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL"
s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable"
}
+
+ override def supportsTableSample: Boolean = true
+
+ override def getTableSample(sample: TableSampleInfo): String = {
+ // hard-coded to BERNOULLI for now because Spark doesn't have a way to specify the
+ // sampling method name
+ s"TABLESAMPLE BERNOULLI" +
+ s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})"
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index d5b8ea9c42e60..9df1a9f7290ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -93,6 +93,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
}
+ // TABLESAMPLE ({integer_expression | decimal_expression} PERCENT) and
+ // TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression)
+ // are tested in JDBC dialect tests because TABLESAMPLE is not supported by all DBMSs
+ test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") {
+ val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)")
+ val scan = df.queryExecution.optimizedPlan.collectFirst {
+ case s: DataSourceV2ScanRelation => s
+ }.get
+ assert(scan.schema.names.sameElements(Seq("NAME")))
+ checkPushedLimit(df, true, 3)
+ checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy")))
+ }
+
test("simple scan with LIMIT") {
val df1 = spark.read.table("h2.test.employee")
.where($"dept" === 1).limit(1)
@@ -146,12 +159,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = {
df.queryExecution.optimizedPlan.collect {
- case DataSourceV2ScanRelation(_, scan, _) => scan match {
+ case relation: DataSourceV2ScanRelation => relation.scan match {
case v1: V1ScanWrapper =>
if (pushed) {
- assert(v1.pushedLimit.nonEmpty && v1.pushedLimit.get === limit)
+ assert(v1.pushedDownOperators.limit === Some(limit))
} else {
- assert(v1.pushedLimit.isEmpty)
+ assert(v1.pushedDownOperators.limit.isEmpty)
}
}
}