Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,11 @@ def normalize_value(v: Any) -> Any:
# Convert NumPy types to Python primitive types.
if isinstance(v, np.generic):
return v.tolist()
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
# require positional arguments and cannot be instantiated
# with a generator expression.
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
return type(v)(*[normalize_value(e) for e in v])
# List / tuple: recursively normalize each element
if isinstance(v, (list, tuple)):
return type(v)(normalize_value(e) for e in v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1663,10 +1663,17 @@ def close(self) -> None:

# A stateful processor that contains composite python type inside Value, List and Map state variable
class PandasStatefulProcessorCompositeType(StatefulProcessor):
from typing import NamedTuple
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this to the top of the test module? Is there a specific reason that you want it in the class definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right. Just moved it the top!


class Address(NamedTuple):
road_id: int
city: str

TAGS = [["dummy1", "dummy2"], ["dummy3"]]
METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]

def init(self, handle: StatefulProcessorHandle) -> None:
obj_schema = StructType(
Expand All @@ -1681,6 +1688,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
)
),
),
StructField(
"address",
ArrayType(
StructType(
[
StructField("road_id", IntegerType()),
StructField("city", StringType()),
]
)
),
),
]
)

Expand All @@ -1700,25 +1718,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:

def _update_obj_state(self, total_temperature):
if self.obj_state.exists():
ids, tags, metadata = self.obj_state.get()
ids, tags, metadata, address = self.obj_state.get()
assert tags == self.TAGS, f"Tag mismatch: {tags}"
assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
assert address == [
Row(**e._asdict()) for e in self.ADDRESS
], f"Address mismatch: {address}"
ids = [int(x + total_temperature) for x in ids]
else:
ids = [0]
self.obj_state.update((ids, self.TAGS, self.METADATA))
self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
return ids

def _update_list_state(self, total_temperature, initial_obj):
existing_list = self.list_state.get()
updated_list = []
for ids, tags, metadata in existing_list:
for ids, tags, metadata, address in existing_list:
ids.append(total_temperature)
updated_list.append((ids, tags, [row.asDict() for row in metadata]))
updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
if not updated_list:
updated_list.append(initial_obj)
self.list_state.put(updated_list)
return [id_val for ids, _, _ in updated_list for id_val in ids]
return [id_val for ids, _, _, _ in updated_list for id_val in ids]

def _update_map_state(self, key, total_temperature):
if not self.map_state.containsKey(key):
Expand All @@ -1736,7 +1757,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:

updated_ids = self._update_obj_state(total_temperature)
flattened_ids = self._update_list_state(
total_temperature, (updated_ids, self.TAGS, self.METADATA)
total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
)
attributes_map, confs_map = self._update_map_state(key, total_temperature)

Expand Down Expand Up @@ -1767,10 +1788,17 @@ def close(self) -> None:


class RowStatefulProcessorCompositeType(StatefulProcessor):
from typing import NamedTuple

class Address(NamedTuple):
road_id: int
city: str

TAGS = [["dummy1", "dummy2"], ["dummy3"]]
METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]

def init(self, handle: StatefulProcessorHandle) -> None:
obj_schema = StructType(
Expand All @@ -1785,6 +1813,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
)
),
),
StructField(
"address",
ArrayType(
StructType(
[
StructField("road_id", IntegerType()),
StructField("city", StringType()),
]
)
),
),
]
)

Expand All @@ -1804,25 +1843,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:

def _update_obj_state(self, total_temperature):
if self.obj_state.exists():
ids, tags, metadata = self.obj_state.get()
ids, tags, metadata, address = self.obj_state.get()
assert tags == self.TAGS, f"Tag mismatch: {tags}"
assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
assert address == [
Row(**e._asdict()) for e in self.ADDRESS
], f"Address mismatch: {address}"
ids = [int(x + total_temperature) for x in ids]
else:
ids = [0]
self.obj_state.update((ids, self.TAGS, self.METADATA))
self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
return ids

def _update_list_state(self, total_temperature, initial_obj):
existing_list = self.list_state.get()
updated_list = []
for ids, tags, metadata in existing_list:
for ids, tags, metadata, address in existing_list:
ids.append(total_temperature)
updated_list.append((ids, tags, [row.asDict() for row in metadata]))
updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
if not updated_list:
updated_list.append(initial_obj)
self.list_state.put(updated_list)
return [id_val for ids, _, _ in updated_list for id_val in ids]
return [id_val for ids, _, _, _ in updated_list for id_val in ids]

def _update_map_state(self, key, total_temperature):
if not self.map_state.containsKey(key):
Expand All @@ -1840,7 +1882,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:

updated_ids = self._update_obj_state(total_temperature)
flattened_ids = self._update_list_state(
total_temperature, (updated_ids, self.TAGS, self.METADATA)
total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
)
attributes_map, confs_map = self._update_map_state(key, total_temperature)

Expand Down