@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
2020import java .util .concurrent .CountDownLatch
2121
2222import org .apache .commons .lang3 .RandomStringUtils
23+ import org .mockito .Mockito ._
2324import org .scalactic .TolerantNumerics
2425import org .scalatest .concurrent .Eventually ._
2526import org .scalatest .BeforeAndAfter
2627import org .scalatest .concurrent .PatienceConfiguration .Timeout
28+ import org .scalatest .mock .MockitoSugar
2729
2830import org .apache .spark .internal .Logging
2931import org .apache .spark .sql .{DataFrame , Dataset }
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
3234import org .apache .spark .sql .execution .streaming ._
3335import org .apache .spark .sql .functions ._
3436import org .apache .spark .sql .internal .SQLConf
35- import org .apache .spark .sql .streaming .util .BlockingSource
37+ import org .apache .spark .sql .streaming .util .{ BlockingSource , MockSourceProvider }
3638import org .apache .spark .util .ManualClock
3739
3840
39- class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
41+ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
4042
4143 import AwaitTerminationTester ._
4244 import testImplicits ._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
481483 }
482484 }
483485
486+ test(" StreamExecution should call stop() on sources when a stream is stopped" ) {
487+ var calledStop = false
488+ val source = new Source {
489+ override def stop (): Unit = {
490+ calledStop = true
491+ }
492+ override def getOffset : Option [Offset ] = None
493+ override def getBatch (start : Option [Offset ], end : Offset ): DataFrame = {
494+ spark.emptyDataFrame
495+ }
496+ override def schema : StructType = MockSourceProvider .fakeSchema
497+ }
498+
499+ MockSourceProvider .withMockSources(source) {
500+ val df = spark.readStream
501+ .format(" org.apache.spark.sql.streaming.util.MockSourceProvider" )
502+ .load()
503+
504+ testStream(df)(StopStream )
505+
506+ assert(calledStop, " Did not call stop on source for stopped stream" )
507+ }
508+ }
509+
510+ testQuietly(" SPARK-19774: StreamExecution should call stop() on sources when a stream fails" ) {
511+ var calledStop = false
512+ val source1 = new Source {
513+ override def stop (): Unit = {
514+ throw new RuntimeException (" Oh no!" )
515+ }
516+ override def getOffset : Option [Offset ] = Some (LongOffset (1 ))
517+ override def getBatch (start : Option [Offset ], end : Offset ): DataFrame = {
518+ spark.range(2 ).toDF(MockSourceProvider .fakeSchema.fieldNames: _* )
519+ }
520+ override def schema : StructType = MockSourceProvider .fakeSchema
521+ }
522+ val source2 = new Source {
523+ override def stop (): Unit = {
524+ calledStop = true
525+ }
526+ override def getOffset : Option [Offset ] = None
527+ override def getBatch (start : Option [Offset ], end : Offset ): DataFrame = {
528+ spark.emptyDataFrame
529+ }
530+ override def schema : StructType = MockSourceProvider .fakeSchema
531+ }
532+
533+ MockSourceProvider .withMockSources(source1, source2) {
534+ val df1 = spark.readStream
535+ .format(" org.apache.spark.sql.streaming.util.MockSourceProvider" )
536+ .load()
537+ .as[Int ]
538+
539+ val df2 = spark.readStream
540+ .format(" org.apache.spark.sql.streaming.util.MockSourceProvider" )
541+ .load()
542+ .as[Int ]
543+
544+ testStream(df1.union(df2).map(i => i / 0 ))(
545+ AssertOnQuery { sq =>
546+ intercept[StreamingQueryException ](sq.processAllAvailable())
547+ sq.exception.isDefined && ! sq.isActive
548+ }
549+ )
550+
551+ assert(calledStop, " Did not call stop on source for stopped stream" )
552+ }
553+ }
554+
484555 /** Create a streaming DF that only execute one batch in which it returns the given static DF */
485556 private def createSingleTriggerStreamingDF (triggerDF : DataFrame ): DataFrame = {
486557 require(! triggerDF.isStreaming)
0 commit comments