
Commit e2d96ca

Expand serializer API and use new function to help control when new UnsafeShuffle path is used.
1 parent e267cee commit e2d96ca

4 files changed, 55 insertions(+), 23 deletions(-)


core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala (5 additions, 0 deletions)

@@ -125,6 +125,11 @@ class KryoSerializer(conf: SparkConf)
   override def newInstance(): SerializerInstance = {
     new KryoSerializerInstance(this)
   }
+
+  override def supportsRelocationOfSerializedObjects: Boolean = {
+    // TODO: we should have a citation / explanatory comment here clarifying _why_ this is the case
+    newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()
+  }
 }

 private[spark]
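
For context, the Kryo answer is not a constant: relocation support tracks Kryo's auto-reset setting, which a user-supplied registrator can turn off. A hedged sketch of that behavior (MyRegistrator is hypothetical, for illustration only):

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer}

// Disabling auto-reset lets Kryo share references across records, which breaks
// byte-level relocation of individual records, so the serializer must report false.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = kryo.setAutoReset(false)
}

// Default configuration: auto-reset is on, so relocation is supported.
assert(new KryoSerializer(new SparkConf(false)).supportsRelocationOfSerializedObjects)

// With the auto-reset-disabling registrator installed, the same check returns false.
val conf = new SparkConf(false).set("spark.kryo.registrator", classOf[MyRegistrator].getName)
assert(!new KryoSerializer(conf).supportsRelocationOfSerializedObjects)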

core/src/main/scala/org/apache/spark/serializer/Serializer.scala (25 additions, 1 deletion)

@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.reflect.ClassTag

 import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}

 /**
@@ -63,6 +63,30 @@ abstract class Serializer {

   /** Creates a new [[SerializerInstance]]. */
   def newInstance(): SerializerInstance
+
+  /**
+   * Returns true if this serializer supports relocation of its serialized objects and false
+   * otherwise. This should return true if and only if reordering the bytes of serialized objects
+   * in serialization stream output results in re-ordered input that can be read with the
+   * deserializer. For instance, the following should work if the serializer supports relocation:
+   *
+   * serOut.open()
+   * position = 0
+   * serOut.write(obj1)
+   * serOut.flush()
+   * position = # of bytes written to stream so far
+   * obj1Bytes = [bytes 0 through position of stream]
+   * serOut.write(obj2)
+   * serOut.flush()
+   * position2 = # of bytes written to stream so far
+   * obj2Bytes = bytes[position through position2 of stream]
+   *
+   * serIn.open([obj2bytes] concatenate [obj1bytes]) should return (obj2, obj1)
+   *
+   * See SPARK-7311 for more discussion.
+   */
+  @Experimental
+  def supportsRelocationOfSerializedObjects: Boolean = false
 }

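
To make the contract concrete, here is a hedged sketch of the property as a small local check against KryoSerializer (which opts in above); it uses only the existing SerializerInstance stream API, and the string records are arbitrary examples:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

val serInstance = new KryoSerializer(new SparkConf(false)).newInstance()
val bytes = new ByteArrayOutputStream()
val serOut = serInstance.serializeStream(bytes)

serOut.writeObject("obj1")
serOut.flush()
val position = bytes.size()                       // bytes written so far
serOut.writeObject("obj2")
serOut.flush()

val all = bytes.toByteArray
val obj1Bytes = all.slice(0, position)            // bytes 0 through position
val obj2Bytes = all.slice(position, all.length)   // bytes position through end

// Concatenate the records in the opposite order; a relocation-friendly serializer
// must deserialize them as (obj2, obj1).
val serIn = serInstance.deserializeStream(new ByteArrayInputStream(obj2Bytes ++ obj1Bytes))
assert(serIn.readObject[String]() == "obj2")
assert(serIn.readObject[String]() == "obj1")
serIn.close()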

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala (24 additions, 20 deletions)

@@ -22,7 +22,7 @@ import java.util

 import com.esotericsoftware.kryo.io.ByteBufferOutputStream

-import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
@@ -34,17 +34,31 @@ import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
 import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}

-private[spark] class UnsafeShuffleHandle[K, V](
+private class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
     override val numMaps: Int,
     override val dependency: ShuffleDependency[K, V, V])
   extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
-  require(UnsafeShuffleManager.canUseUnsafeShuffle(dependency))
 }

-private[spark] object UnsafeShuffleManager {
+private[spark] object UnsafeShuffleManager extends Logging {
   def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
-    dependency.aggregator.isEmpty && dependency.keyOrdering.isEmpty
+    val shufId = dependency.shuffleId
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (dependency.keyOrdering.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+      false
+    } else {
+      log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
+      true
+    }
   }
 }
@@ -73,15 +87,13 @@ private object PartitionerPrefixComparator extends PrefixComparator {
   }
 }

-private[spark] class UnsafeShuffleWriter[K, V](
+private class UnsafeShuffleWriter[K, V](
     shuffleBlockManager: IndexShuffleBlockManager,
     handle: UnsafeShuffleHandle[K, V],
     mapId: Int,
     context: TaskContext)
   extends ShuffleWriter[K, V] {

-  println("Construcing a new UnsafeShuffleWriter")
-
   private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager()

   private[this] val dep = handle.dependency
@@ -158,7 +170,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
         currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
@@ -169,7 +180,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
         currentPagePosition,
         serializedRecordSize)
       currentPagePosition += serializedRecordSize
-      println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)

       // Reset for writing the next record
@@ -195,8 +205,10 @@
     // TODO: don't close and re-open file handles so often; this could be inefficient

     def closePartition(): Unit = {
-      writer.commitAndClose()
-      partitionLengths(currentPartition) = writer.fileSegment().length
+      if (writer != null) {
+        writer.commitAndClose()
+        partitionLengths(currentPartition) = writer.fileSegment().length
+      }
     }

     def switchToPartition(newPartition: Int): Unit = {
@@ -219,8 +231,6 @@
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
       val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
-      println("Base offset is " + baseOffset)
-      println("Record length is " + recordLength)
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
@@ -244,12 +254,8 @@

   /** Write a sequence of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    println("Opened writer!")
-
     val sortedIterator = sortRecords(records)
     val partitionLengths = writeSortedRecordsToFile(sortedIterator)
-
-    println("Partition lengths are " + partitionLengths.toSeq)
     shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
@@ -264,7 +270,6 @@

   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
-    println("Stopping unsafeshufflewriter")
     try {
       if (stopping) {
         None
@@ -300,7 +305,6 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
       numMaps: Int,
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
     if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
-      println("Opening unsafeShuffleWriter")
       new UnsafeShuffleHandle[K, V](
         shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
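
A hedged sketch (the fully-qualified manager name and its acceptance by "spark.shuffle.manager" are assumptions about this work-in-progress patch) of how a job would steer shuffles onto the new path; canUseUnsafeShuffle still falls back per-shuffle when an aggregator or key ordering is defined, or when the serializer does not support relocation:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Point the shuffle manager at the unsafe implementation added in this patch.
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
  // Kryo reports supportsRelocationOfSerializedObjects = true (with auto-reset on),
  // satisfying the first check in canUseUnsafeShuffle.
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")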

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala (1 addition, 2 deletions)

@@ -131,8 +131,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
   private val useSerializedPairBuffer =
     !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-    ser.isInstanceOf[KryoSerializer] &&
-    serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
+    ser.supportsRelocationOfSerializedObjects

   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
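
A small sketch (using only the constructors visible in this patch) of how the rewritten guard evaluates for the two built-in serializers: only Kryo overrides the new method in this commit, so Java serialization keeps the conservative base-class default of false.

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}

val conf = new SparkConf(false)
new KryoSerializer(conf).supportsRelocationOfSerializedObjects   // true when auto-reset is on
new JavaSerializer(conf).supportsRelocationOfSerializedObjects   // false (base-class default)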
