@@ -20,7 +20,6 @@ package org.apache.spark.sql.streaming
import java.io.{File, FileWriter}
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.scalatest.concurrent.Eventually.eventually
@@ -32,7 +31,7 @@ import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf, window}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.util.SparkFileUtils

@@ -354,9 +353,15 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
}

test("streaming query listener") {
testStreamingQueryListener(new EventCollectorV1, "_v1")
testStreamingQueryListener(new EventCollectorV2, "_v2")
}

private def testStreamingQueryListener(
listener: StreamingQueryListener,
tablePostfix: String): Unit = {
assert(spark.streams.listListeners().length == 0)

val listener = new EventCollector
spark.streams.addListener(listener)

val q = spark.readStream
@@ -370,19 +375,29 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(q.isActive)
checkAnswer(spark.table("my_listener_table").toDF(), Seq(Row(1, 2), Row(4, 5)))

assert(!spark.table(s"listener_start_events$tablePostfix").toDF().isEmpty)
assert(!spark.table(s"listener_progress_events$tablePostfix").toDF().isEmpty)
}
} finally {
q.stop()
spark.sql("DROP TABLE IF EXISTS my_listener_table")

eventually(timeout(30.seconds)) {
assert(!q.isActive)
assert(!spark.table(s"listener_terminated_events$tablePostfix").toDF().isEmpty)
}

spark.sql(s"DROP TABLE IF EXISTS listener_start_events$tablePostfix")
spark.sql(s"DROP TABLE IF EXISTS listener_progress_events$tablePostfix")
spark.sql(s"DROP TABLE IF EXISTS listener_terminated_events$tablePostfix")
}

// After adding the listener, listListeners() should return exactly one listener.
val listeners = spark.streams.listListeners()
assert(listeners.length == 1)

// Add listener1 as another instance of EventCollector and validate
val listener1 = new EventCollector
val listener1 = new EventCollectorV2
spark.streams.addListener(listener1)
assert(spark.streams.listListeners().length == 2)
spark.streams.removeListener(listener1)
@@ -462,35 +477,56 @@ case class TestClass(value: Int) {
override def toString: String = value.toString
}

class EventCollector extends StreamingQueryListener {
@volatile var startEvent: QueryStartedEvent = null
@volatile var terminationEvent: QueryTerminatedEvent = null
@volatile var idleEvent: QueryIdleEvent = null
abstract class EventCollector extends StreamingQueryListener {
private lazy val spark = SparkSession.builder().getOrCreate()

private val _progressEvents = new mutable.Queue[StreamingQueryProgress]
protected def tablePostfix: String

def progressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized {
_progressEvents.clone().toSeq
protected def handleOnQueryStarted(event: QueryStartedEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_start_events$tablePostfix")
}

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
startEvent = event
val spark = SparkSession.builder().getOrCreate()
val df = spark.createDataFrame(Seq((1, 2), (4, 5)))
df.write.saveAsTable("my_listener_table")
protected def handleOnQueryProgress(event: QueryProgressEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_progress_events$tablePostfix")
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
_progressEvents += event.progress
protected def handleOnQueryTerminated(event: QueryTerminatedEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_terminated_events$tablePostfix")
}
}

override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
idleEvent = event
}
/**
* V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
 * `onQueryProgress`, and `onQueryTerminated`. This interface predates Spark 3.5.
*/
class EventCollectorV1 extends EventCollector {
override protected def tablePostfix: String = "_v1"

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
terminationEvent = event
}
override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)

override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)

override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
handleOnQueryTerminated(event)
}

/**
 * V2: The interface after the method `onQueryIdle` was added, available in Spark 3.5+.
*/
class EventCollectorV2 extends EventCollector {
override protected def tablePostfix: String = "_v2"

override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)

override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)

override def onQueryIdle(event: QueryIdleEvent): Unit = {}

override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
handleOnQueryTerminated(event)
}
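
For readers skimming the two collector variants above: a V1-style listener (one that never overrides `onQueryIdle`) still compiles against Spark 3.5+ because `onQueryIdle` was added to `StreamingQueryListener` with a default no-op body, as EventCollectorV1 relies on here. A minimal standalone sketch, not part of this PR (the class name and println bodies are illustrative only):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

// Overrides only the pre-3.5 callbacks; onQueryIdle falls back to its default no-op.
class MinimalV1StyleListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"query started: ${event.id}")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"batch completed: ${event.progress.batchId}")
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"query terminated: ${event.id}")
}

// Registration and removal are the same for V1- and V2-style listeners:
//   SparkSession.builder().getOrCreate().streams.addListener(new MinimalV1StyleListener)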

class ForeachBatchFn(val viewName: String)
133 changes: 80 additions & 53 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -26,77 +26,104 @@
from pyspark.testing.connectutils import ReusedConnectTestCase


class TestListener(StreamingQueryListener):
# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
# `onQueryProgress`, and `onQueryTerminated`. This interface predates Spark 3.5.
class TestListenerV1(StreamingQueryListener):
def onQueryStarted(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_start_events")
df.write.mode("append").saveAsTable("listener_start_events_v1")

def onQueryProgress(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_progress_events")
df.write.mode("append").saveAsTable("listener_progress_events_v1")

def onQueryTerminated(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_terminated_events_v1")


# V2: The interface after the method `onQueryIdle` was added, available in Spark 3.5+.
class TestListenerV2(StreamingQueryListener):
def onQueryStarted(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_start_events_v2")

def onQueryProgress(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_progress_events_v2")

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_terminated_events")
df.write.mode("append").saveAsTable("listener_terminated_events_v2")


class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
def test_listener_events(self):
test_listener = TestListener()

try:
self.spark.streams.addListener(test_listener)

# This ensures the read socket on the server won't crash (i.e. because of timeout)
# when there hasn't been a new event for a long time
time.sleep(30)

df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
df_stateful = df_observe.groupBy().count() # make query stateful
q = (
df_stateful.writeStream.format("noop")
.queryName("test")
.outputMode("complete")
.start()
)

self.assertTrue(q.isActive)
# ensure at least one batch has run
while q.lastProgress is None or q.lastProgress["batchId"] == 0:
time.sleep(5)
q.stop()
self.assertFalse(q.isActive)

time.sleep(60) # Sleep to make sure listener_terminated_events is written successfully

start_event = pyspark.cloudpickle.loads(
self.spark.read.table("listener_start_events").collect()[0][0]
)

progress_event = pyspark.cloudpickle.loads(
self.spark.read.table("listener_progress_events").collect()[0][0]
)

terminated_event = pyspark.cloudpickle.loads(
self.spark.read.table("listener_terminated_events").collect()[0][0]
)

self.check_start_event(start_event)
self.check_progress_event(progress_event)
self.check_terminated_event(terminated_event)

finally:
self.spark.streams.removeListener(test_listener)

# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)
def verify(test_listener, table_postfix):
Review comment (Contributor, Author): This is just an indentation change, with test_listener now passed in from outside the method.

try:
self.spark.streams.addListener(test_listener)

# This ensures the read socket on the server won't crash (i.e. because of timeout)
# when there hasn't been a new event for a long time
time.sleep(30)

df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
df_stateful = df_observe.groupBy().count() # make query stateful
q = (
df_stateful.writeStream.format("noop")
.queryName("test")
.outputMode("complete")
.start()
)

self.assertTrue(q.isActive)
# ensure at least one batch has run
while q.lastProgress is None or q.lastProgress["batchId"] == 0:
time.sleep(5)
q.stop()
self.assertFalse(q.isActive)

# Sleep to make sure listener_terminated_events is written successfully
time.sleep(60)

start_table_name = "listener_start_events" + table_postfix
progress_tbl_name = "listener_progress_events" + table_postfix
terminated_tbl_name = "listener_terminated_events" + table_postfix

start_event = pyspark.cloudpickle.loads(
self.spark.read.table(start_table_name).collect()[0][0]
)

progress_event = pyspark.cloudpickle.loads(
self.spark.read.table(progress_tbl_name).collect()[0][0]
)

terminated_event = pyspark.cloudpickle.loads(
self.spark.read.table(terminated_tbl_name).collect()[0][0]
)

self.check_start_event(start_event)
self.check_progress_event(progress_event)
self.check_terminated_event(terminated_event)

finally:
self.spark.streams.removeListener(test_listener)

# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)

verify(TestListenerV1(), "_v1")
verify(TestListenerV2(), "_v2")

def test_accessing_spark_session(self):
spark = self.spark