diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 907198ad5d23..6f287028f740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -21,6 +21,7 @@ import java.util.Random
 import java.util.function.Supplier
 
 import scala.concurrent.Future
+import scala.util.hashing
 
 import org.apache.spark._
 import org.apache.spark.internal.config
@@ -299,7 +300,14 @@ object ShuffleExchangeExec {
     def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) =>
         // Distributes elements evenly across output partitions, starting from a random partition.
-        var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
+        // The nextInt(numPartitions) implementation has a special case when the bound is a power
+        // of 2, which basically takes the several highest bits of the initial seed with only
+        // minimal scrambling. Due to the deterministic seed, using the generator only once,
+        // and the lack of scrambling, the position values for power-of-two numPartitions always
+        // end up being almost the same regardless of the index. Substantially scrambling the
+        // seed by hashing will help. Refer to SPARK-21782 for more details.
+        val partitionId = TaskContext.get().partitionId()
+        var position = new Random(hashing.byteswap32(partitionId)).nextInt(numPartitions)
         (row: InternalRow) => {
           // The HashPartitioner will handle the `mod` by the number of partitions
           position += 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d7ea766b21b6..bc7775144e16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2147,6 +2147,12 @@ class DatasetSuite extends QueryTest
       (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
       (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
   }
+
+  test("SPARK-40407: repartition should not result in severe data skew") {
+    val df = spark.range(0, 100, 1, 50).repartition(4)
+    val result = df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+    assert(result.sorted.toSeq === Seq(19, 25, 25, 31))
+  }
 }
 
 case class Bar(a: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index ddf500b3f93b..e475b0fdaa14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2127,8 +2127,8 @@ class AdaptiveQueryExecSuite
       withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") {
         // partition size [0,258,72,72,72]
         checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
-        // partition size [72,216,216,144,72]
-        checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 4, 7)
+        // partition size [144,72,144,216,144]
+        checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 2, 6)
       }
 
       // no skewed partition should be optimized
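For illustration only (not part of the patch): a minimal, self-contained Scala sketch of the behaviour the new comment describes. With a power-of-two bound, java.util.Random.nextInt derives its first value almost directly from the highest bits of the seed, so small consecutive partition ids used as seeds yield nearly identical starting positions; scrambling the seed with scala.util.hashing.byteswap32 first spreads them out. The object name SeedSkewDemo and the 50-task / 4-partition sizes are assumptions chosen to mirror the new DatasetSuite test.

import java.util.Random

import scala.util.hashing

// Illustrative sketch only, not Spark code: compare the starting positions produced by
// raw partition-id seeds versus byteswap32-scrambled seeds.
object SeedSkewDemo {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4 // power-of-two bound, the problematic case for nextInt
    val numTasks = 50     // assumed number of upstream map tasks, mirroring the new test

    // Count how many tasks start at each position when the raw partition id seeds the
    // Random, i.e. the behaviour before this patch.
    val rawCounts = (0 until numTasks)
      .map(id => new Random(id).nextInt(numPartitions))
      .groupBy(identity)
      .map { case (pos, ids) => pos -> ids.size }

    // Same computation with the seed scrambled by byteswap32 first, i.e. the behaviour
    // after this patch.
    val hashedCounts = (0 until numTasks)
      .map(id => new Random(hashing.byteswap32(id)).nextInt(numPartitions))
      .groupBy(identity)
      .map { case (pos, ids) => pos -> ids.size }

    println(s"raw seeds:        $rawCounts")
    println(s"byteswap32 seeds: $hashedCounts")
  }
}

The raw-seeded counts are expected to cluster on very few starting positions, which is what produced the skew addressed here, while the hashed counts should spread across all four positions.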