SPARK-19774: StreamExecution should call stop() on sources when a stream fails

StreamExecution.scala
@@ -321,6 +321,7 @@ class StreamExecution(
       initializationLatch.countDown()
 
       try {
+        stopSources()
         state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
 
@@ -558,6 +559,18 @@ class StreamExecution(
     sparkSession.streams.postListenerEvent(event)
   }
 
+  /** Stops all streaming sources safely. */
+  private def stopSources(): Unit = {
+    uniqueSources.foreach { source =>
+      try {
+        source.stop()
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e)
+      }
+    }
+  }
+
   /**
    * Signals to the thread executing micro-batches that it should stop running after the next
    * batch. This method blocks until the thread stops running.
@@ -570,7 +583,6 @@ class StreamExecution(
       microBatchThread.interrupt()
       microBatchThread.join()
     }
-    uniqueSources.foreach(_.stop())
     logInfo(s"Query $prettyIdString was stopped")
   }
 
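Taken together, these hunks move source cleanup off the stop() path and into the execution thread's termination block, and make it exception-safe. Below is a condensed, self-contained sketch of that pattern with illustrative names (StoppableSource, stopAll, runStream are not the actual StreamExecution code):

import scala.util.control.NonFatal

trait StoppableSource { def stop(): Unit }

// Stop every source even when one of them throws: a failing stop() must not
// prevent the remaining sources from being stopped.
def stopAll(sources: Seq[StoppableSource]): Unit = {
  sources.foreach { source =>
    try source.stop()
    catch { case NonFatal(e) => Console.err.println(s"Failed to stop $source: $e") }
  }
}

// Running the cleanup on the termination path means it happens whether the
// stream is stopped deliberately or dies with an exception.
def runStream(sources: Seq[StoppableSource])(body: => Unit): Unit = {
  try body
  finally stopAll(sources)
}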
StreamingQuerySuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
 import java.util.concurrent.CountDownLatch
 
 import org.apache.commons.lang3.RandomStringUtils
+import org.mockito.Mockito._
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.BlockingSource
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
 import org.apache.spark.util.ManualClock
 
 
-class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
+class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
 
   import AwaitTerminationTester._
   import testImplicits._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
     }
   }
 
test("StreamExecution should call stop() on sources when a stream is stopped") {
var calledStop = false
val source = new Source {
override def stop(): Unit = {
calledStop = true
}
override def getOffset: Option[Offset] = None
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.emptyDataFrame
}
override def schema: StructType = MockSourceProvider.fakeSchema
}

MockSourceProvider.withMockSources(source) {
val df = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()

testStream(df)(StopStream)

assert(calledStop, "Did not call stop on source for stopped stream")
}
}

testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") {
var calledStop = false
val source1 = new Source {
override def stop(): Unit = {
throw new RuntimeException("Oh no!")
}
override def getOffset: Option[Offset] = Some(LongOffset(1))
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*)
}
override def schema: StructType = MockSourceProvider.fakeSchema
}
val source2 = new Source {
override def stop(): Unit = {
calledStop = true
}
override def getOffset: Option[Offset] = None
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.emptyDataFrame
}
override def schema: StructType = MockSourceProvider.fakeSchema
}

MockSourceProvider.withMockSources(source1, source2) {
val df1 = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()
.as[Int]

val df2 = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()
.as[Int]

testStream(df1.union(df2).map(i => i / 0))(
AssertOnQuery { sq =>
intercept[StreamingQueryException](sq.processAllAvailable())
sq.exception.isDefined && !sq.isActive
}
)

assert(calledStop, "Did not call stop on source for stopped stream")
}
}

   /** Create a streaming DF that only executes one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)

MockSourceProvider.scala (new file)
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.util
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage:
+ *
+ * {{{
+ *   MockSourceProvider.withMockSources(source1, source2) {
+ *     val df1 = spark.readStream
+ *       .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *       .load()
+ *
+ *     val df2 = spark.readStream
+ *       .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *       .load()
+ *
+ *     df1.union(df2)
+ *     ...
+ *   }
+ * }}}
+ */
+class MockSourceProvider extends StreamSourceProvider {
+  override def sourceSchema(
+      spark: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): (String, StructType) = {
+    ("dummySource", MockSourceProvider.fakeSchema)
+  }
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    MockSourceProvider.sourceProviderFunction()
+  }
+}
+
+object MockSourceProvider {
+  // Function to generate sources. May provide multiple sources if the user implements such a
+  // function.
+  private var sourceProviderFunction: () => Source = _
+
+  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+  def withMockSources(sources: Source*)(f: => Unit): Unit = {
+    require(sources.nonEmpty)

Member: nit: You can use withMockSources(firstSource: Source, otherSources: Source*) to make an empty call a compile error. (This and the following suggestion are sketched after the file.)

Author: good idea

+    var i = 0
+    val srcProvider = () => {

Member: nit: srcProvider is not necessary. You can just assign the function to sourceProviderFunction.

+      val source = sources(i % sources.length)
+      i += 1
+      source
+    }
+    sourceProviderFunction = srcProvider
+    try {
+      f
+    } finally {
+      sourceProviderFunction = null
+    }
+  }
+}
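
For reference, here is a sketch of withMockSources with both review suggestions applied: the (firstSource, otherSources*) signature makes a zero-argument call a compile-time error rather than a runtime require failure, and the generator closure is assigned to sourceProviderFunction directly, dropping the intermediate srcProvider val. This illustrates the suggestions only; it is not code from the PR:

def withMockSources(firstSource: Source, otherSources: Source*)(f: => Unit): Unit = {
  val sources = firstSource +: otherSources
  var i = 0
  // Assign the generator directly: each createSource() call hands out the
  // next source, cycling through the ones supplied by the test.
  sourceProviderFunction = () => {
    val source = sources(i % sources.length)
    i += 1
    source
  }
  try {
    f
  } finally {
    sourceProviderFunction = null
  }
}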