
Commit bc08338

ready for review
1 parent 7e5359b commit bc08338

3 files changed: +170 −3 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 13 additions & 1 deletion
@@ -321,6 +321,7 @@ class StreamExecution(
     initializationLatch.countDown()
 
     try {
+      stopSources()
       state.set(TERMINATED)
       currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
@@ -558,6 +559,18 @@ class StreamExecution(
     sparkSession.streams.postListenerEvent(event)
   }
 
+  /** Stops all streaming sources safely. */
+  private def stopSources(): Unit = {
+    uniqueSources.foreach { source =>
+      try {
+        source.stop()
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e)
+      }
+    }
+  }
+
   /**
    * Signals to the thread executing micro-batches that it should stop running after the next
    * batch. This method blocks until the thread stops running.
@@ -570,7 +583,6 @@ class StreamExecution(
       microBatchThread.interrupt()
       microBatchThread.join()
     }
-    uniqueSources.foreach(_.stop())
     logInfo(s"Query $prettyIdString was stopped")
  }
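
Taken together, the hunks above change where and how sources are shut down: `uniqueSources.foreach(_.stop())` used to run only inside the explicit `stop()` call, so a failed query never stopped its sources, and the bare `foreach` let one throwing `stop()` skip the rest. The new `stopSources()` runs in the run loop's termination block, covering both clean stops and failures, and isolates each call. A minimal standalone sketch of that cleanup pattern (plain Scala with illustrative names, not Spark's API):

import scala.util.control.NonFatal

object StopSourcesSketch extends App {
  trait Source { def stop(): Unit }

  // Stop every source; swallow non-fatal failures so one bad source cannot
  // prevent the remaining sources from releasing their resources.
  def stopAll(sources: Seq[Source]): Unit = sources.foreach { source =>
    try source.stop()
    catch {
      case NonFatal(e) =>
        // Fatal errors (e.g. OutOfMemoryError) deliberately still propagate.
        Console.err.println(s"Failed to stop source: $source. Resources may have leaked. ($e)")
    }
  }

  stopAll(Seq(
    new Source { def stop(): Unit = throw new RuntimeException("Oh no!") }, // fails
    new Source { def stop(): Unit = println("second source stopped") }      // still runs
  ))
}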

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala

Lines changed: 73 additions & 2 deletions
@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
 import java.util.concurrent.CountDownLatch
 
 import org.apache.commons.lang3.RandomStringUtils
+import org.mockito.Mockito._
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.BlockingSource
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
 import org.apache.spark.util.ManualClock
 
 
-class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
+class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
 
   import AwaitTerminationTester._
   import testImplicits._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
     }
   }
 
+  test("StreamExecution should call stop() on sources when a stream is stopped") {
+    var calledStop = false
+    val source = new Source {
+      override def stop(): Unit = {
+        calledStop = true
+      }
+      override def getOffset: Option[Offset] = None
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.emptyDataFrame
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+
+    MockSourceProvider.withMockSources(source) {
+      val df = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+
+      testStream(df)(StopStream)
+
+      assert(calledStop, "Did not call stop on source for stopped stream")
+    }
+  }
+
+  testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") {
+    var calledStop = false
+    val source1 = new Source {
+      override def stop(): Unit = {
+        throw new RuntimeException("Oh no!")
+      }
+      override def getOffset: Option[Offset] = Some(LongOffset(1))
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*)
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+    val source2 = new Source {
+      override def stop(): Unit = {
+        calledStop = true
+      }
+      override def getOffset: Option[Offset] = None
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.emptyDataFrame
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+
+    MockSourceProvider.withMockSources(source1, source2) {
+      val df1 = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+        .as[Int]
+
+      val df2 = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+        .as[Int]
+
+      testStream(df1.union(df2).map(i => i / 0))(
+        AssertOnQuery { sq =>
+          intercept[StreamingQueryException](sq.processAllAvailable())
+          sq.exception.isDefined && !sq.isActive
+        }
+      )
+
+      assert(calledStop, "Did not call stop on source for stopped stream")
+    }
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)
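
The SPARK-19774 test is the interesting one: `source1.stop()` throws, yet the final assertion requires that `source2.stop()` still ran (`calledStop`), which only holds if `stopSources()` isolates each call. That isolation hinges on `scala.util.control.NonFatal`, which matches ordinary exceptions like the test's `RuntimeException` but not VM-level errors. A two-line sanity check of that classification (plain Scala, outside the suite):

import scala.util.control.NonFatal

object NonFatalSketch extends App {
  assert(NonFatal(new RuntimeException("Oh no!"))) // matched, so stopSources() logs and continues
  assert(!NonFatal(new OutOfMemoryError()))        // fatal: would propagate out of stopSources()
  println("NonFatal classifies exceptions as expected")
}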
sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.util
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage:
+ *
+ * {{{
+ *   MockSourceProvider.withMockSources(source1, source2) {
+ *     val df1 = spark.readStream
+ *       .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *       .load()
+ *
+ *     val df2 = spark.readStream
+ *       .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *       .load()
+ *
+ *     df1.union(df2)
+ *     ...
+ *   }
+ * }}}
+ */
+class MockSourceProvider extends StreamSourceProvider {
+  override def sourceSchema(
+      spark: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): (String, StructType) = {
+    ("dummySource", MockSourceProvider.fakeSchema)
+  }
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    MockSourceProvider.sourceProviderFunction()
+  }
+}
+
+object MockSourceProvider {
+  // Function to generate sources. May provide multiple sources if the user implements such a
+  // function.
+  private var sourceProviderFunction: () => Source = _
+
+  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+  def withMockSources(sources: Source*)(f: => Unit): Unit = {
+    require(sources.nonEmpty)
+    var i = 0
+    val srcProvider = () => {
+      val source = sources(i % sources.length)
+      i += 1
+      source
+    }
+    sourceProviderFunction = srcProvider
+    try {
+      f
+    } finally {
+      sourceProviderFunction = null
+    }
+  }
+}
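
A note on the dispatch: `withMockSources` hands its sources out round-robin, so each `createSource` call made while `f` runs receives the next source in the list, wrapping around via `i % sources.length`. That is how the failure test above binds `source1` and `source2` to two separate `.load()` calls on the same format string. A tiny sketch of the dispatch (strings standing in for Source instances):

object RoundRobinSketch extends App {
  val sources = Seq("source1", "source2")
  var i = 0
  // Same scheme as withMockSources: next source in order, wrapping around.
  val next = () => { val s = sources(i % sources.length); i += 1; s }

  assert(next() == "source1") // first .load() gets source1
  assert(next() == "source2") // second .load() gets source2
  assert(next() == "source1") // a third call would wrap around
  println("round-robin dispatch verified")
}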
