
Commit 8c70dd9

Fix serialization
1 parent 9c16fe6 commit 8c70dd9

7 files changed (+69, -19 lines)

core/src/main/scala/org/apache/spark/serializer/Serializer.scala

Lines changed: 31 additions & 0 deletions
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
  */
 @DeveloperApi
 abstract class SerializationStream {
+  /** The most general-purpose method to write an object. */
   def writeObject[T: ClassTag](t: T): SerializationStream
+  /** Writes the object representing the key of a key-value pair. */
+  def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+  /** Writes the object representing the value of a key-value pair. */
+  def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
   def flush(): Unit
   def close(): Unit

@@ -120,7 +125,12 @@ abstract class SerializationStream {
  */
 @DeveloperApi
 abstract class DeserializationStream {
+  /** The most general-purpose method to read an object. */
   def readObject[T: ClassTag](): T
+  /** Reads the object representing the key of a key-value pair. */
+  def readKey[T: ClassTag](): T = readObject[T]()
+  /** Reads the object representing the value of a key-value pair. */
+  def readValue[T: ClassTag](): T = readObject[T]()
   def close(): Unit

   /**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
       DeserializationStream.this.close()
     }
   }
+
+  /**
+   * Read the elements of this stream through an iterator over key-value pairs. This can only be
+   * called once, as reading each element will consume data from the input source.
+   */
+  def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+    override protected def getNext() = {
+      try {
+        (readKey[Any](), readValue[Any]())
+      } catch {
+        case eof: EOFException => {
+          finished = true
+          null
+        }
+      }
+    }
+
+    override protected def close() {
+      DeserializationStream.this.close()
+    }
+  }
 }
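A minimal round-trip sketch (not part of this commit) of how the new hooks compose: writeKey/writeValue default to writeObject, and asKeyValueIterator pairs readKey/readValue until EOF. It assumes this patch is applied and uses the existing JavaSerializer; the object and variable names are illustrative only.

// Illustrative sketch, not from the commit: round-trips two key-value pairs
// through JavaSerializer using the API added above.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object KeyValueStreamSketch {
  def main(args: Array[String]): Unit = {
    val instance = new JavaSerializer(new SparkConf()).newInstance()

    val bytes = new ByteArrayOutputStream()
    val out = instance.serializeStream(bytes)
    out.writeKey("a").writeValue(1)   // defaults delegate to writeObject
    out.writeKey("b").writeValue(2)
    out.close()

    val in = instance.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
    // Pairs up readKey()/readValue() until the underlying stream hits EOF.
    in.asKeyValueIterator.foreach { case (k, v) => println(s"$k -> $v") }
  }
}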

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -213,8 +213,8 @@ private[spark] class DiskBlockObjectWriter(
       open()
     }

-    objOut.writeObject(key)
-    objOut.writeObject(value)
+    objOut.writeKey(key)
+    objOut.writeValue(value)
     numRecordsWritten += 1
     writeMetrics.incShuffleRecordsWritten(1)

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 1 addition & 2 deletions
@@ -27,7 +27,6 @@ import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.serializer.{SerializerInstance, Serializer}
 import org.apache.spark.util.{CompletionIterator, Utils}
-import org.apache.spark.util.collection.PairIterator

 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -300,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { is0 =>
           val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = new PairIterator(serializerInstance.deserializeStream(is).asIterator)
+          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
           CompletionIterator[Any, Iterator[Any]](iter, {
             // Once the iterator is exhausted, release the buffer and set currentResult to null
             // so we don't release it again in cleanup.

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 2 additions & 2 deletions
@@ -433,8 +433,8 @@ class ExternalAppendOnlyMap[K, V, C](
      */
     private def readNextItem(): (K, C) = {
       try {
-        val k = deserializeStream.readObject().asInstanceOf[K]
-        val c = deserializeStream.readObject().asInstanceOf[C]
+        val k = deserializeStream.readKey().asInstanceOf[K]
+        val c = deserializeStream.readValue().asInstanceOf[C]
         val item = (k, c)
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 2 deletions
@@ -599,8 +599,8 @@ private[spark] class ExternalSorter[K, V, C](
       if (finished || deserializeStream == null) {
         return null
       }
-      val k = deserializeStream.readObject().asInstanceOf[K]
-      val c = deserializeStream.readObject().asInstanceOf[C]
+      val k = deserializeStream.readKey().asInstanceOf[K]
+      val c = deserializeStream.readValue().asInstanceOf[C]
       lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1

core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala

Lines changed: 2 additions & 2 deletions
@@ -105,8 +105,8 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     var metaBufferPos = 0
     def hasNext: Boolean = metaBufferPos < metaBuffer.position
     def next(): ((Int, K), V) = {
-      val key = deserStream.readObject[Any]().asInstanceOf[K]
-      val value = deserStream.readObject[Any]().asInstanceOf[V]
+      val key = deserStream.readKey[Any]().asInstanceOf[K]
+      val value = deserStream.readValue[Any]().asInstanceOf[V]
       val partition = metaBuffer.get(metaBufferPos + PARTITION)
       metaBufferPos += RECORD_SIZE
       ((partition, key), value)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

Lines changed: 29 additions & 9 deletions
@@ -50,17 +50,27 @@ private[sql] class Serializer2SerializationStream(
   extends SerializationStream with Logging {

   val rowOut = new DataOutputStream(out)
-  val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

-  def writeObject[T: ClassTag](t: T): SerializationStream = {
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
     writeKey(kv._1)
     writeValue(kv._2)

     this
   }

+  override def writeKey[T: ClassTag](t: T): SerializationStream = {
+    writeKeyFunc(t.asInstanceOf[Row])
+    this
+  }
+
+  override def writeValue[T: ClassTag](t: T): SerializationStream = {
+    writeValueFunc(t.asInstanceOf[Row])
+    this
+  }
+
   def flush(): Unit = {
     rowOut.flush()
   }
@@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream(

   val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
   val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)

-  def readObject[T: ClassTag](): T = {
-    readKey()
-    readValue()
+  override def readObject[T: ClassTag](): T = {
+    readKeyFunc()
+    readValueFunc()

     (key, value).asInstanceOf[T]
   }

-  def close(): Unit = {
+  override def readKey[T: ClassTag](): T = {
+    readKeyFunc()
+    key.asInstanceOf[T]
+  }
+
+  override def readValue[T: ClassTag](): T = {
+    readValueFunc()
+    value.asInstanceOf[T]
+  }
+
+  override def close(): Unit = {
     rowIn.close()
   }
 }
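To make the benefit of the new hooks concrete, here is a hypothetical stream (not from the patch) that follows the same override pattern as Serializer2DeserializationStream: readKey and readValue each decode exactly one side of the pair, so callers like ExternalSorter never go through the generic readObject path or allocate an intermediate tuple. The class name and the fixed-width Int framing are assumptions for illustration only.

// Hypothetical example only -- mirrors the override pattern used above for a
// fixed-width (Int key, Int value) framing; not part of Spark.
import java.io.{DataInputStream, InputStream}
import scala.reflect.ClassTag
import org.apache.spark.serializer.DeserializationStream

class IntPairDeserializationStream(in: InputStream) extends DeserializationStream {
  private val data = new DataInputStream(in)

  // Generic path: rebuild the whole (key, value) pair.
  override def readObject[T: ClassTag](): T =
    (data.readInt(), data.readInt()).asInstanceOf[T]

  // Specialized paths: decode exactly one side, no tuple allocation.
  override def readKey[T: ClassTag](): T = data.readInt().asInstanceOf[T]
  override def readValue[T: ClassTag](): T = data.readInt().asInstanceOf[T]

  override def close(): Unit = data.close()
}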
