diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a89a90b735a6..5ac3f1f5b0cac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -216,26 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.toSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } + } /** @@ -257,15 +254,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 0000000000000..5b802ccc637dd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +}