Commit 9c6cf58

Refactor to use DiskBlockObjectWriter.
1 parent 253f13e commit 9c6cf58

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala

Lines changed: 30 additions & 20 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
@@ -29,7 +28,7 @@ import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
@@ -104,15 +103,21 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+  private[this] val fileBufferSize =
+    SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
+
+  private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance()
+
+  private def sortRecords(
+      records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
     val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
     var currentPage: MemoryBlock = null
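
A note on the fileBufferSize read added above, as an illustration only (not code from the commit): SparkConf.getSizeAsKb treats a bare number as kilobytes, which is what the "backwards compatibility" comment refers to, while suffixed values carry explicit units; the result is then converted to bytes. The conf value shown here is an assumed example.

  // Sketch: resolving "spark.shuffle.file.buffer" the same way the diff does.
  val conf = new org.apache.spark.SparkConf()
  conf.set("spark.shuffle.file.buffer", "64")  // bare number: interpreted as 64 KB
  // Falls back to "32k" when unset; .toInt * 1024 converts kilobytes to bytes.
  val bufferBytes = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024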
@@ -178,32 +183,31 @@ private[spark] class UnsafeShuffleWriter[K, V](
     sorter.getSortedIterator
   }
 
-  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
+  private def writeSortedRecordsToFile(
+      sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
 
     var currentPartition = -1
-    var prevPartitionLength: Long = 0
-    var out: OutputStream = null
+    var writer: BlockObjectWriter = null
 
     // TODO: don't close and re-open file handles so often; this could be inefficient
 
     def closePartition(): Unit = {
-      out.flush()
-      out.close()
-      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+      writer.commitAndClose()
+      partitionLengths(currentPartition) = writer.fileSegment().length
     }
 
     def switchToPartition(newPartition: Int): Unit = {
-      assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition")
+      assert (newPartition > currentPartition,
+        s"new partition $newPartition should be > $currentPartition")
       if (currentPartition != -1) {
         closePartition()
-        prevPartitionLength = partitionLengths(currentPartition)
       }
-      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
-      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+      writer =
+        blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
     }
 
     while (sortedRecords.hasNext) {
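
For reference, the writer lifecycle that the hunk above switches to can be summarized in a short sketch. This is an illustration of the pattern, not code from the commit: it assumes the Spark 1.x BlockObjectWriter API as it appears in the diff (getDiskWriter, write, recordWritten, commitAndClose, fileSegment) and the blockManager, blockId, outputFile, serializer, fileBufferSize, and writeMetrics values defined in the surrounding class; the helper name writeOnePartition is hypothetical.

  // Sketch: write one partition's already-serialized records through a
  // BlockObjectWriter and return the number of bytes written for it.
  def writeOnePartition(serializedRecords: Iterator[Array[Byte]]): Long = {
    val writer =
      blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
    serializedRecords.foreach { bytes =>
      writer.write(bytes)     // raw record bytes, already serialized
      writer.recordWritten()  // keep shuffle write metrics in sync
    }
    // commitAndClose() flushes and finalizes the segment; fileSegment().length
    // replaces the old `outputFile.length() - prevPartitionLength` bookkeeping.
    writer.commitAndClose()
    writer.fileSegment().length
  }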
@@ -214,18 +218,24 @@ private[spark] class UnsafeShuffleWriter[K, V](
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
-      val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
+      val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
-      var i: Int = 0
-      while (i < recordLength) {
-        out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
-        i += 1
-      }
+      // TODO: re-use a buffer or avoid double-buffering entirely
+      val arr: Array[Byte] = new Array[Byte](recordLength)
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + 16,
+        arr,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        recordLength)
+      writer.write(arr)
+      // TODO: add a test that detects whether we leave this call out:
+      writer.recordWritten()
     }
     closePartition()
 
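
The record copy introduced in the last hunk relies on the in-page layout used by the rest of the class: the record length is read as a long at baseOffset + 8 and the serialized payload starts at baseOffset + 16. The sketch below isolates that copy step; it is an illustration only, assuming the PlatformDependent helpers (UNSAFE, copyMemory, BYTE_ARRAY_OFFSET) exactly as they are called in the diff, and the helper name copyRecordToArray is hypothetical.

  import org.apache.spark.unsafe.PlatformDependent

  // Sketch: copy one record's payload out of a task memory page into a fresh
  // on-heap byte array that can be handed to BlockObjectWriter.write().
  def copyRecordToArray(baseObject: AnyRef, baseOffset: Long): Array[Byte] = {
    // The record length is stored as a long 8 bytes into the record.
    val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
    val arr = new Array[Byte](recordLength)
    // The payload begins after the first 16 bytes; BYTE_ARRAY_OFFSET is the
    // Unsafe base offset of the first element of a byte[].
    PlatformDependent.copyMemory(
      baseObject,
      baseOffset + 16,
      arr,
      PlatformDependent.BYTE_ARRAY_OFFSET,
      recordLength)
    arr
  }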