 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
@@ -29,7 +28,7 @@ import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
@@ -104,15 +103,21 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+  private[this] val fileBufferSize =
+    SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
+
+  private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance()
+
+  private def sortRecords(
+      records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
     val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
     var currentPage: MemoryBlock = null
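
The `fileBufferSize` field added above reads `spark.shuffle.file.buffer` in kilobytes and converts it to bytes. As a standalone sketch of that conversion (using a bare `SparkConf` instead of the writer's `SparkEnv.get.conf`, purely for illustration):

```scala
import org.apache.spark.SparkConf

// getSizeAsKb accepts unit-suffixed values ("32k", "1m") and, for backwards
// compatibility, bare numbers interpreted as KB; multiplying by 1024 yields bytes.
val conf = new SparkConf().set("spark.shuffle.file.buffer", "64k")
val fileBufferSizeBytes = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
// fileBufferSizeBytes == 65536
```
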
@@ -178,32 +183,31 @@ private[spark] class UnsafeShuffleWriter[K, V](
     sorter.getSortedIterator
   }
 
-  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
+  private def writeSortedRecordsToFile(
+      sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
 
     var currentPartition = -1
-    var prevPartitionLength: Long = 0
-    var out: OutputStream = null
+    var writer: BlockObjectWriter = null
 
     // TODO: don't close and re-open file handles so often; this could be inefficient
 
     def closePartition(): Unit = {
-      out.flush()
-      out.close()
-      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+      writer.commitAndClose()
+      partitionLengths(currentPartition) = writer.fileSegment().length
     }
 
     def switchToPartition(newPartition: Int): Unit = {
-      assert(newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition")
+      assert(newPartition > currentPartition,
+        s"new partition $newPartition should be >= $currentPartition")
       if (currentPartition != -1) {
         closePartition()
-        prevPartitionLength = partitionLengths(currentPartition)
       }
-      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
-      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+      writer =
+        blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
     }
 
     while (sortedRecords.hasNext) {
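
The hunk above replaces the raw `FileOutputStream` with a `BlockObjectWriter` obtained from the block manager, so each partition's length can be read from the committed file segment instead of diffing `outputFile.length()`. A condensed sketch of that per-partition lifecycle, assuming `blockManager`, `blockId`, `outputFile`, `serializer`, `fileBufferSize`, and `writeMetrics` are in scope as they are in `UnsafeShuffleWriter`, with `recordBytes` standing in for one serialized record:

```scala
import org.apache.spark.storage.BlockObjectWriter

// Open, write, commit, then read back the segment length for one partition.
val writer: BlockObjectWriter =
  blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
writer.write(recordBytes)   // record payload copied out of the sorter's page
writer.recordWritten()      // keeps the shuffle write metrics in sync
writer.commitAndClose()     // flushes and finalizes this partition's segment
val partitionLength = writer.fileSegment().length
```
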
@@ -214,18 +218,24 @@ private[spark] class UnsafeShuffleWriter[K, V](
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
-      val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
+      val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
-      var i: Int = 0
-      while (i < recordLength) {
-        out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
-        i += 1
-      }
+      // TODO: re-use a buffer or avoid double-buffering entirely
+      val arr: Array[Byte] = new Array[Byte](recordLength)
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + 16,
+        arr,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        recordLength)
+      writer.write(arr)
+      // TODO: add a test that detects whether we leave this call out:
+      writer.recordWritten()
     }
     closePartition()
 
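
The final hunk swaps the byte-at-a-time `UNSAFE.getByte` loop for a single bulk `copyMemory` into a temporary array that is then handed to the writer. A minimal round-trip sketch of that primitive, copying between two on-heap byte arrays and assuming `PlatformDependent.copyMemory` keeps the (srcObject, srcOffset, dstObject, dstOffset, length) signature used above:

```scala
import org.apache.spark.unsafe.PlatformDependent

// Bulk-copy bytes from one array to another; the writer does the same thing
// with the sorter's memory page as the source and a fresh byte array as the target.
val src = Array[Byte](1, 2, 3, 4)
val dst = new Array[Byte](src.length)
PlatformDependent.copyMemory(
  src, PlatformDependent.BYTE_ARRAY_OFFSET,   // source object + base offset
  dst, PlatformDependent.BYTE_ARRAY_OFFSET,   // destination object + base offset
  src.length)                                 // number of bytes to copy
```

The TODO about re-using a buffer notes the remaining cost: each record is still double-buffered through the freshly allocated array before it reaches the `BlockObjectWriter`.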