Skip to content

Commit e57c66b

Browse files
committed
Added a couple of tests for the full scenario from driver to receivers,
with several rate updates.
1 parent b425d32 commit e57c66b

File tree

3 files changed

+114
-39
lines changed

3 files changed

+114
-39
lines changed

streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
4545
* Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
4646
*/
4747
override protected[streaming] val rateController: Option[RateController] =
48-
RateEstimator.makeEstimator(ssc.conf).map { estimator =>
49-
new RateController(id, estimator) {
50-
override def publish(rate: Long): Unit =
51-
ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
52-
}
53-
}
48+
RateEstimator.makeEstimator(ssc.conf).map { new ReceiverRateController(id, _) }
5449

5550
/**
5651
* Gets the receiver object that will be sent to the worker nodes
@@ -122,4 +117,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
122117
}
123118
Some(blockRDD)
124119
}
120+
121+
/**
122+
* A RateController that sends the new rate to receivers, via the receiver tracker.
123+
*/
124+
private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
    extends RateController(id, estimator) {

  override def publish(rate: Long): Unit = {
    // Hand the new rate, tagged with this controller's id, to the receiver tracker.
    ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
  }
}
125129
}
130+

streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,26 @@
1717

1818
package org.apache.spark.streaming.scheduler
1919

20+
import scala.collection.mutable
21+
import scala.reflect.ClassTag
22+
import scala.util.control.NonFatal
23+
24+
import org.scalatest.Matchers._
2025
import org.scalatest.concurrent.Eventually._
2126
import org.scalatest.time.SpanSugar._
2227

23-
import org.apache.spark.annotation.DeveloperApi
24-
import org.apache.spark.rdd.RDD
25-
import org.apache.spark.streaming.{StreamingContext, TestOutputStreamWithPartitions, TestSuiteBase, Time}
26-
import org.apache.spark.streaming.dstream.InputDStream
28+
import org.apache.spark.streaming._
2729
import org.apache.spark.streaming.scheduler.rate.RateEstimator
2830

31+
32+
2933
class RateControllerSuite extends TestSuiteBase {
3034

35+
override def actuallyWait: Boolean = true
36+
3137
test("rate controller publishes updates") {
3238
val ssc = new StreamingContext(conf, batchDuration)
33-
val dstream = new MockRateLimitDStream(ssc)
39+
val dstream = new MockRateLimitDStream(ssc, Seq(Seq(1)), 1)
3440
val output = new TestOutputStreamWithPartitions(dstream)
3541
output.register()
3642
runStreams(ssc, 1, 1)
@@ -39,41 +45,98 @@ class RateControllerSuite extends TestSuiteBase {
3945
assert(dstream.publishCalls === 1)
4046
}
4147
}
48+
49+
test("receiver rate controller updates reach receivers") {
  val ssc = new StreamingContext(conf, batchDuration)
  // The single rate the estimator keeps producing and that the receiver should end up with.
  val expectedRate = 200L

  val dstream = new RateLimitInputDStream(ssc) {
    override val rateController =
      Some(new ReceiverRateController(id, new ConstantEstimator(expectedRate.toDouble)))
  }
  SingletonDummyReceiver.reset()

  val output = new TestOutputStreamWithPartitions(dstream)
  output.register()
  runStreams(ssc, 2, 2)

  // The update travels asynchronously, so poll until the limit shows up or the timeout hits.
  eventually(timeout(5.seconds)) {
    assert(dstream.getCurrentRateLimit === Some(expectedRate))
  }
}
66+
67+
test("multiple rate controller updates reach receivers") {
  val ssc = new StreamingContext(conf, batchDuration)
  val rates = Seq(100L, 200L, 300L)

  val dstream = new RateLimitInputDStream(ssc) {
    override val rateController =
      Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
  }
  SingletonDummyReceiver.reset()

  val output = new TestOutputStreamWithPartitions(dstream)
  output.register()

  // Written by the polling thread and read by the test thread, so every access goes through
  // the set's monitor: mutable.HashSet is not safe for unsynchronized concurrent use.
  val observedRates = mutable.HashSet.empty[Long]

  @volatile var done = false
  runInBackground {
    while (!done) {
      try {
        dstream.getCurrentRateLimit.foreach { rate =>
          observedRates.synchronized { observedRates += rate }
        }
      } catch {
        case NonFatal(_) => () // don't stop if the executor wasn't installed yet
      }
      Thread.sleep(20)
    }
  }
  try {
    runStreams(ssc, 4, 4)
  } finally {
    // Always stop the poller, even when runStreams throws, so it does not outlive the test.
    done = true
  }

  // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
  val seenRates = observedRates.synchronized { observedRates.toSet }
  seenRates should contain theSameElementsAs (rates :+ Long.MaxValue)
}
99+
100+
/**
 * Runs `f` on a newly started background thread and returns without waiting for it.
 *
 * The thread is marked as a daemon so that a body that never finishes (for example a
 * polling loop whose stop flag is never set) cannot keep the JVM alive.
 *
 * @param f the code to evaluate on the background thread
 */
private def runInBackground(f: => Unit): Unit = {
  val worker = new Thread {
    override def run(): Unit = {
      f
    }
  }
  worker.setDaemon(true)
  worker.start()
}
42107
}
43108

