@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}

2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging {
* and return the trailing part after the last dollar sign in the middle
*/
@scala.annotation.tailrec
private def stripDollars(s: String): String = {
def stripDollars(s: String): String = {
val lastDollarIndex = s.lastIndexOf('$')
if (lastDollarIndex < s.length - 1) {
// The last char is not a dollar sign
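For intuition, a rough Python port of stripDollars, which is made public here (presumably for testing). A hedged sketch: only the "last char is not a dollar sign" branch is visible in this hunk, so the trailing-dollar handling below is inferred from the docstring rather than copied from the Scala source.

    # Hypothetical Python port of Utils.stripDollars, for illustration only.
    def strip_dollars(s: str) -> str:
        last_dollar_index = s.rfind("$")
        if last_dollar_index < len(s) - 1:
            # The last char is not a dollar sign: keep the trailing part
            # after the last dollar sign, if there is one.
            return s if last_dollar_index == -1 else s[last_dollar_index + 1:]
        # The name ends with dollar signs: drop them and retry
        # (the Scala original is tail-recursive, hence @tailrec).
        stripped = s.rstrip("$")
        return strip_dollars(stripped) if stripped else s

    # e.g. strip_dollars("Outer$Inner$") == "Inner"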
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -450,6 +450,7 @@ def __hash__(self):
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_pandas_cogrouped_map",
"pyspark.sql.tests.test_pandas_grouped_map",
"pyspark.sql.tests.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.test_pandas_map",
"pyspark.sql.tests.test_arrow_map",
"pyspark.sql.tests.test_pandas_udf",
2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
@@ -105,6 +105,7 @@
PandasMapIterUDFType,
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
PandasGroupedMapUDFWithStateType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208


def portable_hash(x: Hashable) -> int:
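The Python-side constant mirrors the Scala registry in the first hunk above; both sides must agree on the value. A quick sanity sketch:

    from pyspark.rdd import PythonEvalType

    # 208 is the value registered in the Scala PythonEvalType object above.
    assert PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE == 208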
11 changes: 7 additions & 4 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -22,19 +22,19 @@ from typing import (
Iterable,
NewType,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Protocol, Literal
from types import FunctionType

from pyspark.sql._typing import LiteralType
import pyarrow
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
from numpy import ndarray as NDArray

import pyarrow
from pyspark.sql._typing import LiteralType
from pyspark.sql.streaming.state import GroupStateImpl

ArrayLike = NDArray
DataFrameLike = PandasDataFrame
@@ -51,6 +51,7 @@ PandasScalarIterUDFType = Literal[204]
PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
@@ -253,9 +254,11 @@ PandasScalarIterFunction = Union[

PandasGroupedMapFunction = Union[
Callable[[DataFrameLike], DataFrameLike],
Callable[[Any, DataFrameLike], DataFrameLike],
Callable[[Tuple, DataFrameLike], DataFrameLike],
]

PandasGroupedMapFunctionWithState = Callable[[Tuple, DataFrameLike, GroupStateImpl], DataFrameLike]

class PandasVariadicGroupedAggFunction(Protocol):
def __call__(self, *_: SeriesLike) -> LiteralType: ...

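To make the new alias concrete, here is a hypothetical user function matching PandasGroupedMapFunctionWithState: it takes the grouping key as a tuple, the group's rows as a pandas DataFrame, and the state handle, and returns a pandas DataFrame. The column names and running-count logic are illustrative, not part of the patch.

    import pandas as pd
    from typing import Tuple
    from pyspark.sql.streaming.state import GroupStateImpl

    def count_fn(key: Tuple, pdf: pd.DataFrame, state: GroupStateImpl) -> pd.DataFrame:
        # Fold this batch's row count into the running count kept in state.
        (count,) = state.getOption or (0,)
        count += len(pdf)
        state.update((count,))
        return pd.DataFrame({"id": [key[0]], "count": [count]})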
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
None,
]: # None means it should infer the type from type hints.

@@ -399,6 +400,7 @@ def _create_pandas_udf(f, returnType, evalType):
)
elif evalType in [
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
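With the eval type accepted here, a pandas UDF of the new kind can be constructed directly; this is essentially what applyInPandasWithState does internally (see group_ops.py below). A sketch reusing the illustrative count_fn from the previous block; the DDL return type is a placeholder.

    from pyspark.rdd import PythonEvalType
    from pyspark.sql.functions import pandas_udf

    udf = pandas_udf(
        count_fn,
        returnType="id long, count long",
        functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
    )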
45 changes: 43 additions & 2 deletions python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
from typing import List, Union, TYPE_CHECKING
from typing import List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedMapFunction,
PandasGroupedMapFunctionWithState,
PandasCogroupedMapFunction,
)
from pyspark.sql.group import GroupedData
@@ -216,6 +218,45 @@ def applyInPandas(
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.session)

def applyInPandasWithState(
self,
func: "PandasGroupedMapFunctionWithState",
outputStructType: Union[StructType, str],
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
) -> DataFrame:
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
assert timeoutConf in [
GroupStateTimeout.NoTimeout,
GroupStateTimeout.ProcessingTimeTimeout,
GroupStateTimeout.EventTimeTimeout,
]

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
if isinstance(stateStructType, str):
stateStructType = cast(StructType, _parse_datatype_string(stateStructType))

udf = pandas_udf(
func, # type: ignore[call-overload]
returnType=outputStructType,
functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
udf_column._jc.expr(),
self.session._jsparkSession.parseDataType(outputStructType.json()),
self.session._jsparkSession.parseDataType(stateStructType.json()),
outputMode,
timeoutConf,
)
return DataFrame(jdf, self.session)

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
Cogroups this group with another group so that we can run cogrouped operations.
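End to end, the new method would be used roughly like this. A hedged sketch: sdf is an assumed streaming DataFrame with an "id" column, count_fn is the illustrative function from the typing sketch above, and the schemas and output mode are placeholders.

    from pyspark.sql.streaming.state import GroupStateTimeout

    result = sdf.groupBy("id").applyInPandasWithState(
        count_fn,                                # (key, pdf, state) -> pdf
        outputStructType="id long, count long",  # schema of the returned rows
        stateStructType="count long",            # schema of the state tuple
        outputMode="Update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )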
182 changes: 182 additions & 0 deletions python/pyspark/sql/streaming/state.py
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
from typing import Tuple, Optional

from pyspark.sql.types import DateType, Row, StructType

__all__ = ["GroupStateImpl", "GroupStateTimeout"]


class GroupStateTimeout:
NoTimeout: str = "NoTimeout"
ProcessingTimeTimeout: str = "ProcessingTimeTimeout"
EventTimeTimeout: str = "EventTimeTimeout"


class GroupStateImpl:
NO_TIMESTAMP: int = -1

def __init__(
self,
# JVM Constructor
optionalValue: Row,
batchProcessingTimeMs: int,
eventTimeWatermarkMs: int,
timeoutConf: str,
hasTimedOut: bool,
watermarkPresent: bool,
# JVM internal state.
defined: bool,
updated: bool,
removed: bool,
timeoutTimestamp: int,
# Python internal state.
keySchema: StructType,
) -> None:
self._value = optionalValue
self._batch_processing_time_ms = batchProcessingTimeMs
self._event_time_watermark_ms = eventTimeWatermarkMs

assert timeoutConf in [
GroupStateTimeout.NoTimeout,
GroupStateTimeout.ProcessingTimeTimeout,
GroupStateTimeout.EventTimeTimeout,
]
self._timeout_conf = timeoutConf

self._has_timed_out = hasTimedOut
self._watermark_present = watermarkPresent

self._defined = defined
self._updated = updated
self._removed = removed
self._timeout_timestamp = timeoutTimestamp

self._key_schema = keySchema

@property
def exists(self) -> bool:
return self._defined

@property
def get(self) -> Tuple:
if self.exists:
return tuple(self._value)
else:
raise ValueError("State is either not defined or has already been removed")

@property
def getOption(self) -> Optional[Tuple]:
if self.exists:
return tuple(self._value)
else:
return None

@property
def hasTimedOut(self) -> bool:
return self._has_timed_out

def update(self, newValue: Tuple) -> None:
if newValue is None:
raise ValueError("'None' is not a valid state value")

self._value = Row(*newValue)
self._defined = True
self._updated = True
self._removed = False

def remove(self) -> None:
self._defined = False
self._updated = False
self._removed = True

def setTimeoutDuration(self, durationMs: int) -> None:
if isinstance(durationMs, str):
# TODO(SPARK-XXXXX): Support string representation of durationMs.
raise ValueError("durationMs should be int but get :%s" % type(durationMs))

if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout:
raise RuntimeError(
"Cannot set timeout duration without enabling processing time timeout in "
"applyInPandasWithState"
)

if durationMs <= 0:
raise ValueError("Timeout duration must be positive")
self._timeout_timestamp = durationMs + self._batch_processing_time_ms

# TODO(SPARK-XXXXX): Implement additionalDuration parameter.
def setTimeoutTimestamp(self, timestampMs: int) -> None:
if self._timeout_conf != GroupStateTimeout.EventTimeTimeout:
raise RuntimeError(
"Cannot set timeout duration without enabling processing time timeout in "
"applyInPandasWithState"
)

if isinstance(timestampMs, datetime.datetime):
timestampMs = DateType().toInternal(timestampMs)

if timestampMs <= 0:
raise ValueError("Timeout timestamp must be positive")

if (
self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP
and timestampMs < self._event_time_watermark_ms
):
raise ValueError(
"Timeout timestamp (%s) cannot be earlier than the "
"current watermark (%s)" % (timestampMs, self._event_time_watermark_ms)
)

self._timeout_timestamp = timestampMs

def getCurrentWatermarkMs(self) -> int:
if not self._watermark_present:
raise RuntimeError(
"Cannot get event time watermark timestamp without setting watermark before "
"applyInPandasWithState"
)
return self._event_time_watermark_ms

def getCurrentProcessingTimeMs(self) -> int:
return self._batch_processing_time_ms

def __str__(self) -> str:
if self.exists:
return "GroupState(%s)" % self.get
else:
return "GroupState(<undefined>)"

def json(self) -> str:
return json.dumps(
{
# Constructor
"optionalValue": None, # Note that optionalValue will be manually serialized.
"batchProcessingTimeMs": self._batch_processing_time_ms,
"eventTimeWatermarkMs": self._event_time_watermark_ms,
"timeoutConf": self._timeout_conf,
"hasTimedOut": self._has_timed_out,
"watermarkPresent": self._watermark_present,
# JVM internal state.
"defined": self._defined,
"updated": self._updated,
"removed": self._removed,
"timeoutTimestamp": self._timeout_timestamp,
}
)
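Since the constructor is normally driven from the JVM, here is a minimal local sketch of the handle's lifecycle; every argument value below is made up for illustration.

    from pyspark.sql.streaming.state import GroupStateImpl, GroupStateTimeout
    from pyspark.sql.types import LongType, Row, StructField, StructType

    state = GroupStateImpl(
        optionalValue=Row(count=0),
        batchProcessingTimeMs=1000,
        eventTimeWatermarkMs=GroupStateImpl.NO_TIMESTAMP,
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
        hasTimedOut=False,
        watermarkPresent=False,
        defined=True,
        updated=False,
        removed=False,
        timeoutTimestamp=GroupStateImpl.NO_TIMESTAMP,
        keySchema=StructType([StructField("id", LongType())]),
    )

    assert state.exists and state.get == (0,)
    state.update((state.get[0] + 1,))  # stored value becomes Row(1)
    state.setTimeoutDuration(30000)    # fire 30s after the batch processing time
    state.remove()                     # clears the entry for this key
    assert state.getOption is None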