
Commit e0e64ba

Davies Liu authored and JoshRosen committed
[SPARK-6055] [PySpark] fix incorrect __eq__ of DataType
The __eq__ of DataType is not correct: the class cache is not used correctly (a created class cannot be found again by its DataType), so lots of classes are created (saved in _cached_cls) and never released. Also, all DataType instances share the same hash code, so a dict ends up holding many objects under a single hash code (effectively a hash-collision attack), and accessing that dict becomes very slow (depending on the CPython implementation). This PR also improves the performance of inferSchema (avoiding unnecessary converters for objects). cc pwendell JoshRosen

Author: Davies Liu <[email protected]>

Closes #4808 from davies/leak and squashes the following commits:

6a322a4 [Davies Liu] tests refactor
3da44fc [Davies Liu] fix __eq__ of Singleton
534ac90 [Davies Liu] add more checks
46999dc [Davies Liu] fix tests
d9ae973 [Davies Liu] fix memory leak in sql
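The fix itself lands in python/pyspark/sql/types.py, which this excerpt does not show. As a hedged sketch of the technique the message describes — a value-based __eq__ paired with a consistent __hash__, so that equal DataType instances hash alike and can be found again in a dict-based cache — it might look like this (the class body is illustrative, not the actual patch):

    # Minimal sketch; assumption: not the actual pyspark.sql.types code.
    class DataType(object):

        def __eq__(self, other):
            # value equality: same concrete class and same attributes
            return (isinstance(other, self.__class__) and
                    self.__dict__ == other.__dict__)

        def __ne__(self, other):
            return not self.__eq__(other)

        def __hash__(self):
            # must agree with __eq__: equal instances hash equally, and
            # different types hash differently (no mass collisions in one bucket)
            return hash((self.__class__.__name__,
                         tuple(sorted(self.__dict__.items()))))

Without a __hash__ consistent with __eq__, a dict keyed by DataType can never find an existing entry for an equal-but-distinct instance, which is exactly how the class cache leaked.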
1 parent: 8c468a6

File tree: 4 files changed, +86 −137 lines

python/pyspark/sql/context.py

Lines changed: 1 addition & 89 deletions
@@ -17,15 +17,14 @@
 
 import warnings
 import json
-from array import array
 from itertools import imap
 
 from py4j.protocol import Py4JError
 from py4j.java_collections import MapConverter
 
 from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame

@@ -620,93 +619,6 @@ def _get_hive_ctx(self):
         return self._jvm.HiveContext(self._jsc.sc())
 
 
-def _create_row(fields, values):
-    row = Row(*values)
-    row.__FIELDS__ = fields
-    return row
-
-
-class Row(tuple):
-
-    """
-    A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
-    Row can be used to create a row object by using named arguments,
-    the fields will be sorted by names.
-
-    >>> row = Row(name="Alice", age=11)
-    >>> row
-    Row(age=11, name='Alice')
-    >>> row.name, row.age
-    ('Alice', 11)
-
-    Row also can be used to create another Row like class, then it
-    could be used to create Row objects, such as
-
-    >>> Person = Row("name", "age")
-    >>> Person
-    <Row(name, age)>
-    >>> Person("Alice", 11)
-    Row(name='Alice', age=11)
-    """
-
-    def __new__(self, *args, **kwargs):
-        if args and kwargs:
-            raise ValueError("Can not use both args "
-                             "and kwargs to create Row")
-        if args:
-            # create row class or objects
-            return tuple.__new__(self, args)
-
-        elif kwargs:
-            # create row objects
-            names = sorted(kwargs.keys())
-            values = tuple(kwargs[n] for n in names)
-            row = tuple.__new__(self, values)
-            row.__FIELDS__ = names
-            return row
-
-        else:
-            raise ValueError("No args or kwargs")
-
-    def asDict(self):
-        """
-        Return as an dict
-        """
-        if not hasattr(self, "__FIELDS__"):
-            raise TypeError("Cannot convert a Row class into dict")
-        return dict(zip(self.__FIELDS__, self))
-
-    # let object acts like class
-    def __call__(self, *args):
-        """create new Row object"""
-        return _create_row(self, args)
-
-    def __getattr__(self, item):
-        if item.startswith("__"):
-            raise AttributeError(item)
-        try:
-            # it will be slow when it has many fields,
-            # but this will not be used in normal cases
-            idx = self.__FIELDS__.index(item)
-            return self[idx]
-        except IndexError:
-            raise AttributeError(item)
-
-    def __reduce__(self):
-        if hasattr(self, "__FIELDS__"):
-            return (_create_row, (self.__FIELDS__, tuple(self)))
-        else:
-            return tuple.__reduce__(self)
-
-    def __repr__(self):
-        if hasattr(self, "__FIELDS__"):
-            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
-                                         for k, v in zip(self.__FIELDS__, self))
-        else:
-            return "<Row(%s)>" % ", ".join(self)
-
-
 def _test():
     import doctest
     from pyspark.context import SparkContext
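Row moves out of context.py rather than disappearing: the import change above now pulls it from pyspark.sql.types. A short usage sketch, based on the doctests of the relocated class (Row is a plain tuple subclass, so no SparkContext is needed):

    from pyspark.sql.types import Row   # new home of Row after this commit

    row = Row(name="Alice", age=11)     # keyword fields are sorted by name
    assert (row.name, row.age) == ("Alice", 11)
    assert row.asDict() == {"age": 11, "name": "Alice"}

    Person = Row("name", "age")         # positional args build a Row-like class
    alice = Person("Alice", 11)         # Row(name='Alice', age=11)
    assert alice.name == "Alice"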

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 1 deletion
@@ -1025,10 +1025,12 @@ def cast(self, dataType):
             ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
             jdt = ssql_ctx.parseDataType(dataType.json())
             jc = self._jc.cast(jdt)
+        else:
+            raise TypeError("unexpected type: %s" % type(dataType))
         return Column(jc)
 
     def __repr__(self):
-        return 'Column<%s>' % self._jdf.toString().encode('utf8')
+        return 'Column<%s>' % self._jc.toString().encode('utf8')
 
 
 def _test():
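A hedged sketch of how the new else branch surfaces to callers (df and its age column are assumptions borrowed from the module's doctests, not part of the patch):

    from pyspark.sql.types import LongType

    df.age.cast("long")        # a type-name string is passed straight through
    df.age.cast(LongType())    # a DataType is serialized via dataType.json()
    df.age.cast(42)            # anything else now fails fast with
                               # TypeError: unexpected type: <type 'int'>

Previously an unsupported argument left jc unbound and produced a confusing UnboundLocalError instead of a clear TypeError.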

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@
 import pydoc
 import shutil
 import tempfile
+import pickle
 
 import py4j
 
@@ -88,6 +89,14 @@ def __eq__(self, other):
             other.x == self.x and other.y == self.y
 
 
+class DataTypeTests(unittest.TestCase):
+    # regression test for SPARK-6055
+    def test_data_type_eq(self):
+        lt = LongType()
+        lt2 = pickle.loads(pickle.dumps(LongType()))
+        self.assertEquals(lt, lt2)
+
+
 class SQLTests(ReusedPySparkTestCase):
 
     @classmethod
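To connect the regression test back to the leak described in the commit message, a hedged illustration (plain Python, not part of the patch) of why __hash__ must agree with the fixed __eq__ before a dict-based class cache such as _cached_cls can hit:

    import pickle

    from pyspark.sql.types import LongType

    lt = LongType()
    lt2 = pickle.loads(pickle.dumps(lt))

    assert lt == lt2              # the fixed __eq__ (what the test above checks)
    assert hash(lt) == hash(lt2)  # required for a dict lookup to find lt's entry

    # stand-in for the real _cached_cls dict in pyspark.sql.types:
    cache = {lt: "generated Row class"}
    assert cache[lt2] == "generated Row class"  # cache hit: class reused, not leaked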
