 
 package org.apache.spark.shuffle.sort
 
+import java.io.File
 import java.io.FileOutputStream
 import java.nio.ByteBuffer
 import java.util.Comparator
@@ -38,7 +39,7 @@ import org.apache.spark.util.collection.{MergeUtil, TieredDiskMerger}
  * map output block.
  *
  * As blocks are fetched, we store them in memory until we fail to acquire space from the
- * ShuffleMemoryManager. When this occurs, we merge the in-memory blocks to disk and go back to
+ * ShuffleMemoryManager. When this occurs, we merge some in-memory blocks to disk and go back to
  * fetching.
  *
  * TieredDiskMerger is responsible for managing the merged on-disk blocks and for supplying an
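
The flow this comment describes, reduced to a self-contained toy (the budget, block sizes, and names below are illustrative assumptions, not the real API):

import scala.collection.mutable

// Toy model of the strategy sketched in the comment above: keep fetched blocks
// in memory while a budget allows, and spill older blocks to make room when it
// does not. In the real reader the spilled blocks are merged to a disk file.
object FetchSpillToy {
  def main(args: Array[String]): Unit = {
    val memoryBudget = 100L                        // stand-in for the ShuffleMemoryManager
    var used = 0L
    val inMemoryBlocks = mutable.Queue.empty[Long] // sizes of blocks currently held in memory
    var spilledBytes = 0L

    for (blockSize <- Seq(40L, 30L, 50L, 20L)) {   // sizes of fetched blocks
      // Spill just enough previously fetched blocks to fit the new one.
      while (used + blockSize > memoryBudget && inMemoryBlocks.nonEmpty) {
        val victim = inMemoryBlocks.dequeue()
        used -= victim
        spilledBytes += victim
      }
      inMemoryBlocks += blockSize
      used += blockSize
    }
    println(s"in memory: ${inMemoryBlocks.toList}, spilled: $spilledBytes bytes")
    // prints: in memory: List(30, 50, 20), spilled: 40 bytes
  }
}
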
@@ -70,10 +71,10 @@ private[spark] class SortShuffleReader[K, C](
   private val inMemoryBlocks = new Queue[MemoryShuffleBlock]()
 
   /**
-   * Maintain the relation between shuffle block and its size. The reason we should maintain this
-   * is that the request shuffle block size is not equal to the result size because of
-   * compression of size. So here we should maintain this make sure the correctness of our
-   * algorithm.
+   * Maintain the block manager and reported size of each shuffle block. The block manager is
+   * used for error reporting. The reported size, which, because of size compression, may be
+   * slightly different from the size of the actual fetched block, is used for calculating how
+   * many blocks to spill.
    */
   private val shuffleBlockMap = new HashMap[ShuffleBlockId, (BlockManagerId, Long)]()
 
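
Why the reported size can differ from the actual fetched size: map output sizes are shipped as a single byte on a log scale, so decoding only approximates the original. A minimal sketch of such a scheme (the constants mirror what Spark's MapStatus uses, but treat them as an assumption here):

object SizeCompression {
  // Encode a size into one byte on a log-1.1 scale.
  def compress(size: Long): Byte =
    if (size <= 1L) size.toByte
    else math.min(255, math.ceil(math.log(size.toDouble) / math.log(1.1)).toInt).toByte

  // Decode back; the result is only close to, not equal to, the input.
  def decompress(compressed: Byte): Long =
    if (compressed == 0) 0L
    else math.pow(1.1, (compressed & 0xFF).toDouble).toLong

  def main(args: Array[String]): Unit = {
    val actual = 123456L
    val reported = decompress(compress(actual))
    println(s"actual=$actual, reported=$reported") // reported is close but not equal
  }
}
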
@@ -97,7 +98,7 @@ private[spark] class SortShuffleReader[K, C](
   private var _memoryBytesSpilled: Long = 0L
   private var _diskBytesSpilled: Long = 0L
 
-  /** number of bytes left to fetch */
+  /** Number of bytes left to fetch */
   private var unfetchedBytes: Long = 0L
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled
@@ -131,7 +132,7 @@ private[spark] class SortShuffleReader[K, C](
       val granted = shuffleMemoryManager.tryToAcquire(blockSize)
       if (granted >= blockSize) {
         if (blockData.isDirect) {
-          // If the memory shuffle block is allocated on direct buffer, copy it on heap,
+          // If the shuffle block is allocated on a direct buffer, copy it to an on-heap buffer,
           // otherwise off heap memory will be increased out of control.
           val onHeapBuffer = ByteBuffer.allocate(blockSize.toInt)
           onHeapBuffer.put(blockData.nioByteBuffer)
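
For the direct-buffer copy above, the mechanics in isolation: a direct ByteBuffer lives outside the JVM heap, so copying its contents into an on-heap buffer keeps the memory visible to heap accounting. A small self-contained sketch:

import java.nio.ByteBuffer

object DirectToHeap {
  def main(args: Array[String]): Unit = {
    val direct = ByteBuffer.allocateDirect(16) // off-heap allocation
    direct.put(Array[Byte](1, 2, 3, 4))
    direct.flip()                              // prepare for reading

    val onHeap = ByteBuffer.allocate(direct.remaining())
    onHeap.put(direct)                         // bulk copy off-heap -> heap
    onHeap.flip()
    println(s"copied ${onHeap.remaining()} bytes to an on-heap buffer") // 4
  }
}
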
@@ -142,7 +143,7 @@ private[spark] class SortShuffleReader[K, C](
           inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)
         }
       } else {
-        logDebug(s"Granted $granted memory is not enough to store shuffle block id $blockId, " +
+        logDebug(s"Granted $granted memory is not enough to store shuffle block $blockId, " +
           s"block size $blockSize, spilling in-memory blocks to release the memory")
 
         shuffleMemoryManager.release(granted)
@@ -162,7 +163,7 @@ private[spark] class SortShuffleReader[K, C](
     val mergedItr =
       MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
 
-    // Update the spilled info and do cleanup work when task is finished.
+    // Update the spill metrics and do cleanup work when the task is finished.
     context.taskMetrics().memoryBytesSpilled += memoryBytesSpilled
     context.taskMetrics().diskBytesSpilled += diskBytesSpilled
 
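
MergeUtil.mergeSort is not part of this diff; the shape of the operation it performs over the already-sorted per-block iterators is an ordinary k-way merge. A stand-in sketch using a priority queue keyed on each iterator's head (no aggregation, unlike the real call):

import scala.collection.mutable

object KWayMerge {
  // Merge key-sorted iterators into one sorted iterator, lazily.
  def mergeSorted[T](iters: Seq[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
    // Max-heap with reversed ordering, so the iterator with the smallest head wins.
    val heap = mutable.PriorityQueue.empty[BufferedIterator[T]](
      Ordering.by((it: BufferedIterator[T]) => it.head)(ord.reverse))
    iters.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))

    new Iterator[T] {
      def hasNext: Boolean = heap.nonEmpty
      def next(): T = {
        val it = heap.dequeue()
        val elem = it.next()
        if (it.hasNext) heap.enqueue(it) // re-insert under its new head
        elem
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeSorted(Seq(Iterator(1, 4, 7), Iterator(2, 5), Iterator(3, 6)))
    println(merged.mkString(", ")) // 1, 2, 3, 4, 5, 6, 7
  }
}
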
@@ -182,95 +183,105 @@ private[spark] class SortShuffleReader[K, C](
     new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
   }
 
+  /**
+   * Called when we've failed to acquire memory for a block we've just fetched. Figure out how many
+   * blocks to spill and then spill them.
+   */
   private def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
-    // Write merged blocks to disk
     val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
 
-    def releaseTempShuffleMemory(blocks: ArrayBuffer[MemoryShuffleBlock]): Unit = {
-      for (block <- blocks) {
-        block.blockData.release()
-        if (block != tippingBlock) {
-          shuffleMemoryManager.release(block.blockData.size)
-        }
-      }
-    }
-
     // If the remaining unfetched data would fit inside our current allocation, we don't want to
     // waste time spilling blocks beyond the space needed for it.
-    // We use the request size to calculate the remaining spilled size to make sure the
-    // correctness, since the request size is slightly different from result block size because
-    // of size compression.
+    // Note that unfetchedBytes is not exact, because of the compression used on the sizes of
+    // map output blocks.
     var bytesToSpill = unfetchedBytes
     val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]()
     blocksToSpill += tippingBlock
-    bytesToSpill -= shuffleBlockMap(tippingBlock.blockId.asInstanceOf[ShuffleBlockId])._2
+    bytesToSpill -= tippingBlock.blockData.size
     while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) {
       val block = inMemoryBlocks.dequeue()
       blocksToSpill += block
-      bytesToSpill -= shuffleBlockMap(block.blockId.asInstanceOf[ShuffleBlockId])._2
+      bytesToSpill -= block.blockData.size
     }
 
     _memoryBytesSpilled += blocksToSpill.map(_.blockData.size()).sum
 
     if (blocksToSpill.size > 1) {
-      val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
-      val partialMergedItr =
-        MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
-      val curWriteMetrics = new ShuffleWriteMetrics()
-      var writer =
-        blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
-      var success = false
-
-      try {
-        partialMergedItr.foreach(writer.write)
-        success = true
-      } finally {
-        if (!success) {
-          if (writer != null) {
-            writer.revertPartialWritesAndClose()
-            writer = null
-          }
-          if (file.exists()) {
-            file.delete()
-          }
-        } else {
-          writer.commitAndClose()
-          writer = null
+      spillMultipleBlocks(file, tmpBlockId, blocksToSpill, tippingBlock)
+    } else {
+      spillSingleBlock(file, blocksToSpill.head)
+    }
+
+    tieredMerger.registerOnDiskBlock(tmpBlockId, file)
+
+    logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
+  }
+
+  private def spillSingleBlock(file: File, block: MemoryShuffleBlock): Unit = {
+    val fos = new FileOutputStream(file)
+    val buffer = block.blockData.nioByteBuffer()
+    var channel = fos.getChannel
+    var success = false
+
+    try {
+      while (buffer.hasRemaining) {
+        channel.write(buffer)
+      }
+      success = true
+    } finally {
+      if (channel != null) {
+        channel.close()
+        channel = null
+      }
+      if (!success) {
+        if (file.exists()) {
+          file.delete()
         }
-      releaseTempShuffleMemory(blocksToSpill)
+      } else {
+        _diskBytesSpilled += file.length()
       }
-    _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+      // When we spill a single block, it's the tipping block, which we never acquired memory
+      // for from the shuffle memory manager, so we don't need to release any memory from it.
+      block.blockData.release()
+    }
+  }
 
-    } else {
-      val fos = new FileOutputStream(file)
-      val buffer = blocksToSpill.map(_.blockData.nioByteBuffer()).head
-      var channel = fos.getChannel
-      var success = false
-
-      try {
-        while (buffer.hasRemaining) {
-          channel.write(buffer)
+  /**
+   * Merge multiple in-memory blocks to a single on-disk file.
+   */
+  private def spillMultipleBlocks(file: File, tmpBlockId: BlockId,
+      blocksToSpill: Seq[MemoryShuffleBlock], tippingBlock: MemoryShuffleBlock): Unit = {
+    val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
+    val partialMergedItr =
+      MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
+    val curWriteMetrics = new ShuffleWriteMetrics()
+    var writer = blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
+    var success = false
+
+    try {
+      partialMergedItr.foreach(writer.write)
+      success = true
+    } finally {
+      if (!success) {
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+          writer = null
         }
-        success = true
-      } finally {
-        if (channel != null) {
-          channel.close()
-          channel = null
+        if (file.exists()) {
+          file.delete()
         }
-        if (!success) {
-          if (file.exists()) {
-            file.delete()
-          }
-        } else {
-          _diskBytesSpilled = file.length()
+      } else {
+        writer.commitAndClose()
+        writer = null
+      }
+      for (block <- blocksToSpill) {
+        block.blockData.release()
+        if (block != tippingBlock) {
+          shuffleMemoryManager.release(block.blockData.size)
         }
-      releaseTempShuffleMemory(blocksToSpill)
       }
     }
-
-    tieredMerger.registerOnDiskBlock(tmpBlockId, file)
-
-    logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
+    _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
   }
 
   private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock])
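
The selection rule in spillInMemoryBlocks, isolated as a self-contained sketch (Block is a hypothetical stand-in for MemoryShuffleBlock): the tipping block always spills, and queued blocks spill only until the freed bytes would cover what is still unfetched:

import scala.collection.mutable

object SpillSelection {
  final case class Block(id: String, size: Long)

  // Stand-in for the spill-selection loop: the tipping block always spills;
  // queued blocks spill only while more room is still needed.
  def selectBlocksToSpill(
      tipping: Block,
      inMemory: mutable.Queue[Block],
      unfetchedBytes: Long): Seq[Block] = {
    val toSpill = mutable.ArrayBuffer(tipping)
    var bytesToSpill = unfetchedBytes - tipping.size
    while (bytesToSpill > 0 && inMemory.nonEmpty) {
      val block = inMemory.dequeue()
      toSpill += block
      bytesToSpill -= block.size
    }
    toSpill.toSeq
  }

  def main(args: Array[String]): Unit = {
    val queue = mutable.Queue(Block("a", 40), Block("b", 40), Block("c", 40))
    val spilled = selectBlocksToSpill(Block("tip", 30), queue, 100)
    // 30 + 40 + 40 >= 100 unfetched bytes, so "c" stays in memory.
    println(spilled.map(_.id).mkString(", ")) // tip, a, b
  }
}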