diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233f..e48b0fa0eafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import scala.language.existentials + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -30,7 +32,10 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** If defined, then represents how many partitions are expected by the distribution */ + def numPartitions: Option[Int] = None +} /** * Represents a distribution where no promises are made about co-location of data. @@ -49,12 +54,20 @@ case object AllTuples extends Distribution * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + numClusters: Option[Int] = None, + hashingFunctionClass: Option[Class[_ <: HashExpression[Int]]] = None) + extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + require(numClusters.isEmpty || numClusters.get > 0, + "Number of cluster (if set) should only be a positive integer") + + override def numPartitions: Option[Int] = numClusters } /** @@ -234,7 +247,10 @@ case object SinglePartition extends Partitioning { * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class HashPartitioning( + expressions: Seq[Expression], + numPartitions: Int, + hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash]) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions @@ -243,8 +259,10 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, numClusters, hashingFunctionClazz) => + (numClusters.isEmpty || numClusters.get == numPartitions) && + (hashingFunctionClazz.isEmpty || hashingFunctionClazz.get == hashingFunctionClass) && + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } @@ -260,9 +278,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) /** * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less - * than numPartitions) based on hashing expressions. 
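To make the new contract concrete: a HashPartitioning only satisfies a ClusteredDistribution whose optional partition count and hashing-function class both match its own. A minimal sketch, not part of the patch and assuming the patch above is applied:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, HiveHash, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

object SatisfiesSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val hiveHashed = HashPartitioning(Seq(a), 8, classOf[HiveHash])

    // Same columns, same partition count, same hashing function: no exchange needed.
    println(hiveHashed.satisfies(ClusteredDistribution(Seq(a), Some(8), Some(classOf[HiveHash]))))    // true
    // A different pinned partition count or hashing function forces a shuffle.
    println(hiveHashed.satisfies(ClusteredDistribution(Seq(a), Some(16), Some(classOf[HiveHash]))))   // false
    println(hiveHashed.satisfies(ClusteredDistribution(Seq(a), Some(8), Some(classOf[Murmur3Hash])))) // false
    // Leaving both options empty keeps the old behaviour: only the clustering columns matter.
    println(hiveHashed.satisfies(ClusteredDistribution(Seq(a))))                                      // true
  }
}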
+ * than numPartitions) based on hashing expression(s) and the hashing function. */ - def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + def partitionIdExpression: Expression = { + val hashExpression = hashingFunctionClass match { + case m if m == classOf[Murmur3Hash] => new Murmur3Hash(expressions) + case h if h == classOf[HiveHash] => HiveHash(expressions) + case _ => throw new Exception(s"Unsupported hashingFunction: $hashingFunctionClass") + } + Pmod(hashExpression, Literal(numPartitions)) + } } /** @@ -289,8 +314,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, numClusters, hashingFunctionClass) => + (numClusters.isEmpty || numClusters.get == numPartitions) && hashingFunctionClass.isEmpty && + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d5..4f483167a642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash} import org.apache.spark.sql.catalyst.plans.physical._ class DistributionSuite extends SparkFunSuite { @@ -79,6 +80,26 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(10), Some(classOf[Murmur3Hash])), + true) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(12), Some(classOf[Murmur3Hash])), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e), Some(10), Some(classOf[Murmur3Hash])), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(10), Some(classOf[HiveHash])), + false) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, @@ -127,19 +148,34 @@ class DistributionSuite extends SparkFunSuite { checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c), Some(10), None), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'b, 'a)), + ClusteredDistribution(Seq('c, 'b, 'a), Some(10), None), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('b, 'c, 'a, 'd)), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), None), true) + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), Some(classOf[Murmur3Hash])), + false) + + checkSatisfied( + 
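The reworked partitionIdExpression can be evaluated directly to see which bucket a value lands in under each hashing function; a rough sketch, assuming the patch is applied:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{HiveHash, Literal, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

object PartitionIdSketch {
  def main(args: Array[String]): Unit = {
    val key = Literal(42)
    // partitionIdExpression is Pmod(<hash>(key), numPartitions), so it is evaluable as-is.
    val murmurId = HashPartitioning(Seq(key), 8, classOf[Murmur3Hash]).partitionIdExpression
    val hiveId = HashPartitioning(Seq(key), 8, classOf[HiveHash]).partitionIdExpression
    // The two hash functions generally send the same key to different buckets, which is
    // why satisfies()/compatibleWith() must be sensitive to the hashing-function class.
    println(murmurId.eval(InternalRow.empty))
    println(hiveId.eval(InternalRow.empty))
  }
}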
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(12), Some(classOf[Murmur3Hash])), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), Some(classOf[HiveHash])), + false) + // Cases which need an exchange between two data properties. // TODO: We can have an optimization to first sort the dataset // by a.asc and then sort b, and c in a partition. This optimization diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala index 5b802ccc637d..0e5995fd9119 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.expressions.{HiveHash, InterpretedMutableProjection, Literal, Murmur3Hash} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { + private val expressions = Seq(Literal(2), Literal(3)) + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) @@ -34,11 +35,13 @@ class PartitioningSuite extends SparkFunSuite { val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) @@ -52,4 +55,18 @@ class PartitioningSuite extends SparkFunSuite { assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } + + test("HashPartitioning compatibility should be sensitive to hashing function") { + val partitioningA = HashPartitioning(expressions, 100, classOf[Murmur3Hash]) + val partitioningB = HashPartitioning(expressions, 100, classOf[HiveHash]) + assert(partitioningA != partitioningB) + assert(!partitioningA.compatibleWith(partitioningB)) + } + + test("HashPartitioning compatibility should be sensitive to number of partitions") { + val partitioningA = HashPartitioning(expressions, 10, classOf[Murmur3Hash]) + val partitioningB = HashPartitioning(expressions, 1212, classOf[Murmur3Hash]) + assert(partitioningA != partitioningB) + assert(!partitioningA.compatibleWith(partitioningB)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7cd4baef89e7..57c69d50c556 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder} import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -43,6 +44,10 @@ trait RunnableCommand extends logical.Command { // `ExecutedCommand` during query planning. lazy val metrics: Map[String, SQLMetric] = Map.empty + def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + + def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { throw new NotImplementedError } @@ -94,6 +99,10 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray + override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering + protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult, 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 514969715091..3cf19550e1f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,12 +34,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -107,7 +106,7 @@ object FileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], - bucketSpec: Option[BucketSpec], + bucketIdExpression: Option[Expression], statsTrackers: Seq[WriteJobStatsTracker], options: Map[String, String]) : Set[String] = { @@ 
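The two new hooks let a command declare the layout it needs and have the planner insert the exchange and sort on its behalf. A hypothetical command (illustrative names only, not part of the patch; assumes commands with child plans are supported as in this snapshot):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, HiveHash, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.RunnableCommand

case class WriteBucketedCommand(query: LogicalPlan, bucketCols: Seq[Attribute], numBuckets: Int)
  extends RunnableCommand {

  override def children: Seq[LogicalPlan] = Seq(query)

  // Ask for the input to be clustered into exactly `numBuckets` partitions using HiveHash ...
  override def requiredDistribution: Seq[Distribution] =
    Seq(ClusteredDistribution(bucketCols, Some(numBuckets), Some(classOf[HiveHash])))

  // ... and for each partition to be sorted by the bucket columns.
  override def requiredOrdering: Seq[Seq[SortOrder]] =
    Seq(bucketCols.map(SortOrder(_, Ascending)))

  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
    // A real command would write out children.head.execute() here.
    Seq.empty[Row]
  }
}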
-121,17 +120,6 @@ object FileFormatWriter extends Logging { val partitionSet = AttributeSet(partitionColumns) val dataColumns = allColumns.filterNot(partitionSet.contains) - val bucketIdExpression = bucketSpec.map { spec => - val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) - } - val caseInsensitiveOptions = CaseInsensitiveMap(options) // Note: prepareWrite has side effect. It sets "job". @@ -155,19 +143,6 @@ object FileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns - // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) - } - } - SQLExecution.checkSQLExecutionId(sparkSession) // This call shouldn't be put into the `try` block below because it only initializes and @@ -175,14 +150,7 @@ object FileFormatWriter extends Logging { committer.setupJob(job) try { - val rdd = if (orderingMatched) { - plan.execute() - } else { - SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), - global = false, - child = plan).execute() - } + val rdd = plan.execute() val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( rdd, @@ -195,7 +163,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rdd.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res @@ -514,18 +482,18 @@ object FileFormatWriter extends Logging { var recordsInFile: Long = 0L var fileCounter = 0 val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None + var currentPartitionValues: Option[UnsafeRow] = None var currentBucketId: Option[Int] = None for (row <- iter) { val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). 
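With the sort now guaranteed upstream (via the command's requiredOrdering), the write task only has to watch for a change in the (partition values, bucket id) pair. A stand-alone sketch of that loop, plain Scala with illustrative data:

object WriterLoopSketch {
  def main(args: Array[String]): Unit = {
    // Rows arrive sorted by (partition, bucketId); a new file is opened only when the pair changes.
    val rows = Seq(("dt=1", 0, "r1"), ("dt=1", 0, "r2"), ("dt=1", 1, "r3"), ("dt=2", 0, "r4"))
    var current: Option[(String, Int)] = None
    var filesOpened = 0
    rows.foreach { case (partition, bucketId, _) =>
      if (!current.contains((partition, bucketId))) {
        current = Some((partition, bucketId))
        filesOpened += 1 // corresponds to newOutputWriter(currentPartitionValues, currentBucketId, ...)
      }
    }
    println(s"files opened: $filesOpened") // 3
  }
}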
- if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) } if (isBucketed) { currentBucketId = nextBucketId @@ -536,7 +504,7 @@ object FileFormatWriter extends Logging { fileCounter = 0 releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) + newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions) } else if (desc.maxRecordsPerFile > 0 && recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. @@ -547,7 +515,7 @@ object FileFormatWriter extends Logging { s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) + newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions) } val outputRow = getOutputRow(row) currentWriter.write(outputRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 64e5a57adc37..b075962bb4f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -25,8 +25,9 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, HiveHash, Murmur3Hash, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.util.SchemaUtils @@ -141,6 +142,10 @@ case class InsertIntoHadoopFsRelationCommand( } } + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val bucketIdExpression = getBucketIdExpression(dataColumns) + val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, @@ -151,7 +156,7 @@ case class InsertIntoHadoopFsRelationCommand( qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, partitionColumns = partitionColumns, - bucketSpec = bucketSpec, + bucketIdExpression = bucketIdExpression, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = options) @@ -175,6 +180,43 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = { + bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, 
so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning( + bucketColumns, + spec.numBuckets, + classOf[Murmur3Hash] + ).partitionIdExpression + } + } + + /** + * How is `requiredOrdering` determined ? + * + * table type | requiredOrdering + * -----------------+------------------------------------------------- + * normal table | partition columns + * bucketed table | (partition columns + bucketId + sort columns) + * -----------------+------------------------------------------------- + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + val sortExpressions = bucketSpec match { + case Some(spec) => + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val bucketIdExpression = getBucketIdExpression(dataColumns) + val sortColumns = spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + partitionColumns ++ bucketIdExpression ++ sortColumns + + case _ => partitionColumns + } + Seq(sortExpressions.map(SortOrder(_, Ascending))) + } + /** * Deletes all partition files that match the specified static prefix. Partitions with custom * locations are also cleared based on the custom locations map given to this class. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b91d07744255..9b5186eeca1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.exchange +import scala.language.existentials + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -47,10 +49,40 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def createPartitioning( requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { + numPartitions: Int, + childPartitionings: Seq[Partitioning] = Seq()): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + + case ClusteredDistribution(clustering, numClusters, hashingFunctionClass) => + assert(numClusters.isEmpty || numPartitions == numClusters.get) + + hashingFunctionClass match { + case Some(clazz) => + HashPartitioning(clustering, numPartitions, clazz) + case None => + val distinctChildHashingFunctions = childPartitionings.map { + case HashPartitioning(_, _, hashingFunction) => hashingFunction + case _ => classOf[Murmur3Hash] + }.distinct + + // If all the children use the same hashing function, then use it. Else fallback to the + // default hashing function (ie. Murmur3Hash). This might not be the most optimal thing + // to do. eg. In case of join, if left child is hashed using HiveHash and the right one + // using Murmur3Hash, this would shuffle the left relation. If the left relation is + // larger than the right relation, the cost of shuffling it will be high. Instead more + // optimal thing to do would be to shuffle the right relation using HiveHash so that + // at the join side both the children are shuffled using the same hashing function. 
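As a concrete instance of the bucketed-table row in the requiredOrdering table above: partition columns, then the bucket id expression, then the sort columns, all ascending. A sketch with made-up columns and bucket spec, assuming the patch is applied:

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Murmur3Hash, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.{IntegerType, StringType}

object FileSourceOrderingSketch {
  def main(args: Array[String]): Unit = {
    val dt = AttributeReference("dt", StringType)()        // partition column
    val key = AttributeReference("key", IntegerType)()     // bucket column
    val value = AttributeReference("value", StringType)()  // sort column
    val spec = BucketSpec(numBuckets = 4, bucketColumnNames = Seq("key"), sortColumnNames = Seq("value"))

    // Same recipe as getBucketIdExpression: Pmod(Murmur3Hash(bucket columns), numBuckets).
    val bucketId = HashPartitioning(Seq(key), spec.numBuckets, classOf[Murmur3Hash]).partitionIdExpression

    // partition columns ++ bucketId ++ sort columns
    val requiredOrdering = (Seq(dt) ++ Seq(bucketId) ++ Seq(value)).map(SortOrder(_, Ascending))
    requiredOrdering.foreach(o => println(o.sql))
  }
}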
+ // Using Murmur3Hash might turn out better if there are downstream operator's needing + // data partitioned over Murmur3Hash which cannot be estimated at this point. + val targetHashingClass = if (distinctChildHashingFunctions.length == 1) { + distinctChildHashingFunctions.head + } else { + classOf[Murmur3Hash] + } + HashPartitioning(clustering, numPartitions, targetHashingClass) + } + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } @@ -135,8 +167,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // // It will be great to introduce a new Partitioning to represent the post-shuffle // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val numPartitions = distribution.numPartitions.getOrElse(defaultNumPreShufflePartitions) + val targetPartitioning = createPartitioning(distribution, numPartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchange(targetPartitioning, child, Some(coordinator)) } @@ -155,6 +187,15 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) + // We don't expect an operator to expect different number of partitions across its children + val distinctNumPartitonsExpected = requiredChildDistributions.flatMap(_.numPartitions).distinct + assert(distinctNumPartitonsExpected.size <= 1) + val numPreShufflePartitions = if (distinctNumPartitonsExpected.isEmpty) { + defaultNumPreShufflePartitions + } else { + distinctNumPartitonsExpected.head + } + // Ensure that the operator's children satisfy their output distribution requirements: children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => @@ -162,7 +203,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + ShuffleExchange(createPartitioning(distribution, numPreShufflePartitions), child) } // If the operator has multiple children and specifies child output distributions (e.g. join), @@ -179,11 +220,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // First check if the existing partitions of the children all match. This means they are // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. 
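The fallback rule described in the comment above boils down to: use the children's hashing class when they all agree, otherwise default to Murmur3Hash. An illustrative pure-Scala helper, not the patch's actual code path:

import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash}

object HashChoiceSketch {
  // The real code constrains these class tokens to Class[_ <: HashExpression[Int]].
  def choose(childHashClasses: Seq[Class[_]]): Class[_] = {
    val distinct = childHashClasses.distinct
    if (distinct.length == 1) distinct.head else classOf[Murmur3Hash]
  }

  def main(args: Array[String]): Unit = {
    println(choose(Seq(classOf[HiveHash], classOf[HiveHash])))    // HiveHash: both children agree
    println(choose(Seq(classOf[HiveHash], classOf[Murmur3Hash]))) // Murmur3Hash: disagreement, fall back
  }
}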
- val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childPartitionings = children.map(_.outputPartitioning) + val maxChildrenNumPartitions = childPartitionings.map(_.numPartitions).max + val useExistingPartitioning = childPartitionings.zip(requiredChildDistributions).forall { + case (childPartitioning, distribution) => + distribution.numPartitions match { + case Some(expectedPartitions) if expectedPartitions != maxChildrenNumPartitions => + false + case None => + val targetPartitioning = + createPartitioning(distribution, maxChildrenNumPartitions, childPartitionings) + childPartitioning.guarantees(targetPartitioning) + } } children = if (useExistingPartitioning) { @@ -194,21 +242,32 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // Now, we will determine the number of partitions that will be used by created // partitioning schemes. val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + if (distinctNumPartitonsExpected.nonEmpty) { + distinctNumPartitonsExpected.head + } else { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + distribution.numPartitions match { + case Some(expectedPartitions) if expectedPartitions != maxChildrenNumPartitions => + true + case None => + val targetPartitioning = + createPartitioning(distribution, maxChildrenNumPartitions, childPartitionings) + !child.outputPartitioning.guarantees(targetPartitioning) + } + } + // If we need to shuffle all children, we use numPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) numPreShufflePartitions else maxChildrenNumPartitions } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. 
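Pulling the partition-count decision out of the code above: a count pinned by the required distributions wins; otherwise reuse the children's maximum, unless every child must be shuffled anyway, in which case fall back to the default pre-shuffle partition count. An illustrative helper, not the patch's code:

object PartitionCountSketch {
  def choose(
      pinnedNumPartitions: Option[Int],
      shufflesAllChildren: Boolean,
      defaultNumPreShufflePartitions: Int,
      maxChildrenNumPartitions: Int): Int = {
    pinnedNumPartitions.getOrElse(
      if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions)
  }

  def main(args: Array[String]): Unit = {
    println(choose(Some(8), shufflesAllChildren = true, 200, 50))  // 8: e.g. writing into 8 buckets
    println(choose(None, shufflesAllChildren = false, 200, 50))    // 50: reuse the widest child
    println(choose(None, shufflesAllChildren = true, 200, 50))     // 200: everything gets reshuffled
  }
}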
- if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions } children.zip(requiredChildDistributions).map { case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) + val targetPartitioning = + createPartitioning(distribution, numPartitions, childPartitionings) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index eebe6ad2e794..5d14ba2a8d09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -203,7 +203,7 @@ object ShuffleExchange { serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(_, n) => + case HashPartitioning(_, n, _) => new Partitioner { override def numPartitions: Int = n // For HashPartitioning, the partitioning key is already a valid partition ID, as we use diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala index 534d8c5689c2..7ba01d59ab87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala @@ -54,13 +54,13 @@ class ReorderJoinPredicates extends Rule[SparkPlan] { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { leftPartitioning match { - case HashPartitioning(leftExpressions, _) + case HashPartitioning(leftExpressions, _, _) if leftExpressions.length == leftKeys.length && leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => reorder(leftExpressions, leftKeys) case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) + case HashPartitioning(rightExpressions, _, _) if rightExpressions.length == rightKeys.length && rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => reorder(rightExpressions, rightKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 72e5ac40bbfe..f99357ecee71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -127,7 +127,7 @@ class FileStreamSink( outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionColumns, - bucketSpec = None, + bucketIdExpression = None, statsTrackers = Nil, options = options) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7c0b9bf19bf3..8ed17f191d8c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, 
ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveExternalCatalog.DATASOURCE_PROVIDER import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -918,7 +919,10 @@ private[hive] object HiveClientImpl { } table.bucketSpec match { - case Some(bucketSpec) if DDLUtils.isHiveTable(table) => + case Some(bucketSpec) if DDLUtils.isHiveTable(table) || + (table.tableType != CatalogTableType.VIEW && + table.properties.get(DATASOURCE_PROVIDER).isEmpty) => + hiveTable.setNumBuckets(bucketSpec.numBuckets) hiveTable.setBucketCols(bucketSpec.bucketColumnNames.toList.asJava) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 46610f84dd82..b5b52680e730 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -22,11 +22,12 @@ import java.net.URI import java.text.SimpleDateFormat import java.util.{Date, Locale, Random} +import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.hive.common.{FileUtils, HiveStatsUtils} import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.hadoop.hive.ql.plan.TableDesc @@ -35,8 +36,9 @@ import org.apache.spark.SparkException import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, HiveHash, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.{CommandUtils, DataWritingCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter @@ -311,25 +313,10 @@ case class InsertIntoHiveTable( } } - table.bucketSpec match { - case Some(bucketSpec) => - // Writes to bucketed hive tables are allowed only if user does not care about maintaining - // table's bucketing ie. both "hive.enforce.bucketing" and "hive.enforce.sorting" are - // set to false - val enforceBucketingConfig = "hive.enforce.bucketing" - val enforceSortingConfig = "hive.enforce.sorting" - - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + - "currently does NOT populate bucketed output which is compatible with Hive." 
- - if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || - hadoopConf.get(enforceSortingConfig, "true").toBoolean) { - throw new AnalysisException(message) - } else { - logWarning(message + s" Inserting data anyways since both $enforceBucketingConfig and " + - s"$enforceSortingConfig are set to false.") - } - case _ => // do nothing since table has no bucketing + if (!overwrite && table.bucketSpec.isDefined) { + throw new SparkException("Appending data to hive bucketed table is not allowed as it " + + "will break the table's bucketing guarantee. Consider overwriting instead. Table = " + + table.qualifiedName) } val committer = FileCommitProtocol.instantiate( @@ -344,6 +331,9 @@ case class InsertIntoHiveTable( }.asInstanceOf[Attribute] } + val (_, dataColumns) = getPartitionAndDataColumns + val bucketIdExpression = getBucketIdExpression(dataColumns) + FileFormatWriter.write( sparkSession = sparkSession, plan = children.head, @@ -352,10 +342,21 @@ case class InsertIntoHiveTable( outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, - bucketSpec = None, + bucketIdExpression = bucketIdExpression, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = Map.empty) + // validate bucketing based on number of files before loading to metastore + table.bucketSpec.foreach { spec => + if (partition.nonEmpty && numDynamicPartitions > 0) { + val validPartitionPaths = + getValidPartitionPaths(hadoopConf, tmpLocation, numDynamicPartitions) + validateBuckets(hadoopConf, validPartitionPaths, table.bucketSpec.get.numBuckets) + } else { + validateBuckets(hadoopConf, Seq(tmpLocation), table.bucketSpec.get.numBuckets) + } + } + if (partition.nonEmpty) { if (numDynamicPartitions > 0) { externalCatalog.loadDynamicPartitions( @@ -371,10 +372,10 @@ case class InsertIntoHiveTable( // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries // scalastyle:on val oldPart = - externalCatalog.getPartitionOption( - table.database, - table.identifier.table, - partitionSpec) + externalCatalog.getPartitionOption( + table.database, + table.identifier.table, + partitionSpec) var doHiveOverwrite = overwrite @@ -447,4 +448,132 @@ case class InsertIntoHiveTable( // TODO: implement hive compatibility as rules. 
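Hive compatibility hinges on the bucket id being pmod(hiveHash(bucket columns), numBuckets), which is what validateBucketingAndSorting in the test suite further down recomputes per row. A quick sketch for integer keys, assuming HiveHashFunction as used by that test:

import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
import org.apache.spark.sql.types.IntegerType

object HiveBucketCheckSketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 8
    Seq(0, 7, 8, 42).foreach { key =>
      val hash = HiveHashFunction.hash(key, IntegerType, seed = 0)
      // pmod keeps the bucket id non-negative even when the hash is negative.
      val bucketId = ((hash % numBuckets) + numBuckets) % numBuckets
      println(s"key=$key -> bucket $bucketId")
    }
  }
}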
Seq.empty[Row] } + + private def getValidPartitionPaths( + conf: Configuration, + outputPath: Path, + numDynamicPartitions: Int): Seq[Path] = { + val validPartitionPaths = mutable.HashSet[Path]() + try { + val fs = outputPath.getFileSystem(conf) + HiveStatsUtils.getFileStatusRecurse(outputPath, numDynamicPartitions, fs) + .filter(_.isDirectory) + .foreach(d => validPartitionPaths.add(d.getPath)) + } catch { + case e: IOException => + throw new SparkException("Unable to extract partition paths from temporary output " + + s"location $outputPath due to : ${e.getMessage}", e) + } + validPartitionPaths.toSeq + } + + private def validateBuckets(conf: Configuration, outputPaths: Seq[Path], numBuckets: Int) = { + val bucketedFilePattern = """part-(\d+)(?:.*)?$""".r + + def getBucketIdFromFilename(fileName : String): Option[Int] = + fileName match { + case bucketedFilePattern(bucketId) => Some(bucketId.toInt) + case _ => None + } + + outputPaths.foreach(outputPath => { + val fs = outputPath.getFileSystem(conf) + val allFiles = fs.listStatus(outputPath) + if (allFiles != null && allFiles.nonEmpty) { + val files = allFiles.filterNot(_.getPath.getName == "_SUCCESS") + .map(_.getPath.getName) + .sortBy(_.toString) + + var expectedBucketId = 0 + files.foreach { case file => + getBucketIdFromFilename(file) match { + case Some(id) if id == expectedBucketId => + expectedBucketId += 1 + case Some(_) => + throw new SparkException( + s"Potentially missing bucketed output files in temporary bucketed output " + + s"location. Aborting job. Output location : $outputPath, files found : " + + files.mkString("[", ",", "]")) + case None => + throw new SparkException( + s"Invalid file found in temporary bucketed output location. Aborting job. " + + s"Output location : $outputPath, bad file : $file") + } + } + + if (expectedBucketId != numBuckets) { + throw new SparkException( + s"Potentially missing bucketed output files in temporary bucketed output location. " + + s"Aborting job. Output location : $outputPath, files found : " + + files.mkString("[", ",", "]")) + } + } + }) + } + + private def getPartitionAndDataColumns: (Seq[Attribute], Seq[Attribute]) = { + val allColumns = query.output + val partitionColumnNames = partition.keySet + allColumns.partition(c => partitionColumnNames.contains(c.name)) + } + + /** + * Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + * guarantee the data distribution is same between shuffle and bucketed data source, which + * enables us to only shuffle one side when join a bucketed table and a normal one. + */ + private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = + table.bucketSpec.map { spec => + HashPartitioning( + spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get), + spec.numBuckets, + classOf[HiveHash] + ).partitionIdExpression + } + + /** + * If the table is bucketed, then requiredDistribution would be the bucket columns. + * Else it would be empty + */ + override def requiredDistribution: Seq[Distribution] = table.bucketSpec match { + case Some(bucketSpec) => + val (_, dataColumns) = getPartitionAndDataColumns + Seq(ClusteredDistribution( + bucketSpec.bucketColumnNames.map(b => dataColumns.find(_.name == b).get), + Option(bucketSpec.numBuckets), + Some(classOf[HiveHash]) + )) + + case _ => Seq(UnspecifiedDistribution) + } + + /** + * How is `requiredOrdering` determined ? 
+ * + * table type | normal table | bucketed table + * --------------------+--------------------+----------------------------------------------- + * non-partitioned | Nil | sort columns + * static partition | Nil | sort columns + * dynamic partition | partition columns | (partition columns + bucketId + sort columns) + * --------------------+--------------------+----------------------------------------------- + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + val (partitionColumns, dataColumns) = getPartitionAndDataColumns + val isDynamicPartitioned = + table.partitionColumnNames.nonEmpty && partition.values.exists(_.isEmpty) + + val sortExpressions = table.bucketSpec match { + case Some(bucketSpec) => + val sortColumns = bucketSpec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + if (isDynamicPartitioned) { + partitionColumns ++ getBucketIdExpression(dataColumns) ++ sortColumns + } else { + sortColumns + } + + case _ => if (isDynamicPartitioned) partitionColumns else Nil + } + + Seq(sortExpressions.map(SortOrder(_, Ascending))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index cc80f2e481cb..20dae9150ce9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.hive import java.io.File +import java.net.URI import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.expressions.HiveHashFunction import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -497,31 +497,186 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } } - testBucketedTable("INSERT should NOT fail if strict bucketing is NOT enforced") { - tableName => - withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "false") { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 4, 2 AS c, 3 AS b") - checkAnswer(sql(s"SELECT a, b, c, d FROM $tableName"), Row(1, 2, 3, 4)) + private def validateBucketingAndSorting(numBuckets: Int, dir: URI): Unit = { + val bucketFiles = new File(dir).listFiles().filter(_.getName.startsWith("part-")) + .sortWith((x, y) => x.getName < y.getName) + assert(bucketFiles.length === numBuckets) + + bucketFiles.zipWithIndex.foreach { case(bucketFile, bucketId) => + val rows = spark.read.format("text").load(bucketFile.getAbsolutePath).collect() + var prevKey: Option[Int] = None + rows.foreach(row => { + val key = row.getString(0).split("\t")(0).toInt + assert(HiveHashFunction.hash(key, IntegerType, seed = 0) % numBuckets === bucketId) + + if (prevKey.isDefined) { + assert(prevKey.get <= key) + } + prevKey = Some(key) + }) + } + } + + test("Write data to a non-partitioned bucketed table") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + val session = spark.sessionState + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + 
.map(i => (i, i.toString)).toDF("key", "value") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + val dir = session.catalog.defaultTablePath(session.sqlParser.parseTableIdentifier(tableName)) + validateBucketingAndSorting(numBuckets, dir) + } + } + + test("Write data to a bucketed table with static partition") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString)) + .toDF("key", "value") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="val1", part2="val2") + |SELECT key, value + |FROM $sourceTableName + |""".stripMargin) + + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> "val1", "part2" -> "val2") + ).location + + validateBucketingAndSorting(numBuckets, dir) } + } } - testBucketedTable("INSERT should fail if strict bucketing / sorting is enforced") { - tableName => - withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "false") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") + test("Write data to a bucketed table with dynamic partitions") { + val numBuckets = 7 + val tableName = "bucketizedTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 1000) + .map(i => (i, i.toString, (if (i > 50) i % 2 else 2 - i % 2).toString, (i % 3).toString)) + .toDF("key", "value", "part1", "part2") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + (0 until 2).zip(0 until 3).foreach { case (part1, part2) => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1.toString, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) } } - withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "true") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") + } + } + + test("Write data to a bucketed table with dynamic partitions (along with static partitions)") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + val part1StaticValue = "0" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString, (i % 3).toString)) + .toDF("key", "value", "part") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="$part1StaticValue", part2) + |SELECT key, 
value, part + |FROM $sourceTableName + |""".stripMargin) + + (0 until 3).foreach { case part2 => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1StaticValue, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) } } - withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "true") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") - } + } + } + + test("Appends to bucketed table should NOT be allowed as it breaks bucketing guarantee") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until 100).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Append).insertInto(tableName) } + assert(e.getMessage.contains("Appending data to hive bucketed table is not allowed")) + } + } + + test("Fail the query if number of files produced != number of buckets") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until (numBuckets / 2)).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Overwrite).insertInto(tableName) + } + assert(e.getMessage.contains("Potentially missing bucketed output files")) + } } test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") {
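Taken together, the user-facing effect is that INSERT OVERWRITE into a Hive bucketed table now shuffles with HiveHash and sorts the data so the produced files are Hive-compatible, while appends are rejected. A rough end-to-end usage sketch (illustrative table name; needs a Spark build with Hive support and this patch applied):

import org.apache.spark.sql.{SaveMode, SparkSession}

object HiveBucketedWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").enableHiveSupport().getOrCreate()
    import spark.implicits._

    spark.sql("""
      CREATE TABLE bucketed_target (key INT, value STRING)
      CLUSTERED BY (key) SORTED BY (key ASC) INTO 8 BUCKETS
      ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
    """)

    val df = (0 until 100).map(i => (i, i.toString)).toDF("key", "value")
    // Overwrite succeeds and produces exactly 8 HiveHash-bucketed, sorted files per location.
    df.write.mode(SaveMode.Overwrite).insertInto("bucketed_target")
    // Append would throw: "Appending data to hive bucketed table is not allowed ...".
    // df.write.mode(SaveMode.Append).insertInto("bucketed_target")

    spark.stop()
  }
}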