Commit 0d8cdf0

davies authored and marmbrus committed
[SPARK-3681] [SQL] [PySpark] fix serialization of List and Map in SchemaRDD
Currently, the schema of an object in ArrayType or MapType is attached lazily. This gives better performance but introduces issues during serialization and when accessing nested objects. This patch applies the schema to objects of ArrayType or MapType immediately when they are accessed; this is slightly slower, but much more robust.

Author: Davies Liu <[email protected]>

Closes #2526 from davies/nested and squashes the following commits:

2399ae5 [Davies Liu] fix serialization of List and Map in SchemaRDD
1 parent f0c7e19 commit 0d8cdf0
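
The replacement converters boil down to eagerly walking the nested values. Below is a minimal standalone sketch of that pattern, assuming plain Python with no Spark at hand; to_row, convert_list, and convert_dict are illustrative stand-ins, not names from the patch:

import pickle

def to_row(v):
    # stand-in for PySpark's _create_object(cls, v); it just tags the
    # value so converted elements are easy to spot
    return ("Row", v)

def convert_list(l, convert=to_row):
    # mirrors the new List(l) helper: apply the element schema to every
    # item up front and return a plain builtin list
    if l is None:
        return None
    return [convert(v) for v in l]

def convert_dict(d, convert=to_row):
    # mirrors the new Dict(d) helper: apply the value schema eagerly
    if d is None:
        return None
    return dict((k, convert(v)) for k, v in d.items())

nested = convert_dict({"key": 1})
assert nested == {"key": ("Row", 1)}
assert convert_list([1, 2]) == [("Row", 1), ("Row", 2)]
# plain builtins round-trip through pickle with no custom class involved
assert pickle.loads(pickle.dumps(nested)) == nested

The trade-off named in the commit message falls out directly: every element is converted even when it is never accessed, but the results are ordinary lists and dicts that serialize anywhere.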

2 files changed: +34 −27 lines

python/pyspark/sql.py

Lines changed: 13 additions & 27 deletions
@@ -838,43 +838,29 @@ def _create_cls(dataType):
     >>> obj = _create_cls(schema)(row)
     >>> pickle.loads(pickle.dumps(obj))
     Row(a=[1], b={'key': Row(c=1, d=2.0)})
+    >>> pickle.loads(pickle.dumps(obj.a))
+    [1]
+    >>> pickle.loads(pickle.dumps(obj.b))
+    {'key': Row(c=1, d=2.0)}
     """
 
     if isinstance(dataType, ArrayType):
         cls = _create_cls(dataType.elementType)
 
-        class List(list):
-
-            def __getitem__(self, i):
-                # create object with datetype
-                return _create_object(cls, list.__getitem__(self, i))
-
-            def __repr__(self):
-                # call collect __repr__ for nested objects
-                return "[%s]" % (", ".join(repr(self[i])
-                                           for i in range(len(self))))
-
-            def __reduce__(self):
-                return list.__reduce__(self)
+        def List(l):
+            if l is None:
+                return
+            return [_create_object(cls, v) for v in l]
 
         return List
 
     elif isinstance(dataType, MapType):
-        vcls = _create_cls(dataType.valueType)
-
-        class Dict(dict):
-
-            def __getitem__(self, k):
-                # create object with datetype
-                return _create_object(vcls, dict.__getitem__(self, k))
-
-            def __repr__(self):
-                # call collect __repr__ for nested objects
-                return "{%s}" % (", ".join("%r: %r" % (k, self[k])
-                                           for k in self))
+        cls = _create_cls(dataType.valueType)
 
-            def __reduce__(self):
-                return dict.__reduce__(self)
+        def Dict(d):
+            if d is None:
+                return
+            return dict((k, _create_object(cls, v)) for k, v in d.items())
 
         return Dict
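
For context on why the removed wrappers were fragile: List and Dict were list/dict subclasses defined inside _create_cls, and CPython's pickle cannot reference a class created locally at runtime; iteration also bypassed the converting __getitem__. A small sketch of both failure modes (make_lazy_list is a hypothetical stand-in, not the original code):

import pickle

def make_lazy_list():
    # a list subclass defined inside a function, like the old per-schema
    # List(list) created in _create_cls
    class LazyList(list):
        def __getitem__(self, i):
            return ("Row", list.__getitem__(self, i))
    return LazyList

LazyList = make_lazy_list()
lazy = LazyList([1])
print(lazy[0])      # ('Row', 1): conversion works on indexed access...
print(list(lazy))   # [1]: ...but iteration bypasses __getitem__, values stay raw
try:
    pickle.dumps(lazy)
except (pickle.PicklingError, AttributeError) as e:
    # pickle stores classes by module-level name, which a local class lacks
    print("pickling fails:", e)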

python/pyspark/tests.py

Lines changed: 21 additions & 0 deletions
@@ -698,6 +698,27 @@ def test_apply_schema_to_row(self):
         srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
         self.assertEqual(10, srdd3.count())
 
+    def test_serialize_nested_array_and_map(self):
+        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        row = srdd.first()
+        self.assertEqual(1, len(row.l))
+        self.assertEqual(1, row.l[0].a)
+        self.assertEqual("2", row.d["key"].d)
+
+        l = srdd.map(lambda x: x.l).first()
+        self.assertEqual(1, len(l))
+        self.assertEqual('s', l[0].b)
+
+        d = srdd.map(lambda x: x.d).first()
+        self.assertEqual(1, len(d))
+        self.assertEqual(1.0, d["key"].c)
+
+        row = srdd.map(lambda x: x.d["key"]).first()
+        self.assertEqual(1.0, row.c)
+        self.assertEqual("2", row.d)
+
 
 class TestIO(PySparkTestCase):
