diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5a13674e8bfb..6c9377a436eb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -53,6 +53,7 @@ private[spark] object PythonEvalType { val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } @@ -537,13 +539,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() + val funcInitTime = stream.readLong() val finishTime = stream.readLong() val boot = bootTime - startTime val init = initTime - bootTime - val finish = finishTime - initTime + val funcInit = funcInitTime - initTime + val finish = finishTime - funcInitTime val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) + logInfo("Times: total = %s, boot = %s, init = %s, func_init = %s, finish = %s" + .format(total, boot, init, funcInit, finish)) val memoryBytesSpilled = stream.readLong() val diskBytesSpilled = stream.readLong() context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aef79c7882ca..484a07c18ed0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging { * and return the trailing part after the last dollar sign in the middle */ @scala.annotation.tailrec - private def stripDollars(s: String): String = { + def stripDollars(s: String): String = { val lastDollarIndex = s.lastIndexOf('$') if (lastDollarIndex < s.length - 1) { // The last char is not a dollar sign diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2b9d52693794..f9e2144d334e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -452,6 +452,7 @@ def __hash__(self): "pyspark.sql.tests.test_group", "pyspark.sql.tests.test_pandas_cogrouped_map", "pyspark.sql.tests.test_pandas_grouped_map", + "pyspark.sql.tests.test_pandas_grouped_map_with_state", "pyspark.sql.tests.test_pandas_map", "pyspark.sql.tests.test_arrow_map", "pyspark.sql.tests.test_pandas_udf", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7ef0014ae751..5f4f4d494e13 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238b..7b972edc88dd 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -22,19 +22,19 @@ from typing import ( Iterable, NewType, Tuple, - Type, TypeVar, Union, ) from typing_extensions import Protocol, Literal from types import FunctionType -from pyspark.sql._typing import LiteralType +import pyarrow from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray -import pyarrow +from pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupStateImpl ArrayLike = NDArray DataFrameLike = PandasDataFrame @@ -51,6 +51,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -253,9 +254,11 @@ PandasScalarIterFunction = Union[ PandasGroupedMapFunction = Union[ Callable[[DataFrameLike], DataFrameLike], - Callable[[Any, DataFrameLike], DataFrameLike], + Callable[[Tuple, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[[Tuple, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb2959..1c6c2219edce 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. @@ -399,6 +400,7 @@ def _create_pandas_udf(f, returnType, evalType): ) elif evalType in [ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e9..948fe5ce7135 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. 
# import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,45 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 992e82b403a1..2f8bac1092cc 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -18,8 +18,12 @@ """ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. 
""" +import sys +import time -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType class SpecialLengths: @@ -371,3 +375,249 @@ def load_stream(self, stream): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, + softLimitBytesPerBatch, minDataCountForSample, softTimeoutMillisPurgeBatch): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema + + self.result_state_df_type = StructType([ + StructField('properties', StringType()), + StructField('keyRowAsUnsafe', BinaryType()), + StructField('object', BinaryType()), + StructField('oldTimeoutTimestamp', LongType()), + ]) + + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.softLimitBytesPerBatch = softLimitBytesPerBatch + self.minDataCountForSample = minDataCountForSample + self.softTimeoutMillisPurgeBatch = softTimeoutMillisPurgeBatch + + def load_stream(self, stream): + import pyarrow as pa + import json + from itertools import groupby + from pyspark.sql.streaming.state import GroupStateImpl + + def gen_data_and_state(batches): + state_for_current_group = None + + for batch in batches: + batch_schema = batch.schema + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema([batch_schema[-1], ]) + + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = batch_columns[-1] + + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) + + state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[state_idx] + + if not state_info_col: + # no more data with grouping key + state + break + + state_info_col_properties = state_info_col['properties'] + state_info_col_key_row = state_info_col['keyRowAsUnsafe'] + state_info_col_object = state_info_col['object'] + + data_start_offset = state_info_col['startOffset'] + num_data_rows = state_info_col['numRows'] + is_last_chunk = state_info_col['isLastChunk'] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + if state_for_current_group: + # use the state, we already have state for same group and there should be some + # data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties) + + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not 
state_for_current_group: + # there's no cached state, but more data is expected for the same group, + # so cache the current state + state_for_current_group = state + + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + + # yield the data chunk for the group along with its state + yield (data_pandas, state, ) + + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + + data_state_generator = gen_data_and_state(batches) + + # the state will be the same object for the same grouping key + for state, data in groupby(data_state_generator, key=lambda x: x[1]): + yield (data, state,) + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Arrow RecordBatch requires all columns to have the same number of rows. + Insert empty rows for whichever of state/data has fewer elements to compensate. + """ + + import pandas as pd + import pyarrow as pa + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + + empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt + empty_row_cnt_in_state = max_data_cnt - state_data_cnt + + empty_rows_pdf = pd.DataFrame( + dict.fromkeys(pa.schema(pdf_schema).names), + index=[x for x in range(0, empty_row_cnt_in_data)]) + empty_rows_state = pd.DataFrame( + columns=['properties', 'keyRowAsUnsafe', 'object', 'oldTimeoutTimestamp'], + index=[x for x in range(0, empty_row_cnt_in_state)]) + + pdfs.append(empty_rows_pdf) + state_pdfs.append(empty_rows_state) + + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + + return self._create_batch([ + (merged_pdf, pdf_schema), + (merged_state_pdf, self.result_state_pdf_arrow_type)]) + + def init_stream_yield_batches(): + import pandas as pd + import pyarrow as pa + + should_write_start_length = True + + pdfs = [] + state_pdfs = [] + return_schema = None + + pdf_data_cnt = 0 + state_data_cnt = 0 + + sampled_data_size_per_row = 0 + + last_purged_time_ns = time.time_ns() + + for data in iterator: + packaged_result = data[0] + + pdf_iter = packaged_result[0][0] + state = packaged_result[0][1] + # this won't change across batches + return_schema = packaged_result[1] + + for pdf in pdf_iter: + if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) + + if sampled_data_size_per_row == 0 and \ + pdf_data_cnt > self.minDataCountForSample: + memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] + sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt + + # This effectively works only after the sampling has completed, since we multiply by 0 + # if the sampling is still in progress.
+ batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) >= \ + self.softLimitBytesPerBatch + + if batch_over_limit_on_size: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + last_purged_time_ns = time.time_ns() + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + # pick up state for only last chunk as state should have been updated so far + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + state_old_timeout_timestamp = state.oldTimeoutTimestamp + + state_dict = { + 'properties': [state_properties, ], + 'keyRowAsUnsafe': [state_key_row_as_binary, ], + 'object': [state_object, ], + 'oldTimeoutTimestamp': [state_old_timeout_timestamp, ], + } + + state_pdf = pd.DataFrame.from_dict(state_dict) + + state_pdfs.append(state_pdf) + state_data_cnt += 1 + + cur_time_ns = time.time_ns() + is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \ + self.softTimeoutMillisPurgeBatch + if is_timed_out_on_purge: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + last_purged_time_ns = cur_time_ns + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + # end of loop, we may have remaining data + if pdf_data_cnt > 0 or state_data_cnt > 0: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py new file mode 100644 index 000000000000..c036f9704557 --- /dev/null +++ b/python/pyspark/sql/streaming/state.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import datetime +import json +from typing import Tuple, Optional + +from pyspark.sql.types import DateType, Row, StructType + +__all__ = ["GroupStateImpl", "GroupStateTimeout"] + + +class GroupStateTimeout: + NoTimeout: str = "NoTimeout" + ProcessingTimeTimeout: str = "ProcessingTimeTimeout" + EventTimeTimeout: str = "EventTimeTimeout" + + +class GroupStateImpl: + NO_TIMESTAMP: int = -1 + + def __init__( + self, + # JVM Constructor + optionalValue: Row, + batchProcessingTimeMs: int, + eventTimeWatermarkMs: int, + timeoutConf: str, + hasTimedOut: bool, + watermarkPresent: bool, + # JVM internal state. + defined: bool, + updated: bool, + removed: bool, + timeoutTimestamp: int, + # Python internal state. + keyAsUnsafe: bytes, + valueSchema: StructType, + ) -> None: + self._keyAsUnsafe = keyAsUnsafe + self._value = optionalValue + self._batch_processing_time_ms = batchProcessingTimeMs + self._event_time_watermark_ms = eventTimeWatermarkMs + + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + self._timeout_conf = timeoutConf + + self._has_timed_out = hasTimedOut + self._watermark_present = watermarkPresent + + self._defined = defined + self._updated = updated + self._removed = removed + self._timeout_timestamp = timeoutTimestamp + # Python internal state. + self._old_timeout_timestamp = timeoutTimestamp + + self._value_schema = valueSchema + + @property + def exists(self) -> bool: + return self._defined + + @property + def get(self) -> Tuple: + if self.exists: + return tuple(self._value) + else: + raise ValueError("State is either not defined or has already been removed") + + @property + def getOption(self) -> Optional[Tuple]: + if self.exists: + return tuple(self._value) + else: + return None + + @property + def hasTimedOut(self) -> bool: + return self._has_timed_out + + # NOTE: this function is only available to PySpark implementation due to underlying + # implementation, do not port to Scala implementation! + @property + def oldTimeoutTimestamp(self) -> int: + return self._old_timeout_timestamp + + def update(self, newValue: Tuple) -> None: + if newValue is None: + raise ValueError("'None' is not a valid state value") + + self._value = Row(*newValue) + self._defined = True + self._updated = True + self._removed = False + + def remove(self) -> None: + self._defined = False + self._updated = False + self._removed = True + + def setTimeoutDuration(self, durationMs: int) -> None: + if isinstance(durationMs, str): + # TODO(SPARK-XXXXX): Support string representation of durationMs. + raise ValueError("durationMs should be int but get :%s" % type(durationMs)) + + if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout: + raise RuntimeError( + "Cannot set timeout duration without enabling processing time timeout in " + "applyInPandasWithState" + ) + + if durationMs <= 0: + raise ValueError("Timeout duration must be positive") + self._timeout_timestamp = durationMs + self._batch_processing_time_ms + + # TODO(SPARK-XXXXX): Implement additionalDuration parameter. 
+ def setTimeoutTimestamp(self, timestampMs: int) -> None: + if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: + raise RuntimeError( + "Cannot set timeout timestamp without enabling event time timeout in " + "applyInPandasWithState" + ) + + if isinstance(timestampMs, datetime.datetime): + timestampMs = DateType().toInternal(timestampMs) + + if timestampMs <= 0: + raise ValueError("Timeout timestamp must be positive") + + if ( + self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + and timestampMs < self._event_time_watermark_ms + ): + raise ValueError( + "Timeout timestamp (%s) cannot be earlier than the " + "current watermark (%s)" % (timestampMs, self._event_time_watermark_ms) + ) + + self._timeout_timestamp = timestampMs + + def getCurrentWatermarkMs(self) -> int: + if not self._watermark_present: + raise RuntimeError( + "Cannot get event time watermark timestamp without setting watermark before " + "applyInPandasWithState" + ) + return self._event_time_watermark_ms + + def getCurrentProcessingTimeMs(self) -> int: + return self._batch_processing_time_ms + + def __str__(self) -> str: + if self.exists: + return "GroupState(%s)" % (self.get, ) + else: + return "GroupState()" + + def json(self) -> str: + return json.dumps( + { + # Constructor + "optionalValue": None, # Note that optionalValue will be manually serialized. + "batchProcessingTimeMs": self._batch_processing_time_ms, + "eventTimeWatermarkMs": self._event_time_watermark_ms, + "timeoutConf": self._timeout_conf, + "hasTimedOut": self._has_timed_out, + "watermarkPresent": self._watermark_present, + # JVM internal state. + "defined": self._defined, + "updated": self._updated, + "removed": self._removed, + "timeoutTimestamp": self._timeout_timestamp, + } + ) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py new file mode 100644 index 000000000000..9271853ab625 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +import unittest +from typing import cast + +from pyspark.sql.streaming.state import GroupStateTimeout, GroupStateImpl +from pyspark.sql.types import ( + LongType, + StringType, + StructType, + StructField, + Row, +) +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pandas: + import pandas as pd + +if have_pyarrow: + import pyarrow as pa # noqa: F401 + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): + def test_apply_in_pandas_with_state_basic(self): + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_type = StructType( + [StructField("key", StringType()), StructField("countAsString", StringType())] + ) + state_type = StructType([StructField("c", LongType())]) + + def func(key, pdf_iter, state): + assert isinstance(state, GroupStateImpl) + + total_len = 0 + for pdf in pdf_iter: + total_len += len(pdf) + + state.update((total_len,)) + assert state.get[0] == 1 + yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.collect()), + {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, + ) + + q = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d040..417896ab738c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ def returnType(self) -> DataType: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(withState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(withState): return type must be a " "StructType." 
) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d8..3ba317822196 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. has_resource_module = True @@ -57,6 +58,7 @@ ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType @@ -67,10 +69,11 @@ utf8_deserializer = UTF8Deserializer() -def report_times(outfile, boot, init, finish): +def report_times(outfile, boot, init, func_init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(int(1000 * boot), outfile) write_long(int(1000 * init), outfile) + write_long(int(1000 * func_init), outfile) write_long(int(1000 * finish), outfile) @@ -207,6 +210,60 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + def wrapped(key_series, value_series_gen, state): + import pandas as pd + + key = tuple(s[0] for s in key_series) + + if state.hasTimedOut: + # Timeout processing passes an empty iterator. Here we return an empty DataFrame instead. + values = [pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), ] + else: + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result_iter = f(key, values, state) + + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of the result has to match the return type, + # but it is fine for the result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty + ): + raise RuntimeError( + "Number of columns of the element (pandas.DataFrame) in return iterator " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + return result + + if isinstance(result_iter, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) + ) + + try: + iter(result_iter) + except TypeError: + raise TypeError( + "Return type of the user-defined function should be " + "iterable, but is {}".format(type(result_iter)) + ) + + result_iter_with_validation = (verify_element(x) for x in result_iter) + + return (result_iter_with_validation, state, ) + + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -311,6 +368,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -336,6 +395,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +405,12 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + state_object_schema = None + softLimitBytesPerBatchInApplyInPandasWithState = None + minDataCountForSampleInApplyInPandasWithState = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -359,10 +425,33 @@ def read_udfs(pickleSer, infile, eval_type): == "true" ) + softLimitBytesPerBatchInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", (64 * 1024 * 1024) + ) + softLimitBytesPerBatchInApplyInPandasWithState = \ + int(softLimitBytesPerBatchInApplyInPandasWithState) + + minDataCountForSampleInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100 + ) + minDataCountForSampleInApplyInPandasWithState = \ + int(minDataCountForSampleInApplyInPandasWithState) + softTimeoutMillisPurgeBatchInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", 100 + ) + softTimeoutMillisPurgeBatchInApplyInPandasWithState = \ + int(softTimeoutMillisPurgeBatchInApplyInPandasWithState) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + ser = ApplyInPandasWithStateSerializer(timezone, safecheck, 
assign_cols_by_name, + state_object_schema, + softLimitBytesPerBatchInApplyInPandasWithState, + minDataCountForSampleInApplyInPandasWithState, + softTimeoutMillisPurgeBatchInApplyInPandasWithState) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. @@ -474,9 +563,11 @@ def extract_key_value_indexes(grouped_arg_offsets): # support combining multiple UDFs. assert num_udfs == 1 - # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0 + ) parsed_offsets = extract_key_value_indexes(arg_offsets) # Create function like this: @@ -486,6 +577,37 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0 + ) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + from itertools import tee + + state = a[1] + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. + keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. 
@@ -663,13 +785,16 @@ def main(infile, outfile): broadcast_sock_file.close() _accumulatorRegistry.clear() + + init_time = time.time() + eval_type = read_int(infile) if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) else: func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) - init_time = time.time() + func_init_time = time.time() def process(): iterator = deserializer.load_stream(infile) @@ -716,7 +841,7 @@ def process(): faulthandler_log_file.close() os.remove(faulthandler_log_path) finish_time = time.time() - report_times(outfile, boot_time, init_time, finish_time) + report_times(outfile, boot_time, init_time, func_init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) write_long(shuffle.DiskBytesSpilled, outfile) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index a814525f870c..479a097713f5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -32,6 +32,9 @@ @Experimental @Evolving public class GroupStateTimeout { + // NOTE: if you're adding new type of timeout, you should also fix the places below: + // - Scala: org.apache.spark.sql.api.python.PythonSQLUtils.getGroupStateTimeoutFromString + // - Python: pyspark.sql.streaming.state.GroupStateTimeout /** * Timeout based on processing time. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c11ce7d3b90f..99ba3802097b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging { case s: Aggregate if s.isStreaming => true case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true case d: Deduplicate if d.isStreaming => true case _ => false } @@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } + val applyInPandasWithStates = plan.collect { + case f: FlatMapGroupsInPandasWithState if f.isStreaming => f + } + + // Disallow multiple `applyInPandasWithState`s. 
+ if (applyInPandasWithStates.size >= 2) { + throwError( + "Multiple applyInPandasWithStates are not supported on a streaming " + + "DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging { } } + // applyInPandasWithState + case m: FlatMapGroupsInPandasWithState if m.isStreaming => + // Check compatibility with output modes and aggregations in query + val aggsInQuery = collectStreamingAggregates(plan) + + if (aggsInQuery.isEmpty) { + // applyInPandasWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "applyInPandasWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "applyInPandasWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // applyInPandasWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "applyInPandasWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "applyInPandasWithState in append mode is not supported after " + + "aggregation on a streaming DataFrame/Dataset") + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "applyInPandasWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b350834..e97ff7808f17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +/** + * Similar with [[FlatMapGroupsWithState]]. 
Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with a pandas DataFrame representation and the + * grouping key (tuple). + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index de25c19a26eb..c8acb8ac09cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2705,6 +2705,44 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch") + .internal() + .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " + + "records that can be written to a single ArrowRecordBatch in memory. This is used to " + + "restrict the amount of memory being used to materialize the data in both executor and " + + "Python worker. The accumulated size of records is calculated via sampling a set of " + + "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " + + "is quite huge, the size of the constructed ArrowRecordBatch will be around the " + + "configured value.") + .version("3.4.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("64MB") + + val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE = + buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample") + .internal() + .doc("When using applyInPandasWithState, specify the minimum number of records to sample " + + "the size of a record. The size being retrieved from sampling will be used to estimate " + + "the accumulated size of records. Note that limiting by size does not work if the " + + "number of records is less than the configured value. In such cases, the ArrowRecordBatch " + + "will only be split based on the soft timeout.") + .version("3.4.0") + .intConf + .createWithDefault(100) + + val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch") + .internal() + .doc("When using applyInPandasWithState, specify the soft timeout for purging the " + + "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " + + "the ArrowRecordBatch regardless of the estimated size. 
This config ensures that the receiver " + "of data (both executor and Python worker) does not wait indefinitely for the sender to " + "complete the ArrowRecordBatch, which would hurt both throughput and latency.") + .version("3.4.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("100ms") + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -4529,6 +4567,15 @@ class SQLConf extends Serializable with Logging { def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) + def softLimitBytesPerBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) + + def minDataCountForSampleInApplyInPandasWithState: Int = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) + + def softTimeoutMillisPurgeBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee3252187..69eb8101abf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,35 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def applyInPandasWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of the function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = FlatMapGroupsInPandasWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 2b74bcc38501..258d8a87f8b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -22,14 +22,15 @@ import java.net.Socket import java.nio.channels.Channels import java.util.Locale -import net.razorvine.pickle.Pickler +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} private[sql] object PythonSQLUtils extends Logging { - private lazy val internalRowPickler = { + private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = { EvaluatePython.registerPicklers() - new Pickler(true, false) + val pickler = new Pickler(true, false) + val ret = try { + f(pickler) + } finally { + pickler.close() + } + ret + } + + private def withInternalRowUnpickler(f: Unpickler => Any): Any = { + EvaluatePython.registerPicklers() + val unpickler = new Unpickler + val ret = try { + f(unpickler) + } finally { + unpickler.close() + } + ret } def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText) @@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging { def toPyRow(row: Row): Array[Byte] = { assert(row.isInstanceOf[GenericRowWithSchema]) - internalRowPickler.dumps(EvaluatePython.toJava( - CatalystTypeConverters.convertToCatalyst(row), row.schema)) + withInternalRowPickler(_.dumps(EvaluatePython.toJava( + CatalystTypeConverters.convertToCatalyst(row), row.schema))) + } + + def toJVMRow( + arr: Array[Byte], + returnType: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]): Row = { + val fromJava = EvaluatePython.makeFromJava(returnType) + val internalRow = + fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow] + deserializer(internalRow) } def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7bea..8feb68909b2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical 
operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FlatMapGroupsInPandasWithState( + func, groupAttr, outputAttr, stateType, outputMode, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = python.FlatMapGroupsInPandasWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. */ @@ -793,6 +812,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-XXXXX): Implement batch support for applyInPandasWithState + throw new UnsupportedOperationException( + "applyInPandasWithState is unsupported in batch query. Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 7abca5f0e332..bd27ad59bf03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -44,7 +44,7 @@ object ArrowWriter { new ArrowWriter(root, children.toArray) } - private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) @@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } + def sizeInBytes(): Int = { + var i = 0 + var bytes = 0 + while (i < fields.size) { + bytes += fields(i).getSizeInBytes() + i += 1 + } + bytes + } + def finish(): Unit = { root.setRowCount(count) fields.foreach(_.finish()) @@ -136,6 +146,10 @@ private[arrow] abstract class ArrowFieldWriter { valueVector.setValueCount(count) } + def getSizeInBytes(): Int = { + valueVector.getBufferSizeFor(count) + } + def reset(): Unit = { valueVector.reset() count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala new file mode 100644 index 000000000000..213c9f4e712b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + + +/** + * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. + */ +class ApplyInPandasWithStatePythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + inputSchema: StructType, + override protected val timeZoneId: String, + initialWorkerConf: Map[String, String], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + valueSchema: StructType, + stateValueSchema: StructType, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + with PythonArrowInput[InType] + with PythonArrowOutput[OutType] { + + override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> + softLimitBytesPerBatch.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> + minDataCountForSample.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> + softTimeoutMillsPurgeBatch.toString) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the schema for state value + PythonRDD.writeUTF(stateValueSchema.json, stream) + } + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Unit = { + val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch, + minDataCountForSample, softTimeoutMillsPurgeBatch) + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + w.startNewGroup(keyRow, groupState) + + while (dataIter.hasNext) { + val dataRow = dataIter.next() + w.writeRow(dataRow) + } + + w.finalizeGroup() + } + + w.finalizeData() + } + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. + assert(batch.numRows() > 0) + assert(schema.length == 2) + + def getColumnarBatchForStructTypeColumn( + batch: ColumnarBatch, + ordinal: Int, + expectedType: StructType): ColumnarBatch = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector] + val dataType = schema(ordinal).dataType.asInstanceOf[StructType] + assert(dataType.sameType(expectedType)) + + val outputVectors = dataType.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch + } + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, valueSchema) + dataBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. + None + } else { + Some(row) + } + } + } + + def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) + + stateMetadataBatch.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // NOTE: See StateReaderIterator.STATE_METADATA_SCHEMA for the schema. 
+ val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) + } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) + } + } + } + + (constructIterForState(batch), constructIterForData(batch)) + } +} + +object ApplyInPandasWithStatePythonRunner { + type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) + type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long) + type OutType = (Iterator[OutTypeForState], Iterator[InternalRow]) + + val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("oldTimeoutTimestamp", LongType) + ) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala new file mode 100644 index 000000000000..781335f821eb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + + +class ApplyInPandasWithStateWriter( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) { + + import ApplyInPandasWithStateWriter._ + + // We logically group the columns by family and initialize a writer per family, since it is + // a lot easier and probably more performant to write each row directly rather than + // projecting the row to match up with the overall schema. + // + // The number of data rows and state metadata rows can differ, but an Arrow RecordBatch + // requires all columns to have the same number of rows, so we append empty rows to fill + // the gap. + // + // We always produce at least one data row per grouping key, whereas we only produce one + // state metadata row per grouping key, so we only need to fill up empty rows on the + // state metadata side. + private val arrowWriterForData = createArrowWriter(root.getFieldVectors.asScala.dropRight(1)) + private val arrowWriterForState = createArrowWriter(root.getFieldVectors.asScala.takeRight(1)) + + // We bin-pack the data from multiple groups into one Arrow RecordBatch to improve + // performance. In many cases, the amount of data per grouping key is quite small, which + // on its own would not fully realize the benefits of using Arrow. + // + // The Python worker has to split the record batch back into per-group data to convert each + // group to Pandas, but Arrow RecordBatch supports slicing a range of rows as a "zero-copy" + // view. To help with splitting the range of data per group, we provide the "start offset" + // and the "number of rows" in the state metadata. + // + // Note that we do not bin-pack all groups into a single record batch. We have a soft limit + // on the size - it is not a hard limit, since we allow the current group to write all of + // its data even if doing so exceeds the limit. + // + // We perform basic sampling to roughly estimate the size of a single data row, and simply + // multiply it by the number of rows to estimate the batch size. We extract the size of the + // data from the record batch rather than from UnsafeRow, as we do not hold the memory for + // an UnsafeRow once it has been written to the record batch. If there is a memory bound + // here, it should come from the record batch. + // + // Meanwhile, we also do not want the current record batch to collect data indefinitely, + // since the processing between the executor and the Python worker is pipelined. The Python + // worker cannot process any data until the executor finalizes a record batch, which defeats + // the purpose of pipelining. To address this, we also introduce a timeout for constructing + // a record batch.
+ // Like the size limit, the timeout is also a soft limit: we allow the current group to + // write all of its data even if the timeout has already fired. + + private var numRowsForCurGroup = 0 + private var startOffsetForCurGroup = 0 + private var totalNumRowsForBatch = 0 + private var totalNumStatesForBatch = 0 + + private var sampledDataSizePerRow = 0 + private var lastBatchPurgedMillis = System.currentTimeMillis() + + private var currentGroupKeyRow: UnsafeRow = _ + private var currentGroupState: GroupStateImpl[Row] = _ + + def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = { + currentGroupKeyRow = keyRow + currentGroupState = groupState + } + + def writeRow(dataRow: InternalRow): Unit = { + // Sampling only kicks in once the number of rows exceeds the minimum data count for + // sampling. We have no way to pick a subset of rows out of the record batch and measure + // only their size, hence we measure all data currently in the record batch. We only + // sample once, as sampling could be costly. + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } + + // If the batch exceeds the size condition (size only, not the timeout) and there is more + // data for the same group, flush and construct a new batch. + + // The soft limit on size effectively kicks in only after the sampling has completed, since + // the number of rows is multiplied by 0 while the sampling is still in progress. + if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { + // Provide a state metadata row for the intermediate chunk of the current group. + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + finalizeCurrentArrowBatch() + } + + arrowWriterForData.write(dataRow) + numRowsForCurGroup += 1 + totalNumRowsForBatch += 1 + } + + def finalizeGroup(): Unit = { + // Provide the state metadata row for the last chunk of the current group. + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for the next group is the same as the total number of rows in the + // current batch, unless the next group starts with a new batch. + startOffsetForCurGroup = totalNumRowsForBatch + + // The soft limit on timeout is applied when each group is finalized. + if (System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { + finalizeCurrentArrowBatch() + } + } + + def finalizeData(): Unit = { + if (numRowsForCurGroup > 0) { + // We still have some rows in the current record batch. Need to flush them as well.
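+ // finalizeCurrentArrowBatch() pads the state metadata column with empty rows so that it + // matches the data row count, finishes both Arrow writers, writes out the record batch, + // and resets the per-batch counters.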
+ finalizeCurrentArrowBatch() + } + } + + private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = { + val children = fieldVectors.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + private def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + numRowsForCurGroup = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() + } +} + +object ApplyInPandasWithStateWriter { + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType), + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b5466..b39787b12a48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f998..f0e815e966e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 000000000000..b833809561f7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-XXXXX): Add the support of initial state. 
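+ // Initial state is not supported for applyInPandasWithState yet, so the initial-state + // members required by FlatMapGroupsWithStateExecBase are stubbed out and hasInitialState + // is fixed to false.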
+ override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + + override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + val processIter = groupedIter.map { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj)) + } + + process(processIter, hasTimedOut = false) + } + + override def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + + override def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + + val processIter = timingOutPairs.map { stateData => + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + + (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) + } + + process(processIter, hasTimedOut = true) + } else Iterator.empty + } + + private def process( + iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val runner = new ApplyInPandasWithStatePythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, 
+ stateEncoder.asInstanceOf[ExpressionEncoder[Row]], + groupingAttributes.toStructType, + child.output.toStructType, + stateType, + conf.softLimitBytesPerBatchInApplyInPandasWithState, + conf.minDataCountForSampleInApplyInPandasWithState, + conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState) + + val context = TaskContext.get() + + val processIter = iter.map { case (keyRow, stateData, valueIter) => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + (keyRow, groupedState, valueIter) + } + runner.compute(processIter, context.partitionId(), context).flatMap { + case (stateIter, outputIter) => + // When the iterator is consumed, then write changes to state. + // state does not affect each others, hence when to update does not affect to the result. + def onIteratorCompletion: Unit = { + stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) => + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIter, onIteratorCompletion).map { row => + numOutputRows += 1 + row + } + } + } + + override protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4e..078876664062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. 
] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 6168d0f867ad..bf66791183ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 790a652f2112..1071e522f6e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -56,7 +56,7 @@ trait FlatMapGroupsWithStateExecBase protected val batchTimestampMs: Option[Long] val eventTimeWatermark: Option[Long] - protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout + protected val isTimeoutEnabled = timeoutConf != NoTimeout protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false @@ -271,8 +271,7 @@ trait FlatMapGroupsWithStateExecBase */ def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow] - ): Iterator[InternalRow] = { + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty @@ -284,7 +283,7 @@ trait FlatMapGroupsWithStateExecBase // Create a CoGroupedIterator that will group the two iterators together for every key group. new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { case (keyRow, valueRowIter, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false @@ -299,8 +298,8 @@ trait FlatMapGroupsWithStateExecBase // We apply the values for the key after applying the initial state. 
callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false + valueRowIter, + hasTimedOut = false ) } } @@ -334,9 +333,9 @@ trait FlatMapGroupsWithStateExecBase * @param hasTimedOut Whether this function is being called for a key timeout */ protected def callFunctionAndUpdateState( - stateData: StateData, - valueRowIter: Iterator[InternalRow], - hasTimedOut: Boolean): Iterator[InternalRow] + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index b4f37125f4fa..bcd3cfc4508d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.sql.Date import java.util.concurrent.TimeUnit +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.api.java.Optional import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.util.IntervalUtils @@ -27,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe. @@ -39,7 +43,10 @@ import org.apache.spark.unsafe.types.UTF8String * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class GroupStateImpl[S] private( +private[sql] class GroupStateImpl[S] private[sql]( + // NOTE:if you're adding new properties here, fix: + // - `json` and `fromJson` methods of this class in Scala + // - pyspark.sql.streaming.state.GroupStateImpl in Python optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, @@ -173,6 +180,22 @@ private[sql] class GroupStateImpl[S] private( throw QueryExecutionErrors.cannotSetTimeoutTimestampError() } } + + private[sql] def json(): String = compact(render(new JObject( + // Constructor + "optionalValue" -> JNull :: // Note that optionalValue will be manually serialized. 
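+ // For illustration, a freshly created state for a processing-time-timeout query (no value + // set, no watermark) would serialize to something like the following; the batch processing + // time below is made up, and -1 is GroupStateImpl.NO_TIMESTAMP: + // {"optionalValue":null,"batchProcessingTimeMs":1000,"eventTimeWatermarkMs":-1, + // "timeoutConf":"ProcessingTimeTimeout","hasTimedOut":false,"watermarkPresent":false, + // "defined":false,"updated":false,"removed":false,"timeoutTimestamp":-1}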
+ "batchProcessingTimeMs" -> JLong(batchProcessingTimeMs) :: + "eventTimeWatermarkMs" -> JLong(eventTimeWatermarkMs) :: + "timeoutConf" -> JString(Utils.stripDollars(Utils.getSimpleName(timeoutConf.getClass))) :: + "hasTimedOut" -> JBool(hasTimedOut) :: + "watermarkPresent" -> JBool(watermarkPresent) :: + + // Internal state + "defined" -> JBool(defined) :: + "updated" -> JBool(updated) :: + "removed" -> JBool(removed) :: + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil + ))) } @@ -214,4 +237,35 @@ private[sql] object GroupStateImpl { hasTimedOut = false, watermarkPresent) } + + def groupStateTimeoutFromString(clazz: String): GroupStateTimeout = clazz match { + case "ProcessingTimeTimeout" => GroupStateTimeout.ProcessingTimeTimeout + case "EventTimeTimeout" => GroupStateTimeout.EventTimeTimeout + case "NoTimeout" => GroupStateTimeout.NoTimeout + case _ => throw new IllegalStateException("Invalid string for GroupStateTimeout: " + clazz) + } + + def fromJson[S](value: Option[S], json: JValue): GroupStateImpl[S] = { + implicit val formats = org.json4s.DefaultFormats + + val hmap = json.extract[Map[String, Any]] + + // Constructor + val newGroupState = new GroupStateImpl[S]( + value, + hmap("batchProcessingTimeMs").asInstanceOf[Number].longValue(), + hmap("eventTimeWatermarkMs").asInstanceOf[Number].longValue(), + groupStateTimeoutFromString(hmap("timeoutConf").asInstanceOf[String]), + hmap("hasTimedOut").asInstanceOf[Boolean], + hmap("watermarkPresent").asInstanceOf[Boolean]) + + // Internal state + newGroupState.defined = hmap("defined").asInstanceOf[Boolean] + newGroupState.updated = hmap("updated").asInstanceOf[Boolean] + newGroupState.removed = hmap("removed").asInstanceOf[Boolean] + newGroupState.timeoutTimestamp = + hmap("timeoutTimestamp").asInstanceOf[Number].longValue() + + newGroupState + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e973..f386282a0b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,6 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,13 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: FlatMapGroupsInPandasWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) 
+ ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7bc..022fd1239ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a [[StateStore]]. */ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 827cfcf32fea..3c41f6b47b5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType} /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, @@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) { + private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -213,7 +214,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { + private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -235,6 +236,34 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } + private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = { + if (shouldTestPandasUDFs) { + var binaryPandasFunc: Array[Byte] = null + withTempPath { codePath => + Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8)) + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + s"exec(open('$codePath', 'r').read());" + + "f.write(CloudPickleSerializer().dumps((" + + "func, tpe)))"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } else { + throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") + } + } + + // Make sure this map stays mutable - this map gets updated later in Python runners. private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") @@ -251,11 +280,9 @@ object IntegratedUDFTestUtils extends SQLHelper { lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable - lazy val shouldTestScalarPandasUDFs: Boolean = + lazy val shouldTestPandasUDFs: Boolean = isPythonAvailable && isPandasAvailable && isPyArrowAvailable - lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs - /** * A base trait for various UDFs defined in this object. */ @@ -420,6 +447,41 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Grouped Aggregate Pandas UDF" } + /** + * Arbitrary stateful processing in Python is used for + * `DataFrame.groupBy.applyInPandasWithState`. It requires `pythonScript` to + * define `func` (Python function) and `tpe` (`StructType` for the return type of `func`). + * + * Virtually equivalent to: + * + * {{{ + * # exec defines 'func' and 'tpe' (struct type for the return type of 'func') + * exec(pythonScript) + * + * # ... are filled when this UDF is invoked, see also 'FlatMapGroupsInPandasWithStateSuite'. + * df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...) + * }}} + */ + case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( + name = name, + func = SimplePythonFunction( + command = createPandasGroupedMapFuncWithState(pythonScript), + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = NullType, // This is not respected. + pythonEvalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Grouped Map Pandas UDF with State" + } + /** + * A Scala UDF that takes one column, casts into string, executes the + * Scala native function, and casts back to the type of input column.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index cca9bb6741f6..a662caea74a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -244,7 +244,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper /* Do nothing */ } case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ @@ -433,7 +433,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs => s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 00c774e2d1be..92aadb6779e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -128,7 +128,7 @@ class QueryCompilationErrorsSuite test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), @@ -180,7 +180,7 @@ class QueryCompilationErrorsSuite test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 4ad7f9010537..42e4b1accde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") { - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b") val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala new file mode 100644 index 000000000000..9c6573fd782a --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils.{shouldTestPandasUDFs, TestGroupedMapPandasUDFWithState} +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.util.{StatefulOpClusteredDistributionTestHelper, StreamManualClock} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} + +class FlatMapGroupsInPandasWithStateDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper { + + import testImplicits._ + + test("applyInPandasWithState should require StatefulOpClusteredDistribution " + + "from children - without initial state") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType, IntegerType + | + |tpe = StructType([ + | StructField("key1", StringType()), + | StructField("key2", StringType()), + | StructField("count", IntegerType())]) + | + |def func(key, pdf_iter, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'count': [count]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, String, Long)] + val outputStructType = StructType( + Seq( + StructField("key1", StringType), + StructField("key2", StringType), + StructField("count", IntegerType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + val result = + inputDataDS + .withWatermark("timestamp", "10 second") + .repartition($"key1") + .groupBy($"key1", $"key2") + .applyInPandasWithState( + pythonFunc(inputDataDS("key1"), inputDataDS("key2"), inputDataDS("timestamp")) + .expr.asInstanceOf[PythonUDF], + outputStructType, + 
stateStructType, + "Update", + "NoTimeout") + .select("key1", "key2", "count") + + val clock = new StreamManualClock + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. + CheckNewAnswer(("a", "a", 1)), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsInPandasWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + + assert(flatMapGroupsInPandasWithStateExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala new file mode 100644 index 000000000000..aa2d7169ce88 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ + +class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { + + import testImplicits._ + + test("applyInPandasWithState - streaming") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("applyInPandasWithState - streaming, multiple groups in partition, " + + "multiple outputs per grouping key") { + assume(shouldTestPandasUDFs) + + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import IntegerType, StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("value", 
IntegerType()), + | StructField("valueAsString", StringType()), + | StructField("prevCountAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | prev_count = state.getOption + | if prev_count is None: + | prev_count = 0 + | else: + | prev_count = prev_count[0] + | + | count = prev_count + | for pdf in pdf_iter: + | count += len(pdf) + | yield pdf.assign(valueAsString=lambda x: x.value.apply(str), + | prevCountAsString=str(prev_count)) + | + | state.update((count,)) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputData = MemoryStream[(String, Int)] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("value", IntegerType), + StructField("valueAsString", StringType), + StructField("prevCountAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") + val result = + inputDataDS + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDS("key"), inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + .select("key", "value", "valueAsString", "prevCountAsString") + + testStream(result, Update)( + AddData(inputData, ("a", 1)), + CheckNewAnswer(("a", 1, "1", "0")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), + CheckNewAnswer(("a", 2, "2", "1"), ("a", 3, "3", "1"), ("b", 1, "1", "0")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, ("b", 2), ("c", 1), ("d", 1), ("e", 1)), + CheckNewAnswer(("b", 2, "2", "1"), ("c", 1, "1", "0"), ("d", 1, "1", "0"), + ("e", 1, "1", "0")), + assertNumStateRows(total = 5, updated = 4), + AddData(inputData, ("a", 4)), + CheckNewAnswer(("a", 4, "4", "3")) + ) + } + } + + test("applyInPandasWithState - streaming + aggregation") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | + | ret = pd.DataFrame() + | if count >= 3: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + 
"Append", + "NoTimeout") + .groupBy("key") + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckNewAnswer(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckNewAnswer(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("applyInPandasWithState - streaming with processing time timeout") { + assume(shouldTestPandasUDFs) + + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | ret = None + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | state.setTimeoutDuration(10000) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("b", "-1"), ("c", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => 
clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows( + total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + test("applyInPandasWithState - streaming w/ event time timeout + watermark") { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val pythonScript = + """ + |import calendar + |import os + |import datetime + |import pandas as pd + |from pyspark.sql.types import StructType, StringType, StructField, IntegerType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("maxEventTimeSec", IntegerType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | assert state.getCurrentWatermarkMs() >= -1 + | + | timeout_delay_sec = 5 + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | else: + | m = state.getOption + | if m is None: + | max_event_time_sec = 0 + | else: + | max_event_time_sec = m[0] + | + | for pdf in pdf_iter: + | pser = pdf.eventTime.apply( + | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + | max_event_time_sec = int(max(pser.max(), max_event_time_sec)) + | + | state.update((max_event_time_sec,)) + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | yield pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, Int)] + val inputDataDF = + inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputDataDF + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "EventTimeTimeout") + + testStream(result, Update)( + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
+ CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + } + + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // String, (String, Long), RunningCount(Long) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | state.setTimeoutDuration(10000) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val inputDataDF = inputData + .toDF.toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDF + .withWatermark("timestamp", "10 second") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("timestamp")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")) + ) + } + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) + + test("applyInPandasWithState - uses state format version 2 by default") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': 
[str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 26c201d5921e..fc6b51dce790 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -279,7 +279,7 @@ class ContinuousSuite extends ContinuousSuiteBase { Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).foreach { udf => test(s"continuous mode with various UDFs - ${udf.prettyName}") { assume( - shouldTestScalarPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || + shouldTestPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] || udf.isInstanceOf[TestScalaUDF]) diff --git a/test-applyinpandaswithstate.py b/test-applyinpandaswithstate.py new file mode 100644 index 000000000000..03233b9e2cc6 --- /dev/null +++ b/test-applyinpandaswithstate.py @@ -0,0 +1,94 @@ +import calendar +import os +import datetime +import pandas as pd +from pyspark.sql import SparkSession +from pyspark.sql import Row + +def user_func(key, pdf, state): + timeout_delay_sec = 10 + + print('=' * 80) + print(key) + print(pdf) + print(state.getOption) + print(state.hasTimedOut) + print('=' * 80) + + if state.hasTimedOut: + state.remove() + return pd.DataFrame({'key1': [], 'key2': [], 'maxTimestampSeenMs': [], 'average': []}) + else: + prev_state = state.getOption + if prev_state is None: + prev_sum = 0 + prev_count = 0 + prev_max_timestamp_seen_sec = 0 # should be -Inf or something along with + else: + # FIXME: Is it better UX to access the state object as tuple instead of Row or dict at least? + prev_sum = prev_state[0] + prev_count = prev_state[1] + prev_max_timestamp_seen_sec = prev_state[2] + + new_sum = prev_sum + int(pdf.value.sum()) + new_count = prev_count + len(pdf) + + # TODO: now it's taking second precision - lower down to millisecond + # print(key) + # print(pdf) + pser = pdf.timestamp.apply(lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + #print(pser) + new_max_event_time_sec = int(max(pser.max(), prev_max_timestamp_seen_sec)) + timeout_timestamp_sec = new_max_event_time_sec + timeout_delay_sec + + # FIXME: Is it better UX to access the state object as tuple instead of Row or dict at least? 
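+ # Note: the state is read and written positionally as a plain tuple whose fields
+ # follow state_struct ('sum long, count long, maxTimestampSeenSec long') in order.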
+ state.update((new_sum, new_count, new_max_event_time_sec,))
+ state.setTimeoutTimestamp(timeout_timestamp_sec * 1000)
+ return pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'maxTimestampSeenMs': [new_max_event_time_sec * 1000], 'average': [new_sum * 1.0 / new_count]})
+
+
+spark = SparkSession \
+ .builder \
+ .appName("Python ApplyInPandasWithState example") \
+ .config("spark.sql.shuffle.partitions", 1) \
+ .getOrCreate()
+
+rate_stream = (
+ spark.readStream
+ .format('rate')
+ .option('numPartitions', 1)
+ .option('rowsPerSecond', 500000)
+ .load()
+)
+
+output_struct = 'key1 string, key2 long, maxTimestampSeenMs long, average double'
+state_struct = 'sum long, count long, maxTimestampSeenSec long'
+
+# desired_group_keys = 100
+desired_group_keys = 100000
+key1_expr = "(case when value % 5 = 0 then 'a' when value % 5 = 1 then 'b' when value % 5 = 2 then 'c' when value % 5 = 3 then 'd' else 'e' end) AS key1"
+key2_expr = f"ceil(value / 5) % {desired_group_keys // 5} AS key2"
+
+# schema from rate source: 'timestamp' - TimestampType, 'value' - LongType
+custom_session_window_stream = (
+ rate_stream
+ # TODO: how many groups do we want to track?
+ .selectExpr("timestamp", key1_expr, key2_expr, "value")
+ .withWatermark('timestamp', '0 seconds')
+ .groupby('key1', 'key2')
+ .applyInPandasWithState(user_func, outputStructType=output_struct,
+ stateStructType=state_struct, outputMode='update', timeoutConf='EventTimeTimeout')
+ .selectExpr('maxTimestampSeenMs', 'key1', 'key2', 'average')
+)
+
+query = (
+ custom_session_window_stream
+ .writeStream
+ .trigger(processingTime='0 seconds')
+ .outputMode('Update')
+ .format('console')
+ .start()
+)
+
+query.awaitTermination()
+
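
The test suites and the script above all follow the same state-handling pattern: read the previous value with getOption (None the first time a key is seen), fold the incoming batches into it, write it back with update, and either remove the state or arm a timeout. The sketch below distills that pattern into plain Python so it can be run without Spark; FakeGroupState is a hypothetical stand-in for the state handle Spark supplies at runtime, and count_func mirrors the running-count functions used in the tests. It is illustrative only and not part of this change.

import pandas as pd


class FakeGroupState:
    """Hypothetical stand-in for the state handle Spark passes to the user function.

    It mimics only the subset of the interface exercised above:
    getOption, update, remove, hasTimedOut, and setTimeoutDuration.
    """

    def __init__(self, value=None, timed_out=False):
        self._value = value
        self.hasTimedOut = timed_out

    @property
    def getOption(self):
        # None when no state exists for the key, otherwise the stored tuple.
        return self._value

    def update(self, value):
        self._value = tuple(value)

    def remove(self):
        self._value = None

    def setTimeoutDuration(self, duration_ms):
        # The real handle schedules a processing-time timeout; here it is a no-op.
        pass


def count_func(key, pdf_iter, state):
    # Mirrors the "running count" scripts above: on timeout emit -1, otherwise
    # accumulate the count over the incoming batches and drop the state at 3.
    if state.hasTimedOut:
        state.remove()
        yield pd.DataFrame({"key": [key[0]], "countAsString": [str(-1)]})
    else:
        count = state.getOption[0] if state.getOption is not None else 0
        for pdf in pdf_iter:
            count += len(pdf)
        state.update((count,))
        state.setTimeoutDuration(10000)
        if count >= 3:
            state.remove()
            yield pd.DataFrame()
        else:
            yield pd.DataFrame({"key": [key[0]], "countAsString": [str(count)]})


if __name__ == "__main__":
    state = FakeGroupState()
    # Batch 1: one row for key "a" -> count becomes 1 and is kept as state.
    print(list(count_func(("a",), [pd.DataFrame({"value": ["a"]})], state)))
    # Batch 2: two more rows -> count reaches 3, so the state is removed.
    print(list(count_func(("a",), [pd.DataFrame({"value": ["a", "a"]})], state)))
    print(state.getOption)  # None, because the state was removed above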