5 changes: 5 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -495,6 +495,8 @@ def __hash__(self):
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
"pyspark.sql.tests.streaming.test_streaming",
"pyspark.sql.tests.streaming.test_streaming_foreach",
"pyspark.sql.tests.streaming.test_streaming_foreachBatch",
"pyspark.sql.tests.streaming.test_streaming_listener",
"pyspark.sql.tests.test_types",
"pyspark.sql.tests.test_udf",
@@ -749,6 +751,8 @@ def __hash__(self):
"pyspark.sql.connect.dataframe",
"pyspark.sql.connect.functions",
"pyspark.sql.connect.avro.functions",
"pyspark.sql.connect.streaming.readwriter",
"pyspark.sql.connect.streaming.query",
# sql unittests
"pyspark.sql.tests.connect.test_client",
"pyspark.sql.tests.connect.test_connect_plan",
@@ -773,6 +777,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_arrow_map",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
"pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
# ml doctests
"pyspark.ml.connect.functions",
# ml unittests
35 changes: 30 additions & 5 deletions python/pyspark/sql/connect/streaming/query.py
@@ -16,6 +16,7 @@
#

import json
import sys
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional

from pyspark.errors import StreamingQueryException
@@ -65,10 +66,11 @@ def isActive(self) -> bool:

isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__

# TODO (SPARK-42960): Implement and uncomment the doc
def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
raise NotImplementedError()

awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
# awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__

@property
def status(self) -> Dict[str, Any]:
@@ -114,7 +116,8 @@ def stop(self) -> None:
cmd.stop = True
self._execute_streaming_query_cmd(cmd)

stop.__doc__ = PySparkStreamingQuery.stop.__doc__
# TODO (SPARK-42962): uncomment below
# stop.__doc__ = PySparkStreamingQuery.stop.__doc__

def explain(self, extended: bool = False) -> None:
cmd = pb2.StreamingQueryCommand()
@@ -124,6 +127,7 @@ def explain(self, extended: bool = False) -> None:

explain.__doc__ = PySparkStreamingQuery.explain.__doc__

# TODO (SPARK-42960): Implement and uncomment the doc
def exception(self) -> Optional[StreamingQueryException]:
raise NotImplementedError()

@@ -149,10 +153,31 @@ def _execute_streaming_query_cmd(


def _test() -> None:
# TODO(SPARK-43031): port _test() from legacy query.py.
pass
import doctest
import os
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.streaming.query

os.chdir(os.environ["SPARK_HOME"])

globs = pyspark.sql.connect.streaming.query.__dict__.copy()

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.streaming.query tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.streaming.query,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
globs["spark"].stop()

if failure_count:
sys.exit(-1)


if __name__ == "__main__":
# TODO(SPARK-43031): Add this file to dev/sparktestsupport/modules.py to enable testing in CI.
_test()
114 changes: 105 additions & 9 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -168,9 +168,75 @@ def json(

@WweiL (Contributor Author, Apr 6, 2023): Please ignore the changes in this file; they are added here but will be reviewed in #40689.

json.__doc__ = PySparkDataStreamReader.json.__doc__

# def orc() TODO
# def parquet() TODO
# def text() TODO
def orc(
self,
path: str,
mergeSchema: Optional[bool] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
) -> "DataFrame":
self._set_opts(
mergeSchema=mergeSchema,
pathGlobFilter=pathGlobFilter,
recursiveFileLookup=recursiveFileLookup,
)
if isinstance(path, str):
return self.load(path=path, format="orc")
else:
raise TypeError("path can be only a single string")

orc.__doc__ = PySparkDataStreamReader.orc.__doc__

def parquet(
self,
path: str,
mergeSchema: Optional[bool] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
datetimeRebaseMode: Optional[Union[bool, str]] = None,
int96RebaseMode: Optional[Union[bool, str]] = None,
) -> "DataFrame":
self._set_opts(
mergeSchema=mergeSchema,
pathGlobFilter=pathGlobFilter,
recursiveFileLookup=recursiveFileLookup,
datetimeRebaseMode=datetimeRebaseMode,
int96RebaseMode=int96RebaseMode,
)
if isinstance(path, str):
return self.load(path=path, format="parquet")
else:
raise TypeError("path can be only a single string")

parquet.__doc__ = PySparkDataStreamReader.parquet.__doc__

def text(
self,
path: str,
wholetext: bool = False,
lineSep: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
) -> "DataFrame":
self._set_opts(
wholetext=wholetext,
lineSep=lineSep,
pathGlobFilter=pathGlobFilter,
recursiveFileLookup=recursiveFileLookup,
)
if isinstance(path, str):
return self.load(path=path, format="text")
else:
raise TypeError("path can be only a single string")

text.__doc__ = PySparkDataStreamReader.text.__doc__

def csv(
self,
@@ -245,7 +311,7 @@ def csv(

csv.__doc__ = PySparkDataStreamReader.csv.__doc__

# def table() TODO. Use Read(table_name) relation.
# def table() TODO(SPARK-43042). Use Read(table_name) relation.


DataStreamReader.__doc__ = PySparkDataStreamReader.__doc__
@@ -366,6 +432,7 @@ def trigger(

trigger.__doc__ = PySparkDataStreamWriter.trigger.__doc__

# TODO (SPARK-43054): Implement and uncomment the doc
@overload
def foreach(self, f: Callable[[Row], None]) -> "DataStreamWriter":
...
@@ -377,7 +444,13 @@ def foreach(self, f: "SupportsProcess") -> "DataStreamWriter":
def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
raise NotImplementedError("foreach() is not implemented.")

foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__
# foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__

# TODO (SPARK-42944): Implement and uncomment the doc
def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
raise NotImplementedError("foreachBatch() is not implemented.")

# foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__

def _start_internal(
self,
@@ -435,7 +508,8 @@ def start(
**options,
)

start.__doc__ = PySparkDataStreamWriter.start.__doc__
# TODO (SPARK-42962): uncomment below
# start.__doc__ = PySparkDataStreamWriter.start.__doc__

def toTable(
self,
@@ -460,10 +534,32 @@ def toTable(


def _test() -> None:
# TODO(SPARK-43031): port _test() from legacy query.py.
pass
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.streaming.readwriter

globs = pyspark.sql.connect.streaming.readwriter.__dict__.copy()

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.streaming.readwriter tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.streaming.readwriter,
globs=globs,
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL,
)

globs["spark"].stop()

if failure_count:
sys.exit(-1)


if __name__ == "__main__":
# TODO(SPARK-43031): Add this file to dev/sparktestsupport/modules.py to enable testing in CI.
_test()
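
For context, the readers added in this file mirror the classic DataStreamReader surface, so driving the new ORC reader looks roughly like the minimal sketch below. The schema, paths, and console sink are illustrative placeholders, not taken from this PR.

# Minimal sketch of the ORC streaming reader surface; paths and schema are placeholders.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("streaming-reader-sketch").getOrCreate()

# File-based streaming sources need an explicit schema up front.
stream_df = (
    spark.readStream.schema("id LONG, name STRING")
    .orc("/tmp/orc_input_dir")
)

# Echo each micro-batch to the console; the checkpoint directory tracks progress.
query = (
    stream_df.writeStream.format("console")
    .option("checkpointLocation", "/tmp/orc_checkpoint")
    .start()
)
query.stop()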
2 changes: 1 addition & 1 deletion python/pyspark/sql/dataframe.py
@@ -535,7 +535,7 @@ def writeStream(self) -> DataStreamWriter:
>>> with tempfile.TemporaryDirectory() as d:
... # Create a table with Rate source.
... df.writeStream.toTable(
... "my_table", checkpointLocation=d) # doctest: +ELLIPSIS
Reviewer (Contributor): Curious, what does # doctest: +ELLIPSIS mean?

@WweiL (Contributor Author): My understanding is that with this flag set, doctest checks the output in a regex-like way: "..." in the expected output matches any text. The line below is now <...streaming.query.StreamingQuery object at 0x...>, but it used to be <pyspark.sql.streaming.query.StreamingQuery object at 0x...>, which would conflict with Connect's test, which returns <pyspark.sql.connect.streaming.query.StreamingQuery object at 0x...>. So to make this test work for both Connect and non-Connect, we enable this regex-like check and replace the module prefix with "...". It doesn't matter much anyway, since we enable the flag in the test options in the __main__ method below. (A minimal standalone doctest illustrating the flag appears after this diff.)

... "my_table", checkpointLocation=d)
<...streaming.query.StreamingQuery object at 0x...>
"""
return DataStreamWriter(self)
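
To make the ELLIPSIS behavior discussed above concrete, here is a tiny standalone doctest (hypothetical, not part of Spark) where "..." in the expected output absorbs both the module prefix and the object address:

# Standalone illustration of doctest's ELLIPSIS flag; StreamingQuery here is a dummy
# stand-in, not PySpark's class.
import doctest


class StreamingQuery:
    """Dummy class whose default repr mimics the shape of PySpark's."""


def make_query():
    """
    >>> make_query()  # doctest: +ELLIPSIS
    <...StreamingQuery object at 0x...>
    """
    return StreamingQuery()


if __name__ == "__main__":
    # Each "..." matches arbitrary text, so the expected line above passes whether the
    # class comes from __main__, pyspark.sql.streaming.query, or
    # pyspark.sql.connect.streaming.query, and regardless of the memory address.
    doctest.testmod()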
14 changes: 9 additions & 5 deletions python/pyspark/sql/streaming/query.py
@@ -16,7 +16,6 @@
#

import json
import sys
from typing import Any, Dict, List, Optional

from py4j.java_gateway import JavaObject, java_import
@@ -37,6 +36,9 @@ class StreamingQuery:

.. versionadded:: 2.0.0

.. versionchanged:: 3.5.0
Supports Spark Connect.

Notes
-----
This API is evolving.
@@ -68,7 +70,7 @@ def id(self) -> str:

Get the unique id of this query that persists across restarts from checkpoint data

>>> sq.id # doctest: +ELLIPSIS
>>> sq.id
'...'

>>> sq.stop()
@@ -95,7 +97,7 @@ def runId(self) -> str:

Get the unique id of this query that does not persist across restarts

>>> sq.runId # doctest: +ELLIPSIS
>>> sq.runId
'...'

>>> sq.stop()
@@ -219,7 +221,7 @@ def status(self) -> Dict[str, Any]:

Get the current status of the query

>>> sq.status # doctest: +ELLIPSIS
>>> sq.status
{'message': '...', 'isDataAvailable': ..., 'isTriggerActive': ...}

>>> sq.stop()
@@ -248,7 +250,7 @@ def recentProgress(self) -> List[Dict[str, Any]]:

Get an array of the most recent query progress updates for this query

>>> sq.recentProgress # doctest: +ELLIPSIS
>>> sq.recentProgress
[...]

>>> sq.stop()
@@ -330,6 +332,7 @@ def stop(self) -> None:
Stop streaming query

>>> sq.stop()

>>> sq.isActive
False
"""
@@ -632,6 +635,7 @@ def removeListener(self, listener: StreamingQueryListener) -> None:
def _test() -> None:
import doctest
import os
import sys
from pyspark.sql import SparkSession
import pyspark.sql.streaming.query
from py4j.protocol import Py4JError