diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 09d0a8189d25..d69a14ff1469 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -112,6 +112,7 @@ private object ParallelCollectionRDD { * it efficient to run Spark over RDDs representing large sets of numbers. */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { + def validNumSlices(vSeq: Seq[_]) = if (vSeq.length < numSlices) vSeq.length else numSlices if (numSlices < 1) { throw new IllegalArgumentException("Positive number of slices required") } @@ -126,28 +127,31 @@ private object ParallelCollectionRDD { r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { - (0 until numSlices).map(i => { - val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i + 1) * r.length.toLong) / numSlices).toInt + val vNumSlices = validNumSlices(r) + (0 until vNumSlices).map(i => { + val start = ((i * r.length.toLong) / vNumSlices).toInt + val end = (((i + 1) * r.length.toLong) / vNumSlices).toInt new Range(r.start + start * r.step, r.start + end * r.step, r.step) }).asInstanceOf[Seq[Seq[T]]] } case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc - val slices = new ArrayBuffer[Seq[T]](numSlices) - val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything + val vNumSlices = validNumSlices(nr) + val slices = new ArrayBuffer[Seq[T]](vNumSlices) var r = nr - for (i <- 0 until numSlices) { + for (i <- 0 until vNumSlices) { + val sliceSize = (((i + 1) * nr.length.toLong) / vNumSlices).toInt - ((i * nr.length.toLong) / vNumSlices).toInt slices += r.take(sliceSize).asInstanceOf[Seq[T]] r = r.drop(sliceSize) } slices } case _ => { + val vNumSlices = validNumSlices(seq) val array = seq.toArray // To prevent O(n^2) operations for List etc - (0 until numSlices).map(i => { - val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i + 1) * array.length.toLong) / numSlices).toInt + (0 until vNumSlices).map(i => { + val start = ((i * array.length.toLong) / vNumSlices).toInt + val end = (((i + 1) * array.length.toLong) / vNumSlices).toInt array.slice(start, end).toSeq }) }