From 94198c91fedfa530b830cc173182f5f9400a40dc Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Sun, 17 Oct 2021 23:29:06 -0700
Subject: [PATCH 1/6] [SPARK-37038][SQL] DSV2 Sample Push Down

---
 docs/sql-data-sources-jdbc.md                 |  8 +++
 .../jdbc/v2/PostgresIntegrationSuite.scala    |  4 ++
 .../connector/expressions/Expressions.java    | 22 ++++++++
 .../connector/expressions/TableSample.java    | 55 +++++++++++++++++++
 .../read/SupportsPushDownTableSample.java     | 42 ++++++++++++++
 .../connector/expressions/expressions.scala   | 20 +++++++
 .../sql/execution/DataSourceScanExec.scala    | 11 +++-
 .../datasources/DataSourceStrategy.scala      |  3 +
 .../datasources/jdbc/JDBCOptions.scala        |  4 ++
 .../execution/datasources/jdbc/JDBCRDD.scala  | 20 +++++--
 .../datasources/jdbc/JDBCRelation.scala       |  7 ++-
 .../datasources/v2/PushDownUtils.scala        | 20 ++++++-
 .../v2/V2ScanRelationPushDown.scala           | 32 ++++++++++-
 .../datasources/v2/jdbc/JDBCScan.scala        | 23 +++++++-
 .../apache/spark/sql/jdbc/JdbcDialects.scala  |  6 +-
 .../spark/sql/jdbc/PostgresDialect.scala      | 27 +++++++++
 16 files changed, 288 insertions(+), 16 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/TableSample.java
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java

diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 16d525eea3b1e..23d4e66957836 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -246,6 +246,14 @@ logging into the data sources.
     read
+
+  pushDownTableSample
+  false
+
+  The option to enable or disable TABLESAMPLE push-down into the JDBC data source. The default value is false, in which case Spark does not push down TABLESAMPLE to the JDBC data source. Otherwise, if the value is set to true, TABLESAMPLE is pushed down to the JDBC data source.
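For illustration, a minimal sketch of how this option is meant to be used together with a JDBC catalog; the catalog name, URL and table below are placeholders that mirror the integration-test setup later in this patch, and push-down additionally requires the dialect to report supportsTableSample:

    import org.apache.spark.sql.SparkSession

    // Register a PostgreSQL-backed JDBC catalog and opt in to TABLESAMPLE push-down.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.postgresql",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.postgresql.url", "jdbc:postgresql://localhost:5432/mydb")
      .config("spark.sql.catalog.postgresql.pushDownTableSample", "true")
      .getOrCreate()

    // With the option enabled, the sample below can be evaluated by PostgreSQL instead of by Spark.
    spark.sql("SELECT * FROM postgresql.my_table TABLESAMPLE (50 PERCENT)").show()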
+ + read + keytab (none) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index de0916deb0154..d8dd41e39ce24 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -49,6 +49,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") + override def dataPreparation(conn: Connection): Unit = {} override def testUpdateColumnType(tbl: String): Unit = { @@ -75,4 +77,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def supportsTableSample: Boolean = true } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java index 6aed8896e9f58..99e3647268611 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java @@ -190,4 +190,26 @@ public static SortOrder sort(Expression expr, SortDirection direction, NullOrder public static SortOrder sort(Expression expr, SortDirection direction) { return LogicalExpressions.sort(expr, direction, direction.defaultNullOrdering()); } + + /** + * Create a tableSample expression. + * + * @param methodName the sample method name + * @param lowerBound the lower-bound of the sampling probability (usually 0.0) + * @param upperBound the upper-bound of the sampling probability + * @param withReplacement whether to sample with replacement + * @param seed the random seed + * @return a TableSample + * + * @since 3.3.0 + */ + public static TableSample tableSample( + String methodName, + double lowerBound, + double upperBound, + boolean withReplacement, + long seed) { + return LogicalExpressions.tableSample( + methodName, lowerBound, upperBound, withReplacement, seed); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/TableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/TableSample.java new file mode 100644 index 0000000000000..dc2e40c8c4bc8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/TableSample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Represents a TableSample in the public expression API. + * + * @since 3.3.0 + */ +@Experimental +public interface TableSample extends Expression { + + /** + * Returns the sample method name. + */ + String methodName(); + + /** + * Returns the lower-bound of the sampling probability (usually 0.0). + */ + double lowerBound(); + + /** + * Returns the upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + */ + double upperBound(); + + /** + * Returns whether to sample with replacement. + */ + boolean withReplacement(); + + /** + * Returns the random seed. + */ + long seed(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java new file mode 100644 index 0000000000000..303e0f1cdcc47 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.TableSample; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * push down SAMPLE. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTableSample extends Scan { + + /** + * Pushes down SAMPLE to the data source. + */ + boolean pushTableSample(TableSample limit); + + /** + * Returns the TableSample that is pushed to the data source via + * {@link #pushTableSample(TableSample)}. 
+   */
+  TableSample pushedTableSample();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index 2863d94d198b2..a1772b66d209b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -61,6 +61,15 @@ private[sql] object LogicalExpressions {
       nullOrdering: NullOrdering): SortOrder = {
     SortValue(reference, direction, nullOrdering)
   }
+
+  def tableSample(
+      methodName: String,
+      lowerBound: Double,
+      upperBound: Double,
+      withReplacement: Boolean,
+      seed: Long): TableSample = {
+    TableSampleValue(methodName, lowerBound, upperBound, withReplacement, seed)
+  }
 }

 /**
@@ -357,3 +366,14 @@ private[sql] object SortValue {
       None
   }
 }
+
+private[sql] final case class TableSampleValue(
+    methodName: String,
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long) extends TableSample {
+
+  override def describe(): String = s"$methodName $lowerBound $upperBound" +
+    s" $withReplacement $seed"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 4f282edaf81b3..479261bc77028 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.expressions.TableSample
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
@@ -104,6 +105,7 @@ case class RowDataSourceScanExec(
     filters: Set[Filter],
     handledFilters: Set[Filter],
     aggregation: Option[Aggregation],
+    sample: TableSample,
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
     tableIdentifier: Option[TableIdentifier])
@@ -149,11 +151,18 @@ case class RowDataSourceScanExec(
       handledFilters
     }

+    val sampleStr = if (sample != null) {
+      s"TABLESAMPLE ${sample.methodName} ${sample.lowerBound} ${sample.upperBound} " +
+        s"${sample.withReplacement} ${sample.seed}"
+    } else {
+      "[]"
+    }
     Map(
       "ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> seqToString(markedFilters.toSeq),
       "PushedAggregates" -> aggString,
-      "PushedGroupby" -> groupByString)
+      "PushedGroupby" -> groupByString,
+      "PushedSample" -> sampleStr)
   }

   // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
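For illustration, a small sketch (assuming a build with this patch applied) of creating a sample through the new public factory; note that the optimizer rule later in this series passes an empty method name and leaves the choice of sampling method to the data source:

    import org.apache.spark.sql.connector.expressions.{Expressions, TableSample}

    // Sample roughly 10 percent of the rows, without replacement, using seed 42.
    val sample: TableSample = Expressions.tableSample("BERNOULLI", 0.0, 0.1, false, 42L)
    // The expected sampled fraction is upperBound - lowerBound = 0.1.
    println(sample.describe())   // prints "BERNOULLI 0.0 0.1 false 42"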
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2c2dac1b9b78a..ed910404b1694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -336,6 +336,7 @@ object DataSourceStrategy Set.empty, Set.empty, None, + null, toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,6 +411,7 @@ object DataSourceStrategy pushedFilters.toSet, handledFilters, None, + null, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -433,6 +435,7 @@ object DataSourceStrategy pushedFilters.toSet, handledFilters, None, + null, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 510a22caa3335..9ede5a86aca82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -191,6 +191,9 @@ class JDBCOptions( // An option to allow/disallow pushing down aggregate into JDBC data source val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean + // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source + val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean + // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either // by --files option of spark-submit or manually val keytab = { @@ -266,6 +269,7 @@ object JDBCOptions { val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") + val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") val JDBC_TABLE_COMMENT = newOption("tableComment") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e024e4bb02102..7b5ed00677d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.TableSample import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -179,6 +180,7 @@ object JDBCRDD extends Logging { * @param options - JDBC options that contains url, table and other information. 
* @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param sample - The pushed down tableSample. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ @@ -190,7 +192,8 @@ object JDBCRDD extends Logging { parts: Array[Partition], options: JDBCOptions, outputSchema: Option[StructType] = None, - groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = { + groupByColumns: Option[Array[String]] = None, + sample: Option[TableSample] = None): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -208,7 +211,8 @@ object JDBCRDD extends Logging { parts, url, options, - groupByColumns) + groupByColumns, + sample) } } @@ -226,7 +230,8 @@ private[jdbc] class JDBCRDD( partitions: Array[Partition], url: String, options: JDBCOptions, - groupByColumns: Option[Array[String]]) + groupByColumns: Option[Array[String]], + sample: Option[TableSample]) extends RDD[InternalRow](sc, Nil) { /** @@ -274,6 +279,13 @@ private[jdbc] class JDBCRDD( } } + /** + * A TABLESAMPLE clause representing pushed-down TableSample. + */ + private def getTableSample(): String = { + JdbcDialects.get(url).getTableSample(sample) + } + /** * Runs the SQL query against the JDBC driver. * @@ -350,7 +362,7 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + - s" $getGroupByClause" + s" $getGroupByClause $getTableSample" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8098fa0b83a95..b3c8ca3e93de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.TableSample import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects @@ -298,7 +299,8 @@ private[sql] case class JDBCRelation( requiredColumns: Array[String], finalSchema: StructType, filters: Array[Filter], - groupByColumns: Option[Array[String]]): RDD[Row] = { + groupByColumns: Option[Array[String]], + tableSample: Option[TableSample]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -308,7 +310,8 @@ private[sql] case class JDBCRelation( parts, jdbcOptions, Some(finalSchema), - groupByColumns).asInstanceOf[RDD[Row]] + groupByColumns, + tableSample).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 335038ab53f5a..d95ac09ffcb3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,10 +22,10 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, TableSample} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -138,6 +138,22 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down TableSample to the data source Scan + */ + def pushTableSample(scan: Scan, sample: TableSample): Boolean = { + scan match { + case s: SupportsPushDownTableSample => s.pushTableSample(sample) + case v1: V1ScanWrapper => + v1.v1Scan match { + case s: SupportsPushDownTableSample => + s.pushTableSample(sample) + case _ => false + } + case _ => false + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ec45a5d7853c9..db46f0dd8f128 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,10 +23,11 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project, Sample} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.LogicalExpressions import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownTableSample, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -36,7 +37,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { - applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))) + applySample(applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))) } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -225,6 +226,31 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { withProjection } + def applySample(plan: LogicalPlan): LogicalPlan = plan.transform { + case sample @ Sample(_, _, _, _, DataSourceV2ScanRelation(_, scan, _)) => + val supportsPushDownSample = scan match { + case _: SupportsPushDownTableSample => true + case v1: V1ScanWrapper => + v1.v1Scan match { + case _: SupportsPushDownTableSample => true + case _ => false + } + case _ => false + } + if (supportsPushDownSample) { + val tableSample = LogicalExpressions.tableSample( + "", + sample.lowerBound, + sample.upperBound, + sample.withReplacement, + sample.seed) + val pushed = PushDownUtils.pushTableSample(scan, tableSample) + if (pushed) sample.child else sample + } else { + sample + } + } + private def getWrappedScan( scan: Scan, sHolder: ScanBuilderHolder, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ef42691e5ca94..7ff4874dca70f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.connector.read.V1Scan +import 
org.apache.spark.sql.connector.expressions.TableSample +import org.apache.spark.sql.connector.read.{SupportsPushDownTableSample, V1Scan} import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} import org.apache.spark.sql.types.StructType @@ -28,9 +30,23 @@ case class JDBCScan( prunedSchema: StructType, pushedFilters: Array[Filter], pushedAggregateColumn: Array[String] = Array(), - groupByColumns: Option[Array[String]]) extends V1Scan { + groupByColumns: Option[Array[String]]) extends V1Scan + with SupportsPushDownTableSample { override def readSchema(): StructType = prunedSchema + private var tableSample: Option[TableSample] = None + + override def pushTableSample(sample: TableSample): Boolean = { + if (relation.jdbcOptions.pushDownTableSample && + JdbcDialects.get(relation.jdbcOptions.url).supportsTableSample) { + this.tableSample = Some(sample) + true + } else { + false + } + } + + override def pushedTableSample: TableSample = if (tableSample.nonEmpty) tableSample.get else null override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = { new BaseRelation with TableScan { @@ -43,7 +59,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns) + relation.buildScan( + columnList, prunedSchema, pushedFilters, groupByColumns, tableSample) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 9e54ba7ce27e4..d79977367daf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.{NamedReference, TableSample} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf @@ -358,6 +358,10 @@ abstract class JdbcDialect extends Serializable with Logging{ def classifyException(message: String, e: Throwable): AnalysisException = { new AnalysisException(message, cause = Some(e)) } + + def supportsTableSample: Boolean = false + + def getTableSample(sample: Option[TableSample]): String = "" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3ce785ed844c5..01f06e8056f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} import java.util.Locale +import org.apache.spark.sql.connector.expressions.TableSample import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types._ @@ -154,4 +155,30 @@ private object PostgresDialect extends JdbcDialect { val nullable = if (isNullable) "DROP 
NOT NULL" else "SET NOT NULL" s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def supportsTableSample: Boolean = true + + override def getTableSample(sample: Option[TableSample]): String = { + if (sample.nonEmpty) { + val method = if (sample.get.methodName.isEmpty) { + "BERNOULLI" + } else { + sample.get.methodName + } + + val repeatable = if (sample.get.withReplacement()) { + if (sample.get.seed() != 0) { + "REPEATABLE (" + sample.get.seed() + ")" + } else { + "REPEATABLE" + } + } else { + "" + } + s"TABLESAMPLE $method" + + s" ( ${(sample.get.upperBound - sample.get.lowerBound) * 100} ) $repeatable" + } else { + "" + } + } } From 751578c1676ecc654db5f59bafb0d27d6ea4431a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 18 Oct 2021 00:09:32 -0700 Subject: [PATCH 2/6] fix style --- .../org/apache/spark/sql/connector/expressions/Expressions.java | 2 +- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java index 99e3647268611..c6364f1fa89a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java @@ -190,7 +190,7 @@ public static SortOrder sort(Expression expr, SortDirection direction, NullOrder public static SortOrder sort(Expression expr, SortDirection direction) { return LogicalExpressions.sort(expr, direction, direction.defaultNullOrdering()); } - + /** * Create a tableSample expression. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index db46f0dd8f128..0b5a531802d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -250,7 +250,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sample } } - + private def getWrappedScan( scan: Scan, sHolder: ScanBuilderHolder, From 0dc45e66bb7f0f9feda53c61125640de50df8f73 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 19 Oct 2021 16:14:38 -0700 Subject: [PATCH 3/6] add seed in TableSample SQL syntax --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 14 ++++++++++---- .../apache/spark/sql/jdbc/PostgresDialect.scala | 11 +---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 32b080cdbb080..e1bccf653b21d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -674,7 +674,7 @@ joinCriteria ; sample - : TABLESAMPLE '(' sampleMethod? ')' + : TABLESAMPLE '(' sampleMethod? ')' (REPEATABLE '('seed=INTEGER_VALUE')')? 
; sampleMethod @@ -1194,6 +1194,7 @@ ansiNonReserved | REFRESH | RENAME | REPAIR + | REPEATABLE | REPLACE | RESET | RESPECT @@ -1460,6 +1461,7 @@ nonReserved | REFRESH | RENAME | REPAIR + | REPEATABLE | REPLACE | RESET | RESPECT @@ -1726,6 +1728,7 @@ REFERENCES: 'REFERENCES'; REFRESH: 'REFRESH'; RENAME: 'RENAME'; REPAIR: 'REPAIR'; +REPEATABLE: 'REPEATABLE'; REPLACE: 'REPLACE'; RESET: 'RESET'; RESPECT: 'RESPECT'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d36c7ac82e9bd..4dd4c48ac6436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1180,7 +1180,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { // Create a sampled plan if we need one. - def sample(fraction: Double): Sample = { + def sample(fraction: Double, seed: Long): Sample = { // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. @@ -1188,13 +1188,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + Sample(0.0, fraction, withReplacement = false, seed, query) } if (ctx.sampleMethod() == null) { throw QueryParsingErrors.emptyInputForTableSampleError(ctx) } + val seed = if (ctx.seed != null) { + ctx.seed.getText.toLong + } else { + (math.random * 1000).toLong + } + ctx.sampleMethod() match { case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) @@ -1202,7 +1208,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 - sample(sign * fraction / 100.0d) + sample(sign * fraction / 100.0d, seed) case ctx: SampleByBytesContext => val bytesStr = ctx.bytes.getText @@ -1222,7 +1228,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } case ctx: SampleByBucketContext => - sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 01f06e8056f9a..874ff36fc7a66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -165,16 +165,7 @@ private object PostgresDialect extends JdbcDialect { } else { sample.get.methodName } - - val repeatable = if (sample.get.withReplacement()) { - if (sample.get.seed() != 0) { - "REPEATABLE (" + sample.get.seed() + ")" - } else { - "REPEATABLE" - } - } else { - "" - } + val repeatable = "REPEATABLE (" + sample.get.seed() + ")" s"TABLESAMPLE $method" + s" ( ${(sample.get.upperBound - sample.get.lowerBound) * 100} ) $repeatable" } else { From 
28c392c568c0a084d5a134f0e0ece376f6a2fc0f Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 19 Oct 2021 16:45:39 -0700
Subject: [PATCH 4/6] add REPEATABLE in sql-ref-ansi-compliance.md

---
 docs/sql-ref-ansi-compliance.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index c10e8661bfde1..e4f404965f407 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -491,6 +491,7 @@ Below is a list of all the keywords in Spark SQL.
 |REGEXP|non-reserved|non-reserved|not a keyword|
 |RENAME|non-reserved|non-reserved|non-reserved|
 |REPAIR|non-reserved|non-reserved|non-reserved|
+|REPEATABLE|non-reserved|non-reserved|non-reserved|
 |REPLACE|non-reserved|non-reserved|non-reserved|
 |RESET|non-reserved|non-reserved|non-reserved|
 |RESPECT|non-reserved|non-reserved|non-reserved|

From 0d7ddbd40b14a80decca0ddf2219b4caf9a0c095 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Wed, 27 Oct 2021 16:25:01 -0700
Subject: [PATCH 5/6] SupportsPushDownTableSample should extend ScanBuilder

---
 .../read/SupportsPushDownTableSample.java     |  8 +--
 .../sql/execution/DataSourceScanExec.scala    | 14 ++---
 .../datasources/v2/PushDownUtils.scala        | 13 ++---
 .../v2/V2ScanRelationPushDown.scala           | 52 +++++++++----------
 .../datasources/v2/jdbc/JDBCScan.scala        | 20 ++-----
 .../datasources/v2/jdbc/JDBCScanBuilder.scala | 17 +++++-
 6 files changed, 54 insertions(+), 70 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
index 303e0f1cdcc47..4fde1a0277623 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
@@ -27,16 +27,10 @@
  * @since 3.3.0
  */
 @Evolving
-public interface SupportsPushDownTableSample extends Scan {
+public interface SupportsPushDownTableSample extends ScanBuilder {

   /**
    * Pushes down SAMPLE to the data source.
    */
   boolean pushTableSample(TableSample limit);
-
-  /**
-   * Returns the TableSample that is pushed to the data source via
-   * {@link #pushTableSample(TableSample)}.
- */ - TableSample pushedTableSample(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 479261bc77028..75611f4218304 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -105,7 +105,7 @@ case class RowDataSourceScanExec( filters: Set[Filter], handledFilters: Set[Filter], aggregation: Option[Aggregation], - sample: TableSample, + sample: Option[TableSample], rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -151,18 +151,14 @@ case class RowDataSourceScanExec( handledFilters } - val sampleStr = if (sample != null) { - s"TABLESAMPLE ${sample.methodName} ${sample.lowerBound} ${sample.upperBound} " + - s"${sample.withReplacement} ${sample.seed}" - } else { - "[]" - } Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq), "PushedAggregates" -> aggString, - "PushedGroupby" -> groupByString, - "PushedSample" -> sampleStr) + "PushedGroupby" -> groupByString) ++ + sample.map(v => "PushedSample" -> + s"SAMPLE ${v.methodName} ${v.lowerBound} ${v.upperBound} ${v.withReplacement} ${v.seed}" + ) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index d95ac09ffcb3f..93a49086c9296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -141,15 +141,10 @@ object PushDownUtils extends PredicateHelper { /** * Pushes down TableSample to the data source Scan */ - def pushTableSample(scan: Scan, sample: TableSample): Boolean = { - scan match { - case s: SupportsPushDownTableSample => s.pushTableSample(sample) - case v1: V1ScanWrapper => - v1.v1Scan match { - case s: SupportsPushDownTableSample => - s.pushTableSample(sample) - case _ => false - } + def pushTableSample(scanBuilder: ScanBuilder, sample: TableSample): Boolean = { + scanBuilder match { + case s: SupportsPushDownTableSample => + s.pushTableSample(sample) case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 0b5a531802d28..59238a35556d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project, Sample} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.LogicalExpressions +import org.apache.spark.sql.connector.expressions.{LogicalExpressions, TableSample} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownTableSample, V1Scan} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -37,7 +37,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { - applySample(applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))) + applyColumnPruning(applySample(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))) } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -227,27 +227,23 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def applySample(plan: LogicalPlan): LogicalPlan = plan.transform { - case sample @ Sample(_, _, _, _, DataSourceV2ScanRelation(_, scan, _)) => - val supportsPushDownSample = scan match { - case _: SupportsPushDownTableSample => true - case v1: V1ScanWrapper => - v1.v1Scan match { - case _: SupportsPushDownTableSample => true - case _ => false + case sample @ Sample(_, _, _, _, child) => child match { + case ScanOperation(_, _, sHolder: ScanBuilderHolder) => + val tableSample = LogicalExpressions.tableSample( + "", + sample.lowerBound, + sample.upperBound, + sample.withReplacement, + sample.seed) + val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample) + if (pushed) { + sHolder.setSample(Some(tableSample)) + sample.child + } else { + sample } - case _ => false - } - if (supportsPushDownSample) { - val tableSample = LogicalExpressions.tableSample( - "", - sample.lowerBound, - sample.upperBound, - sample.withReplacement, - sample.seed) - val pushed = PushDownUtils.pushTableSample(scan, tableSample) - if (pushed) sample.child else sample - } else { - sample + + case _ => sample } } @@ -262,7 +258,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - V1ScanWrapper(v1, pushedFilters, aggregation) + V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedSample) case _ => scan } } @@ -271,13 +267,17 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case class ScanBuilderHolder( output: Seq[AttributeReference], relation: DataSourceV2Relation, - builder: ScanBuilder) extends LeafNode + builder: ScanBuilder) extends LeafNode { + var pushedSample: Option[TableSample] = None + private[sql] def setSample(sample: Option[TableSample]): Unit = pushedSample = sample +} // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by // the physical v1 scan node. 
 case class V1ScanWrapper(
     v1Scan: V1Scan,
     handledFilters: Seq[sources.Filter],
-    pushedAggregate: Option[Aggregation]) extends Scan {
+    pushedAggregate: Option[Aggregation],
+    pushedSample: Option[TableSample]) extends Scan {
   override def readSchema(): StructType = v1Scan.readSchema()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index 7ff4874dca70f..da220ea9c7a56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.connector.expressions.TableSample
-import org.apache.spark.sql.connector.read.{SupportsPushDownTableSample, V1Scan}
+import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
-import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
 import org.apache.spark.sql.types.StructType
@@ -30,23 +29,10 @@ case class JDBCScan(
     prunedSchema: StructType,
     pushedFilters: Array[Filter],
     pushedAggregateColumn: Array[String] = Array(),
-    groupByColumns: Option[Array[String]]) extends V1Scan
-  with SupportsPushDownTableSample {
+    groupByColumns: Option[Array[String]],
+    tableSample: Option[TableSample]) extends V1Scan {

   override def readSchema(): StructType = prunedSchema
-  private var tableSample: Option[TableSample] = None
-
-  override def pushTableSample(sample: TableSample): Boolean = {
-    if (relation.jdbcOptions.pushDownTableSample &&
-      JdbcDialects.get(relation.jdbcOptions.url).supportsTableSample) {
-      this.tableSample = Some(sample)
-      true
-    } else {
-      false
-    }
-  }
-
-  override def pushedTableSample: TableSample = if (tableSample.nonEmpty) tableSample.get else null

   override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = {
     new BaseRelation with TableScan {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index b0de7c015c91a..d4a48f59e559f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -20,8 +20,9 @@ import scala.util.control.NonFatal

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.expressions.TableSample
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownTableSample}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.jdbc.JdbcDialects
@@ -36,6 +37,7 @@ case class JDBCScanBuilder(
     with SupportsPushDownFilters
with SupportsPushDownRequiredColumns with SupportsPushDownAggregates + with SupportsPushDownTableSample with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -44,6 +46,8 @@ case class JDBCScanBuilder( private var finalSchema = schema + private var tableSample: Option[TableSample] = None + override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -98,6 +102,15 @@ case class JDBCScanBuilder( } } + override def pushTableSample(sample: TableSample): Boolean = { + if (jdbcOptions.pushDownTableSample && + JdbcDialects.get(jdbcOptions.url).supportsTableSample) { + this.tableSample = Some(sample) + return true + } + false + } + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. @@ -123,6 +136,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols) + pushedAggregateList, pushedGroupByCols, tableSample) } } From 1ee110568825fb8d068cb5771413489edba3888b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 27 Oct 2021 17:17:55 -0700 Subject: [PATCH 6/6] fix file conflicts --- .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 29 +++++++++++++++++++ .../datasources/v2/DataSourceV2Strategy.scala | 3 +- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index a97adf94ed1e4..ab310c549ace9 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -23,6 +23,7 @@ import org.apache.log4j.Level import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.catalyst.plans.logical.Sample import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} @@ -284,4 +285,32 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu testIndexUsingSQL(s"$catalogName.new_table") } } + + def supportsTableSample: Boolean = false + + test("Test TABLESAMPLE") { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (1, 2)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (3, 4)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (5, 6)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (7, 8)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (9, 10)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (11, 12)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (13, 14)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (15, 16)") + sql(s"INSERT INTO TABLE $catalogName.new_table values (17, 18)") + sql(s"INSERT INTO TABLE $catalogName.new_table 
values (19, 20)") + if (supportsTableSample) { + val df = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + + s" REPEATABLE (12345)") + df.explain(true) + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + assert(sample.isEmpty) + assert(df.collect().length <= 7) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 66ee43130976d..e058ebdd3bb67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -94,7 +94,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) => + DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, sample), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -108,6 +108,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat Set.empty, pushed.toSet, aggregate, + sample, unsafeRowRDD, v1Relation, tableIdentifier = None)
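As a closing illustration, a hedged sketch (the JDBC URL is a placeholder) of the clause the PostgreSQL dialect introduced in this series produces for the sampled query in the test above; JDBCRDD appends this clause to the SELECT statement it sends over JDBC:

    import org.apache.spark.sql.connector.expressions.Expressions
    import org.apache.spark.sql.jdbc.JdbcDialects

    // TABLESAMPLE (BUCKET 6 OUT OF 10) REPEATABLE (12345) parses into bounds [0.0, 0.6]
    // and seed 12345; the empty method name lets the dialect default to BERNOULLI.
    val dialect = JdbcDialects.get("jdbc:postgresql://localhost:5432/mydb")
    val clause = dialect.getTableSample(
      Some(Expressions.tableSample("", 0.0, 0.6, false, 12345L)))
    // clause should be: TABLESAMPLE BERNOULLI ( 60.0 ) REPEATABLE (12345)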