@@ -1121,7 +1121,7 @@ def applySchema(self, rdd, schema):
11211121 batched = isinstance (rdd ._jrdd_deserializer , BatchedSerializer )
11221122 jrdd = self ._pythonToJava (rdd ._jrdd , batched )
11231123 srdd = self ._ssql_ctx .applySchemaToPythonRDD (jrdd .rdd (), str (schema ))
1124- return SchemaRDD (srdd , self )
1124+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
11251125
11261126 def registerRDDAsTable (self , rdd , tableName ):
11271127 """Registers the given RDD as a temporary table in the catalog.
@@ -1133,8 +1133,8 @@ def registerRDDAsTable(self, rdd, tableName):
11331133 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
11341134 """
11351135 if (rdd .__class__ is SchemaRDD ):
1136- jschema_rdd = rdd ._jschema_rdd
1137- self ._ssql_ctx .registerRDDAsTable (jschema_rdd , tableName )
1136+ srdd = rdd ._jschema_rdd . baseSchemaRDD ()
1137+ self ._ssql_ctx .registerRDDAsTable (srdd , tableName )
11381138 else :
11391139 raise ValueError ("Can only register SchemaRDD as table" )
11401140
@@ -1150,7 +1150,7 @@ def parquetFile(self, path):
11501150 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
11511151 True
11521152 """
1153- jschema_rdd = self ._ssql_ctx .parquetFile (path )
1153+ jschema_rdd = self ._ssql_ctx .parquetFile (path ). toJavaSchemaRDD ()
11541154 return SchemaRDD (jschema_rdd , self )
11551155
11561156 def jsonFile (self , path , schema = None ):
@@ -1206,11 +1206,11 @@ def jsonFile(self, path, schema=None):
12061206 [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
12071207 """
12081208 if schema is None :
1209- jschema_rdd = self ._ssql_ctx .jsonFile (path )
1209+ srdd = self ._ssql_ctx .jsonFile (path )
12101210 else :
12111211 scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1212- jschema_rdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1213- return SchemaRDD (jschema_rdd , self )
1212+ srdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1213+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
12141214
12151215 def jsonRDD (self , rdd , schema = None ):
12161216 """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1274,11 +1274,11 @@ def func(iterator):
12741274 keyed ._bypass_serializer = True
12751275 jrdd = keyed ._jrdd .map (self ._jvm .BytesToString ())
12761276 if schema is None :
1277- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
1277+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
12781278 else :
12791279 scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1280- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1281- return SchemaRDD (jschema_rdd , self )
1280+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1281+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
12821282
12831283 def sql (self , sqlQuery ):
12841284 """Return a L{SchemaRDD} representing the result of the given query.
@@ -1289,7 +1289,7 @@ def sql(self, sqlQuery):
12891289 >>> srdd2.collect()
12901290 [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
12911291 """
1292- return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ), self )
1292+ return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ). toJavaSchemaRDD () , self )
12931293
12941294 def table (self , tableName ):
12951295 """Returns the specified table as a L{SchemaRDD}.
@@ -1300,7 +1300,7 @@ def table(self, tableName):
13001300 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
13011301 True
13021302 """
1303- return SchemaRDD (self ._ssql_ctx .table (tableName ), self )
1303+ return SchemaRDD (self ._ssql_ctx .table (tableName ). toJavaSchemaRDD () , self )
13041304
13051305 def cacheTable (self , tableName ):
13061306 """Caches the specified table in-memory."""
@@ -1352,7 +1352,7 @@ def hiveql(self, hqlQuery):
13521352 warnings .warn ("hiveql() is deprecated as the sql function now parses using HiveQL by" +
13531353 "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'" ,
13541354 DeprecationWarning )
1355- return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ), self )
1355+ return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ). toJavaSchemaRDD () , self )
13561356
13571357 def hql (self , hqlQuery ):
13581358 """
@@ -1508,6 +1508,8 @@ class SchemaRDD(RDD):
15081508 def __init__ (self , jschema_rdd , sql_ctx ):
15091509 self .sql_ctx = sql_ctx
15101510 self ._sc = sql_ctx ._sc
1511+ clsName = jschema_rdd .getClass ().getName ()
1512+ assert clsName .endswith ("JavaSchemaRDD" ), "jschema_rdd must be JavaSchemaRDD"
15111513 self ._jschema_rdd = jschema_rdd
15121514
15131515 self .is_cached = False
@@ -1524,7 +1526,7 @@ def _jrdd(self):
15241526 L{pyspark.rdd.RDD} super class (map, filter, etc.).
15251527 """
15261528 if not hasattr (self , '_lazy_jrdd' ):
1527- self ._lazy_jrdd = self ._jschema_rdd .javaToPython ()
1529+ self ._lazy_jrdd = self ._jschema_rdd .baseSchemaRDD (). javaToPython ()
15281530 return self ._lazy_jrdd
15291531
15301532 @property
@@ -1580,7 +1582,7 @@ def saveAsTable(self, tableName):
def schema(self):
    """Return this SchemaRDD's schema, parsed into a L{StructType}.

    The schema is read from the base SchemaRDD behind the wrapped
    Java object, rendered with ``toString()``, and handed to
    ``_parse_datatype_string``.
    """
    base_rdd = self._jschema_rdd.baseSchemaRDD()
    schema_text = base_rdd.schema().toString()
    return _parse_datatype_string(schema_text)
15841586
15851587 def schemaString (self ):
15861588 """Returns the output schema in the tree format."""
@@ -1631,8 +1633,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
16311633 rdd = RDD (self ._jrdd , self ._sc , self ._jrdd_deserializer )
16321634
16331635 schema = self .schema ()
1634- import pickle
1635- pickle .loads (pickle .dumps (schema ))
16361636
16371637 def applySchema (_ , it ):
16381638 cls = _create_cls (schema )
@@ -1669,10 +1669,8 @@ def isCheckpointed(self):
16691669
def getCheckpointFile(self):
    """Return the checkpoint file reported by the wrapped Java SchemaRDD.

    Calls ``getCheckpointFile()`` on ``self._jschema_rdd`` and unwraps the
    result: the value from ``get()`` when ``isPresent()`` is true,
    otherwise ``None``.
    """
    checkpointFile = self._jschema_rdd.getCheckpointFile()
    if checkpointFile.isPresent():
        return checkpointFile.get()
    # Explicit so that every exit path of the method returns a value.
    return None
16761674
16771675 def coalesce (self , numPartitions , shuffle = False ):
16781676 rdd = self ._jschema_rdd .coalesce (numPartitions , shuffle )
0 commit comments