44109
/**
45110
* An InputDStream that counts how often its rate controller's `publish` method was called.
46111
*/
47-
private class MockRateLimitDStream(@transient ssc: StreamingContext)
48-
extends InputDStream[Int](ssc) {
112+
private class MockRateLimitDStream[T: ClassTag](
113+
@transient ssc: StreamingContext,
114+
input: Seq[Seq[T]],
115+
numPartitions: Int) extends TestInputStream[T](ssc, input, numPartitions) {
49116

50117
@volatile
51118
var publishCalls = 0
52119

53-
private object ConstantEstimator extends RateEstimator {
54-
def compute(
55-
time: Long,
56-
elements: Long,
57-
processingDelay: Long,
58-
schedulingDelay: Long): Option[Double] = {
59-
Some(100.0)
60-
}
61-
}
62-
63120
override val rateController: Option[RateController] =
64-
Some(new RateController(id, ConstantEstimator) {
121+
Some(new RateController(id, new ConstantEstimator(100.0)) {
65122
override def publish(rate: Long): Unit = {
66123
publishCalls += 1
67124
}
68125
})
126+
}
69127

70-
def compute(validTime: Time): Option[RDD[Int]] = {
71-
val data = Seq(1)
72-
ssc.scheduler.inputInfoTracker.reportInfo(validTime, StreamInputInfo(id, data.size))
73-
Some(ssc.sc.parallelize(data))
74-
}
128+
private class ConstantEstimator(rates: Double*) extends RateEstimator {
129+
private var idx: Int = 0
75130

76-
def stop(): Unit = {}
131+
/** Returns the rate at the current position and advances the position, wrapping at the end. */
private def nextRate(): Double = {
  val position = idx
  // Read before advancing, so the lookup is what fails if `rates` has no entry here.
  val selected = rates(position)
  idx = if (position + 1 == rates.size) 0 else position + 1
  selected
}
77136

78-
def start(): Unit = {}
137+
def compute(
138+
time: Long,
139+
elements: Long,
140+
processingDelay: Long,
141+
schedulingDelay: Long): Option[Double] = Some(nextRate())
79142
}

streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@
1818
package org.apache.spark.streaming.scheduler
1919

2020
import org.scalatest.concurrent.Eventually._
21-
import org.scalatest.concurrent.Timeouts
2221
import org.scalatest.time.SpanSugar._
23-
import org.apache.spark.streaming._
22+
2423
import org.apache.spark.SparkConf
24+
import org.apache.spark.annotation.DeveloperApi
2525
import org.apache.spark.storage.StorageLevel
26-
import org.apache.spark.streaming.receiver._
27-
import org.apache.spark.util.Utils
28-
import org.apache.spark.streaming.dstream.InputDStream
29-
import scala.reflect.ClassTag
26+
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestSuiteBase}
3027
import org.apache.spark.streaming.dstream.ReceiverInputDStream
28+
import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisor}
29+
3130

3231
/** Test suite for receiver scheduling */
3332
class ReceiverTrackerSuite extends TestSuiteBase {
@@ -129,12 +128,20 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
129128
}
130129

131130
/**
132-
* A Receiver as an object so we can read its rate limit.
131+
* A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
132+
* reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
133+
* serialized when receivers are installed on executors.
133134
*
134135
* @note It's necessary to be a top-level object, or else serialization would create another
135136
* one on the executor side and we won't be able to read its rate limit.
136137
*/
137-
private object SingletonDummyReceiver extends DummyReceiver
138+
private object SingletonDummyReceiver extends DummyReceiver {

  /**
   * Clears the `executor_` field so that this singleton can be used by another test.
   */
  def reset(): Unit = {
    executor_ = null
  }
}
138145

139146
/**
140147
* Dummy receiver implementation

0 commit comments

Comments
 (0)