@@ -66,6 +66,7 @@ message Relation {
GroupMap group_map = 31;
CoGroupMap co_group_map = 32;
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;

// NA functions
NAFill fill_na = 90;
@@ -840,6 +841,29 @@ message CoGroupMap {
CommonInlineUserDefinedFunction func = 5;
}

message ApplyInPandasWithState {
// (Required) Input relation for applyInPandasWithState.
Relation input = 1;

// (Required) Expressions for grouping keys.
repeated Expression grouping_expressions = 2;

// (Required) Input user-defined function.
CommonInlineUserDefinedFunction func = 3;

// (Required) Schema for the output DataFrame.
string output_schema = 4;

// (Required) Schema for the state.
string state_schema = 5;

// (Required) The output mode of the function.
string output_mode = 6;

// (Required) Timeout configuration for groups that do not receive data for a while.
string timeout_conf = 7;
}

// Collect arbitrary (named) metrics from a dataset.
message CollectMetrics {
// (Required) The input relation.
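For reference, a minimal sketch of what the new `ApplyInPandasWithState` relation looks like when built with the generated Python classes (this assumes `pyspark.sql.connect.proto` re-exports `relations_pb2`, as the client code below relies on; the schema strings and the mode/timeout values are illustrative — the client actually sends `StructType.json()` when a `StructType` is passed):

```python
import pyspark.sql.connect.proto as proto

# Build the new message directly; func and grouping_expressions are left
# unset for brevity -- the client plan (see plan.py below) fills them in.
msg = proto.ApplyInPandasWithState(
    output_schema="id LONG, count LONG",
    state_schema="count LONG",
    output_mode="Update",
    timeout_conf="NoTimeout",
)
msg.input.CopyFrom(proto.Relation())  # placeholder child relation
print(msg)
```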
@@ -133,6 +133,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformGroupMap(rel.getGroupMap)
case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -583,6 +585,27 @@ class SparkConnectPlanner(val session: SparkSession) {
input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
}

private def transformApplyInPandasWithState(rel: proto.ApplyInPandasWithState): LogicalPlan = {
val pythonUdf = transformPythonUDF(rel.getFunc)
val cols =
rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))

val outputSchema = parseSchema(rel.getOutputSchema)

val stateSchema = parseSchema(rel.getStateSchema)

Dataset
.ofRows(session, transformRelation(rel.getInput))
.groupBy(cols: _*)
.applyInPandasWithState(
pythonUdf,
outputSchema,
stateSchema,
rel.getOutputMode,
rel.getTimeoutConf)
.logicalPlan
}

private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -778,6 +778,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
"pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
# ml doctests
"pyspark.ml.connect.functions",
# ml unittests
5 changes: 5 additions & 0 deletions python/pyspark/sql/connect/_typing.py
@@ -32,6 +32,7 @@

from pyspark.sql.connect.column import Column
from pyspark.sql.connect.types import DataType
from pyspark.sql.streaming.state import GroupState


ColumnOrName = Union[Column, str]
@@ -63,6 +64,10 @@

PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike]

PandasGroupedMapFunctionWithState = Callable[
[Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]


class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
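A function conforming to the new `PandasGroupedMapFunctionWithState` alias receives the grouping key tuple, an iterator of pandas DataFrames for the group, and a mutable `GroupState` handle, and yields output DataFrames. An illustrative sketch (the function and column names are invented for this example):

```python
import pandas as pd
from typing import Any, Iterable, Tuple
from pyspark.sql.streaming.state import GroupState

def count_rows(
    key: Tuple[Any, ...], pdfs: Iterable[pd.DataFrame], state: GroupState
) -> Iterable[pd.DataFrame]:
    # Accumulate the row count for this group across micro-batches.
    total = sum(len(pdf) for pdf in pdfs)
    if state.exists:
        (prev,) = state.get  # state tuple matches the declared state schema
        total += prev
    state.update((total,))
    yield pd.DataFrame({"id": [key[0]], "count": [total]})
```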
46 changes: 44 additions & 2 deletions python/pyspark/sql/connect/group.py
@@ -36,6 +36,7 @@
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.types import NumericType
from pyspark.sql.types import StructType

import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.column import Column
@@ -47,6 +48,7 @@
PandasGroupedMapFunction,
GroupedMapPandasUserDefinedFunction,
PandasCogroupedMapFunction,
PandasGroupedMapFunctionWithState,
)
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.types import StructType
@@ -262,8 +264,48 @@ def applyInPandas(

applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__

def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("applyInPandasWithState() is not implemented.")
def applyInPandasWithState(
self,
func: "PandasGroupedMapFunctionWithState",
outputStructType: Union[StructType, str],
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame

udf_obj = UserDefinedFunction(
func,
returnType=outputStructType,
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)

output_schema: str = (
outputStructType.json()
if isinstance(outputStructType, StructType)
else outputStructType
)

state_schema: str = (
stateStructType.json() if isinstance(stateStructType, StructType) else stateStructType
)

return DataFrame.withPlan(
plan.ApplyInPandasWithState(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
output_schema=output_schema,
state_schema=state_schema,
output_mode=outputMode,
timeout_conf=timeoutConf,
cols=self._df.columns,
),
session=self._df._session,
)

applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
return PandasCogroupedOps(self, other)
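With the method added above, the Connect client accepts the same call shape as the classic API. A usage sketch (assuming a streaming DataFrame `df` with an `id` column and the illustrative `count_rows` function from the `_typing.py` sketch earlier):

```python
from pyspark.sql.streaming.state import GroupStateTimeout

# Same signature as pyspark.sql.GroupedData.applyInPandasWithState.
result = df.groupBy("id").applyInPandasWithState(
    count_rows,
    outputStructType="id LONG, count LONG",
    stateStructType="count LONG",
    outputMode="Update",
    timeoutConf=GroupStateTimeout.NoTimeout,
)
```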
39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -2045,6 +2045,45 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class ApplyInPandasWithState(LogicalPlan):
"""Logical plan object for a applyInPandasWithState."""

def __init__(
self,
child: Optional["LogicalPlan"],
grouping_cols: Sequence[Column],
function: "UserDefinedFunction",
output_schema: str,
state_schema: str,
output_mode: str,
timeout_conf: str,
cols: List[str],
):
assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)

super().__init__(child)
self._grouping_cols = grouping_cols
self._func = function._build_common_inline_user_defined_function(*cols)
self._output_schema = output_schema
self._state_schema = state_schema
self._output_mode = output_mode
self._timeout_conf = timeout_conf

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.apply_in_pandas_with_state.input.CopyFrom(self._child.plan(session))
plan.apply_in_pandas_with_state.grouping_expressions.extend(
[c.to_plan(session) for c in self._grouping_cols]
)
plan.apply_in_pandas_with_state.func.CopyFrom(self._func.to_plan_udf(session))
plan.apply_in_pandas_with_state.output_schema = self._output_schema
plan.apply_in_pandas_with_state.state_schema = self._state_schema
plan.apply_in_pandas_with_state.output_mode = self._output_mode
plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf
return plan


class CachedRelation(LogicalPlan):
def __init__(self, plan: proto.Relation) -> None:
super(CachedRelation, self).__init__(None)
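To sanity-check the serialization end to end, one can inspect the relation that a grouped call produces through the plan node above — a rough sketch, assuming an active Spark Connect session `spark` and the illustrative `count_rows` from earlier (`to_proto` and the private `_plan`/`_client` attributes are used purely for inspection and may differ across versions):

```python
gd = spark.range(10).groupBy("id")
df2 = gd.applyInPandasWithState(
    count_rows, "id LONG, count LONG", "count LONG", "Update", "NoTimeout"
)
# The printed plan should show apply_in_pandas_with_state with the
# output_schema, state_schema, output_mode, and timeout_conf fields set.
print(df2._plan.to_proto(spark._client))
```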
258 changes: 136 additions & 122 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

80 changes: 80 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -95,6 +95,7 @@ class Relation(google.protobuf.message.Message):
GROUP_MAP_FIELD_NUMBER: builtins.int
CO_GROUP_MAP_FIELD_NUMBER: builtins.int
WITH_WATERMARK_FIELD_NUMBER: builtins.int
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -176,6 +177,8 @@
@property
def with_watermark(self) -> global___WithWatermark: ...
@property
def apply_in_pandas_with_state(self) -> global___ApplyInPandasWithState: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -245,6 +248,7 @@
group_map: global___GroupMap | None = ...,
co_group_map: global___CoGroupMap | None = ...,
with_watermark: global___WithWatermark | None = ...,
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -265,6 +269,8 @@
field_name: typing_extensions.Literal[
"aggregate",
b"aggregate",
"apply_in_pandas_with_state",
b"apply_in_pandas_with_state",
"approx_quantile",
b"approx_quantile",
"catalog",
@@ -366,6 +372,8 @@
field_name: typing_extensions.Literal[
"aggregate",
b"aggregate",
"apply_in_pandas_with_state",
b"apply_in_pandas_with_state",
"approx_quantile",
b"approx_quantile",
"catalog",
@@ -497,6 +505,7 @@
"group_map",
"co_group_map",
"with_watermark",
"apply_in_pandas_with_state",
"fill_na",
"drop_na",
"replace",
@@ -2980,6 +2989,77 @@ class CoGroupMap(google.protobuf.message.Message):

global___CoGroupMap = CoGroupMap

class ApplyInPandasWithState(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_FIELD_NUMBER: builtins.int
GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
FUNC_FIELD_NUMBER: builtins.int
OUTPUT_SCHEMA_FIELD_NUMBER: builtins.int
STATE_SCHEMA_FIELD_NUMBER: builtins.int
OUTPUT_MODE_FIELD_NUMBER: builtins.int
TIMEOUT_CONF_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for applyInPandasWithState."""
@property
def grouping_expressions(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.expressions_pb2.Expression
]:
"""(Required) Expressions for grouping keys."""
@property
def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
"""(Required) Input user-defined function."""
output_schema: builtins.str
"""(Required) Schema for the output DataFrame."""
state_schema: builtins.str
"""(Required) Schema for the state."""
output_mode: builtins.str
"""(Required) The output mode of the function."""
timeout_conf: builtins.str
"""(Required) Timeout configuration for groups that do not receive data for a while."""
def __init__(
self,
*,
input: global___Relation | None = ...,
grouping_expressions: collections.abc.Iterable[
pyspark.sql.connect.proto.expressions_pb2.Expression
]
| None = ...,
func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
| None = ...,
output_schema: builtins.str = ...,
state_schema: builtins.str = ...,
output_mode: builtins.str = ...,
timeout_conf: builtins.str = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"func",
b"func",
"grouping_expressions",
b"grouping_expressions",
"input",
b"input",
"output_mode",
b"output_mode",
"output_schema",
b"output_schema",
"state_schema",
b"state_schema",
"timeout_conf",
b"timeout_conf",
],
) -> None: ...

global___ApplyInPandasWithState = ApplyInPandasWithState

class CollectMetrics(google.protobuf.message.Message):
"""Collect arbitrary (named) metrics from a dataset."""

3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/group_ops.py
@@ -273,6 +273,9 @@ def applyInPandasWithState(

.. versionadded:: 3.4.0

.. versionchanged:: 3.5.0
Supports Spark Connect.

Parameters
----------
func : function
7 changes: 0 additions & 7 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2928,13 +2928,6 @@ def test_unsupported_functions(self):
with self.assertRaises(NotImplementedError):
getattr(df, f)()

def test_unsupported_group_functions(self):
# SPARK-41927: Disable unsupported functions.
cg = self.connect.read.table(self.tbl_name).groupBy("id")
for f in ("applyInPandasWithState",):
with self.assertRaises(NotImplementedError):
getattr(cg, f)()

def test_unsupported_session_functions(self):
# SPARK-41934: Disable unsupported functions.
