@@ -1129,6 +1129,36 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
11291129 }.collect()
11301130 }
11311131
1132+ test(" SPARK-23496: order of input partitions can result in severe skew in coalesce" ) {
1133+ val numInputPartitions = 100
1134+ val numCoalescedPartitions = 50
1135+ val locations = Array (" locA" , " locB" )
1136+
1137+ val inputRDD = sc.makeRDD(Range (0 , numInputPartitions).toArray[Int ], numInputPartitions)
1138+ assert(inputRDD.getNumPartitions == numInputPartitions)
1139+
1140+ val locationPrefRDD = new LocationPrefRDD (inputRDD, { (p : Partition ) =>
1141+ if (p.index < numCoalescedPartitions) {
1142+ Seq (locations(0 ))
1143+ } else {
1144+ Seq (locations(1 ))
1145+ }
1146+ })
1147+ val coalescedRDD = new CoalescedRDD (locationPrefRDD, numCoalescedPartitions)
1148+
1149+ val numPartsPerLocation = coalescedRDD
1150+ .getPartitions
1151+ .map(coalescedRDD.getPreferredLocations(_).head)
1152+ .groupBy(identity)
1153+ .mapValues(_.size)
1154+
1155+ // Without the fix these would be:
1156+ // numPartsPerLocation(locations(0)) == numCoalescedPartitions - 1
1157+ // numPartsPerLocation(locations(1)) == 1
1158+ assert(numPartsPerLocation(locations(0 )) > 0.4 * numCoalescedPartitions)
1159+ assert(numPartsPerLocation(locations(1 )) > 0.4 * numCoalescedPartitions)
1160+ }
1161+
11321162 // NOTE
11331163 // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
11341164 // running after them and if they access sc those tests will fail as sc is already closed, because
@@ -1210,3 +1240,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria
12101240 groups.toArray
12111241 }
12121242}
1243+
1244+ /** Alters the preferred locations of the parent RDD using provided function. */
1245+ class LocationPrefRDD [T : ClassTag ](
1246+ @ transient var prev : RDD [T ],
1247+ val locationPicker : Partition => Seq [String ]) extends RDD [T ](prev) {
1248+ override protected def getPartitions : Array [Partition ] = prev.partitions
1249+
1250+ override def compute (partition : Partition , context : TaskContext ): Iterator [T ] =
1251+ null .asInstanceOf [Iterator [T ]]
1252+
1253+ override def getPreferredLocations (partition : Partition ): Seq [String ] =
1254+ locationPicker(partition)
1255+ }
0 commit comments