diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 27561857c122..ed27f611fc88 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, SortOrder, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} @@ -46,7 +47,9 @@ class InMemoryTable( val name: String, val schema: StructType, override val partitioning: Array[Transform], - override val properties: util.Map[String, String]) + override val properties: util.Map[String, String], + val distribution: Distribution = Distributions.unspecified(), + val ordering: Array[SortOrder] = Array.empty) extends Table with SupportsRead with SupportsWrite with SupportsDelete with SupportsMetadataColumns { @@ -274,7 +277,11 @@ class InMemoryTable( this } - override def build(): Write = new Write { + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = distribution + + override def requiredOrdering: Array[SortOrder] = ordering + override def toBatch: BatchWrite = writer override def toStreaming: StreamingWrite = streamingWriter match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala index edb9f6548083..7baa66cf23ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala @@ -24,7 +24,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -69,13 +70,24 @@ class BasicInMemoryTableCatalog extends TableCatalog { schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { + createTable(ident, schema, partitions, properties, Distributions.unspecified(), Array.empty) + } + + def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + distribution: 
Distribution, + ordering: Array[SortOrder]): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident) } InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) - val table = new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties) + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, ordering) tables.put(ident, table) namespaces.putIfAbsent(ident.namespace.toList, Map()) table diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala new file mode 100644 index 000000000000..c4f8a7ad1147 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, NamedExpression, NullOrdering, NullsFirst, NullsLast, SortDirection, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort} +import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, IdentityTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortValue} +import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} +import org.apache.spark.sql.internal.SQLConf + +object DistributionAndOrderingUtils { + + def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match { + case write: RequiresDistributionAndOrdering => + val resolver = conf.resolver + + val distribution = write.requiredDistribution match { + case d: OrderedDistribution => + d.ordering.map(e => toCatalyst(e, query, resolver)) + case d: ClusteredDistribution => + d.clustering.map(e => toCatalyst(e, query, resolver)) + case _: UnspecifiedDistribution => + Array.empty[Expression] + } + + val queryWithDistribution = if (distribution.nonEmpty) { + val numShufflePartitions = conf.numShufflePartitions + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(distribution, query, 
numShufflePartitions) + } else { + query + } + + val ordering = write.requiredOrdering.toSeq + .map(e => toCatalyst(e, query, resolver)) + .asInstanceOf[Seq[SortOrder]] + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + Sort(ordering, global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + + case _ => + query + } + + private def toCatalyst( + expr: V2Expression, + query: LogicalPlan, + resolver: Resolver): Expression = { + + // we cannot perform the resolution in the analyzer since we need to optimize expressions + // in nodes like OverwriteByExpression before constructing a logical write + def resolve(ref: FieldReference): NamedExpression = { + query.resolve(ref.parts, resolver) match { + case Some(attr) => attr + case None => throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}") + } + } + + expr match { + case SortValue(child, direction, nullOrdering) => + val catalystChild = toCatalyst(child, query, resolver) + SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) + case IdentityTransform(ref) => + resolve(ref) + case ref: FieldReference => + resolve(ref) + case _ => + throw new AnalysisException(s"$expr is not currently supported") + } + } + + private def toCatalyst(direction: V2SortDirection): SortDirection = direction match { + case V2SortDirection.ASCENDING => Ascending + case V2SortDirection.DESCENDING => Descending + } + + private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match { + case V2NullOrdering.NULLS_FIRST => NullsFirst + case V2NullOrdering.NULLS_LAST => NullsLast + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index a8e0731edf14..6efc64087e86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -40,7 +40,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) => val writeBuilder = newWriteBuilder(r.table, query, options) val write = writeBuilder.build() - a.copy(write = Some(write)) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) + a.copy(write = Some(write), query = newQuery) case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
@@ -63,7 +64,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { throw new SparkException(s"Table does not support overwrite by expression: $table") } - o.copy(write = Some(write)) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) => val table = r.table @@ -74,7 +76,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } - o.copy(write = Some(write)) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) } private def isTruncate(filters: Array[Filter]): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala new file mode 100644 index 000000000000..317ebb725653 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{catalyst, DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder} +import org.apache.spark.sql.connector.expressions.LogicalExpressions._ +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.util.QueryExecutionListener + +class WriteDistributionAndOrderingSuite + extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") + } + + private val namespace = Array("ns1") + private val ident = Identifier.of(namespace, "test_table") + private val tableNameAsString = "testcat." 
+ ident.toString + private val emptyProps = Collections.emptyMap[String, String] + private val schema = new StructType() + .add("id", IntegerType) + .add("data", StringType) + + private val resolver = conf.resolver + + test("ordered distribution and sort with same exprs: append") { + checkOrderedDistributionAndSortWithSameExprs("append") + } + + test("ordered distribution and sort with same exprs: overwrite") { + checkOrderedDistributionAndSortWithSameExprs("overwrite") + } + + test("ordered distribution and sort with same exprs: overwriteDynamic") { + checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic") + } + + private def checkOrderedDistributionAndSortWithSameExprs(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.ordered(tableOrdering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = RangePartitioning(writeOrdering, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeCommand = command) + } + + test("clustered distribution and sort with same exprs: append") { + checkClusteredDistributionAndSortWithSameExprs("append") + } + + test("clustered distribution and sort with same exprs: overwrite") { + checkClusteredDistributionAndSortWithSameExprs("overwrite") + } + + test("clustered distribution and sort with same exprs: overwriteDynamic") { + checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic") + } + + private def checkClusteredDistributionAndSortWithSameExprs(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val clustering = Array[Expression](FieldReference("data"), FieldReference("id")) + val tableDistribution = Distributions.clustered(clustering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data"), attr("id")) + val writePartitioning = HashPartitioning(writePartitioningExprs, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeCommand = command) + } + + test("clustered distribution and sort with extended exprs: append") { + checkClusteredDistributionAndSortWithExtendedExprs("append") + } + + test("clustered distribution and sort with extended exprs: overwrite") { + checkClusteredDistributionAndSortWithExtendedExprs("overwrite") + } + + test("clustered distribution and sort with extended exprs: overwriteDynamic") { + checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic") + } + + private def checkClusteredDistributionAndSortWithExtendedExprs(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), 
SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val clustering = Array[Expression](FieldReference("data")) + val tableDistribution = Distributions.clustered(clustering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val writePartitioning = HashPartitioning(writePartitioningExprs, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeCommand = command) + } + + test("unspecified distribution and local sort: append") { + checkUnspecifiedDistributionAndLocalSort("append") + } + + test("unspecified distribution and local sort: overwrite") { + checkUnspecifiedDistributionAndLocalSort("overwrite") + } + + test("unspecified distribution and local sort: overwriteDynamic") { + checkUnspecifiedDistributionAndLocalSort("overwriteDynamic") + } + + private def checkUnspecifiedDistributionAndLocalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.unspecified() + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = UnknownPartitioning(0) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeCommand = command) + } + + test("unspecified distribution and no sort: append") { + checkUnspecifiedDistributionAndNoSort("append") + } + + test("unspecified distribution and no sort: overwrite") { + checkUnspecifiedDistributionAndNoSort("overwrite") + } + + test("unspecified distribution and no sort: overwriteDynamic") { + checkUnspecifiedDistributionAndNoSort("overwriteDynamic") + } + + private def checkUnspecifiedDistributionAndNoSort(command: String): Unit = { + val tableOrdering = Array.empty[SortOrder] + val tableDistribution = Distributions.unspecified() + + val writeOrdering = Seq.empty[catalyst.expressions.SortOrder] + val writePartitioning = UnknownPartitioning(0) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeCommand = command) + } + + test("ordered distribution and sort with manual global sort: append") { + checkOrderedDistributionAndSortWithManualGlobalSort("append") + } + + test("ordered distribution and sort with manual global sort: overwrite") { + checkOrderedDistributionAndSortWithManualGlobalSort("overwrite") + } + + test("ordered distribution and sort with manual global sort: overwriteDynamic") { + checkOrderedDistributionAndSortWithManualGlobalSort("overwriteDynamic") + } + + private def checkOrderedDistributionAndSortWithManualGlobalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.ordered(tableOrdering) + + val writeOrdering = 
Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = RangePartitioning(writeOrdering, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy("data", "id"), + writeCommand = command) + } + + test("ordered distribution and sort with incompatible global sort: append") { + checkOrderedDistributionAndSortWithIncompatibleGlobalSort("append") + } + + test("ordered distribution and sort with incompatible global sort: overwrite") { + checkOrderedDistributionAndSortWithIncompatibleGlobalSort("overwrite") + } + + test("ordered distribution and sort with incompatible global sort: overwriteDynamic") { + checkOrderedDistributionAndSortWithIncompatibleGlobalSort("overwriteDynamic") + } + + private def checkOrderedDistributionAndSortWithIncompatibleGlobalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.ordered(tableOrdering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = RangePartitioning(writeOrdering, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy(df("data").desc, df("id").asc), + writeCommand = command) + } + + test("ordered distribution and sort with manual local sort: append") { + checkOrderedDistributionAndSortWithManualLocalSort("append") + } + + test("ordered distribution and sort with manual local sort: overwrite") { + checkOrderedDistributionAndSortWithManualLocalSort("overwrite") + } + + test("ordered distribution and sort with manual local sort: overwriteDynamic") { + checkOrderedDistributionAndSortWithManualLocalSort("overwriteDynamic") + } + + private def checkOrderedDistributionAndSortWithManualLocalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.ordered(tableOrdering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = RangePartitioning(writeOrdering, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.sortWithinPartitions("data", "id"), + writeCommand = 
command) + } + + test("clustered distribution and local sort with manual global sort: append") { + checkClusteredDistributionAndLocalSortWithManualGlobalSort("append") + } + + test("clustered distribution and local sort with manual global sort: overwrite") { + checkClusteredDistributionAndLocalSortWithManualGlobalSort("overwrite") + } + + test("clustered distribution and local sort with manual global sort: overwriteDynamic") { + checkClusteredDistributionAndLocalSortWithManualGlobalSort("overwriteDynamic") + } + + private def checkClusteredDistributionAndLocalSortWithManualGlobalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.clustered(Array(FieldReference("data"))) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val writePartitioning = HashPartitioning(writePartitioningExprs, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy("data", "id"), + writeCommand = command) + } + + test("clustered distribution and local sort with manual local sort: append") { + checkClusteredDistributionAndLocalSortWithManualLocalSort("append") + } + + test("clustered distribution and local sort with manual local sort: overwrite") { + checkClusteredDistributionAndLocalSortWithManualLocalSort("overwrite") + } + + test("clustered distribution and local sort with manual local sort: overwriteDynamic") { + checkClusteredDistributionAndLocalSortWithManualLocalSort("overwriteDynamic") + } + + private def checkClusteredDistributionAndLocalSortWithManualLocalSort(command: String): Unit = { + val tableOrdering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val tableDistribution = Distributions.clustered(Array(FieldReference("data"))) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val writePartitioning = HashPartitioning(writePartitioningExprs, conf.numShufflePartitions) + + checkWriteRequirements( + tableDistribution, + tableOrdering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy("data", "id"), + writeCommand = command) + } + + private def checkWriteRequirements( + tableDistribution: Distribution, + tableOrdering: Array[SortOrder], + expectedWritePartitioning: physical.Partitioning, + expectedWriteOrdering: Seq[catalyst.expressions.SortOrder], + writeTransform: DataFrame => DataFrame = df => df, + writeCommand: String = "append"): Unit = { + + catalog.createTable(ident, schema, Array.empty, emptyProps, 
tableDistribution, tableOrdering) + + val df = spark.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("id", "data") + val writer = writeTransform(df).writeTo(tableNameAsString) + val executedPlan = writeCommand match { + case "append" => execute(writer.append()) + case "overwrite" => execute(writer.overwrite(lit(true))) + case "overwriteDynamic" => execute(writer.overwritePartitions()) + } + + checkPartitioningAndOrdering(executedPlan, expectedWritePartitioning, expectedWriteOrdering) + + checkAnswer(spark.table(tableNameAsString), df) + } + + private def checkPartitioningAndOrdering( + plan: SparkPlan, + partitioning: physical.Partitioning, + ordering: Seq[catalyst.expressions.SortOrder]): Unit = { + + val sorts = collect(plan) { case s: SortExec => s } + assert(sorts.size <= 1, "must be at most one sort") + val shuffles = collect(plan) { case s: ShuffleExchangeLike => s } + assert(shuffles.size <= 1, "must be at most one shuffle") + + val actualPartitioning = plan.outputPartitioning + val expectedPartitioning = partitioning match { + case p: physical.RangePartitioning => + val resolvedOrdering = p.ordering.map(resolveAttrs(_, plan)) + p.copy(ordering = resolvedOrdering.asInstanceOf[Seq[catalyst.expressions.SortOrder]]) + case p: physical.HashPartitioning => + val resolvedExprs = p.expressions.map(resolveAttrs(_, plan)) + p.copy(expressions = resolvedExprs) + case other => other + } + assert(actualPartitioning == expectedPartitioning, "partitioning must match") + + val actualOrdering = plan.outputOrdering + val expectedOrdering = ordering.map(resolveAttrs(_, plan)) + assert(actualOrdering == expectedOrdering, "ordering must match") + } + + private def resolveAttrs( + expr: catalyst.expressions.Expression, + plan: SparkPlan): catalyst.expressions.Expression = { + + expr.transform { + case UnresolvedAttribute(Seq(attrName)) => + plan.output.find(attr => resolver(attr.name, attrName)).get + case UnresolvedAttribute(nameParts) => + val attrName = nameParts.mkString(".") + fail(s"cannot resolve a nested attr: $attrName") + } + } + + private def attr(name: String): UnresolvedAttribute = { + UnresolvedAttribute(name) + } + + private def catalog: InMemoryTableCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog] + } + + // executes a write operation and keeps the executed physical plan + private def execute(writeFunc: => Unit): SparkPlan = { + var executedPlan: SparkPlan = null + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + executedPlan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + } + } + spark.listenerManager.register(listener) + + writeFunc + + sparkContext.listenerBus.waitUntilEmpty() + + executedPlan match { + case w: V2TableWriteExec => + stripAQEPlan(w.query) + case _ => + fail("expected V2TableWriteExec") + } + } +}
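
For illustration, a minimal sketch (hypothetical connector code, not part of this patch) of how a custom Write can declare its requirements through the interface wired up above. Only Write, RequiresDistributionAndOrdering, Distributions, FieldReference and LogicalExpressions.sort are APIs exercised by this change; the class, parameter and column names below are assumptions.

// Hypothetical example: a connector Write that asks Spark to cluster rows by "id"
// and sort them by "id" within each write task. `delegate` stands in for whatever
// BatchWrite the connector already builds.
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write}

class ClusteredSortedWrite(delegate: BatchWrite) extends Write with RequiresDistributionAndOrdering {

  // co-locate all rows with the same "id" in one write task
  override def requiredDistribution: Distribution =
    Distributions.clustered(Array(FieldReference("id")))

  // and sort rows by "id" ascending, nulls first, within each task
  override def requiredOrdering: Array[SortOrder] =
    Array(LogicalExpressions.sort(
      FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))

  override def toBatch: BatchWrite = delegate
}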
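
And a sketch, assumed rather than taken from the patch itself, of the plan rewrite that V2Writes plus DistributionAndOrderingUtils.prepareQuery perform for such a write, matching the shuffle and sort assertions in checkPartitioningAndOrdering.

// Logical plan after the V2Writes rule (shape only, node arguments simplified):
//
//   AppendData
//   +- Sort [id ASC NULLS FIRST], global = false
//      +- RepartitionByExpression [id], numPartitions = conf.numShufflePartitions
//         +- <original query>
//
// Because a ClusteredDistribution yields generic expressions (not SortOrder),
// RepartitionByExpression resolves to hash partitioning; an OrderedDistribution
// would yield SortOrder expressions and resolve to range partitioning instead.
// The physical plan therefore contains a SortExec over a ShuffleExchangeLike with
// HashPartitioning, which is what the suite above checks via outputPartitioning
// and outputOrdering.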