
Commit c246b95

Davies Liu authored and JoshRosen committed
[SPARK-4841] fix zip with textFile()
UTF8Deserializer cannot be used inside a BatchedSerializer, so always use PickleSerializer() when changing the batch size in zip(). Also, if the two RDDs already have the same batch size, they no longer need to be re-serialized.

Author: Davies Liu <[email protected]>

Closes apache#3706 from davies/fix_4841 and squashes the following commits:

20ce3a3 [Davies Liu] fix bug in _reserialize()
e3ebf7c [Davies Liu] add comment
379d2c8 [Davies Liu] fix zip with textFile()
1 parent c762877 · commit c246b95
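
To see the failure this commit fixes, here is a minimal sketch (assuming a local Spark checkout of this era; the file path matches the one used by the regression test below). textFile() deserializes lines with UTF8Deserializer, and zip() used to wrap each side's deserializer in a BatchedSerializer when aligning batch sizes, which is invalid for UTF8Deserializer; batch_as() now always re-serializes with PickleSerializer() instead.

    # Minimal sketch of the scenario, not an exact reproduction.
    from pyspark import SparkContext

    sc = SparkContext("local", "spark-4841-sketch")
    t = sc.textFile("python/test_support/hello.txt")  # uses UTF8Deserializer
    # Before this commit this zip() failed; afterwards it works because
    # both sides are re-batched with PickleSerializer().
    print(t.zip(t.map(str)).count())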

File tree

3 files changed (+26, -14 lines)


python/pyspark/rdd.py

Lines changed: 11 additions & 14 deletions
@@ -469,8 +469,7 @@ def intersection(self, other):
     def _reserialize(self, serializer=None):
         serializer = serializer or self.ctx.serializer
         if self._jrdd_deserializer != serializer:
-            if not isinstance(self, PipelinedRDD):
-                self = self.map(lambda x: x, preservesPartitioning=True)
+            self = self.map(lambda x: x, preservesPartitioning=True)
             self._jrdd_deserializer = serializer
         return self
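
The _reserialize() change removes the in-place mutation path: previously a PipelinedRDD had its _jrdd_deserializer rewritten on the existing object, so any other holder of the same RDD silently picked up the new deserializer. Mapping the identity function first always derives a fresh RDD. A sketch of the new behavior (assuming a PySpark build containing this commit; this touches internals, so it is illustrative only):

    from pyspark import SparkContext
    from pyspark.serializers import BatchedSerializer, PickleSerializer

    sc = SparkContext("local", "reserialize-sketch")
    rdd = sc.parallelize(range(10)).map(str)  # a PipelinedRDD
    copy = rdd._reserialize(BatchedSerializer(PickleSerializer(), 2))
    assert copy is not rdd  # the original RDD is left untouched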

@@ -1798,23 +1797,21 @@ def zip(self, other):
         def get_batch_size(ser):
             if isinstance(ser, BatchedSerializer):
                 return ser.batchSize
-            return 1
+            return 1  # not batched

         def batch_as(rdd, batchSize):
-            ser = rdd._jrdd_deserializer
-            if isinstance(ser, BatchedSerializer):
-                ser = ser.serializer
-            return rdd._reserialize(BatchedSerializer(ser, batchSize))
+            return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))

         my_batch = get_batch_size(self._jrdd_deserializer)
         other_batch = get_batch_size(other._jrdd_deserializer)
-        # use the smallest batchSize for both of them
-        batchSize = min(my_batch, other_batch)
-        if batchSize <= 0:
-            # auto batched or unlimited
-            batchSize = 100
-        other = batch_as(other, batchSize)
-        self = batch_as(self, batchSize)
+        if my_batch != other_batch:
+            # use the smallest batchSize for both of them
+            batchSize = min(my_batch, other_batch)
+            if batchSize <= 0:
+                # auto batched or unlimited
+                batchSize = 100
+            other = batch_as(other, batchSize)
+            self = batch_as(self, batchSize)

         if self.getNumPartitions() != other.getNumPartitions():
             raise ValueError("Can only zip with RDD which has the same number of partitions")
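
The rewritten zip() only re-serializes when the two batch sizes disagree. A standalone sketch of that decision rule (align_batch_sizes is an illustrative helper, not part of PySpark):

    def align_batch_sizes(my_batch, other_batch):
        """Return the batch size both sides should use, or None when
        they already match and re-serialization can be skipped."""
        if my_batch == other_batch:
            return None  # same batch size: skip the extra map/serialize pass
        batchSize = min(my_batch, other_batch)  # smallest of the two
        if batchSize <= 0:
            # auto batched or unlimited
            batchSize = 100
        return batchSize

    assert align_batch_sizes(10, 10) is None  # already aligned
    assert align_batch_sizes(10, 2) == 2      # shrink to the smaller batch
    assert align_batch_sizes(0, 16) == 100    # auto-batched side forces 100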

python/pyspark/serializers.py

Lines changed: 6 additions & 0 deletions
@@ -463,6 +463,9 @@ def dumps(self, obj):
     def loads(self, obj):
         return self.serializer.loads(zlib.decompress(obj))

+    def __eq__(self, other):
+        return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+

 class UTF8Deserializer(Serializer):

@@ -489,6 +492,9 @@ def load_stream(self, stream):
         except EOFError:
             return

+    def __eq__(self, other):
+        return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+

 def read_long(stream):
     length = stream.read(8)
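
These __eq__ overrides give CompressedSerializer and UTF8Deserializer structural equality, which is what lets _reserialize() recognize a no-op: it compares the RDD's current deserializer against the target and skips the identity map() when they are equal. A sketch (assuming a PySpark build containing this commit):

    from pyspark.serializers import UTF8Deserializer

    # Two deserializers with the same settings now compare equal, so
    # re-serializing to an equivalent serializer becomes a no-op.
    assert UTF8Deserializer(use_unicode=True) == UTF8Deserializer(use_unicode=True)
    assert UTF8Deserializer(use_unicode=True) != UTF8Deserializer(use_unicode=False)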

python/pyspark/tests.py

Lines changed: 9 additions & 0 deletions
@@ -533,6 +533,15 @@ def test_zip_with_different_serializers(self):
         a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
         b = b._reserialize(MarshalSerializer())
         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+        # regression test for SPARK-4841
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        t = self.sc.textFile(path)
+        cnt = t.count()
+        self.assertEqual(cnt, t.zip(t).count())
+        rdd = t.map(str)
+        self.assertEqual(cnt, t.zip(rdd).count())
+        # regression test for bug in _reserializer()
+        self.assertEqual(cnt, t.zip(rdd).count())

     def test_zip_with_different_number_of_items(self):
         a = self.sc.parallelize(range(5), 2)
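
The doubled t.zip(rdd).count() at the end of the test targets the _reserialize() fix: with the old in-place mutation, the first zip() could retag rdd's deserializer on the original object, so a second zip() over the same RDDs might disagree. A standalone version of that check (a sketch, assuming a local Spark checkout with the test-support file):

    from pyspark import SparkContext

    sc = SparkContext("local", "spark-4841-regression")
    t = sc.textFile("python/test_support/hello.txt")
    rdd = t.map(str)  # a PipelinedRDD
    # zip() may re-serialize rdd internally, but must not mutate it:
    # running the same zip twice has to give the same count.
    assert t.zip(rdd).count() == t.zip(rdd).count()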
