
Commit 4f51bdf

Browse files
tbfenet authored and markhamstra committed

updated streaming iterable

Conflicts:
	core/src/main/scala/org/apache/spark/rdd/RDD.scala
	core/src/main/scala/org/apache/spark/util/RDDiterable.scala
1 parent 0324bfd commit 4f51bdf

File tree

4 files changed: +103 −77 lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 8 additions & 8 deletions

@@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue, SerializableHyperLogLog}
+import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog}
 
 import org.apache.spark.SparkContext._
 import org.apache.spark._
@@ -603,8 +603,6 @@ abstract class RDD[T: ClassTag](
     sc.runJob(this, (iter: Iterator[T]) => f(iter))
   }
 
-
-
   /**
    * Return an array that contains all of the elements in this RDD.
    */
@@ -626,14 +624,16 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * Return iterable that lazily fetches partitions
-   * @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism but also increases
-   *                           driver memory requirement
+   * Return an iterator that lazily fetches partitions.
+   * @param prefetchPartitions How many partitions to prefetch. A larger value increases parallelism
+   *                           but also increases the driver memory requirement.
+   * @param partitionBatchSize How many partitions to fetch per job.
    * @param timeOut how long to wait for each partition fetch
    * @return Iterable of every element in this RDD
    */
-  def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)) = {
-    new RDDiterable[T](this, prefetchPartitions, timeOut)
+  def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
+      timeOut: Duration = Duration(30, TimeUnit.SECONDS)): Iterator[T] = {
+    new RDDiterator[T](this, prefetchPartitions, partitionBatchSize, timeOut)
   }
 
  /**
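
A usage sketch of the new toIterator method (illustrative only, not part of the commit; the SparkContext `sc` and the example RDD are assumed): elements stream to the driver in partition batches, so driver memory is bounded by roughly prefetchPartitions partitions rather than the whole RDD.

    // Hedged usage sketch assuming an existing SparkContext `sc`.
    import java.util.concurrent.TimeUnit
    import scala.concurrent.duration.Duration

    val nums = sc.makeRDD(1 to 1000000, 100)
    // Fetch 10 partitions per job, keeping up to 20 partitions prefetched
    // ahead of consumption; wait at most 60 seconds for each batch.
    val it: Iterator[Int] = nums.toIterator(prefetchPartitions = 20,
      partitionBatchSize = 10, timeOut = Duration(60, TimeUnit.SECONDS))
    println(it.take(5).toList) // only the partitions actually needed are fetched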
core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+package org.apache.spark.rdd
+
+import scala.concurrent.{Await, Future}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import scala.collection.mutable
+import org.apache.spark.rdd.RDDiterator._
+import org.apache.spark.FutureAction
+
+/**
+ * Iterator over all elements of an RDD that fetches partitions to the driver
+ * process lazily rather than all at once.
+ *
+ * @param rdd RDD to iterate over
+ * @param prefetchPartitions The number of partitions to prefetch.
+ *                           If < 1, no partitions are prefetched;
+ *                           partitions prefetched = min(prefetchPartitions, partitionBatchSize)
+ * @param partitionBatchSize How many partitions to fetch per job
+ * @param timeOut How long to wait for each partition before failing
+ */
+class RDDiterator[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
+    timeOut: Duration)
+  extends Iterator[T] {
+
+  val batchSize = math.max(1, partitionBatchSize)
+  var partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
+  var pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
+  // Kick off the initial prefetch of up to prefetchPartitions partitions
+  0.until(math.max(0, prefetchPartitions / batchSize)).foreach(x => enqueueDataFetch())
+
+  var currentIterator: Iterator[T] = Iterator.empty
+  @tailrec
+  final def hasNext = {
+    if (currentIterator.hasNext) {
+      // Still values left in the current partition
+      true
+    } else {
+      // Move on to the next partition and
+      // queue a new prefetch of a partition batch
+      enqueueDataFetch()
+      if (pendingFetchesQueue.isEmpty) {
+        // No more partitions
+        currentIterator = Iterator.empty
+        false
+      } else {
+        val future = pendingFetchesQueue.dequeue()
+        currentIterator = Await.result(future, timeOut).flatMap(x => x).iterator
+        // The next partition might be empty, so check again
+        this.hasNext
+      }
+    }
+  }
+  def next() = {
+    hasNext
+    currentIterator.next()
+  }
+
+  def enqueueDataFetch() = {
+    if (partitionsBatches.hasNext) {
+      pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
+    }
+  }
+}
+
+object RDDiterator {
+  private def fetchData[T: ClassManifest](partitionIndexes: Seq[Int],
+      rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
+    val results = new ArrayBuffer[Seq[T]]()
+    rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
+      x => x.toArray,
+      partitionIndexes,
+      (inx: Int, res: Array[T]) => results.append(res),
+      results.toSeq)
+  }
+}
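
A note on the constructor's prefetch arithmetic: the number of batch fetches queued up front is max(0, prefetchPartitions / batchSize), where batchSize = max(1, partitionBatchSize). A standalone sketch of that calculation (illustrative only; the helper name `batchesPrefetched` is hypothetical):

    // Sketch of RDDiterator's constructor-time prefetch math.
    def batchesPrefetched(prefetchPartitions: Int, partitionBatchSize: Int): Int = {
      val batchSize = math.max(1, partitionBatchSize) // batch size clamped to >= 1
      math.max(0, prefetchPartitions / batchSize)     // jobs submitted eagerly
    }

    assert(batchesPrefetched(20, 10) == 2) // two batches (20 partitions) in flight
    assert(batchesPrefetched(3, 10) == 0)  // under one full batch: no prefetch
    assert(batchesPrefetched(-1, 0) == 0)  // prefetchPartitions < 1 disables prefetching

This is why the tests further below can pass prefetchPartitions = -1 or partitionBatchSize = 0 safely: both degenerate to no prefetching with a batch size of 1, and partitions are then fetched on demand.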

core/src/main/scala/org/apache/spark/util/RDDiterable.scala

Lines changed: 0 additions & 60 deletions
This file was deleted.

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 19 additions & 9 deletions

@@ -382,23 +382,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("toIterable") {
     var nums = sc.makeRDD(Range(1, 1000), 100)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
-    assert(nums.toIterable().toArray === (1 to 999).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
 
     nums = sc.makeRDD(Range(1000, 1, -1), 100)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
-    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
 
     nums = sc.makeRDD(Range(1, 100), 1000)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 99)
-    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
 
     nums = sc.makeRDD(Range(1, 1000), 100)
-    assert(nums.toIterable(prefetchPartitions = -1).size === 999)
-    assert(nums.toIterable().toArray === (1 to 999).toArray)
-  }
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 3, partitionBatchSize = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
 
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1, partitionBatchSize = 0).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
 
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+  }
 
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
