
Commit 34a389d

Various style changes and a first test for the rate controller.

1 parent d32ca36
File tree

6 files changed: +122 −25 lines changed

streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala

Lines changed: 11 additions & 5 deletions

@@ -53,11 +53,17 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
    * A rate estimator configured by the user to compute a dynamic ingestion bound for this stream.
    * @see `RateEstimator`
    */
-  protected [streaming] val rateEstimator = ssc.conf
-    .getOption("spark.streaming.RateEstimator")
-    .getOrElse("noop") match {
-      case _ => new NoopRateEstimator()
-    }
+  protected [streaming] val rateEstimator = newEstimator()
+
+  /**
+   * Return the configured estimator, or `noop` if none was specified.
+   */
+  private def newEstimator() =
+    ssc.conf.get("spark.streaming.RateEstimator", "noop") match {
+      case "noop" => new NoopRateEstimator()
+      case estimator => throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
+    }
+
 
   // Keep track of the freshest rate for this stream using the rateEstimator
   protected[streaming] val rateController: RateController = new RateController(id, rateEstimator) {
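For context, a minimal sketch of how the new lookup behaves from the user's side. The config key is taken from the diff above; the "pid" value is purely illustrative:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("rate-estimator-sketch")
  .set("spark.streaming.RateEstimator", "noop") // the only accepted value in this commit

// Any other value, e.g. .set("spark.streaming.RateEstimator", "pid"), now makes
// newEstimator() throw IllegalArgumentException("Unknown rate estimator: pid")
// when the input stream is constructed, instead of silently falling back to noop
// as the old `case _` branch did.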

streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala

Lines changed: 3 additions & 3 deletions

@@ -45,9 +45,9 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
    * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
    */
   override val rateController: RateController = new RateController(id, rateEstimator) {
-      override def publish(rate: Long): Unit =
-        ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
-    }
+    override def publish(rate: Long): Unit =
+      ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+  }
 
   /**
    * Gets the receiver object that will be sent to the worker nodes

streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
     eventLoop.start()
 
     // Estimators receive updates from batch completion
-    ssc.graph.getInputStreams.map(_.rateController).foreach(ssc.addStreamingListener(_))
+    ssc.graph.getInputStreams.foreach(is => ssc.addStreamingListener(is.rateController))
     listenerBus.start(ssc.sparkContext)
     receiverTracker = new ReceiverTracker(ssc)
     inputInfoTracker = new InputInfoTracker(ssc)
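Both versions register the same listeners; a side-by-side sketch of the difference:

// Before: materializes an intermediate collection of controllers, then registers each.
ssc.graph.getInputStreams.map(_.rateController).foreach(ssc.addStreamingListener(_))

// After: registers each stream's controller directly, using an explicit parameter
// instead of placeholder syntax, and skipping the intermediate collection.
ssc.graph.getInputStreams.foreach(is => ssc.addStreamingListener(is.rateController))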

streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala

Lines changed: 13 additions & 11 deletions

@@ -19,12 +19,12 @@ package org.apache.spark.streaming.scheduler
 
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.concurrent.{ExecutionContext, Future}
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.streaming.scheduler.rate.RateEstimator
 import org.apache.spark.util.ThreadUtils
 
-import scala.concurrent.{ExecutionContext, Future}
-
 /**
  * :: DeveloperApi ::
  * A StreamingListener that receives batch completion updates, and maintains
@@ -38,32 +38,34 @@ private [streaming] abstract class RateController(val streamUID: Int, rateEstima
   protected def publish(rate: Long): Unit
 
   // Used to compute & publish the rate update asynchronously
-  @transient private val executionContext = ExecutionContext.fromExecutorService(
+  @transient
+  implicit private val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))
 
-  private val rateLimit : AtomicLong = new AtomicLong(-1L)
+  private val rateLimit: AtomicLong = new AtomicLong(-1L)
 
-  // Asynchronous computation of the rate update
+  /**
+   * Compute the new rate limit and publish it asynchronously.
+   */
   private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
     Future[Unit] {
       val newSpeed = rateEstimator.compute(time, elems, workDelay, waitDelay)
       newSpeed foreach { s =>
         rateLimit.set(s.toLong)
         publish(getLatestRate())
       }
-    } (executionContext)
+    }
 
   def getLatestRate(): Long = rateLimit.get()
 
-  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted){
-      val elements = batchCompleted.batchInfo.streamIdToInputInfo
+  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+    val elements = batchCompleted.batchInfo.streamIdToInputInfo
 
-    for (
+    for {
       processingEnd <- batchCompleted.batchInfo.processingEndTime;
       workDelay <- batchCompleted.batchInfo.processingDelay;
       waitDelay <- batchCompleted.batchInfo.schedulingDelay;
       elems <- elements.get(streamUID).map(_.numRecords)
-    ) computeAndPublish(processingEnd, elems, workDelay, waitDelay)
+    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
   }
-
 }
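As a reading aid, a minimal sketch of a concrete controller. This subclass is hypothetical (it is not part of the commit, and would have to live in org.apache.spark.streaming.scheduler since RateController is package-private); publish is the only abstract member, so everything else, including the asynchronous computeAndPublish path driven by onBatchCompleted, is inherited:

package org.apache.spark.streaming.scheduler

import org.apache.spark.streaming.scheduler.rate.RateEstimator

// Hypothetical: logs each new bound instead of pushing it to a receiver.
private[streaming] class LoggingRateController(streamUID: Int, estimator: RateEstimator)
  extends RateController(streamUID, estimator) {

  override def publish(rate: Long): Unit =
    println(s"stream $streamUID: new ingestion bound of $rate records/s")
}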

streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala

Lines changed: 16 additions & 5 deletions

@@ -31,16 +31,27 @@ private[streaming] trait RateEstimator extends Serializable {
    * Computes the number of elements the stream attached to this `RateEstimator`
    * should ingest per second, given an update on the size and completion
    * times of the latest batch.
+   *
+   * @param time The timestamp of the current batch interval that just finished
+   * @param elements The number of elements that were processed in this batch
+   * @param processingDelay The time in ms it took for the job to complete
+   * @param schedulingDelay The time in ms that the job spent in the scheduling queue
    */
-  def compute(time: Long, elements: Long,
-      processingDelay: Long, schedulingDelay: Long): Option[Double]
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double]
 }
 
 /**
-  * The trivial rate estimator never sends an update
+ * The trivial rate estimator never sends an update
  */
 private[streaming] class NoopRateEstimator extends RateEstimator {
 
-  def compute(time: Long, elements: Long,
-      processingDelay: Long, schedulingDelay: Long): Option[Double] = None
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = None
 }
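To make the contract concrete, a hypothetical estimator that is not part of this commit: returning Some(bound) asks the controller to publish that bound, while returning None, as NoopRateEstimator does, suppresses the update entirely.

// Hypothetical: proposes the same fixed bound regardless of batch statistics.
private[streaming] class FixedRateEstimator(rate: Double) extends RateEstimator {

  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double] = Some(rate)
}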
streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{StreamingContext, TestOutputStreamWithPartitions, TestSuiteBase, Time}
+import org.apache.spark.streaming.dstream.InputDStream
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
+class RateControllerSuite extends TestSuiteBase {
+
+  test("rate controller publishes updates") {
+    val ssc = new StreamingContext(conf, batchDuration)
+    val dstream = new MockRateLimitDStream(ssc)
+    val output = new TestOutputStreamWithPartitions(dstream)
+    output.register()
+    runStreams(ssc, 1, 1)
+
+    eventually(timeout(2.seconds)) {
+      assert(dstream.publishCalls === 1)
+    }
+  }
+}
+
+/**
+ * An InputDStream that counts how often its rate controller `publish` method was called.
+ */
+private class MockRateLimitDStream(@transient ssc: StreamingContext)
+    extends InputDStream[Int](ssc) {
+
+  @volatile
+  var publishCalls = 0
+
+  private object ConstantEstimator extends RateEstimator {
+    def compute(
+        time: Long,
+        elements: Long,
+        processingDelay: Long,
+        schedulingDelay: Long): Option[Double] = {
+      Some(100.0)
+    }
+  }
+
+  override val rateController: RateController = new RateController(id, ConstantEstimator) {
+    override def publish(rate: Long): Unit = {
+      publishCalls += 1
+    }
+  }
+
+  def compute(validTime: Time): Option[RDD[Int]] = {
+    val data = Seq(1)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, StreamInputInfo(id, data.size))
+    Some(ssc.sc.parallelize(data))
+  }
+
+  def stop(): Unit = {}
+
+  def start(): Unit = {}
+}
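A note on the test's shape: publish runs on the controller's single-threaded "stream-rate-update" executor rather than on the thread that completes the batch, so the assertion is wrapped in eventually(timeout(2.seconds)) instead of being checked immediately after runStreams returns.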
