38 changes: 18 additions & 20 deletions python/pyspark/sql.py
@@ -1122,7 +1122,7 @@ def applySchema(self, rdd, schema):
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
- return SchemaRDD(srdd, self)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1134,8 +1134,8 @@ def registerRDDAsTable(self, rdd, tableName):
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
"""
if (rdd.__class__ is SchemaRDD):
- jschema_rdd = rdd._jschema_rdd
- self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+ srdd = rdd._jschema_rdd.baseSchemaRDD()
+ self._ssql_ctx.registerRDDAsTable(srdd, tableName)
else:
raise ValueError("Can only register SchemaRDD as table")

@@ -1151,7 +1151,7 @@ def parquetFile(self, path):
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path)
+ jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)

def jsonFile(self, path, schema=None):
@@ -1207,11 +1207,11 @@ def jsonFile(self, path, schema=None):
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- jschema_rdd = self._ssql_ctx.jsonFile(path)
+ srdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
- jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(jschema_rdd, self)
+ srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def jsonRDD(self, rdd, schema=None):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1275,11 +1275,11 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(jschema_rdd, self)
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def sql(self, sqlQuery):
"""Return a L{SchemaRDD} representing the result of the given query.
@@ -1290,7 +1290,7 @@ def sql(self, sqlQuery):
>>> srdd2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+ return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)

def table(self, tableName):
"""Returns the specified table as a L{SchemaRDD}.
@@ -1301,7 +1301,7 @@ def table(self, tableName):
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName), self)
+ return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)

def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1353,7 +1353,7 @@ def hiveql(self, hqlQuery):
warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
"default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
DeprecationWarning)
- return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+ return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)

def hql(self, hqlQuery):
"""
@@ -1524,6 +1524,8 @@ class SchemaRDD(RDD):
def __init__(self, jschema_rdd, sql_ctx):
self.sql_ctx = sql_ctx
self._sc = sql_ctx._sc
+ clsName = jschema_rdd.getClass().getName()
+ assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
self._jschema_rdd = jschema_rdd
self._id = None
self.is_cached = False
@@ -1540,7 +1542,7 @@ def _jrdd(self):
L{pyspark.rdd.RDD} super class (map, filter, etc.).
"""
if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.javaToPython()
+ self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
return self._lazy_jrdd

def id(self):
@@ -1598,7 +1600,7 @@ def saveAsTable(self, tableName):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.schema().toString())
+ return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())

def schemaString(self):
"""Returns the output schema in the tree format."""
@@ -1649,8 +1651,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)

schema = self.schema()
- import pickle
- pickle.loads(pickle.dumps(schema))

def applySchema(_, it):
cls = _create_cls(schema)
@@ -1687,10 +1687,8 @@ def isCheckpointed(self):

def getCheckpointFile(self):
checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isDefined():
+ if checkpointFile.isPresent():
return checkpointFile.get()
- else:
- return None

def coalesce(self, numPartitions, shuffle=False):
rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
28 changes: 28 additions & 0 deletions python/pyspark/tests.py
@@ -574,6 +574,34 @@ def test_broadcast_in_udf(self):
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])

def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
srdd = self.sqlCtx.jsonRDD(rdd)
srdd.count()
srdd.collect()
srdd.schemaString()
srdd.schema()

# cache and checkpoint
self.assertFalse(srdd.is_cached)
srdd.persist()
srdd.unpersist()
srdd.cache()
self.assertTrue(srdd.is_cached)
self.assertFalse(srdd.isCheckpointed())
self.assertEqual(None, srdd.getCheckpointFile())

srdd = srdd.coalesce(2, True)
srdd = srdd.repartition(3)
srdd = srdd.distinct()
Contributor:
@davies Shouldn't we also test srdd.distinct(n) since that was the missing functionality documented in SPARK-3500?

Contributor:
Fair point, although if srdd.distinct() works then srdd.distinct(n) should also work due to how distinct() and this fix were implemented.

Contributor:
OK, I'll take your word for it but just point out that one of the reported issues in SPARK-3500 was specifically that distinct() worked but distinct(n) didn't. Since that is a possible failure mode, it probably makes sense to have a test for each.

Contributor:
Actually, it looks like we don't support distinct(n) in PySpark (the original ticket dealt with distinct() and coalesce() simply not working). Let's open a separate JIRA for that.

Contributor:
Yeah, a separate issue makes sense (I actually suggested that in the JIRA ticket). But to clarify, the ticket was originally about coalesce() not working, then repartition() and distinct(n) were added on.

distinct() with no parameters was always working. There was no question about that.

Contributor (Author):
distinct(n) is a missing API; we could fix it in another issue or defer it to later.

srdd.intersection(srdd)
self.assertEqual(2, srdd.count())

srdd.registerTempTable("temp")
srdd = self.sqlCtx.sql("select foo from temp")
srdd.count()
srdd.collect()


class TestIO(PySparkTestCase):
