@@ -335,28 +335,32 @@ def _parse_datatype_string(datatype_string):
335335 >>> check_datatype(complex_maptype)
336336 True
337337 """
338- left_bracket_index = datatype_string .find ("(" )
339- if left_bracket_index == - 1 :
338+ index = datatype_string .find ("(" )
339+ if index == - 1 :
340340 # It is a primitive type.
341- left_bracket_index = len (datatype_string )
342- type_or_field = datatype_string [:left_bracket_index ]
343- rest_part = datatype_string [left_bracket_index + 1 :len (datatype_string )- 1 ].strip ()
341+ index = len (datatype_string )
342+ type_or_field = datatype_string [:index ]
343+ rest_part = datatype_string [index + 1 :len (datatype_string )- 1 ].strip ()
344+
344345 if type_or_field in _all_primitive_types :
345346 return _all_primitive_types [type_or_field ]()
347+
346348 elif type_or_field == "ArrayType" :
347349 last_comma_index = rest_part .rfind ("," )
348350 containsNull = True
349351 if rest_part [last_comma_index + 1 :].strip ().lower () == "false" :
350352 containsNull = False
351353 elementType = _parse_datatype_string (rest_part [:last_comma_index ].strip ())
352354 return ArrayType (elementType , containsNull )
355+
353356 elif type_or_field == "MapType" :
354357 last_comma_index = rest_part .rfind ("," )
355358 valueContainsNull = True
356359 if rest_part [last_comma_index + 1 :].strip ().lower () == "false" :
357360 valueContainsNull = False
358361 keyType , valueType = _parse_datatype_list (rest_part [:last_comma_index ].strip ())
359362 return MapType (keyType , valueType , valueContainsNull )
363+
360364 elif type_or_field == "StructField" :
361365 first_comma_index = rest_part .find ("," )
362366 name = rest_part [:first_comma_index ].strip ()
@@ -367,6 +371,7 @@ def _parse_datatype_string(datatype_string):
367371 dataType = _parse_datatype_string (
368372 rest_part [first_comma_index + 1 :last_comma_index ].strip ())
369373 return StructField (name , dataType , nullable )
374+
370375 elif type_or_field == "StructType" :
371376 # rest_part should be in the format like
372377 # List(StructField(field1,IntegerType,false)).
@@ -378,13 +383,13 @@ def _parse_datatype_string(datatype_string):
378383
379384_cached_namedtuples = {}
380385
381- def _restore_object (fields , obj ):
386+ def _restore_object (name , fields , obj ):
382387 """ Restore namedtuple object during unpickling. """
383388 cls = _cached_namedtuples .get (fields )
384389 if cls is None :
385- cls = namedtuple ("Row" , fields )
390+ cls = namedtuple (name , fields )
386391 def __reduce__ (self ):
387- return (_restore_object , (fields , tuple (self )))
392+ return (_restore_object , (name , fields , tuple (self )))
388393 cls .__reduce__ = __reduce__
389394 _cached_namedtuples [fields ] = cls
390395 return cls (* obj )
@@ -395,13 +400,13 @@ def _create_object(cls, v):
395400
def _create_getter(dt, i):
    """Create a getter for item `i` with schema `dt`.

    The returned callable wraps the raw value at position ``i`` in the
    class generated for ``dt`` only when it is accessed.
    """
    nested_cls = _create_cls(dt)

    def getter(self):
        # Lazily convert the stored value into the schema-backed object.
        return _create_object(nested_cls, self[i])

    return getter
403407
404408def _has_struct (dt ):
409+ """Return whether `dt` is or has StructType in it"""
405410 if isinstance (dt , StructType ):
406411 return True
407412 elif isinstance (dt , ArrayType ):
@@ -416,34 +421,35 @@ def _create_cls(dataType):
416421
417422 The created class is similar to namedtuple, but can have nested schema.
418423 """
419- # this can not be in global
420- from pyspark .sql import _has_struct , _create_getter
421424 from operator import itemgetter
422425
423-
424- # TODO: update to new DataType
425426 if isinstance (dataType , ArrayType ):
426427 cls = _create_cls (dataType .elementType )
427428 class List (list ):
428429 def __getitem__ (self , i ):
430+ # create object with datetype
429431 return _create_object (cls , list .__getitem__ (self , i ))
430432 def __repr__ (self ):
433+ # call collect __repr__ for nested objects
431434 return "[%s]" % (", " .join (repr (self [i ])
432435 for i in range (len (self ))))
433436 def __reduce__ (self ):
434- # the nested struct can be reduced by itself
437+ # pickle as dict, the nested struct can be reduced by itself
435438 return (list , (list (self ),))
436439 return List
437440
438441 elif isinstance (dataType , MapType ):
439442 vcls = _create_cls (dataType .valueType )
440443 class Dict (dict ):
441444 def __getitem__ (self , k ):
445+ # create object with datetype
442446 return _create_object (vcls , dict .__getitem__ (self , k ))
443447 def __repr__ (self ):
448+ # call collect __repr__ for nested objects
444449 return "{%s}" % (", " .join ("%r: %r" % (k , self [k ])
445450 for k in self ))
446451 def __reduce__ (self ):
452+ # pickle as dict, the nested struct can be reduced by itself
447453 return (dict , (dict (self ),))
448454 return Dict
449455
@@ -454,24 +460,24 @@ class Row(tuple):
454460 """ Row in SchemaRDD """
455461 _fields = tuple (f .name for f in dataType .fields )
456462
463+ # create property for fast access
457464 # use local vars begins with "_"
458465 for _i ,_f in enumerate (dataType .fields ):
459466 if _has_struct (_f .dataType ):
460- _getter = property (_create_getter (_f .dataType , _i ))
467+ # delay creating object until accessing it
468+ _getter = _create_getter (_f .dataType , _i )
461469 else :
462- _getter = property ( itemgetter (_i ) )
463- locals ()[_f .name ] = _getter
470+ _getter = itemgetter (_i )
471+ locals ()[_f .name ] = property ( _getter )
464472 del _i , _f , _getter
465473
466474 def __repr__ (self ):
475+ # call collect __repr__ for nested objects
467476 return ("Row(%s)" % ", " .join ("%s=%r" % (n , getattr (self , n ))
468477 for n in self ._fields ))
469-
470- def __str__ (self ):
471- return repr (self )
472-
473478 def __reduce__ (self ):
474- return (_restore_object , (self ._fields , tuple (self )))
479+ # pickle as namedtuple
480+ return (_restore_object , ("Row" , self ._fields , tuple (self )))
475481
476482 return Row
477483
0 commit comments