Skip to content

Commit 3a95b55

Browse files
committed
The speed should always be <= tuplesPerSecond
1 parent 5b45b1b commit 3a95b55

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class RateStreamSource(
110110
useManualClock: Boolean) extends Source with Logging {
111111

112112
import RateSourceProvider._
113+
import RateStreamSource._
113114

114115
val clock = if (useManualClock) new ManualClock else new SystemClock
115116

@@ -183,15 +184,8 @@ class RateStreamSource(
183184
if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
184185
lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
185186
}
186-
val (rangeStart, rangeEnd) = if (rampUpTimeSeconds > endSeconds) {
187-
(math.rint(tuplesPerSecond * (startSeconds * 1.0 / rampUpTimeSeconds)).toLong * startSeconds,
188-
math.rint(tuplesPerSecond * (endSeconds * 1.0 / rampUpTimeSeconds)).toLong * endSeconds)
189-
} else if (startSeconds < rampUpTimeSeconds) {
190-
(math.rint(tuplesPerSecond * (startSeconds * 1.0 / rampUpTimeSeconds)).toLong * startSeconds,
191-
endSeconds * tuplesPerSecond)
192-
} else {
193-
(startSeconds * tuplesPerSecond, endSeconds * tuplesPerSecond)
194-
}
187+
val rangeStart = valueAtSecond(startSeconds, tuplesPerSecond, rampUpTimeSeconds)
188+
val rangeEnd = valueAtSecond(endSeconds, tuplesPerSecond, rampUpTimeSeconds)
195189
logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
196190
s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
197191
val localStartTimeMs = startTimeMs
@@ -206,3 +200,30 @@ class RateStreamSource(
206200

207201
override def stop(): Unit = {}
208202
}
203+
204+
object RateStreamSource {

  /**
   * Computes the end value that has been emitted once `seconds` seconds have elapsed.
   *
   * While ramping up, the per-second speed grows by a fixed increment of
   * `tuplesPerSecond / (rampUpTimeSeconds + 1)` (integer division) every second. Once
   * `seconds` exceeds `rampUpTimeSeconds`, every additional second contributes
   * `tuplesPerSecond` values.
   *
   * Worked example with rampUpTimeSeconds = 4 and tuplesPerSecond = 10 (increment = 2):
   *
   * {{{
   *   seconds   = 0 1 2  3  4  5  6
   *   speed     = 0 2 4  6  8 10 10
   *   end value = 0 2 6 12 20 30 40
   * }}}
   *
   * @param seconds           the elapsed time, in seconds
   * @param tuplesPerSecond   the full speed reached after the ramp-up
   * @param rampUpTimeSeconds the length of the ramp-up period, in seconds
   * @return the end value emitted at the time `seconds`
   */
  def valueAtSecond(seconds: Long, tuplesPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    val increment = tuplesPerSecond / (rampUpTimeSeconds + 1)

    // Total emitted after `n` ramp-up seconds: increment * n * (n + 1) / 2.
    // Exactly one of `n` and `n + 1` is even, so that factor is halved before multiplying,
    // which keeps the intermediate products smaller and helps avoid overflow.
    def rampUpTotal(n: Long): Long = {
      val (halved, untouched) = if (n % 2 == 1) ((n + 1) / 2, n) else (n / 2, n + 1)
      halved * increment * untouched
    }

    if (seconds > rampUpTimeSeconds) {
      // The whole ramp-up period, followed by the remaining seconds at full speed.
      rampUpTotal(rampUpTimeSeconds) + (seconds - rampUpTimeSeconds) * tuplesPerSecond
    } else {
      rampUpTotal(seconds)
    }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,46 @@ class RateSourceSuite extends StreamTest {
6464
)
6565
}
6666

67+
test("valueAtSecond") {
  import RateStreamSource.valueAtSecond

  // Each case: (tuplesPerSecond, rampUpTimeSeconds, expected end values at seconds 0, 1, 2, ...)
  val cases = Seq(
    (5L, 2L, Seq(0L, 1L, 3L, 8L)),
    (10L, 4L, Seq(0L, 2L, 6L, 12L, 20L, 30L))
  )

  cases.foreach { case (rate, rampUp, expectedValues) =>
    expectedValues.zipWithIndex.foreach { case (expected, second) =>
      val actual = valueAtSecond(second, tuplesPerSecond = rate, rampUpTimeSeconds = rampUp)
      assert(actual === expected)
    }
  }
}
82+
6783
test("rampUpTimeSeconds") {
  val rateStream = spark.readStream
    .format("rate")
    .option("tuplesPerSecond", "10")
    .option("rampUpTimeSeconds", "4")
    .option("useManualClock", "true")
    .load()
  val input = rateStream.select($"value")

  // Values expected in the last batch after each successive one-second clock advance.
  val expectedBatches = Seq(
    0 until 2, // speed = 2
    2 until 6, // speed = 4
    6 until 12, // speed = 6
    12 until 20, // speed = 8
    // Now we should reach full speed
    20 until 30, // speed = 10
    30 until 40, // speed = 10
    40 until 50 // speed = 10
  )

  // Advance the manual clock by one second, then check the batch produced by that advance.
  val actions = expectedBatches.flatMap { batch =>
    Seq[StreamAction](AdvanceRateManualClock(seconds = 1), CheckLastBatch(batch: _*))
  }
  testStream(input)(actions: _*)
}
90109

0 commit comments

Comments
 (0)