From 4d0fcdd8fa028dc6e3f96b9cd01be998bd10f710 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Thu, 6 Apr 2023 15:18:31 -0700 Subject: [PATCH 01/12] done --- dev/sparktestsupport/modules.py | 5 + python/pyspark/sql/connect/streaming/query.py | 29 +- .../sql/connect/streaming/readwriter.py | 102 ++++- python/pyspark/sql/streaming/query.py | 7 +- python/pyspark/sql/streaming/readwriter.py | 80 ++-- .../streaming/test_parity_streaming.py | 69 ++++ .../sql/tests/streaming/test_streaming.py | 376 +++--------------- .../test_streaming_foreach_family.py | 369 +++++++++++++++++ 8 files changed, 656 insertions(+), 381 deletions(-) create mode 100644 python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py create mode 100644 python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 1a28a644e550..d946783c8126 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -495,6 +495,7 @@ def __hash__(self): "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", "pyspark.sql.tests.streaming.test_streaming", + "pyspark.sql.tests.streaming.test_streaming_foreach_family", "pyspark.sql.tests.streaming.test_streaming_listener", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", @@ -749,6 +750,8 @@ def __hash__(self): "pyspark.sql.connect.dataframe", "pyspark.sql.connect.functions", "pyspark.sql.connect.avro.functions", + "pyspark.sql.connect.streaming.readwriter", + "pyspark.sql.connect.streaming.query", # sql unittests "pyspark.sql.tests.connect.test_client", "pyspark.sql.tests.connect.test_connect_plan", @@ -777,6 +780,8 @@ def __hash__(self): "pyspark.ml.connect.functions", # ml unittests "pyspark.ml.tests.connect.test_connect_function", + # streaming unittests + "pyspark.sql.tests.connect.streaming.test_parity_streaming", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 2866945d161f..3e8c679a9bca 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -16,6 +16,7 @@ # import json +import sys from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional from pyspark.errors import StreamingQueryException @@ -149,10 +150,32 @@ def _execute_streaming_query_cmd( def _test() -> None: - # TODO(SPARK-43031): port _test() from legacy query.py. - pass + import doctest + import os + from pyspark.sql import SparkSession as PySparkSession + import pyspark.sql.connect.streaming.query + from py4j.protocol import Py4JError + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.sql.connect.streaming.query.__dict__.copy() + + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.streaming.query tests") + .remote("local[4]") + .getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.connect.streaming.query, + globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF, + ) + globs["spark"].stop() + + if failure_count: + sys.exit(-1) if __name__ == "__main__": - # TODO(SPARK-43031): Add this file dev/sparktestsupport/modules.py to enable testing in CI. 
_test() diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index b266f485c96c..6246d24a0928 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -168,9 +168,75 @@ def json( json.__doc__ = PySparkDataStreamReader.json.__doc__ - # def orc() TODO - # def parquet() TODO - # def text() TODO + def orc( + self, + path: str, + mergeSchema: Optional[bool] = None, + pathGlobFilter: Optional[Union[bool, str]] = None, + recursiveFileLookup: Optional[Union[bool, str]] = None, + ) -> "DataFrame": + self._set_opts( + mergeSchema=mergeSchema, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + ) + if isinstance(path, str): + return self.load(path=path, format="orc") + else: + raise TypeError("path can be only a single string") + + orc.__doc__ = PySparkDataStreamReader.orc.__doc__ + + def parquet( + self, + path: str, + mergeSchema: Optional[bool] = None, + pathGlobFilter: Optional[Union[bool, str]] = None, + recursiveFileLookup: Optional[Union[bool, str]] = None, + datetimeRebaseMode: Optional[Union[bool, str]] = None, + int96RebaseMode: Optional[Union[bool, str]] = None, + ) -> "DataFrame": + self._set_opts( + mergeSchema=mergeSchema, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + datetimeRebaseMode=datetimeRebaseMode, + int96RebaseMode=int96RebaseMode, + ) + self._set_opts( + mergeSchema=mergeSchema, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + datetimeRebaseMode=datetimeRebaseMode, + int96RebaseMode=int96RebaseMode, + ) + if isinstance(path, str): + return self.load(path=path, format="parquet") + else: + raise TypeError("path can be only a single string") + + parquet.__doc__ = PySparkDataStreamReader.parquet.__doc__ + + def text( + self, + path: str, + wholetext: bool = False, + lineSep: Optional[str] = None, + pathGlobFilter: Optional[Union[bool, str]] = None, + recursiveFileLookup: Optional[Union[bool, str]] = None, + ) -> "DataFrame": + self._set_opts( + wholetext=wholetext, + lineSep=lineSep, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + ) + if isinstance(path, str): + return self.load(path=path, format="text") + else: + raise TypeError("path can be only a single string") + + text.__doc__ = PySparkDataStreamReader.text.__doc__ def csv( self, @@ -245,7 +311,7 @@ def csv( csv.__doc__ = PySparkDataStreamReader.csv.__doc__ - # def table() TODO. Use Read(table_name) relation. + # def table() TODO(SPARK-43042). Use Read(table_name) relation. DataStreamReader.__doc__ = PySparkDataStreamReader.__doc__ @@ -460,10 +526,32 @@ def toTable( def _test() -> None: - # TODO(SPARK-43031): port _test() from legacy query.py. - pass + import sys + import doctest + from pyspark.sql import SparkSession as PySparkSession + import pyspark.sql.connect.streaming.readwriter + + globs = pyspark.sql.connect.readwriter.__dict__.copy() + + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.streaming.readwriter tests") + .remote("local[4]") + .getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.connect.streaming.readwriter, + globs=globs, + optionflags=doctest.ELLIPSIS + | doctest.NORMALIZE_WHITESPACE + | doctest.IGNORE_EXCEPTION_DETAIL, + ) + + globs["spark"].stop() + + if failure_count: + sys.exit(-1) if __name__ == "__main__": - # TODO(SPARK-43031): Add this file dev/sparktestsupport/modules.py to enable testing in CI. 
_test() diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 3c43628bf378..0268de2da6ec 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -188,7 +188,8 @@ def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: Return whether the query has terminated or not within 5 seconds - >>> sq.awaitTermination(5) + TODO(SPARK-42960): remove the SKIP flag below + >>> sq.awaitTermination(5) # doctest: +SKIP False >>> sq.stop() @@ -330,7 +331,9 @@ def stop(self) -> None: Stop streaming query >>> sq.stop() - >>> sq.isActive + + # TODO(SPARK-42940): remove the SKIP flag below + >>> sq.isActive # doctest: +SKIP False """ self._jsq.stop() diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index c58848dc5085..16c44ddbbcbf 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -49,8 +49,8 @@ class DataStreamReader(OptionUtils): Examples -------- - >>> spark.readStream - + >>> spark.readStream # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then writes the stream out to the console. @@ -89,8 +89,8 @@ def format(self, source: str) -> "DataStreamReader": Examples -------- - >>> spark.readStream.format("text") - + >>> spark.readStream.format("text") # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> This API allows to configure other sources to read. The example below writes a small text file, and reads it back via Text source. @@ -132,10 +132,10 @@ def schema(self, schema: Union[StructType, str]) -> "DataStreamReader": Examples -------- >>> from pyspark.sql.types import StructField, StructType, StringType - >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)])) - - >>> spark.readStream.schema("col0 INT, col1 DOUBLE") - + >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)])) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> + >>> spark.readStream.schema("col0 INT, col1 DOUBLE") # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> The example below specifies a different schema to CSV file. @@ -171,8 +171,8 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader" Examples -------- - >>> spark.readStream.option("x", 1) - + >>> spark.readStream.option("x", 1) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' option to Rate source in order to generate 10 rows every second. @@ -197,8 +197,8 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader": Examples -------- - >>> spark.readStream.options(x="1", y=2) - + >>> spark.readStream.options(x="1", y=2) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' and 'numPartitions' options to Rate source in order to generate 10 rows with 10 partitions every second. 
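The connect-side orc(), parquet() and text() readers added above mirror the classic DataStreamReader API: each accepts a single path string, forwards its keyword arguments through _set_opts(), and delegates to load() with the matching format, raising TypeError for anything other than one string path. A minimal usage sketch, assuming a reachable Spark Connect endpoint and placeholder input directories (neither is part of this patch):

    from pyspark.sql import SparkSession

    # Connect-backed session; "sc://localhost" is an assumed endpoint.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # The text source always yields a single string column named 'value',
    # so no schema is required; other file sources need an explicit schema.
    lines = spark.readStream.text("/tmp/stream/text")

    parquet_stream = (
        spark.readStream
        .schema("id LONG, ts TIMESTAMP")
        .parquet("/tmp/stream/parquet", mergeSchema=True)
    )

    query = lines.writeStream.format("console").start()
    query.stop()
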
@@ -763,8 +763,8 @@ def outputMode(self, outputMode: str) -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.outputMode('append') - + >>> df.writeStream.outputMode('append') # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> The example below uses Complete mode that the entire aggregated counts are printed out. @@ -797,8 +797,8 @@ def format(self, source: str) -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.format("text") - + >>> df.writeStream.format("text") # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> This API allows to configure the source to write. The example below writes a CSV file from Rate source in a streaming manner. @@ -831,8 +831,8 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter" Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.option("x", 1) - + >>> df.writeStream.option("x", 1) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' option to Console source in order to print 3 rows for every batch. @@ -859,8 +859,8 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.option("x", 1) - + >>> df.writeStream.option("x", 1) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' and 'truncate' options to Console source in order to print 3 rows for every batch without truncating the results. @@ -904,8 +904,8 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.partitionBy("value") - + >>> df.writeStream.partitionBy("value") # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> Partition-by timestamp column from Rate source. @@ -1014,18 +1014,18 @@ def trigger( Trigger the query for execution every 5 seconds - >>> df.writeStream.trigger(processingTime='5 seconds') - + >>> df.writeStream.trigger(processingTime='5 seconds') # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for execution every 5 seconds - >>> df.writeStream.trigger(continuous='5 seconds') - + >>> df.writeStream.trigger(continuous='5 seconds') # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for reading all available data with multiple batches - >>> df.writeStream.trigger(availableNow=True) - + >>> df.writeStream.trigger(availableNow=True) # doctest: +ELLIPSIS + <...streaming.readwriter.DataStreamWriter object ...> """ params = [processingTime, once, continuous, availableNow] @@ -1150,6 +1150,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt This API is evolving. Examples + TODO(SPARK-43054): remove the SKIP flags below -------- >>> import time >>> df = spark.readStream.format("rate").load() @@ -1159,9 +1160,9 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt >>> def print_row(row): ... print(row) ... 
- >>> q = df.writeStream.foreach(print_row).start() - >>> time.sleep(3) - >>> q.stop() + >>> q = df.writeStream.foreach(print_row).start() # doctest: +SKIP + >>> time.sleep(3) # doctest: +SKIP + >>> q.stop() # doctest: +SKIP Print every row using a object with process() method @@ -1176,9 +1177,9 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt ... def close(self, error): ... print("Closed with error: %s" % str(error)) ... - >>> q = df.writeStream.foreach(print_row).start() - >>> time.sleep(3) - >>> q.stop() + >>> q = df.writeStream.foreach(print_row).start() # doctest: +SKIP + >>> time.sleep(3) # doctest: +SKIP + >>> q.stop() # doctest: +SKIP """ from pyspark.rdd import _wrap_function @@ -1280,14 +1281,15 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamW Examples -------- + # TODO(SPARK-42944): remove the SKIP flags below >>> import time >>> df = spark.readStream.format("rate").load() >>> def func(batch_df, batch_id): ... batch_df.collect() ... - >>> q = df.writeStream.foreachBatch(func).start() - >>> time.sleep(3) - >>> q.stop() + >>> q = df.writeStream.foreachBatch(func).start() # doctest: +SKIP + >>> time.sleep(3) # doctest: +SKIP + >>> q.stop() # doctest: +SKIP """ from pyspark.java_gateway import ensure_callback_server_started @@ -1359,7 +1361,9 @@ def start( >>> q.name 'this_query' >>> q.stop() - >>> q.isActive + + # TODO(SPARK-42940): remove the SKIP flag below + >>> q.isActive # doctest: +SKIP False Example with using other parameters with a trigger. diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py new file mode 100644 index 000000000000..d28d2c0524a9 --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.testing.connectutils import should_test_connect +from pyspark.sql.tests.streaming.test_streaming import StreamingTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class StreamingParityTests(StreamingTestsMixin, ReusedConnectTestCase): + @unittest.skip("Will be supported with SPARK-42960.") + def test_stream_await_termination(self): + super().test_stream_await_termination() + + @unittest.skip("Will be supported with SPARK-42960.") + def test_stream_exception(self): + super().test_stream_exception() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_status_and_progress(self): + super().test_stream_status_and_progress() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_query_manager_await_termination(self): + super().test_query_manager_await_termination() + + @unittest.skip("table API will be supported later with SPARK-43042.") + def test_streaming_read_from_table(self): + super().test_streaming_read_from_table() + + @unittest.skip("table API will be supported later with SPARK-43042.") + def test_streaming_write_to_table(self): + super().test_streaming_write_to_table() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_save_options(self): + super().test_stream_save_options() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_save_options_overwrite(self): + super().test_stream_save_options_overwrite() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_streaming import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 9f02ae848bf6..2b3903a855ae 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -26,7 +26,39 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTests(ReusedSQLTestCase): +class StreamingTestsMixin: + def test_streaming_query_functions_sanity(self): + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + query = ( + df.writeStream.format("memory") + .queryName("test_streaming_query_functions_sanity") + .start() + ) + try: + self.assertEquals(query.name, "test_streaming_query_functions_sanity") + self.assertTrue(isinstance(query.id, str)) + self.assertTrue(isinstance(query.runId, str)) + self.assertTrue(query.isActive) + # TODO: Will be uncommented with [SPARK-42960] + # self.assertEqual(query.exception(), None) + # self.assertFalse(query.awaitTermination(1)) + query.processAllAvailable() + recentProgress = query.recentProgress + lastProgress = query.lastProgress + self.assertEqual(lastProgress["name"], query.name) + self.assertEqual(lastProgress["id"], query.id) + self.assertTrue(any(p == lastProgress for p in recentProgress)) + query.explain() + + except Exception as e: + self.fail( + "Streaming query functions sanity check shouldn't throw any error. 
" + "Error message: " + str(e) + ) + + finally: + query.stop() + def test_stream_trigger(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") @@ -77,8 +109,14 @@ def test_stream_read_options_overwrite(self): .schema(bad_schema) .load(path="python/test_support/sql/streaming", schema=schema, format="text") ) - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") + # TODO: Moving this outside of with block will trigger the following error, + # which doesn't happen in non-connect + # pyspark.errors.exceptions.connect.AnalysisException: + # There is a 'path' option set and load() is called with a path parameter. + # Either remove the path option, or call load() without the parameter. + # To ignore this check, set 'spark.sql.legacy.pathOptionBehavior.enabled' to 'true'. + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct") def test_stream_save_options(self): df = ( @@ -295,334 +333,6 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) - class ForeachWriterTester: - def __init__(self, spark): - self.spark = spark - - def write_open_event(self, partitionId, epochId): - self._write_event(self.open_events_dir, {"partition": partitionId, "epoch": epochId}) - - def write_process_event(self, row): - self._write_event(self.process_events_dir, {"value": "text"}) - - def write_close_event(self, error): - self._write_event(self.close_events_dir, {"error": str(error)}) - - def write_input_file(self): - self._write_event(self.input_dir, "text") - - def open_events(self): - return self._read_events(self.open_events_dir, "partition INT, epoch INT") - - def process_events(self): - return self._read_events(self.process_events_dir, "value STRING") - - def close_events(self): - return self._read_events(self.close_events_dir, "error STRING") - - def run_streaming_query_on_writer(self, writer, num_files): - self._reset() - try: - sdf = self.spark.readStream.format("text").load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - for i in range(num_files): - self.write_input_file() - sq.processAllAvailable() - finally: - self.stop_all() - - def assert_invalid_writer(self, writer, msg=None): - self._reset() - try: - sdf = self.spark.readStream.format("text").load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - self.write_input_file() - sq.processAllAvailable() - self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected - except Exception as e: - if msg: - assert msg in str(e), "%s not in %s" % (msg, str(e)) - - finally: - self.stop_all() - - def stop_all(self): - for q in self.spark.streams.active: - q.stop() - - def _reset(self): - self.input_dir = tempfile.mkdtemp() - self.open_events_dir = tempfile.mkdtemp() - self.process_events_dir = tempfile.mkdtemp() - self.close_events_dir = tempfile.mkdtemp() - - def _read_events(self, dir, json): - rows = self.spark.read.schema(json).json(dir).collect() - dicts = [row.asDict() for row in rows] - return dicts - - def _write_event(self, dir, event): - import uuid - - with open(os.path.join(dir, str(uuid.uuid4())), "w") as f: - f.write("%s\n" % str(event)) - - def __getstate__(self): - return (self.open_events_dir, self.process_events_dir, self.close_events_dir) - - def __setstate__(self, state): - self.open_events_dir, self.process_events_dir, self.close_events_dir = state - - # Those foreach tests are failed in macOS High Sierra by defined rules - # at 
http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html - # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. - def test_streaming_foreach_with_simple_function(self): - tester = self.ForeachWriterTester(self.spark) - - def foreach_func(row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(foreach_func, 2) - self.assertEqual(len(tester.process_events()), 2) - - def test_streaming_foreach_with_basic_open_process_close(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) - return True - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - open_events = tester.open_events() - self.assertEqual(len(open_events), 2) - self.assertSetEqual(set([e["epoch"] for e in open_events]), {0, 1}) - - self.assertEqual(len(tester.process_events()), 2) - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) - - def test_streaming_foreach_with_open_returning_false(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return False - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - self.assertEqual(len(tester.open_events()), 2) - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) - - def test_streaming_foreach_without_open_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 2) - - def test_streaming_foreach_without_close_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return True - - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 2) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_without_open_and_close_methods(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_with_process_throwing_error(self): - from pyspark.errors import StreamingQueryException - - tester = self.ForeachWriterTester(self.spark) - - class 
ForeachWriter: - def process(self, row): - raise RuntimeError("test error") - - def close(self, error): - tester.write_close_event(error) - - try: - tester.run_streaming_query_on_writer(ForeachWriter(), 1) - self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException: - # TODO: Verify whether original error message is inside the exception - pass - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - close_events = tester.close_events() - self.assertEqual(len(close_events), 1) - # TODO: Verify whether original error message is inside the exception - - def test_streaming_foreach_with_invalid_writers(self): - - tester = self.ForeachWriterTester(self.spark) - - def func_with_iterator_input(iter): - for x in iter: - print(x) - - tester.assert_invalid_writer(func_with_iterator_input) - - class WriterWithoutProcess: - def open(self, partition): - pass - - tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") - - class WriterWithNonCallableProcess: - process = True - - tester.assert_invalid_writer( - WriterWithNonCallableProcess(), "'process' in provided object is not callable" - ) - - class WriterWithNoParamProcess: - def process(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamProcess()) - - # Abstract class for tests below - class WithProcess: - def process(self, row): - pass - - class WriterWithNonCallableOpen(WithProcess): - open = True - - tester.assert_invalid_writer( - WriterWithNonCallableOpen(), "'open' in provided object is not callable" - ) - - class WriterWithNoParamOpen(WithProcess): - def open(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamOpen()) - - class WriterWithNonCallableClose(WithProcess): - close = True - - tester.assert_invalid_writer( - WriterWithNonCallableClose(), "'close' in provided object is not callable" - ) - - def test_streaming_foreachBatch(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_tempview(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - batch_df.createOrReplaceTempView("updates") - # it should use the spark session within given DataFrame, as microbatch execution will - # clone the session which is no longer same with the session used to start the - # streaming query - collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_propagates_python_errors(self): - from pyspark.errors import StreamingQueryException - - q = None - - def collectBatch(df, id): - raise RuntimeError("this should fail the query") - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.fail("Expected a failure") - except StreamingQueryException as e: - self.assertTrue("this should 
fail" in str(e)) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_graceful_stop(self): - # SPARK-39218: Make foreachBatch streaming query stop gracefully - def func(batch_df, _): - batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) - - q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() - time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. - q.stop() - self.assertIsNone(q.exception(), "No exception has to be propagated.") - def test_streaming_read_from_table(self): with self.table("input_table", "this_query"): self.spark.sql("CREATE TABLE input_table (value string) USING parquet") @@ -648,6 +358,10 @@ def test_streaming_write_to_table(self): self.assertTrue(len(result) > 0) +class StreamingTests(StreamingTestsMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": import unittest from pyspark.sql.tests.streaming.test_streaming import * # noqa: F401 diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py new file mode 100644 index 000000000000..86ecfebf6fe0 --- /dev/null +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py @@ -0,0 +1,369 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import tempfile +import time + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTestsForeachFamilyMixin: + class ForeachWriterTester: + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event(self.open_events_dir, {"partition": partitionId, "epoch": epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {"value": "text"}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {"error": str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, "partition INT, epoch INT") + + def process_events(self): + return self._read_events(self.process_events_dir, "value STRING") + + def close_events(self): + return self._read_events(self.close_events_dir, "error STRING") + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format("text").load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format("text").load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + + with open(os.path.join(dir, str(uuid.uuid4())), "w") as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + # Those foreach tests are failed in macOS High Sierra by defined rules + # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html + # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
+ def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e["epoch"] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.errors import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise RuntimeError("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess: + process = True + + tester.assert_invalid_writer( + WriterWithNonCallableProcess(), "'process' in provided object is not callable" + ) + + class WriterWithNoParamProcess: + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess: + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer( + WriterWithNonCallableOpen(), "'open' in provided object is not callable" + ) + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer( + WriterWithNonCallableClose(), "'close' in provided object is not callable" + ) + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_tempview(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + batch_df.createOrReplaceTempView("updates") + # it should use the spark session within given DataFrame, as microbatch execution will + # clone the session which is no longer same with the session used to start the + # streaming query + collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.errors import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise RuntimeError("this should fail the query") + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_graceful_stop(self): + # SPARK-39218: Make foreachBatch streaming 
query stop gracefully + def func(batch_df, _): + batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) + + q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() + time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. + q.stop() + self.assertIsNone(q.exception(), "No exception has to be propagated.") + + +class StreamingTestsForeachFamily(StreamingTestsForeachFamilyMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.streaming.test_streaming_foreach_family import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From 0ae7e33929ae09112b6dbd31e33f90f36ef71a2a Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Thu, 6 Apr 2023 17:59:00 -0700 Subject: [PATCH 02/12] add versionchanged to query and readwriter --- python/pyspark/sql/streaming/query.py | 3 +++ python/pyspark/sql/streaming/readwriter.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 0268de2da6ec..ca83fcddf522 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -36,6 +36,9 @@ class StreamingQuery: All these methods are thread-safe. .. versionadded:: 2.0.0 + + .. versionchanged:: 3.5.0 + Supports Spark Connect. Notes ----- diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index 16c44ddbbcbf..359d8cf7cd48 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -42,6 +42,9 @@ class DataStreamReader(OptionUtils): Use :attr:`SparkSession.readStream ` to access this. .. versionadded:: 2.0.0 + + .. versionchanged:: 3.5.0 + Supports Spark Connect. Notes ----- From 17720b7121eee050a546126fad7ce8229fe6bda2 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 11:21:52 -0700 Subject: [PATCH 03/12] style --- python/pyspark/sql/streaming/query.py | 2 +- python/pyspark/sql/streaming/readwriter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 75dafa19b7c1..d909eba0a60d 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -36,7 +36,7 @@ class StreamingQuery: All these methods are thread-safe. .. versionadded:: 2.0.0 - + .. versionchanged:: 3.5.0 Supports Spark Connect. diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index 9fa5ee993df1..793c4f4cdd1a 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -42,7 +42,7 @@ class DataStreamReader(OptionUtils): Use :attr:`SparkSession.readStream ` to access this. .. versionadded:: 2.0.0 - + .. versionchanged:: 3.5.0 Supports Spark Connect. 
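The test refactor in the first patch uses a mixin so the same assertions run against both a classic session and a Spark Connect session: StreamingTestsMixin holds the shared cases, StreamingTests pairs it with ReusedSQLTestCase, and StreamingParityTests pairs it with ReusedConnectTestCase while skipping cases that connect cannot run yet. A condensed sketch of that shape, with illustrative class and test names rather than the real suite:

    import unittest

    from pyspark.testing.sqlutils import ReusedSQLTestCase
    from pyspark.testing.connectutils import ReusedConnectTestCase


    class ExampleTestsMixin:
        # Written only against the public API, so either session type can run it.
        def test_rate_source_is_streaming(self):
            df = self.spark.readStream.format("rate").load()
            self.assertTrue(df.isStreaming)


    class ExampleTests(ExampleTestsMixin, ReusedSQLTestCase):  # classic JVM-backed session
        pass


    class ExampleParityTests(ExampleTestsMixin, ReusedConnectTestCase):  # connect session
        # The skip-and-delegate idiom used above for not-yet-supported features.
        @unittest.skip("Tracked by a follow-up ticket.")
        def test_rate_source_is_streaming(self):
            super().test_rate_source_is_streaming()
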
From 1e68a3c212fa7ae39b958d79eec7af77b4884859 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 12:42:14 -0700 Subject: [PATCH 04/12] comments --- python/pyspark/sql/tests/streaming/test_streaming.py | 12 +++--------- .../tests/streaming/test_streaming_foreach_family.py | 6 +----- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 2b3903a855ae..838d413a0cc3 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -27,15 +27,15 @@ class StreamingTestsMixin: - def test_streaming_query_functions_sanity(self): + def test_streaming_query_functions_basic(self): df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() query = ( df.writeStream.format("memory") - .queryName("test_streaming_query_functions_sanity") + .queryName("test_streaming_query_functions_basic") .start() ) try: - self.assertEquals(query.name, "test_streaming_query_functions_sanity") + self.assertEquals(query.name, "test_streaming_query_functions_basic") self.assertTrue(isinstance(query.id, str)) self.assertTrue(isinstance(query.runId, str)) self.assertTrue(query.isActive) @@ -109,12 +109,6 @@ def test_stream_read_options_overwrite(self): .schema(bad_schema) .load(path="python/test_support/sql/streaming", schema=schema, format="text") ) - # TODO: Moving this outside of with block will trigger the following error, - # which doesn't happen in non-connect - # pyspark.errors.exceptions.connect.AnalysisException: - # There is a 'path' option set and load() is called with a path parameter. - # Either remove the path option, or call load() without the parameter. - # To ignore this check, set 'spark.sql.legacy.pathOptionBehavior.enabled' to 'true'. 
self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct") diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py index 86ecfebf6fe0..89fb32e5c034 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py @@ -22,7 +22,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTestsForeachFamilyMixin: +class StreamingTestsForeachFamily(ReusedSQLTestCase): class ForeachWriterTester: def __init__(self, spark): self.spark = spark @@ -352,10 +352,6 @@ def func(batch_df, _): self.assertIsNone(q.exception(), "No exception has to be propagated.") -class StreamingTestsForeachFamily(StreamingTestsForeachFamilyMixin, ReusedSQLTestCase): - pass - - if __name__ == "__main__": import unittest from pyspark.sql.tests.streaming.test_streaming_foreach_family import * # noqa: F401 From dc05be8c79b7a3eea47f008dc6f5c137349203a1 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 14:07:29 -0700 Subject: [PATCH 05/12] address comments, add a new foreachBatch test class, remove all ELLIPSIS flag as it's already in test options --- dev/sparktestsupport/modules.py | 3 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/streaming/query.py | 8 +- python/pyspark/sql/streaming/readwriter.py | 28 ++--- ...ch_family.py => test_streaming_foreach.py} | 4 +- .../streaming/test_streaming_foreachBatch.py | 104 ++++++++++++++++++ 6 files changed, 127 insertions(+), 22 deletions(-) rename python/pyspark/sql/tests/streaming/{test_streaming_foreach_family.py => test_streaming_foreach.py} (98%) create mode 100644 python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a65789c1da8e..08924a86fd7c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -495,7 +495,8 @@ def __hash__(self): "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", "pyspark.sql.tests.streaming.test_streaming", - "pyspark.sql.tests.streaming.test_streaming_foreach_family", + "pyspark.sql.tests.streaming.test_streaming_foreach", + "pyspark.sql.tests.streaming.test_streaming_foreachBatch", "pyspark.sql.tests.streaming.test_streaming_listener", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e7df25d20fcb..542b898015b4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -535,7 +535,7 @@ def writeStream(self) -> DataStreamWriter: >>> with tempfile.TemporaryDirectory() as d: ... # Create a table with Rate source. ... df.writeStream.toTable( - ... "my_table", checkpointLocation=d) # doctest: +ELLIPSIS + ... "my_table", checkpointLocation=d) <...streaming.query.StreamingQuery object at 0x...> """ return DataStreamWriter(self) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index d909eba0a60d..fd0318f821b5 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -71,7 +71,7 @@ def id(self) -> str: Get the unique id of this query that persists across restarts from checkpoint data - >>> sq.id # doctest: +ELLIPSIS + >>> sq.id '...' 
>>> sq.stop() @@ -98,7 +98,7 @@ def runId(self) -> str: Get the unique id of this query that does not persist across restarts - >>> sq.runId # doctest: +ELLIPSIS + >>> sq.runId '...' >>> sq.stop() @@ -223,7 +223,7 @@ def status(self) -> Dict[str, Any]: Get the current status of the query - >>> sq.status # doctest: +ELLIPSIS + >>> sq.status {'message': '...', 'isDataAvailable': ..., 'isTriggerActive': ...} >>> sq.stop() @@ -252,7 +252,7 @@ def recentProgress(self) -> List[Dict[str, Any]]: Get an array of the most recent query progress updates for this query - >>> sq.recentProgress # doctest: +ELLIPSIS + >>> sq.recentProgress [...] >>> sq.stop() diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index 793c4f4cdd1a..df21b257f8b2 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -52,7 +52,7 @@ class DataStreamReader(OptionUtils): Examples -------- - >>> spark.readStream # doctest: +ELLIPSIS + >>> spark.readStream <...streaming.readwriter.DataStreamReader object ...> The example below uses Rate source that generates rows continuously. @@ -92,7 +92,7 @@ def format(self, source: str) -> "DataStreamReader": Examples -------- - >>> spark.readStream.format("text") # doctest: +ELLIPSIS + >>> spark.readStream.format("text") <...streaming.readwriter.DataStreamReader object ...> This API allows to configure other sources to read. The example below writes a small text @@ -135,9 +135,9 @@ def schema(self, schema: Union[StructType, str]) -> "DataStreamReader": Examples -------- >>> from pyspark.sql.types import StructField, StructType, StringType - >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)])) # doctest: +ELLIPSIS + >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)])) <...streaming.readwriter.DataStreamReader object ...> - >>> spark.readStream.schema("col0 INT, col1 DOUBLE") # doctest: +ELLIPSIS + >>> spark.readStream.schema("col0 INT, col1 DOUBLE") <...streaming.readwriter.DataStreamReader object ...> The example below specifies a different schema to CSV file. @@ -174,7 +174,7 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader" Examples -------- - >>> spark.readStream.option("x", 1) # doctest: +ELLIPSIS + >>> spark.readStream.option("x", 1) <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' option to Rate source in order to generate @@ -200,7 +200,7 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader": Examples -------- - >>> spark.readStream.options(x="1", y=2) # doctest: +ELLIPSIS + >>> spark.readStream.options(x="1", y=2) <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' and 'numPartitions' options to @@ -766,7 +766,7 @@ def outputMode(self, outputMode: str) -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.outputMode('append') # doctest: +ELLIPSIS + >>> df.writeStream.outputMode('append') <...streaming.readwriter.DataStreamWriter object ...> The example below uses Complete mode that the entire aggregated counts are printed out. 
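The per-example "# doctest: +ELLIPSIS" directives are dropped throughout this patch because, as the commit message notes, ELLIPSIS is already enabled by the modules' _test() harnesses (the connect-side ones added earlier in this series pass doctest.ELLIPSIS to doctest.testmod explicitly). A standalone toy illustration of why the inline flags then become redundant, unrelated to Spark:

    import doctest


    def make_handle():
        """
        >>> make_handle()
        <...object at 0x...>
        """
        return object()


    if __name__ == "__main__":
        # With ELLIPSIS passed here, '...' in expected output matches any text,
        # so no per-example directive is needed.
        doctest.testmod(optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
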
@@ -800,7 +800,7 @@ def format(self, source: str) -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.format("text") # doctest: +ELLIPSIS + >>> df.writeStream.format("text") <...streaming.readwriter.DataStreamWriter object ...> This API allows to configure the source to write. The example below writes a CSV @@ -834,7 +834,7 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter" Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.option("x", 1) # doctest: +ELLIPSIS + >>> df.writeStream.option("x", 1) <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' option to Console source in order to print @@ -862,7 +862,7 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter": Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.option("x", 1) # doctest: +ELLIPSIS + >>> df.writeStream.option("x", 1) <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' and 'truncate' options to Console source in order @@ -907,7 +907,7 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] Examples -------- >>> df = spark.readStream.format("rate").load() - >>> df.writeStream.partitionBy("value") # doctest: +ELLIPSIS + >>> df.writeStream.partitionBy("value") <...streaming.readwriter.DataStreamWriter object ...> Partition-by timestamp column from Rate source. @@ -1017,17 +1017,17 @@ def trigger( Trigger the query for execution every 5 seconds - >>> df.writeStream.trigger(processingTime='5 seconds') # doctest: +ELLIPSIS + >>> df.writeStream.trigger(processingTime='5 seconds') <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for execution every 5 seconds - >>> df.writeStream.trigger(continuous='5 seconds') # doctest: +ELLIPSIS + >>> df.writeStream.trigger(continuous='5 seconds') <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for reading all available data with multiple batches - >>> df.writeStream.trigger(availableNow=True) # doctest: +ELLIPSIS + >>> df.writeStream.trigger(availableNow=True) <...streaming.readwriter.DataStreamWriter object ...> """ params = [processingTime, once, continuous, availableNow] diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py similarity index 98% rename from python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py rename to python/pyspark/sql/tests/streaming/test_streaming_foreach.py index 89fb32e5c034..bac0c45e8303 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_family.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -22,7 +22,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTestsForeachFamily(ReusedSQLTestCase): +class StreamingTestsForeach(ReusedSQLTestCase): class ForeachWriterTester: def __init__(self, spark): self.spark = spark @@ -354,7 +354,7 @@ def func(batch_df, _): if __name__ == "__main__": import unittest - from pyspark.sql.tests.streaming.test_streaming_foreach_family import * # noqa: F401 + from pyspark.sql.tests.streaming.test_streaming_foreach import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py new file mode 100644 index 
000000000000..7d56804c3353 --- /dev/null +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import tempfile +import time + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTestsForeachBatch(ReusedSQLTestCase): + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_tempview(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + batch_df.createOrReplaceTempView("updates") + # it should use the spark session within given DataFrame, as microbatch execution will + # clone the session which is no longer same with the session used to start the + # streaming query + collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.errors import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise RuntimeError("this should fail the query") + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_graceful_stop(self): + # SPARK-39218: Make foreachBatch streaming query stop gracefully + def func(batch_df, _): + batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) + + q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() + time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. 
+ q.stop() + self.assertIsNone(q.exception(), "No exception has to be propagated.") + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.streaming.test_streaming_foreachBatch import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From c1674ebb85fafb42fc5453832ad7807e3807ba4b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 14:08:56 -0700 Subject: [PATCH 06/12] minor --- .../tests/streaming/test_streaming_foreach.py | 67 ------------------- 1 file changed, 67 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py index bac0c45e8303..ffaedd0a18fc 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -284,73 +284,6 @@ class WriterWithNonCallableClose(WithProcess): WriterWithNonCallableClose(), "'close' in provided object is not callable" ) - def test_streaming_foreachBatch(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_tempview(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - batch_df.createOrReplaceTempView("updates") - # it should use the spark session within given DataFrame, as microbatch execution will - # clone the session which is no longer same with the session used to start the - # streaming query - collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_propagates_python_errors(self): - from pyspark.errors import StreamingQueryException - - q = None - - def collectBatch(df, id): - raise RuntimeError("this should fail the query") - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.fail("Expected a failure") - except StreamingQueryException as e: - self.assertTrue("this should fail" in str(e)) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_graceful_stop(self): - # SPARK-39218: Make foreachBatch streaming query stop gracefully - def func(batch_df, _): - batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) - - q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() - time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. 
- q.stop() - self.assertIsNone(q.exception(), "No exception has to be propagated.") - if __name__ == "__main__": import unittest From 23b9c93e4ad16f050695cff82648861786d9832c Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 14:10:16 -0700 Subject: [PATCH 07/12] minor --- python/pyspark/sql/streaming/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index fd0318f821b5..6d14cc7560e8 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -16,7 +16,6 @@ # import json -import sys from typing import Any, Dict, List, Optional from py4j.java_gateway import JavaObject, java_import @@ -638,6 +637,7 @@ def removeListener(self, listener: StreamingQueryListener) -> None: def _test() -> None: import doctest import os + import sys from pyspark.sql import SparkSession import pyspark.sql.streaming.query from py4j.protocol import Py4JError From 60ddd01191aaa0487bf54c5b07b1f04ae214bc18 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 10 Apr 2023 22:35:20 -0700 Subject: [PATCH 08/12] lint --- python/pyspark/sql/connect/streaming/query.py | 1 - .../sql/tests/connect/streaming/test_parity_streaming.py | 1 - python/pyspark/sql/tests/streaming/test_streaming_foreach.py | 1 - .../pyspark/sql/tests/streaming/test_streaming_foreachBatch.py | 2 -- 4 files changed, 5 deletions(-) diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 3e8c679a9bca..64455e8b394e 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -154,7 +154,6 @@ def _test() -> None: import os from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.streaming.query - from py4j.protocol import Py4JError os.chdir(os.environ["SPARK_HOME"]) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py index d28d2c0524a9..6b4460bab521 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -17,7 +17,6 @@ import unittest -from pyspark.testing.connectutils import should_test_connect from pyspark.sql.tests.streaming.test_streaming import StreamingTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py index ffaedd0a18fc..8bd36020c9ad 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -17,7 +17,6 @@ import os import tempfile -import time from pyspark.testing.sqlutils import ReusedSQLTestCase diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py index 7d56804c3353..7e5720e42999 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -import os -import tempfile import time from pyspark.testing.sqlutils import ReusedSQLTestCase From e576821a78c65b8d721ea0bddfda326c47ee99ce Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 11 Apr 2023 10:52:13 -0700 Subject: [PATCH 09/12] remove empty line --- python/pyspark/sql/streaming/readwriter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index df21b257f8b2..e4c38dc07925 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -1364,7 +1364,6 @@ def start( >>> q.name 'this_query' >>> q.stop() - >>> q.isActive # doctest: +SKIP False From 26e2488274e07ce9b7e8808740c70aa3ec7f6610 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 11 Apr 2023 17:35:47 -0700 Subject: [PATCH 10/12] remove several docs in connect readwriter.py and query.py to pass doc test --- dev/sparktestsupport/modules.py | 2 -- python/pyspark/sql/connect/streaming/query.py | 7 ++++-- .../sql/connect/streaming/readwriter.py | 12 ++++++++-- python/pyspark/sql/streaming/query.py | 6 ++--- python/pyspark/sql/streaming/readwriter.py | 22 +++++++++---------- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 08924a86fd7c..7b1d57b95d5b 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -751,8 +751,6 @@ def __hash__(self): "pyspark.sql.connect.dataframe", "pyspark.sql.connect.functions", "pyspark.sql.connect.avro.functions", - "pyspark.sql.connect.streaming.readwriter", - "pyspark.sql.connect.streaming.query", # sql unittests "pyspark.sql.tests.connect.test_client", "pyspark.sql.tests.connect.test_connect_plan", diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 64455e8b394e..aebab9fc69fd 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -66,10 +66,11 @@ def isActive(self) -> bool: isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__ + # TODO (SPARK-42960): Implement and uncomment the doc def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: raise NotImplementedError() - awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__ + # awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__ @property def status(self) -> Dict[str, Any]: @@ -115,7 +116,8 @@ def stop(self) -> None: cmd.stop = True self._execute_streaming_query_cmd(cmd) - stop.__doc__ = PySparkStreamingQuery.stop.__doc__ + # TODO (SPARK-42962): uncomment below + # stop.__doc__ = PySparkStreamingQuery.stop.__doc__ def explain(self, extended: bool = False) -> None: cmd = pb2.StreamingQueryCommand() @@ -125,6 +127,7 @@ def explain(self, extended: bool = False) -> None: explain.__doc__ = PySparkStreamingQuery.explain.__doc__ + # TODO (SPARK-42960): Implement and uncomment the doc def exception(self) -> Optional[StreamingQueryException]: raise NotImplementedError() diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 6246d24a0928..b89a6db1a9d7 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -432,6 +432,7 @@ def trigger( trigger.__doc__ = PySparkDataStreamWriter.trigger.__doc__ + # TODO (SPARK-43054): Implement and uncomment the doc @overload def foreach(self, f: Callable[[Row], 
None]) -> "DataStreamWriter": ... @@ -443,7 +444,13 @@ def foreach(self, f: "SupportsProcess") -> "DataStreamWriter": def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter": raise NotImplementedError("foreach() is not implemented.") - foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ + # foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ + + # TODO (SPARK-42944): Implement and uncomment the doc + def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": + raise NotImplementedError("foreachBatch() is not implemented.") + + # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ def _start_internal( self, @@ -501,7 +508,8 @@ def start( **options, ) - start.__doc__ = PySparkDataStreamWriter.start.__doc__ + # TODO (SPARK-42962): uncomment below + # start.__doc__ = PySparkDataStreamWriter.start.__doc__ def toTable( self, diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 6d14cc7560e8..b902f0514fce 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -156,7 +156,6 @@ def isActive(self) -> bool: """ return self._jsq.isActive() - # TODO(SPARK-42960): remove the doctest: +SKIP flag below def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: """ Waits for the termination of `this` query, either by :func:`query.stop()` or by an @@ -191,7 +190,7 @@ def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: Return whether the query has terminated or not within 5 seconds - >>> sq.awaitTermination(5) # doctest: +SKIP + >>> sq.awaitTermination(5) False >>> sq.stop() @@ -317,7 +316,6 @@ def processAllAvailable(self) -> None: """ return self._jsq.processAllAvailable() - # TODO(SPARK-42940): remove the doctest: +SKIP flag below def stop(self) -> None: """ Stop this streaming query. @@ -335,7 +333,7 @@ def stop(self) -> None: >>> sq.stop() - >>> sq.isActive # doctest: +SKIP + >>> sq.isActive False """ self._jsq.stop() diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index e4c38dc07925..529e3aeb60d9 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -1162,9 +1162,9 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt >>> def print_row(row): ... print(row) ... - >>> q = df.writeStream.foreach(print_row).start() # doctest: +SKIP - >>> time.sleep(3) # doctest: +SKIP - >>> q.stop() # doctest: +SKIP + >>> q = df.writeStream.foreach(print_row).start() + >>> time.sleep(3) + >>> q.stop() Print every row using a object with process() method @@ -1179,9 +1179,9 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt ... def close(self, error): ... print("Closed with error: %s" % str(error)) ... 
- >>> q = df.writeStream.foreach(print_row).start() # doctest: +SKIP - >>> time.sleep(3) # doctest: +SKIP - >>> q.stop() # doctest: +SKIP + >>> q = df.writeStream.foreach(print_row).start() + >>> time.sleep(3) + >>> q.stop() """ from pyspark.rdd import _wrap_function @@ -1264,7 +1264,6 @@ def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Itera self._jwrite.foreach(jForeachWriter) return self - # TODO(SPARK-42944): remove the doctest: +SKIP flag below def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": """ Sets the output of the streaming query to be processed using the provided @@ -1289,9 +1288,9 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamW >>> def func(batch_df, batch_id): ... batch_df.collect() ... - >>> q = df.writeStream.foreachBatch(func).start() # doctest: +SKIP - >>> time.sleep(3) # doctest: +SKIP - >>> q.stop() # doctest: +SKIP + >>> q = df.writeStream.foreachBatch(func).start() + >>> time.sleep(3) + >>> q.stop() """ from pyspark.java_gateway import ensure_callback_server_started @@ -1305,7 +1304,6 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamW ensure_callback_server_started(gw) return self - # TODO(SPARK-42940): remove the doctest: +SKIP flag below def start( self, path: Optional[str] = None, @@ -1364,7 +1362,7 @@ def start( >>> q.name 'this_query' >>> q.stop() - >>> q.isActive # doctest: +SKIP + >>> q.isActive False Example with using other parameters with a trigger. From e25f7e6c27fa105b30bb8bf907ee7457fe12fe28 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 11 Apr 2023 18:01:17 -0700 Subject: [PATCH 11/12] minor, add back doc tests in module.py --- dev/sparktestsupport/modules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7b1d57b95d5b..08924a86fd7c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -751,6 +751,8 @@ def __hash__(self): "pyspark.sql.connect.dataframe", "pyspark.sql.connect.functions", "pyspark.sql.connect.avro.functions", + "pyspark.sql.connect.streaming.readwriter", + "pyspark.sql.connect.streaming.query", # sql unittests "pyspark.sql.tests.connect.test_client", "pyspark.sql.tests.connect.test_connect_plan", From aa1d4c288d5a7a9b9d083365e478e7838a69e3b7 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 11 Apr 2023 22:39:32 -0700 Subject: [PATCH 12/12] style --- python/pyspark/sql/connect/streaming/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index b89a6db1a9d7..e702b3523a4a 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -449,7 +449,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt # TODO (SPARK-42944): Implement and uncomment the doc def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": raise NotImplementedError("foreachBatch() is not implemented.") - + # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ def _start_internal(