diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 65e58e025b17..93bc11ed2b23 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -501,6 +501,11 @@ def normalize_value(v: Any) -> Any:
     # Convert NumPy types to Python primitive types.
     if isinstance(v, np.generic):
         return v.tolist()
+    # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
+    # require positional arguments and cannot be instantiated
+    # with a generator expression.
+    if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
+        return type(v)(*[normalize_value(e) for e in v])
     # List / tuple: recursively normalize each element
     if isinstance(v, (list, tuple)):
         return type(v)(normalize_value(e) for e in v)
diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index 09ef3a447f9c..e76405da447c 100644
--- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -17,7 +17,10 @@
 
 from abc import abstractmethod
 import sys
-from typing import Iterator
+from typing import (
+    Iterator,
+    NamedTuple,
+)
 import unittest
 from pyspark.errors import PySparkRuntimeError
 from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
@@ -1663,10 +1666,15 @@ def close(self) -> None:
 
 # A stateful processor that contains composite python type inside Value, List and Map state variable
 class PandasStatefulProcessorCompositeType(StatefulProcessor):
+    class Address(NamedTuple):
+        road_id: int
+        city: str
+
     TAGS = [["dummy1", "dummy2"], ["dummy3"]]
     METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
     ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
     CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
+    ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]
 
     def init(self, handle: StatefulProcessorHandle) -> None:
         obj_schema = StructType(
@@ -1681,6 +1689,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
                         )
                     ),
                 ),
+                StructField(
+                    "address",
+                    ArrayType(
+                        StructType(
+                            [
+                                StructField("road_id", IntegerType()),
+                                StructField("city", StringType()),
+                            ]
+                        )
+                    ),
+                ),
             ]
         )
 
@@ -1700,25 +1719,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:
 
     def _update_obj_state(self, total_temperature):
         if self.obj_state.exists():
-            ids, tags, metadata = self.obj_state.get()
+            ids, tags, metadata, address = self.obj_state.get()
             assert tags == self.TAGS, f"Tag mismatch: {tags}"
             assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
+            assert address == [
+                Row(**e._asdict()) for e in self.ADDRESS
+            ], f"Address mismatch: {address}"
             ids = [int(x + total_temperature) for x in ids]
         else:
             ids = [0]
-        self.obj_state.update((ids, self.TAGS, self.METADATA))
+        self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
         return ids
 
     def _update_list_state(self, total_temperature, initial_obj):
         existing_list = self.list_state.get()
         updated_list = []
-        for ids, tags, metadata in existing_list:
+        for ids, tags, metadata, address in existing_list:
             ids.append(total_temperature)
-            updated_list.append((ids, tags, [row.asDict() for row in metadata]))
+            updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
         if not updated_list:
             updated_list.append(initial_obj)
         self.list_state.put(updated_list)
-        return [id_val for ids, _, _ in updated_list for id_val in ids]
+        return [id_val for ids, _, _, _ in updated_list for id_val in ids]
 
     def _update_map_state(self, key, total_temperature):
         if not self.map_state.containsKey(key):
@@ -1736,7 +1758,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
         updated_ids = self._update_obj_state(total_temperature)
 
         flattened_ids = self._update_list_state(
-            total_temperature, (updated_ids, self.TAGS, self.METADATA)
+            total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
         )
 
         attributes_map, confs_map = self._update_map_state(key, total_temperature)
@@ -1767,10 +1789,15 @@ def close(self) -> None:
 
 
 class RowStatefulProcessorCompositeType(StatefulProcessor):
+    class Address(NamedTuple):
+        road_id: int
+        city: str
+
     TAGS = [["dummy1", "dummy2"], ["dummy3"]]
     METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
     ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
     CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
+    ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]
 
     def init(self, handle: StatefulProcessorHandle) -> None:
         obj_schema = StructType(
@@ -1785,6 +1812,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
                         )
                     ),
                 ),
+                StructField(
+                    "address",
+                    ArrayType(
+                        StructType(
+                            [
+                                StructField("road_id", IntegerType()),
+                                StructField("city", StringType()),
+                            ]
+                        )
+                    ),
+                ),
             ]
         )
 
@@ -1804,25 +1842,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:
 
     def _update_obj_state(self, total_temperature):
         if self.obj_state.exists():
-            ids, tags, metadata = self.obj_state.get()
+            ids, tags, metadata, address = self.obj_state.get()
             assert tags == self.TAGS, f"Tag mismatch: {tags}"
             assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
+            assert address == [
+                Row(**e._asdict()) for e in self.ADDRESS
+            ], f"Address mismatch: {address}"
             ids = [int(x + total_temperature) for x in ids]
         else:
             ids = [0]
-        self.obj_state.update((ids, self.TAGS, self.METADATA))
+        self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
         return ids
 
     def _update_list_state(self, total_temperature, initial_obj):
         existing_list = self.list_state.get()
         updated_list = []
-        for ids, tags, metadata in existing_list:
+        for ids, tags, metadata, address in existing_list:
             ids.append(total_temperature)
-            updated_list.append((ids, tags, [row.asDict() for row in metadata]))
+            updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
         if not updated_list:
             updated_list.append(initial_obj)
         self.list_state.put(updated_list)
-        return [id_val for ids, _, _ in updated_list for id_val in ids]
+        return [id_val for ids, _, _, _ in updated_list for id_val in ids]
 
     def _update_map_state(self, key, total_temperature):
         if not self.map_state.containsKey(key):
@@ -1840,7 +1881,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
         updated_ids = self._update_obj_state(total_temperature)
 
         flattened_ids = self._update_list_state(
-            total_temperature, (updated_ids, self.TAGS, self.METADATA)
+            total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
         )
 
         attributes_map, confs_map = self._update_map_state(key, total_temperature)
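
Context for the first hunk: the new branch in normalize_value exists because a namedtuple (or a pyspark Row) constructor takes one positional argument per field, whereas a plain tuple constructor accepts a single iterable. The standalone Python sketch below is only an illustration of that behavior (the Point type is hypothetical and not part of the patch or the PySpark codebase):

# Sketch (not part of the patch): why namedtuples and Row need their own branch
# in normalize_value instead of the generic tuple path.
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)

# A plain tuple can be rebuilt from a generator expression...
assert type((1, 2))(e for e in (1, 2)) == (1, 2)

# ...but a namedtuple constructor expects one positional argument per field,
# so passing a single generator raises TypeError (missing argument 'y').
try:
    type(p)(e for e in p)
except TypeError as exc:
    print("namedtuple rejects a generator:", exc)

# Unpacking the normalized elements preserves the concrete named tuple type,
# which is what the added branch does with type(v)(*[...]).
rebuilt = type(p)(*[e for e in p])
assert rebuilt == p and isinstance(rebuilt, Point)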