
Commit f5df97f

refactor, address comments
1 parent 9d9af55 commit f5df97f

2 files changed: +20, -22 lines

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 16 additions & 16 deletions
@@ -705,25 +705,25 @@ private[spark] object PythonRDD extends Logging {
    * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
    * It is only used by pyspark.sql.
    */
-  def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[_]] = {
+  def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
+
+    def toArray(obj: Any): Array[_] = {
+      obj match {
+        case objs: JArrayList[_] =>
+          objs.toArray
+        case obj if obj.getClass.isArray =>
+          obj.asInstanceOf[Array[_]].toArray
+      }
+    }
+
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
       iter.flatMap { row =>
-        unpickle.loads(row) match {
-          // in case of objects are pickled in batch mode
-          case objs: JArrayList[_] => Try(objs.map(obj => obj match {
-            case list: JArrayList[_] => list.toArray // list
-            case obj if obj.getClass.isArray => // tuple
-              obj.asInstanceOf[Array[_]].toArray
-          })) match {
-            // objs is list of list or tuple
-            case Success(v) => v
-            // objs is a row, list of different objects
-            case Failure(e) => Seq(objs.toArray)
-          }
-          // not in batch mode
-          case obj if obj.getClass.isArray => // tuple
-            Seq(obj.asInstanceOf[Array[_]].toArray)
+        val obj = unpickle.loads(row)
+        if (batched) {
+          obj.asInstanceOf[JArrayList[_]].map(toArray)
+        } else {
+          Seq(toArray(obj))
         }
       }
     }.toJavaRDD()
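To make the new control flow concrete, here is a minimal Python sketch (an illustration, not part of the commit) of what pythonToJavaArray does with the pickled payloads it receives: in batch mode each payload deserializes to a list of rows, otherwise to a single row; the names to_array and python_to_java_array are hypothetical stand-ins for the Scala toArray and pythonToJavaArray above.

    # Hypothetical stand-in for the Scala logic above, using plain pickle.
    import pickle

    def to_array(obj):
        # Mirrors the new toArray helper: normalize a row that is either
        # a list or a tuple into a flat sequence of fields.
        if isinstance(obj, (list, tuple)):
            return list(obj)
        raise TypeError("expected a list or tuple, got %r" % type(obj))

    def python_to_java_array(payloads, batched):
        # Mirrors pythonToJavaArray's flatMap: unpickle each payload and,
        # in batch mode, treat it as a list of rows rather than one row.
        for payload in payloads:
            obj = pickle.loads(payload)
            if batched:
                for row in obj:
                    yield to_array(row)
            else:
                yield to_array(obj)

    rows = [(1, "a"), (2, "b")]
    batched_payload = [pickle.dumps(rows)]            # one blob holding many rows
    plain_payloads = [pickle.dumps(r) for r in rows]  # one blob per row

    assert list(python_to_java_array(batched_payload, True)) == [[1, "a"], [2, "b"]]
    assert list(python_to_java_array(plain_payloads, False)) == [[1, "a"], [2, "b"]]

Note how the refactor replaces the old Try/Success/Failure probing with an explicit batched flag, so the JVM side no longer has to guess whether a deserialized list is a batch of rows or a single row of values.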

python/pyspark/sql.py

Lines changed: 4 additions & 6 deletions
@@ -640,7 +640,7 @@ def __init__(self, sparkContext, sqlContext=None):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
-        self._pythonToJava = self._jvm.PythonRDD.pythonToJava
+        self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray

         if sqlContext:
             self._scala_SQLContext = sqlContext
@@ -686,10 +686,7 @@ def inferSchema(self, rdd):

         schema = _inferSchema(first)
         rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema))
-
-        jrdd = self._pythonToJava(rdd._jrdd)
-        srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
-        return SchemaRDD(srdd, self)
+        return self.applySchema(rdd, schema)

     def applySchema(self, rdd, schema):
         """Applies the given schema to the given RDD of L{dict}s.
@@ -719,7 +716,8 @@ def applySchema(self, rdd, schema):
         >>> srdd.collect()[0]
         (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
         """
-        jrdd = self._pythonToJava(rdd._jrdd)
+        batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
+        jrdd = self._pythonToJava(rdd._jrdd, batched)
         srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
         return SchemaRDD(srdd, self)
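The batched flag itself comes from inspecting the RDD's deserializer on the Python side, as the applySchema hunk shows. A hedged sketch of that check, assuming PySpark's BatchedSerializer as used in the diff (is_batched is a hypothetical helper name):

    # Sketch of the batched check added to applySchema; BatchedSerializer
    # is PySpark's wrapper that pickles rows in groups rather than one at
    # a time.
    from pyspark.serializers import BatchedSerializer

    def is_batched(rdd):
        # rdd._jrdd_deserializer is the serializer PySpark uses to read
        # this RDD's Java-side bytes; a BatchedSerializer means each
        # pickled blob holds a list of rows, so the JVM side must unpack
        # it accordingly.
        return isinstance(rdd._jrdd_deserializer, BatchedSerializer)

This is also why inferSchema can now simply delegate to applySchema: the batch handling lives in one place instead of being duplicated across both call sites.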
