Commit d69d397

refactor

1 parent 2cc2d45 commit d69d397

File tree

1 file changed: +28 -22 lines

python/pyspark/sql.py

Lines changed: 28 additions & 22 deletions
@@ -335,28 +335,32 @@ def _parse_datatype_string(datatype_string):
     >>> check_datatype(complex_maptype)
     True
     """
-    left_bracket_index = datatype_string.find("(")
-    if left_bracket_index == -1:
+    index = datatype_string.find("(")
+    if index == -1:
         # It is a primitive type.
-        left_bracket_index = len(datatype_string)
-    type_or_field = datatype_string[:left_bracket_index]
-    rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
+        index = len(datatype_string)
+    type_or_field = datatype_string[:index]
+    rest_part = datatype_string[index+1:len(datatype_string)-1].strip()
+
     if type_or_field in _all_primitive_types:
         return _all_primitive_types[type_or_field]()
+
     elif type_or_field == "ArrayType":
         last_comma_index = rest_part.rfind(",")
         containsNull = True
         if rest_part[last_comma_index+1:].strip().lower() == "false":
             containsNull = False
         elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
         return ArrayType(elementType, containsNull)
+
     elif type_or_field == "MapType":
         last_comma_index = rest_part.rfind(",")
         valueContainsNull = True
         if rest_part[last_comma_index+1:].strip().lower() == "false":
             valueContainsNull = False
         keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip())
         return MapType(keyType, valueType, valueContainsNull)
+
     elif type_or_field == "StructField":
         first_comma_index = rest_part.find(",")
         name = rest_part[:first_comma_index].strip()
@@ -367,6 +371,7 @@ def _parse_datatype_string(datatype_string):
         dataType = _parse_datatype_string(
             rest_part[first_comma_index+1:last_comma_index].strip())
         return StructField(name, dataType, nullable)
+
     elif type_or_field == "StructType":
         # rest_part should be in the format like
         # List(StructField(field1,IntegerType,false)).
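
The strings this parser consumes are the Scala-side DataType.toString() forms, as the List(StructField(...)) comment above suggests. A minimal usage sketch under that assumption; the inputs are illustrative, and the equality checks rely on DataType comparison as exercised by the check_datatype doctests above:

    # Primitive types have no parentheses, so index == -1 and the whole
    # string is looked up in _all_primitive_types:
    t = _parse_datatype_string("IntegerType")

    # For ArrayType, rest_part is "IntegerType,false"; the token after the
    # last comma is the containsNull flag, the rest is the element type:
    at = _parse_datatype_string("ArrayType(IntegerType,false)")
    assert at == ArrayType(IntegerType(), False)

    # Nesting works because the element type is parsed recursively:
    nested = _parse_datatype_string("ArrayType(ArrayType(StringType,true),true)")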
@@ -378,13 +383,13 @@ def _parse_datatype_string(datatype_string):
 
 _cached_namedtuples = {}
 
-def _restore_object(fields, obj):
+def _restore_object(name, fields, obj):
     """ Restore namedtuple object during unpickling. """
     cls = _cached_namedtuples.get(fields)
     if cls is None:
-        cls = namedtuple("Row", fields)
+        cls = namedtuple(name, fields)
         def __reduce__(self):
-            return (_restore_object, (fields, tuple(self)))
+            return (_restore_object, (name, fields, tuple(self)))
         cls.__reduce__ = __reduce__
         _cached_namedtuples[fields] = cls
     return cls(*obj)
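
Note on the new name parameter: _restore_object is the function pickle calls to rebuild the row, and before this change every restored class was hardcoded to "Row". Passing the name through __reduce__ lets the class name survive a pickle round trip. A sketch of that round trip; the field names are made up for illustration:

    import pickle

    row = _restore_object("Row", ("id", "name"), (1, "alice"))

    # cls.__reduce__ returns (_restore_object, (name, fields, values)), so
    # pickle never needs to import the dynamically created class by name:
    row2 = pickle.loads(pickle.dumps(row))
    assert row2 == row and row2.name == "alice"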
@@ -395,13 +400,13 @@ def _create_object(cls, v):
 
 def _create_getter(dt, i):
     """ Create a getter for item `i` with schema """
-    # TODO: cache created class
     cls = _create_cls(dt)
     def getter(self):
         return _create_object(cls, self[i])
     return getter
 
 def _has_struct(dt):
+    """Return whether `dt` is or has StructType in it"""
     if isinstance(dt, StructType):
         return True
     elif isinstance(dt, ArrayType):
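
The hunk cuts off inside _has_struct; based on its new docstring, the rest of the helper presumably recurses into container types. An assumed reconstruction, not the verbatim remainder of the file:

    # Sketch of the full shape of _has_struct, inferred from the docstring:
    # recurse into ArrayType/MapType and report whether a StructType appears.
    def _has_struct_sketch(dt):
        if isinstance(dt, StructType):
            return True
        elif isinstance(dt, ArrayType):
            return _has_struct_sketch(dt.elementType)
        elif isinstance(dt, MapType):
            return _has_struct_sketch(dt.valueType)
        return False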
@@ -416,34 +421,35 @@ def _create_cls(dataType):
 
     The created class is similar to namedtuple, but can have nested schema.
     """
-    # this can not be in global
-    from pyspark.sql import _has_struct, _create_getter
     from operator import itemgetter
 
-
-    # TODO: update to new DataType
     if isinstance(dataType, ArrayType):
         cls = _create_cls(dataType.elementType)
         class List(list):
             def __getitem__(self, i):
+                # create object with datatype
                 return _create_object(cls, list.__getitem__(self, i))
             def __repr__(self):
+                # call __repr__ of nested objects
                 return "[%s]" % (", ".join(repr(self[i])
                                            for i in range(len(self))))
             def __reduce__(self):
-                # the nested struct can be reduced by itself
+                # pickle as list, the nested struct can be reduced by itself
                 return (list, (list(self),))
         return List
 
     elif isinstance(dataType, MapType):
         vcls = _create_cls(dataType.valueType)
         class Dict(dict):
             def __getitem__(self, k):
+                # create object with datatype
                 return _create_object(vcls, dict.__getitem__(self, k))
             def __repr__(self):
+                # call __repr__ of nested objects
                 return "{%s}" % (", ".join("%r: %r" % (k, self[k])
                                            for k in self))
             def __reduce__(self):
+                # pickle as dict, the nested struct can be reduced by itself
                 return (dict, (dict(self),))
         return Dict
 
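The __reduce__ overrides on List and Dict matter because both classes are defined inside _create_cls, so pickle could not import them by name; reducing to the plain built-in container sidesteps that. A sketch of the behavior, assuming _create_object applies the element class on access as the getter above suggests:

    import pickle

    # For an array<struct<x:int>> schema, _create_cls returns the List wrapper:
    cls = _create_cls(ArrayType(StructType([StructField("x", IntegerType(), True)])))
    rows = cls([(1,), (2,)])

    print(rows[0].x)   # element is wrapped into a Row-like object on access

    # Unpickling yields a plain list; nested elements reduce themselves:
    assert type(pickle.loads(pickle.dumps(rows))) is list
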
@@ -454,24 +460,24 @@ class Row(tuple):
         """ Row in SchemaRDD """
         _fields = tuple(f.name for f in dataType.fields)
 
+        # create property for fast access
         # use local vars begins with "_"
         for _i,_f in enumerate(dataType.fields):
             if _has_struct(_f.dataType):
-                _getter = property(_create_getter(_f.dataType, _i))
+                # delay creating object until accessing it
+                _getter = _create_getter(_f.dataType, _i)
             else:
-                _getter = property(itemgetter(_i))
-            locals()[_f.name] = _getter
+                _getter = itemgetter(_i)
+            locals()[_f.name] = property(_getter)
         del _i, _f, _getter
 
         def __repr__(self):
+            # call __repr__ of nested objects
             return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
                                           for n in self._fields))
-
-        def __str__(self):
-            return repr(self)
-
         def __reduce__(self):
-            return (_restore_object, (self._fields, tuple(self)))
+            # pickle as namedtuple
+            return (_restore_object, ("Row", self._fields, tuple(self)))
 
     return Row
 
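Taken together, the Row generated by _create_cls exposes each field as a property, defers wrapper creation for struct-valued fields until access, and now pickles through _restore_object under the name "Row". A usage sketch with an assumed schema:

    import pickle

    schema = StructType([
        StructField("id", IntegerType(), False),
        StructField("tags", ArrayType(StringType(), True), True),
    ])
    Row = _create_cls(schema)

    r = Row((42, ["a", "b"]))
    print(r.id)        # plain field: property(itemgetter(0)) -> 42
    print(r)           # Row(id=42, tags=['a', 'b']) via the custom __repr__

    # __reduce__ routes through _restore_object("Row", ...), so the
    # dynamically created class round-trips without being importable:
    r2 = pickle.loads(pickle.dumps(r))
    assert tuple(r2) == (42, ["a", "b"])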