Closed
Changes from all commits (44 commits)
512958b
Basework
HeartSaVioR Jul 12, 2022
d36373b
Add Python implementation
HyukjinKwon Jul 14, 2022
f754fd9
Reorder key attributes from deduplicated data attributes
HyukjinKwon Jul 27, 2022
5194e0c
Apply suggestions from code review
HyukjinKwon Jul 27, 2022
1301ee5
Refactoring a bit to respect the column order
HyukjinKwon Aug 11, 2022
135a826
WIP Changes to execute in pipelined manner
HeartSaVioR Aug 15, 2022
9282e5c
WIP further optimization
HeartSaVioR Aug 18, 2022
a792c98
WIP comments for more tunes
HeartSaVioR Aug 18, 2022
27e7af9
WIP further tune...
HeartSaVioR Aug 18, 2022
04a6b98
WIP done more tune! didn't do any of pandas/arrow side tunes
HeartSaVioR Aug 18, 2022
765f4d3
WIP avoid adding additional empty row for state, empty row will be ad…
HeartSaVioR Aug 19, 2022
9e11225
WIP remove debug log
HeartSaVioR Aug 19, 2022
f33d978
WIP hack around to see the possibility of perf gain on binpacking
HeartSaVioR Aug 27, 2022
8604fdf
WIP proper work to apply binpacking on python worker -> executor
HeartSaVioR Aug 27, 2022
0d024e0
WIP fix silly bug
HeartSaVioR Aug 27, 2022
43c623b
WIP another silly bugfix on migration
HeartSaVioR Aug 27, 2022
af1725a
WIP apply binpacking for executor -> python worker as well
HeartSaVioR Aug 27, 2022
31e9687
WIP fix silly bug
HeartSaVioR Aug 27, 2022
cad77a2
WIP fix another silly bug
HeartSaVioR Aug 27, 2022
c3da996
WIP batching per specified size, with sampling
HeartSaVioR Aug 29, 2022
cfb2780
WIP introduce DBR-only change
HeartSaVioR Aug 29, 2022
228b140
WIP debugging now...
HeartSaVioR Aug 29, 2022
ee4ed57
WIP still debugging... weirdness happened
HeartSaVioR Aug 30, 2022
4045ab3
WIP small fix
HeartSaVioR Aug 30, 2022
2d115ab
WIP fix a serious bug... make sure all columns in Arrow RecordBatch h…
HeartSaVioR Aug 30, 2022
3e7d785
WIP strengthen test
HeartSaVioR Aug 30, 2022
029dae7
WIP documenting the changes for pipelining and bin-packing... not yet…
HeartSaVioR Sep 2, 2022
d7ecaf9
WIP sync
HeartSaVioR Sep 2, 2022
6a6dd20
WIP start with is_last_chunk since it's easier to implement... severa…
HeartSaVioR Sep 2, 2022
5cfd59c
WIP adjust the test code to make test pass with multiple calls
HeartSaVioR Sep 2, 2022
63f8f87
WIP refactor a bit... just extract the abstract classes to explicit ones
HeartSaVioR Sep 5, 2022
6e772cd
WIP iterator of DatFrame done! updated tests and they all passed
HeartSaVioR Sep 5, 2022
00836b5
WIP FIX pyspark side test failure
HeartSaVioR Sep 6, 2022
5fdde94
WIP sort out codebase a bit
HeartSaVioR Sep 14, 2022
e7ad043
WIP no batch query support in applyInPandasWithState
HeartSaVioR Sep 6, 2022
5070b81
WIP address some missed things
HeartSaVioR Sep 6, 2022
1b919b8
WIP remove comments which are obsolete or won't be addressed
HeartSaVioR Sep 7, 2022
198fc17
WIP change the return type of user function to Iterator[DataFrame]
HeartSaVioR Sep 7, 2022
f2a75f1
WIP remove unnecessary interface/implementation changes on GroupState…
HeartSaVioR Sep 13, 2022
3e5f5d4
WIP refine out some code
HeartSaVioR Sep 13, 2022
4e34d29
WIP fix scalastyle
HeartSaVioR Sep 13, 2022
50e743e
WIP remove obsolete class
HeartSaVioR Sep 13, 2022
d22d7db
WIP remove the temp fix
HeartSaVioR Sep 13, 2022
e60408f
remove unused code
HeartSaVioR Sep 14, 2022
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}

@@ -537,13 +539,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val funcInitTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val funcInit = funcInitTime - initTime
val finish = finishTime - funcInitTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
logInfo("Times: total = %s, boot = %s, init = %s, func_init = %s, finish = %s"
.format(total, boot, init, funcInit, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
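The hunk above adds a fourth timestamp (funcInitTime) to the timing data the executor reads back from the Python worker. The corresponding worker-side write is not part of this excerpt; purely as a hedged sketch of what that counterpart could look like, the snippet below mirrors how the existing worker reports timing. The function name, the TIMING_DATA framing, and the millisecond conversion are assumptions, not something shown in this diff.

```python
# Hedged sketch only: the worker-side counterpart to the driver-side read in the
# hunk above. The real change to pyspark/worker.py is not shown in this excerpt.
from pyspark.serializers import SpecialLengths, write_int, write_long


def report_times(outfile, boot, init, func_init, finish):
    # Emit the timing marker, then the four timestamps in the same order the
    # JVM side reads them: bootTime, initTime, funcInitTime, finishTime.
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(int(1000 * boot), outfile)
    write_long(int(1000 * init), outfile)
    write_long(int(1000 * func_init), outfile)
    write_long(int(1000 * finish), outfile)
```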
core/src/main/scala/org/apache/spark/util/Utils.scala (1 addition, 1 deletion)
@@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging {
* and return the trailing part after the last dollar sign in the middle
*/
@scala.annotation.tailrec
private def stripDollars(s: String): String = {
def stripDollars(s: String): String = {
val lastDollarIndex = s.lastIndexOf('$')
if (lastDollarIndex < s.length - 1) {
// The last char is not a dollar sign
dev/sparktestsupport/modules.py (1 addition, 0 deletions)
@@ -452,6 +452,7 @@ def __hash__(self):
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_pandas_cogrouped_map",
"pyspark.sql.tests.test_pandas_grouped_map",
"pyspark.sql.tests.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.test_pandas_map",
"pyspark.sql.tests.test_arrow_map",
"pyspark.sql.tests.test_pandas_udf",
python/pyspark/rdd.py (2 additions, 0 deletions)
@@ -105,6 +105,7 @@
PandasMapIterUDFType,
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
PandasGroupedMapUDFWithStateType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208


def portable_hash(x: Hashable) -> int:
python/pyspark/sql/pandas/_typing/__init__.pyi (7 additions, 4 deletions)
@@ -22,19 +22,19 @@ from typing import (
Iterable,
NewType,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Protocol, Literal
from types import FunctionType

from pyspark.sql._typing import LiteralType
import pyarrow
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
from numpy import ndarray as NDArray

import pyarrow
from pyspark.sql._typing import LiteralType
from pyspark.sql.streaming.state import GroupStateImpl

ArrayLike = NDArray
DataFrameLike = PandasDataFrame
@@ -51,6 +51,7 @@ PandasScalarIterUDFType = Literal[204]
PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
@@ -253,9 +254,11 @@ PandasScalarIterFunction = Union[

PandasGroupedMapFunction = Union[
Callable[[DataFrameLike], DataFrameLike],
Callable[[Any, DataFrameLike], DataFrameLike],
Callable[[Tuple, DataFrameLike], DataFrameLike],
]

PandasGroupedMapFunctionWithState = Callable[[Tuple, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]]

class PandasVariadicGroupedAggFunction(Protocol):
def __call__(self, *_: SeriesLike) -> LiteralType: ...

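The new `PandasGroupedMapFunctionWithState` alias captures the user-function contract: the grouping key as a tuple, an iterable of pandas DataFrame chunks for that key, and a group state handle, returning an iterable of pandas DataFrames. A minimal sketch of such a function follows; the column names, the single-count state layout, and the `exists`/`get`/`update` calls on the state handle are assumptions mirroring the Scala `GroupState` API rather than anything fixed by this diff.

```python
# Illustrative only: a user function shaped like PandasGroupedMapFunctionWithState.
# The column names, the (count,) state layout, and the exists/get/update methods
# on the state handle are assumptions mirroring Scala's GroupState, not this diff.
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Tuple

import pandas as pd

if TYPE_CHECKING:
    from pyspark.sql.streaming.state import GroupStateImpl


def count_events(
    key: Tuple[Any, ...],
    pdf_iter: Iterable[pd.DataFrame],
    state: "GroupStateImpl",
) -> Iterator[pd.DataFrame]:
    # Restore the running count for this key, defaulting to zero for a new key.
    (count,) = state.get if state.exists else (0,)
    # Input arrives as an iterator of DataFrame chunks for the group.
    for pdf in pdf_iter:
        count += len(pdf)
    # Persist the updated state and emit one summary row for the group.
    state.update((count,))
    yield pd.DataFrame({"id": [key[0]], "count": [count]})
```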
python/pyspark/sql/pandas/functions.py (2 additions, 0 deletions)
@@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
None,
]: # None means it should infer the type from type hints.

@@ -399,6 +400,7 @@ def _create_pandas_udf(f, returnType, evalType):
)
elif evalType in [
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
python/pyspark/sql/pandas/group_ops.py (43 additions, 2 deletions)
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
from typing import List, Union, TYPE_CHECKING
from typing import List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedMapFunction,
PandasGroupedMapFunctionWithState,
PandasCogroupedMapFunction,
)
from pyspark.sql.group import GroupedData
@@ -216,6 +218,45 @@ def applyInPandas(
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.session)

def applyInPandasWithState(
self,
func: "PandasGroupedMapFunctionWithState",
outputStructType: Union[StructType, str],
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
) -> DataFrame:
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
assert timeoutConf in [
GroupStateTimeout.NoTimeout,
GroupStateTimeout.ProcessingTimeTimeout,
GroupStateTimeout.EventTimeTimeout,
]

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
if isinstance(stateStructType, str):
stateStructType = cast(StructType, _parse_datatype_string(stateStructType))

udf = pandas_udf(
func, # type: ignore[call-overload]
returnType=outputStructType,
functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
udf_column._jc.expr(),
self.session._jsparkSession.parseDataType(outputStructType.json()),
self.session._jsparkSession.parseDataType(stateStructType.json()),
outputMode,
timeoutConf,
)
return DataFrame(jdf, self.session)

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
Cogroups this group with another group so that we can run cogrouped operations.
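To tie the pieces together, here is a hedged end-to-end sketch of how the new method could be called on a streaming query, using the keyword arguments introduced above (`outputStructType`, `stateStructType`, `outputMode`, `timeoutConf`). The socket source, the schema strings, and the condensed `count_events` function are illustrative assumptions; only the method signature and the `GroupStateTimeout` import come from this diff.

```python
# Illustrative usage of groupBy(...).applyInPandasWithState(...). The source,
# schemas, and user function are assumptions; only the method signature and the
# GroupStateTimeout import come from this diff.
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.streaming.state import GroupStateTimeout


def count_events(key, pdf_iter, state):
    # Condensed version of the earlier sketch: keep a running row count per key.
    (count,) = state.get if state.exists else (0,)
    count += sum(len(pdf) for pdf in pdf_iter)
    state.update((count,))
    yield pd.DataFrame({"id": [key[0]], "count": [count]})


spark = SparkSession.builder.appName("applyInPandasWithState-sketch").getOrCreate()

events = (
    spark.readStream.format("socket")
    .option("host", "localhost")
    .option("port", 9999)
    .load()
    .selectExpr("value AS id")
)

counts = events.groupBy("id").applyInPandasWithState(
    count_events,
    outputStructType="id STRING, count LONG",
    stateStructType="count LONG",
    outputMode="update",
    timeoutConf=GroupStateTimeout.NoTimeout,
)

query = counts.writeStream.outputMode("update").format("console").start()
query.awaitTermination()
```

Here `stateStructType` describes the tuple handed to `state.update`, while `outputStructType` must match the columns of the DataFrames the user function yields, since it is passed as the return type of the underlying pandas UDF.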