1717
1818package org .apache .spark .streaming .scheduler
1919
20+ import scala .collection .mutable
21+ import scala .reflect .ClassTag
22+ import scala .util .control .NonFatal
23+
24+ import org .scalatest .Matchers ._
2025import org .scalatest .concurrent .Eventually ._
2126import org .scalatest .time .SpanSugar ._
2227
23- import org .apache .spark .annotation .DeveloperApi
24- import org .apache .spark .rdd .RDD
25- import org .apache .spark .streaming .{StreamingContext , TestOutputStreamWithPartitions , TestSuiteBase , Time }
26- import org .apache .spark .streaming .dstream .InputDStream
28+ import org .apache .spark .streaming ._
2729import org .apache .spark .streaming .scheduler .rate .RateEstimator
2830
31+
32+
2933class RateControllerSuite extends TestSuiteBase {
3034
// Turns on the base suite's `actuallyWait` flag for every test in this suite.
// NOTE(review): presumably this makes the stream runner wait in real time rather than
// advancing a manual clock -- confirm against TestSuiteBase.
override def actuallyWait: Boolean = {
  true
}
36+
3137 test(" rate controller publishes updates" ) {
3238 val ssc = new StreamingContext (conf, batchDuration)
33- val dstream = new MockRateLimitDStream (ssc)
39+ val dstream = new MockRateLimitDStream (ssc, Seq ( Seq ( 1 )), 1 )
3440 val output = new TestOutputStreamWithPartitions (dstream)
3541 output.register()
3642 runStreams(ssc, 1 , 1 )
@@ -39,41 +45,98 @@ class RateControllerSuite extends TestSuiteBase {
3945 assert(dstream.publishCalls === 1 )
4046 }
4147 }
48+
test(" receiver rate controller updates reach receivers") {
  val ssc = new StreamingContext(conf, batchDuration)

  // Input stream whose receiver rate controller is driven by an estimator that
  // always answers 200.0.
  val rateLimitedStream = new RateLimitInputDStream(ssc) {
    override val rateController =
      Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
  }
  SingletonDummyReceiver.reset()

  // An output operation has to be registered before the streams can be run.
  new TestOutputStreamWithPartitions(rateLimitedStream).register()
  runStreams(ssc, 2, 2)

  // The limit is observed asynchronously, so poll until it shows up (or time out).
  eventually(timeout(5.seconds)) {
    assert(rateLimitedStream.getCurrentRateLimit === Some(200))
  }
}
66+
test(" multiple rate controller updates reach receivers") {
  val ssc = new StreamingContext(conf, batchDuration)
  val rates = Seq(100L, 200L, 300L)

  // The estimator cycles through `rates`, one value per `compute` call.
  val dstream = new RateLimitInputDStream(ssc) {
    override val rateController =
      Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
  }
  SingletonDummyReceiver.reset()

  val output = new TestOutputStreamWithPartitions(dstream)
  output.register()

  // Written only by the polling thread below; read by this thread only after the
  // poller has signalled completion through `pollerFinished`.
  val observedRates = mutable.HashSet.empty[Long]

  @volatile var done = false
  val pollerFinished = new java.util.concurrent.CountDownLatch(1)
  runInBackground {
    try {
      while (!done) {
        try {
          dstream.getCurrentRateLimit.foreach(observedRates += _)
        } catch {
          case NonFatal(_) => () // don't stop if the executor wasn't installed yet
        }
        Thread.sleep(20)
      }
    } finally {
      pollerFinished.countDown()
    }
  }

  try {
    runStreams(ssc, 4, 4)
  } finally {
    // Always release the poller, even when running the streams fails; otherwise the
    // background loop would spin forever.
    done = true
  }

  // Waiting on the latch guarantees the poller has stopped mutating `observedRates`
  // and makes its writes visible to this thread before the set is inspected.
  assert(
    pollerFinished.await(5, java.util.concurrent.TimeUnit.SECONDS),
    "background rate poller did not terminate")

  // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
  observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
}
99+
/**
 * Starts a new daemon thread that evaluates `f` once, and returns immediately
 * without waiting for it to finish.
 *
 * @param f the by-name body to run on the background thread
 */
private def runInBackground(f: => Unit): Unit = {
  val worker = new Thread("rate-controller-suite-background") {
    override def run(): Unit = f
  }
  // Daemon: a body that never returns (e.g. a polling loop whose stop flag is never
  // set because the test failed first) must not keep the test JVM alive.
  worker.setDaemon(true)
  worker.start()
}
42107}
43108
/**
 * A test input stream that records how many times the `publish` method of its
 * rate controller has been invoked.
 */
private class MockRateLimitDStream[T: ClassTag](
    @transient ssc: StreamingContext,
    input: Seq[Seq[T]],
    numPartitions: Int) extends TestInputStream[T](ssc, input, numPartitions) {

  // Number of `publish` invocations seen so far; incremented by the controller below.
  @volatile var publishCalls = 0

  override val rateController: Option[RateController] = {
    // Controller fed by an estimator that always answers 100.0; publishing a rate
    // only bumps the counter.
    val countingController = new RateController(id, new ConstantEstimator(100.0)) {
      override def publish(rate: Long): Unit = publishCalls += 1
    }
    Some(countingController)
  }
}
69127
70- def compute (validTime : Time ): Option [RDD [Int ]] = {
71- val data = Seq (1 )
72- ssc.scheduler.inputInfoTracker.reportInfo(validTime, StreamInputInfo (id, data.size))
73- Some (ssc.sc.parallelize(data))
74- }
/**
 * A [[RateEstimator]] for tests that ignores the batch statistics it is given and
 * hands out the supplied rates one after another, wrapping around to the first
 * rate once the last one has been returned.
 *
 * @param rates the rates to return, in order; at least one value is required
 */
private class ConstantEstimator(rates: Double*) extends RateEstimator {
  // Fail at construction with a clear message: with no rates, `nextRate` could only
  // fail later, on the first `compute` call, with an index or modulo-by-zero error.
  require(rates.nonEmpty, "ConstantEstimator requires at least one rate")

  // Index into `rates` of the value the next `compute` call will return.
  private var idx: Int = 0

  // Returns the rate at the current position and advances it cyclically.
  private def nextRate(): Double = {
    val rate = rates(idx)
    idx = (idx + 1) % rates.size
    rate
  }

  /** Returns the next configured rate; all four arguments are ignored. */
  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double] = Some(nextRate())
}
0 commit comments