Skip to content

Commit 9c06c72

Browse files
davies authored and JoshRosen committed
[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd
Currently, SchemaRDD._jschema_rdd is a Scala SchemaRDD, so Scala API methods such as coalesce() and repartition() cannot be called easily from Python: there is no way to specify the implicit parameter `ord`. Since _jrdd is a JavaRDD, _jschema_rdd should likewise be a JavaSchemaRDD. This patch changes _jschema_rdd to a JavaSchemaRDD and adds an assert for it. If a method is missing from JavaSchemaRDD, it is called via _jschema_rdd.baseSchemaRDD().xxx(). BTW, do we need JavaSQLContext? Author: Davies Liu <[email protected]> Closes #2369 from davies/fix_schemardd and squashes the following commits: abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd (cherry picked from commit 885d162) Signed-off-by: Josh Rosen <[email protected]> Conflicts: python/pyspark/tests.py
1 parent 6cbf83c commit 9c06c72

File tree

2 files changed

+55
-20
lines changed

2 files changed

+55
-20
lines changed

python/pyspark/sql.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,7 @@ def applySchema(self, rdd, schema):
11211121
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
11221122
jrdd = self._pythonToJava(rdd._jrdd, batched)
11231123
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
1124-
return SchemaRDD(srdd, self)
1124+
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
11251125

11261126
def registerRDDAsTable(self, rdd, tableName):
11271127
"""Registers the given RDD as a temporary table in the catalog.
@@ -1133,8 +1133,8 @@ def registerRDDAsTable(self, rdd, tableName):
11331133
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
11341134
"""
11351135
if (rdd.__class__ is SchemaRDD):
1136-
jschema_rdd = rdd._jschema_rdd
1137-
self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
1136+
srdd = rdd._jschema_rdd.baseSchemaRDD()
1137+
self._ssql_ctx.registerRDDAsTable(srdd, tableName)
11381138
else:
11391139
raise ValueError("Can only register SchemaRDD as table")
11401140

@@ -1150,7 +1150,7 @@ def parquetFile(self, path):
11501150
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
11511151
True
11521152
"""
1153-
jschema_rdd = self._ssql_ctx.parquetFile(path)
1153+
jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
11541154
return SchemaRDD(jschema_rdd, self)
11551155

11561156
def jsonFile(self, path, schema=None):
@@ -1206,11 +1206,11 @@ def jsonFile(self, path, schema=None):
12061206
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
12071207
"""
12081208
if schema is None:
1209-
jschema_rdd = self._ssql_ctx.jsonFile(path)
1209+
srdd = self._ssql_ctx.jsonFile(path)
12101210
else:
12111211
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
1212-
jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
1213-
return SchemaRDD(jschema_rdd, self)
1212+
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
1213+
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
12141214

12151215
def jsonRDD(self, rdd, schema=None):
12161216
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1274,11 +1274,11 @@ def func(iterator):
12741274
keyed._bypass_serializer = True
12751275
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
12761276
if schema is None:
1277-
jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
1277+
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
12781278
else:
12791279
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
1280-
jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
1281-
return SchemaRDD(jschema_rdd, self)
1280+
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
1281+
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
12821282

12831283
def sql(self, sqlQuery):
12841284
"""Return a L{SchemaRDD} representing the result of the given query.
@@ -1289,7 +1289,7 @@ def sql(self, sqlQuery):
12891289
>>> srdd2.collect()
12901290
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
12911291
"""
1292-
return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
1292+
return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
12931293

12941294
def table(self, tableName):
12951295
"""Returns the specified table as a L{SchemaRDD}.
@@ -1300,7 +1300,7 @@ def table(self, tableName):
13001300
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
13011301
True
13021302
"""
1303-
return SchemaRDD(self._ssql_ctx.table(tableName), self)
1303+
return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
13041304

13051305
def cacheTable(self, tableName):
13061306
"""Caches the specified table in-memory."""
@@ -1352,7 +1352,7 @@ def hiveql(self, hqlQuery):
13521352
warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
13531353
"default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
13541354
DeprecationWarning)
1355-
return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
1355+
return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
13561356

