
Commit 75ca5bd

[SPARK-2024] Better type checking for batch serialized RDD
1 parent: 0bdec55

4 files changed: 44 additions, 18 deletions

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 8 additions & 4 deletions
@@ -607,9 +607,11 @@ private[spark] object PythonRDD extends Logging {
    */
   def saveAsSequenceFile[K, V, C <: CompressionCodec](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       path: String,
       compressionCodecClass: String) = {
-    saveAsHadoopFile(pyRDD, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+    saveAsHadoopFile(
+      pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat",
       null, null, null, null, new java.util.HashMap(), compressionCodecClass, false)
   }

@@ -625,6 +627,7 @@ private[spark] object PythonRDD extends Logging {
   def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], G <: NewOutputFormat[_, _],
       C <: CompressionCodec](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       path: String,
       outputFormatClass: String,
       keyClass: String,
@@ -634,7 +637,7 @@ private[spark] object PythonRDD extends Logging {
       confAsMap: java.util.HashMap[String, String],
       compressionCodecClass: String,
       useNewAPI: Boolean) = {
-    val rdd = SerDeUtil.pythonToPairRDD(pyRDD)
+    val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized)
     val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse(
       inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass))
     val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration)
@@ -660,13 +663,14 @@ private[spark] object PythonRDD extends Logging {
    */
   def saveAsHadoopDataset[K, V](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       confAsMap: java.util.HashMap[String, String],
       useNewAPI: Boolean,
       keyConverterClass: String,
       valueConverterClass: String) = {
     val conf = PythonHadoopUtil.mapToConf(confAsMap)
-    val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD), keyConverterClass,
-      valueConverterClass, new JavaToWritableConverter)
+    val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized),
+      keyConverterClass, valueConverterClass, new JavaToWritableConverter)
     if (useNewAPI) {
       converted.saveAsNewAPIHadoopDataset(conf)
     } else {
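
Background on what the new flag encodes: PySpark's BatchedSerializer pickles a list of records per chunk, while the unbatched serializer pickles one record per chunk, so the JVM side must be told which framing to expect rather than guessing from the decoded type. A minimal sketch of the wire-level difference, using plain pickle outside Spark (names are illustrative):

    import pickle

    records = [(1, "a"), (2, "b")]
    batched_payload = pickle.dumps(records)                   # one blob holding a list of records
    unbatched_payloads = [pickle.dumps(r) for r in records]   # one blob per record

Guessing was the old behavior: any payload that decoded to a java.util.List was flattened as a batch, which silently mangled unbatched RDDs whose records are themselves lists.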

core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala

Lines changed: 12 additions & 8 deletions
@@ -86,21 +86,25 @@ private[python] object SerDeUtil extends Logging {
   /**
    * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
    */
-  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]]): RDD[(K, V)] = {
+  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
     def isPair(obj: Any): Boolean = {
       Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
         obj.asInstanceOf[Array[_]].length == 2
     }
     pyRDD.mapPartitions { iter =>
       val unpickle = new Unpickler
-      iter.flatMap { row =>
-        unpickle.loads(row) match {
-          // batch serialized Python RDDs
-          case objs: java.util.List[_] => objs
-          // unbatched case
-          case obj => Seq(obj)
+      val unpickled = if (batchSerialized) {
+        iter.flatMap { batch =>
+          unpickle.loads(batch) match {
+            case objs: java.util.List[_] => collectionAsScalaIterable(objs)
+            case other => throw new SparkException(
+              s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
+          }
         }
-      }.map {
+      } else {
+        iter.map(unpickle.loads(_))
+      }
+      unpickled.map {
         // we only accept pickled (K, V)
         case obj if isPair(obj) =>
           val arr = obj.asInstanceOf[Array[_]]
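
The same dispatch in Python terms, as a sketch (function and variable names are illustrative, not part of the commit):

    import pickle

    def unpickle_rows(rows, batch_serialized):
        # Mirrors the new pythonToPairRDD contract: a batch must decode to a
        # list of records; anything else is an error instead of a pass-through.
        for row in rows:
            obj = pickle.loads(row)
            if batch_serialized:
                if not isinstance(obj, list):
                    raise TypeError("Unexpected type %s for batch serialized "
                                    "Python RDD" % type(obj).__name__)
                for record in obj:
                    yield record
            else:
                yield obj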

python/pyspark/rdd.py

Lines changed: 16 additions & 6 deletions
@@ -232,7 +232,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._id = jrdd.id()

     def _toPickleSerialization(self):
-        if (self._jrdd_deserializer == PickleSerializer or
+        if (self._jrdd_deserializer == PickleSerializer() or
                 self._jrdd_deserializer == BatchedSerializer(PickleSerializer())):
             return self
         else:

@@ -1049,7 +1049,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None
         @param valueConverter: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(self._toPickleSerialization()._jrdd, jconf,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickled._jrdd, batched, jconf,
                                                     True, keyConverter, valueConverter)

     def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,

@@ -1074,7 +1076,9 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl
         @param conf: Hadoop job configuration, passed in as a dict (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopFile(self._toPickleSerialization()._jrdd, path,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickled._jrdd, batched, path,
             outputFormatClass, keyClass, valueClass, keyConverter, valueConverter,
             jconf, None, True)

@@ -1090,7 +1094,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         @param valueConverter: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(self._toPickleSerialization()._jrdd, jconf,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickled._jrdd, batched, jconf,
                                                     False, keyConverter, valueConverter)

     def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,

@@ -1116,7 +1122,9 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No
         @param compressionCodecClass: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopFile(self._toPickleSerialization()._jrdd,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickled._jrdd, batched,
             path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter,
             jconf, compressionCodecClass, False)

@@ -1131,7 +1139,9 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None):
         @param path: path to sequence file
         @param compressionCodecClass: (None by default)
         """
-        self.ctx._jvm.PythonRDD.saveAsSequenceFile(self._toPickleSerialization()._jrdd,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickled._jrdd, batched,
             path, compressionCodecClass)

     def saveAsPickleFile(self, path, batchSize=10):
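
All five save methods now compute the pickled/batched pair the same way; a hypothetical helper (not part of this commit) could factor out the repetition:

    def _pickled_and_batched(self):
        # Hypothetical helper, not in the commit: returns the pickle-serialized
        # RDD together with whether its serializer batches records.
        pickled = self._toPickleSerialization()
        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
        return pickled, batched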

python/pyspark/tests.py

Lines changed: 8 additions & 0 deletions
@@ -702,6 +702,14 @@ def test_unbatched_save_and_read(self):
                                        batchSize=1).collect())
         self.assertEqual(unbatched_newAPIHadoopRDD, ei)

+    def test_malformed_RDD(self):
+        basepath = self.tempdir.name
+        # non-batch-serialized RDD of type RDD[[(K, V)]] should be rejected
+        data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
+        rdd = self.sc.parallelize(data, numSlices=len(data))
+        self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
+            basepath + "/malformed/sequence"))
+

 class TestDaemon(unittest.TestCase):
     def connect(self, port):
         from socket import socket, AF_INET, SOCK_STREAM
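
For contrast with the malformed case above, a flat RDD of (K, V) pairs goes down the same path without error; a sketch assuming a live SparkContext sc and an empty output directory:

    # Records are pairs, not lists of pairs, so batch deserialization
    # yields (K, V) records and the save succeeds.
    wellformed = sc.parallelize([(1, "a"), (2, "aa"), (3, "aaa")])
    wellformed.saveAsSequenceFile("/tmp/wellformed/sequence")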
