Commit 52ce1b7

SPARK-1158: Fix flaky RateLimitedOutputStreamSuite.
There was actually a problem with the RateLimitedOutputStream implementation where the first second doesn't output anything because of integer rounding.
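
For context, a minimal sketch (not part of the commit) of the integer-rounding problem described above: during the first second, converting the elapsed nanoseconds to whole seconds yields 0, so the old rate check degenerates and the stream keeps spinning instead of writing until a full second has passed. The values below are illustrative only.

import java.util.concurrent.TimeUnit._

// Hypothetical: half a second has passed since the last sync.
val elapsedNanos = 500L * 1000 * 1000
val elapsedSecs = SECONDS.convert(elapsedNanos, NANOSECONDS)   // 0 -- rounded down to whole seconds

// Old computation: 0.0 / 0 is NaN (and any positive value / 0 is Infinity),
// so `rate < bytesPerSec` is never true and nothing gets written yet.
val oldRate = 0L.toDouble / elapsedSecs                        // NaN

// The fixed computation stays in nanoseconds, so the rate is well defined immediately.
val newRate = 0L.toDouble * 1000000000 / math.max(elapsedNanos, 1)   // 0.0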
1 parent c3f5e07 commit 52ce1b7

2 files changed: +32 −20 lines

streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala

Lines changed: 24 additions & 15 deletions

@@ -22,12 +22,20 @@ import scala.annotation.tailrec
 import java.io.OutputStream
 import java.util.concurrent.TimeUnit._
 
+import org.apache.spark.Logging
+
+
 private[streaming]
-class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
-  val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
-  val CHUNK_SIZE = 8192
-  var lastSyncTime = System.nanoTime
-  var bytesWrittenSinceSync: Long = 0
+class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int)
+  extends OutputStream
+  with Logging {
+
+  require(desiredBytesPerSec > 0)
+
+  private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+  private val CHUNK_SIZE = 8192
+  private var lastSyncTime = System.nanoTime
+  private var bytesWrittenSinceSync = 0L
 
   override def write(b: Int) {
     waitToWrite(1)
@@ -59,9 +67,9 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu
   @tailrec
   private def waitToWrite(numBytes: Int) {
     val now = System.nanoTime
-    val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS)
-    val rate = bytesWrittenSinceSync.toDouble / elapsedSecs
-    if (rate < bytesPerSec) {
+    val elapsedNanosecs = math.max(now - lastSyncTime, 1)
+    val rate = bytesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs
+    if (rate < desiredBytesPerSec) {
       // It's okay to write; just update some variables and return
       bytesWrittenSinceSync += numBytes
       if (now > lastSyncTime + SYNC_INTERVAL) {
@@ -71,13 +79,14 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu
       }
     } else {
       // Calculate how much time we should sleep to bring ourselves to the desired rate.
-      // Based on throttler in Kafka
-      // scalastyle:off
-      // (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala)
-      // scalastyle:on
-      val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs),
-        SECONDS)
-      if (sleepTime > 0) Thread.sleep(sleepTime)
+      val targetTimeInMillis = bytesWrittenSinceSync * 1000 / desiredBytesPerSec
+      val elapsedTimeInMillis = elapsedNanosecs / 1000000
+      val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis
+      if (sleepTimeInMillis > 0) {
+        logTrace("Natural rate is " + rate + " per second but desired rate is " +
+          desiredBytesPerSec + ", sleeping for " + sleepTimeInMillis + " ms to compensate.")
+        Thread.sleep(sleepTimeInMillis)
+      }
       waitToWrite(numBytes)
     }
   }
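
For intuition, a rough worked example of the new sleep-time arithmetic in waitToWrite, using assumed numbers rather than anything taken from the diff:

// Assume desiredBytesPerSec = 10000 and that one 8192-byte chunk was written
// about 1 ms after the last sync.
val desiredBytesPerSec = 10000
val bytesWrittenSinceSync = 8192L
val elapsedNanosecs = 1L * 1000 * 1000

val targetTimeInMillis = bytesWrittenSinceSync * 1000 / desiredBytesPerSec  // 819 ms needed to stay at the desired rate
val elapsedTimeInMillis = elapsedNanosecs / 1000000                         // 1 ms actually elapsed
val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis            // sleep roughly 818 ms

After sleeping, waitToWrite recurses and re-checks the rate before the write is allowed to proceed.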

streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala

Lines changed: 8 additions & 5 deletions

@@ -17,10 +17,11 @@
 
 package org.apache.spark.streaming.util
 
-import org.scalatest.FunSuite
 import java.io.ByteArrayOutputStream
 import java.util.concurrent.TimeUnit._
 
+import org.scalatest.FunSuite
+
 class RateLimitedOutputStreamSuite extends FunSuite {
 
   private def benchmark[U](f: => U): Long = {
@@ -29,12 +30,14 @@ class RateLimitedOutputStreamSuite extends FunSuite {
     System.nanoTime - start
   }
 
-  ignore("write") {
+  test("write") {
     val underlying = new ByteArrayOutputStream
     val data = "X" * 41000
-    val stream = new RateLimitedOutputStream(underlying, 10000)
+    val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000)
     val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
-    assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
-    assert(underlying.toString("UTF-8") == data)
+
+    // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down.
+    assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4)
+    assert(underlying.toString("UTF-8") === data)
   }
 }
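
As a side note on the expected value in the test, a small sketch (illustrative values, not from the suite) of why exactly 4 whole seconds is asserted: 41000 bytes at 10000 bytes per second takes about 4.1 s, and SECONDS.convert rounds down.

import java.util.concurrent.TimeUnit._

val expectedSecs = 41000.0 / 10000                        // 4.1 seconds of expected throttling
val measuredNs = (4.1 * 1000 * 1000 * 1000).toLong        // e.g. a measured elapsed time of 4.1 s
val wholeSecs = SECONDS.convert(measuredNs, NANOSECONDS)  // 4 -- any duration in [4.0 s, 5.0 s) floors to 4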
