Skip to content

Commit 479582b

Browse files
committed
Address and fix timestamp for rampUpTimeSeconds
1 parent 3a95b55 commit 479582b

File tree

2 files changed

+38
-35
lines changed

2 files changed

+38
-35
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit
2424
import org.apache.commons.io.IOUtils
2525

2626
import org.apache.spark.internal.Logging
27+
import org.apache.spark.network.util.JavaUtils
2728
import org.apache.spark.sql.{DataFrame, SQLContext}
2829
import org.apache.spark.sql.catalyst.InternalRow
2930
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
@@ -37,11 +38,12 @@ import org.apache.spark.util.{ManualClock, SystemClock}
3738
* with 0L.
3839
*
3940
* This source supports the following options:
40-
* - `tuplesPerSecond` (default: 1): How many tuples should be generated per second.
41-
* - `rampUpTimeSeconds` (default: 0): How many seconds to ramp up before the generating speed
42-
* becomes `tuplesPerSecond`.
43-
* - `numPartitions` (default: Spark's default parallelism): The partition number for the generated
44-
* tuples.
41+
* - `tuplesPerSecond` (e.g. 100, default: 1): How many tuples should be generated per second.
42+
* - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
43+
* becomes `tuplesPerSecond`. Values with a granularity finer than seconds are truncated to whole
44+
* seconds.
45+
* - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
46+
* generated tuples.
4547
*/
4648
class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {
4749

@@ -63,22 +65,23 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {
6365
val tuplesPerSecond = params.get("tuplesPerSecond").map(_.toLong).getOrElse(1L)
6466
if (tuplesPerSecond <= 0) {
6567
throw new IllegalArgumentException(
66-
s"Invalid value '${params("tuplesPerSecond")}' for option 'tuplesPerSecond', " +
68+
s"Invalid value '${params("tuplesPerSecond")}'. The option 'tuplesPerSecond' " +
6769
"must be positive")
6870
}
6971

70-
val rampUpTimeSeconds = params.get("rampUpTimeSeconds").map(_.toLong).getOrElse(0L)
72+
val rampUpTimeSeconds =
73+
params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
7174
if (rampUpTimeSeconds < 0) {
7275
throw new IllegalArgumentException(
73-
s"Invalid value '${params("rampUpTimeSeconds")}' for option 'rampUpTimeSeconds', " +
76+
s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
7477
"must not be negative")
7578
}
7679

7780
val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
7881
sqlContext.sparkContext.defaultParallelism)
7982
if (numPartitions <= 0) {
8083
throw new IllegalArgumentException(
81-
s"Invalid value '${params("numPartitions")}' for option 'numPartitions', " +
84+
s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
8285
"must be positive")
8386
}
8487

@@ -117,8 +120,9 @@ class RateStreamSource(
117120
private val maxSeconds = Long.MaxValue / tuplesPerSecond
118121

119122
if (rampUpTimeSeconds > maxSeconds) {
120-
throw new ArithmeticException("integer overflow. Max offset with tuplesPerSecond " +
121-
s"$tuplesPerSecond is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
123+
throw new ArithmeticException(
124+
s"Integer overflow. Max offset with $tuplesPerSecond tuplesPerSecond" +
125+
s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
122126
}
123127

124128
private val startTimeMs = {
@@ -175,10 +179,10 @@ class RateStreamSource(
175179
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
176180
val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
177181
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
178-
assert(startSeconds <= endSeconds)
182+
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
179183
if (endSeconds > maxSeconds) {
180-
throw new ArithmeticException("integer overflow. Max offset with " +
181-
s"tuplesPerSecond $tuplesPerSecond is $maxSeconds, but it's $endSeconds now.")
184+
throw new ArithmeticException("Integer overflow. Max offset with " +
185+
s"$tuplesPerSecond tuplesPerSecond is $maxSeconds, but it's $endSeconds now.")
182186
}
183187
// Fix "lastTimeMs" for recovery
184188
if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
@@ -188,11 +192,17 @@ class RateStreamSource(
188192
val rangeEnd = valueAtSecond(endSeconds, tuplesPerSecond, rampUpTimeSeconds)
189193
logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
190194
s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
191-
val localStartTimeMs = startTimeMs
192-
val localPerSecond = tuplesPerSecond
195+
196+
if (rangeStart == rangeEnd) {
197+
return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema)
198+
}
199+
200+
val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
201+
val relativeMsPerValue =
202+
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds) / (rangeEnd - rangeStart)
193203

194204
val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
195-
val relative = v * 1000L / localPerSecond
205+
val relative = (v - rangeStart) * relativeMsPerValue
196206
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
197207
}
198208
sqlContext.internalCreateDataFrame(rdd, schema)

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,6 @@ class RateSourceSuite extends StreamTest {
3939
}
4040
}
4141

42-
private def getManualClockFromQuery(query: StreamExecution): ManualClock = {
43-
val rateSource = query.logicalPlan.collect {
44-
case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
45-
source.asInstanceOf[RateStreamSource]
46-
}.head
47-
rateSource.clock.asInstanceOf[ManualClock]
48-
}
49-
5042
test("basic") {
5143
val input = spark.readStream
5244
.format("rate")
@@ -80,30 +72,31 @@ class RateSourceSuite extends StreamTest {
8072
assert(valueAtSecond(seconds = 5, tuplesPerSecond = 10, rampUpTimeSeconds = 4) === 30)
8173
}
8274

83-
test("rampUpTimeSeconds") {
75+
test("rampUpTime") {
8476
val input = spark.readStream
8577
.format("rate")
8678
.option("tuplesPerSecond", "10")
87-
.option("rampUpTimeSeconds", "4")
79+
.option("rampUpTime", "4s")
8880
.option("useManualClock", "true")
8981
.load()
90-
.select($"value")
82+
.as[(java.sql.Timestamp, Long)]
83+
.map(v => (v._1.getTime, v._2))
9184
testStream(input)(
9285
AdvanceRateManualClock(seconds = 1),
93-
CheckLastBatch(0 until 2: _*), // speed = 2
86+
CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
9487
AdvanceRateManualClock(seconds = 1),
95-
CheckLastBatch(2 until 6: _*), // speed = 4
88+
CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
9689
AdvanceRateManualClock(seconds = 1),
97-
CheckLastBatch(6 until 12: _*), // speed = 6
90+
CheckLastBatch((6 until 12).map(v => 2000 + (v - 6) * 166 -> v): _*), // speed = 6
9891
AdvanceRateManualClock(seconds = 1),
99-
CheckLastBatch(12 until 20: _*), // speed = 8
92+
CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
10093
AdvanceRateManualClock(seconds = 1),
10194
// Now we should reach full speed
102-
CheckLastBatch(20 until 30: _*), // speed = 10
95+
CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
10396
AdvanceRateManualClock(seconds = 1),
104-
CheckLastBatch(30 until 40: _*), // speed = 10
97+
CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
10598
AdvanceRateManualClock(seconds = 1),
106-
CheckLastBatch(40 until 50: _*) // speed = 10
99+
CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
107100
)
108101
}
109102

0 commit comments

Comments
 (0)