13571357
def hql(self, hqlQuery):
13581358
"""
@@ -1508,6 +1508,8 @@ class SchemaRDD(RDD):
15081508
def __init__(self, jschema_rdd, sql_ctx):
15091509
self.sql_ctx = sql_ctx
15101510
self._sc = sql_ctx._sc
1511+
clsName = jschema_rdd.getClass().getName()
1512+
assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
15111513
self._jschema_rdd = jschema_rdd
15121514

15131515
self.is_cached = False
@@ -1524,7 +1526,7 @@ def _jrdd(self):
15241526
L{pyspark.rdd.RDD} super class (map, filter, etc.).
15251527
"""
15261528
if not hasattr(self, '_lazy_jrdd'):
1527-
self._lazy_jrdd = self._jschema_rdd.javaToPython()
1529+
self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
15281530
return self._lazy_jrdd
15291531

15301532
@property
@@ -1580,7 +1582,7 @@ def saveAsTable(self, tableName):
15801582
def schema(self):
15811583
"""Returns the schema of this SchemaRDD (represented by
15821584
a L{StructType})."""
1583-
return _parse_datatype_string(self._jschema_rdd.schema().toString())
1585+
return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
15841586

15851587
def schemaString(self):
15861588
"""Returns the output schema in the tree format."""
@@ -1631,8 +1633,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
16311633
rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
16321634

16331635
schema = self.schema()
1634-
import pickle
1635-
pickle.loads(pickle.dumps(schema))
16361636

16371637
def applySchema(_, it):
16381638
cls = _create_cls(schema)
@@ -1669,10 +1669,8 @@ def isCheckpointed(self):
16691669

16701670
def getCheckpointFile(self):
16711671
checkpointFile = self._jschema_rdd.getCheckpointFile()
1672-
if checkpointFile.isDefined():
1672+
if checkpointFile.isPresent():
16731673
return checkpointFile.get()
1674-
else:
1675-
return None
16761674

16771675
def coalesce(self, numPartitions, shuffle=False):
16781676
rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)

python/pyspark/tests.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from pyspark.files import SparkFiles
4242
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
4343
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
44+
from pyspark.storagelevel import StorageLevel
45+
from pyspark.sql import SQLContext
4446

4547
_have_scipy = False
4648
_have_numpy = False
@@ -469,6 +471,41 @@ def test_histogram(self):
469471
self.assertRaises(TypeError, lambda: rdd.histogram(2))
470472

471473

474+
class TestSQL(PySparkTestCase):
475+
476+
def setUp(self):
477+
PySparkTestCase.setUp(self)
478+
self.sqlCtx = SQLContext(self.sc)
479+
480+
def test_basic_functions(self):
481+
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
482+
srdd = self.sqlCtx.jsonRDD(rdd)
483+
srdd.count()
484+
srdd.collect()
485+
srdd.schemaString()
486+
srdd.schema()
487+
488+
# cache and checkpoint
489+
self.assertFalse(srdd.is_cached)
490+
srdd.persist(StorageLevel.MEMORY_ONLY_SER)
491+
srdd.unpersist()
492+
srdd.cache()
493+
self.assertTrue(srdd.is_cached)
494+
self.assertFalse(srdd.isCheckpointed())
495+
self.assertEqual(None, srdd.getCheckpointFile())
496+
497+
srdd = srdd.coalesce(2, True)
498+
srdd = srdd.repartition(3)
499+
srdd = srdd.distinct()
500+
srdd.intersection(srdd)
501+
self.assertEqual(2, srdd.count())
502+
503+
srdd.registerTempTable("temp")
504+
srdd = self.sqlCtx.sql("select foo from temp")
505+
srdd.count()
506+
srdd.collect()
507+
508+
472509
class TestIO(PySparkTestCase):
473510

474511
def test_stdout_redirection(self):

0 commit comments

Comments
 (0)