
Commit 1068005

sryza authored and jerryshao committed
Clean up comments, break up large methods, spill based on actual block size, and properly increment _diskBytesSpilled
1 parent a3da81c commit 1068005

File tree

3 files changed: +90, -76 lines changed

core/src/main/scala/org/apache/spark/shuffle/sort/MixedShuffleReader.scala

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ import org.apache.spark.{TaskContext, Logging}
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 import org.apache.spark.shuffle.hash.HashShuffleReader
 
+/**
+ * ShuffleReader that chooses SortShuffleReader or HashShuffleReader depending on whether there is
+ * a key ordering.
+ */
 private[spark] class MixedShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
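The new doc comment describes a simple dispatch on whether the shuffle dependency defines a key ordering. As a rough, self-contained sketch of that pattern (not the actual Spark class; the Reader types below are simplified stand-ins for SortShuffleReader and HashShuffleReader):

    object MixedReaderDispatchSketch {
      trait Reader[K, C] { def read(): Iterator[(K, C)] }
      case class SortReader[K, C]() extends Reader[K, C] { def read(): Iterator[(K, C)] = Iterator.empty }
      case class HashReader[K, C]() extends Reader[K, C] { def read(): Iterator[(K, C)] = Iterator.empty }

      // Pick the sort-based reader only when a key ordering is defined;
      // otherwise fall back to the cheaper hash-based path.
      def chooseReader[K, C](keyOrdering: Option[Ordering[K]]): Reader[K, C] =
        keyOrdering match {
          case Some(_) => SortReader[K, C]()
          case None    => HashReader[K, C]()
        }

      def main(args: Array[String]): Unit = {
        println(chooseReader[Int, String](Some(Ordering.Int))) // SortReader()
        println(chooseReader[Int, String](None))               // HashReader()
      }
    }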

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleReader.scala

Lines changed: 85 additions & 74 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.sort
 
+import java.io.File
 import java.io.FileOutputStream
 import java.nio.ByteBuffer
 import java.util.Comparator
@@ -38,7 +39,7 @@ import org.apache.spark.util.collection.{MergeUtil, TieredDiskMerger}
  * map output block.
  *
  * As blocks are fetched, we store them in memory until we fail to acquire space from the
- * ShuffleMemoryManager. When this occurs, we merge the in-memory blocks to disk and go back to
+ * ShuffleMemoryManager. When this occurs, we merge some in-memory blocks to disk and go back to
  * fetching.
  *
  * TieredDiskMerger is responsible for managing the merged on-disk blocks and for supplying an
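The class doc above is built around merge-sorting fetched blocks. As an illustrative, self-contained sketch of the k-way merge that MergeUtil.mergeSort is used for in this file (this is not Spark's implementation, and aggregation/combiner logic is omitted): a priority queue keyed on each iterator's head element yields a single sorted stream.

    import scala.collection.mutable

    object KWayMergeSketch {
      def mergeSort[K, V](iterators: Seq[Iterator[(K, V)]])(implicit ord: Ordering[K]): Iterator[(K, V)] = {
        // Min-heap of non-empty iterators, ordered by the key at each iterator's head.
        val heap = mutable.PriorityQueue.empty[BufferedIterator[(K, V)]](
          Ordering.by((it: BufferedIterator[(K, V)]) => it.head._1).reverse)
        iterators.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))

        new Iterator[(K, V)] {
          def hasNext: Boolean = heap.nonEmpty
          def next(): (K, V) = {
            val it = heap.dequeue()          // iterator with the smallest head key
            val kv = it.next()
            if (it.hasNext) heap.enqueue(it) // re-queue under its new head
            kv
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val merged = mergeSort(Seq(Iterator(1 -> "a", 3 -> "c"), Iterator(2 -> "b", 4 -> "d")))
        println(merged.toList) // List((1,a), (2,b), (3,c), (4,d))
      }
    }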
@@ -70,10 +71,10 @@ private[spark] class SortShuffleReader[K, C](
   private val inMemoryBlocks = new Queue[MemoryShuffleBlock]()
 
   /**
-   * Maintain the relation between shuffle block and its size. The reason we should maintain this
-   * is that the request shuffle block size is not equal to the result size because of
-   * compression of size. So here we should maintain this make sure the correctness of our
-   * algorithm.
+   * Maintain block manager and reported size of each shuffle block. The block manager is used for
+   * error reporting. The reported size, which, because of size compression, may be slightly
+   * different than the size of the actual fetched block, is used for calculating how many blocks
+   * to spill.
    */
   private val shuffleBlockMap = new HashMap[ShuffleBlockId, (BlockManagerId, Long)]()
 
@@ -97,7 +98,7 @@ private[spark] class SortShuffleReader[K, C](
   private var _memoryBytesSpilled: Long = 0L
   private var _diskBytesSpilled: Long = 0L
 
-  /** number of bytes left to fetch */
+  /** Number of bytes left to fetch */
   private var unfetchedBytes: Long = 0L
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled
@@ -131,7 +132,7 @@ private[spark] class SortShuffleReader[K, C](
       val granted = shuffleMemoryManager.tryToAcquire(blockSize)
       if (granted >= blockSize) {
         if (blockData.isDirect) {
-          // If the memory shuffle block is allocated on direct buffer, copy it on heap,
+          // If the shuffle block is allocated on a direct buffer, copy it to an on-heap buffer,
           // otherwise off heap memory will be increased out of control.
           val onHeapBuffer = ByteBuffer.allocate(blockSize.toInt)
           onHeapBuffer.put(blockData.nioByteBuffer)
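The reworded comment is about copying direct (off-heap) buffers onto the heap so that retained shuffle blocks are accounted for by the JVM heap rather than growing off-heap memory without bound. A standalone sketch of that copy in plain Java NIO (not tied to Spark's buffer API):

    import java.nio.ByteBuffer

    object DirectToHeapCopy {
      def toHeap(direct: ByteBuffer): ByteBuffer = {
        val onHeap = ByteBuffer.allocate(direct.remaining())
        onHeap.put(direct.duplicate()) // duplicate() leaves the caller's position untouched
        onHeap.flip()                  // make the copy readable from the start
        onHeap
      }

      def main(args: Array[String]): Unit = {
        val direct = ByteBuffer.allocateDirect(4).put(Array[Byte](1, 2, 3, 4))
        direct.flip()
        val heap = toHeap(direct)
        assert(!heap.isDirect && heap.remaining() == 4)
      }
    }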
@@ -142,7 +143,7 @@ private[spark] class SortShuffleReader[K, C](
           inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)
         }
       } else {
-        logDebug(s"Granted $granted memory is not enough to store shuffle block id $blockId, " +
+        logDebug(s"Granted $granted memory is not enough to store shuffle block $blockId, " +
           s"block size $blockSize, spilling in-memory blocks to release the memory")
 
         shuffleMemoryManager.release(granted)
@@ -162,7 +163,7 @@ private[spark] class SortShuffleReader[K, C](
     val mergedItr =
       MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
 
-    // Update the spilled info and do cleanup work when task is finished.
+    // Update the spill metrics and do cleanup work when task is finished.
     context.taskMetrics().memoryBytesSpilled += memoryBytesSpilled
     context.taskMetrics().diskBytesSpilled += diskBytesSpilled
 
@@ -182,95 +183,105 @@ private[spark] class SortShuffleReader[K, C](
     new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
   }
 
+  /**
+   * Called when we've failed to acquire memory for a block we've just fetched. Figure out how many
+   * blocks to spill and then spill them.
+   */
   private def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
-    // Write merged blocks to disk
     val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
 
-    def releaseTempShuffleMemory(blocks: ArrayBuffer[MemoryShuffleBlock]): Unit = {
-      for (block <- blocks) {
-        block.blockData.release()
-        if (block != tippingBlock) {
-          shuffleMemoryManager.release(block.blockData.size)
-        }
-      }
-    }
-
     // If the remaining unfetched data would fit inside our current allocation, we don't want to
     // waste time spilling blocks beyond the space needed for it.
-    // We use the request size to calculate the remaining spilled size to make sure the
-    // correctness, since the request size is slightly different from result block size because
-    // of size compression.
+    // Note that the number of unfetchedBytes is not exact, because of the compression used on the
+    // sizes of map output blocks.
     var bytesToSpill = unfetchedBytes
     val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]()
     blocksToSpill += tippingBlock
-    bytesToSpill -= shuffleBlockMap(tippingBlock.blockId.asInstanceOf[ShuffleBlockId])._2
+    bytesToSpill -= tippingBlock.blockData.size
     while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) {
       val block = inMemoryBlocks.dequeue()
       blocksToSpill += block
-      bytesToSpill -= shuffleBlockMap(block.blockId.asInstanceOf[ShuffleBlockId])._2
+      bytesToSpill -= block.blockData.size
     }
 
     _memoryBytesSpilled += blocksToSpill.map(_.blockData.size()).sum
 
     if (blocksToSpill.size > 1) {
-      val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
-      val partialMergedItr =
-        MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
-      val curWriteMetrics = new ShuffleWriteMetrics()
-      var writer =
-        blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
-      var success = false
-
-      try {
-        partialMergedItr.foreach(writer.write)
-        success = true
-      } finally {
-        if (!success) {
-          if (writer != null) {
-            writer.revertPartialWritesAndClose()
-            writer = null
-          }
-          if (file.exists()) {
-            file.delete()
-          }
-        } else {
-          writer.commitAndClose()
-          writer = null
+      spillMultipleBlocks(file, tmpBlockId, blocksToSpill, tippingBlock)
+    } else {
+      spillSingleBlock(file, blocksToSpill.head)
+    }
+
+    tieredMerger.registerOnDiskBlock(tmpBlockId, file)
+
+    logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
+  }
+
+  private def spillSingleBlock(file: File, block: MemoryShuffleBlock): Unit = {
+    val fos = new FileOutputStream(file)
+    val buffer = block.blockData.nioByteBuffer()
+    var channel = fos.getChannel
+    var success = false
+
+    try {
+      while (buffer.hasRemaining) {
+        channel.write(buffer)
+      }
+      success = true
+    } finally {
+      if (channel != null) {
+        channel.close()
+        channel = null
+      }
+      if (!success) {
+        if (file.exists()) {
+          file.delete()
         }
-        releaseTempShuffleMemory(blocksToSpill)
+      } else {
+        _diskBytesSpilled += file.length()
       }
-      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+      // When we spill a single block, it's the single tipping block that we never acquired memory
+      // from the shuffle memory manager for, so we don't need to release any memory from there.
+      block.blockData.release()
+    }
+  }
 
-    } else {
-      val fos = new FileOutputStream(file)
-      val buffer = blocksToSpill.map(_.blockData.nioByteBuffer()).head
-      var channel = fos.getChannel
-      var success = false
-
-      try {
-        while (buffer.hasRemaining) {
-          channel.write(buffer)
+  /**
+   * Merge multiple in-memory blocks to a single on-disk file.
+   */
+  private def spillMultipleBlocks(file: File, tmpBlockId: BlockId,
+      blocksToSpill: Seq[MemoryShuffleBlock], tippingBlock: MemoryShuffleBlock): Unit = {
+    val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
+    val partialMergedItr =
+      MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
+    val curWriteMetrics = new ShuffleWriteMetrics()
+    var writer = blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
+    var success = false
+
+    try {
+      partialMergedItr.foreach(writer.write)
+      success = true
+    } finally {
+      if (!success) {
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+          writer = null
         }
-        success = true
-      } finally {
-        if (channel != null) {
-          channel.close()
-          channel = null
+        if (file.exists()) {
+          file.delete()
         }
-        if (!success) {
-          if (file.exists()) {
-            file.delete()
-          }
-        } else {
-          _diskBytesSpilled = file.length()
+      } else {
+        writer.commitAndClose()
+        writer = null
+      }
+      for (block <- blocksToSpill) {
+        block.blockData.release()
+        if (block != tippingBlock) {
+          shuffleMemoryManager.release(block.blockData.size)
         }
-        releaseTempShuffleMemory(blocksToSpill)
       }
     }
-
-    tieredMerger.registerOnDiskBlock(tmpBlockId, file)
-
-    logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
+    _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
   }
 
   private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock])
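Two behavioral fixes in the hunk above are easy to miss: the spill loop now subtracts each block's actual in-memory size (blockData.size) rather than the compressed size recorded in shuffleBlockMap, and the single-block path now accumulates _diskBytesSpilled with += where the old code overwrote it with =. A self-contained sketch of the selection loop under the new semantics (Block here is a simplified stand-in, not Spark's MemoryShuffleBlock):

    import scala.collection.mutable.{ArrayBuffer, Queue}

    object SpillSelection {
      case class Block(id: Int, size: Long)

      // Spill the tipping block plus just enough queued blocks to cover the
      // bytes still to be fetched, measured by actual block size.
      def chooseBlocksToSpill(tipping: Block, inMemory: Queue[Block], unfetchedBytes: Long): Seq[Block] = {
        var bytesToSpill = unfetchedBytes
        val toSpill = new ArrayBuffer[Block]()
        toSpill += tipping
        bytesToSpill -= tipping.size
        while (bytesToSpill > 0 && inMemory.nonEmpty) {
          toSpill += inMemory.dequeue()
          bytesToSpill -= toSpill.last.size
        }
        toSpill
      }

      def main(args: Array[String]): Unit = {
        val queue = Queue(Block(1, 100L), Block(2, 200L), Block(3, 300L))
        // 250 bytes still unfetched; the tipping block covers 50, so two more blocks spill.
        println(chooseBlocksToSpill(Block(0, 50L), queue, 250L).map(_.id)) // ArrayBuffer(0, 1, 2)
      }
    }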

core/src/main/scala/org/apache/spark/util/collection/TieredDiskMerger.scala

Lines changed: 1 addition & 2 deletions
@@ -30,8 +30,6 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.CompletionIterator
 
 /**
- * Explain the boundaries of where this starts and why we have second thread
- *
  * Manages blocks of sorted data on disk that need to be merged together. Carries out a tiered
 * merge that will never merge more than spark.shuffle.maxMergeFactor segments at a time. Except for
 * the final merge, which merges disk blocks to a returned iterator, TieredDiskMerger merges blocks
@@ -72,6 +70,7 @@ private[spark] class TieredDiskMerger[K, C](
 
   private val mergeFinished = new CountDownLatch(1)
 
+  /** Whether more on-disk blocks may come in */
   @volatile private var doneRegistering = false
 
   /** Number of bytes spilled on disk */
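The doc comment above describes a merge bounded by spark.shuffle.maxMergeFactor. As an illustrative, self-contained sketch of that tiering policy (not the Spark implementation; a naive recursive two-way merge stands in for the real k-way merge over disk segments):

    object TieredMergeSketch {
      // Naive two-way merge of sorted lists; stand-in for a real k-way merge.
      private def merge2(a: List[Int], b: List[Int]): List[Int] = (a, b) match {
        case (Nil, ys) => ys
        case (xs, Nil) => xs
        case (x :: xs, y :: ys) =>
          if (x <= y) x :: merge2(xs, y :: ys) else y :: merge2(x :: xs, ys)
      }

      // Merge at most maxMergeFactor runs at a time, feeding each merged run
      // back into the pool, until a single sorted run remains.
      def tieredMerge(runs: List[List[Int]], maxMergeFactor: Int): List[Int] = {
        var pool = runs
        while (pool.size > 1) {
          val (group, rest) = pool.splitAt(maxMergeFactor)
          pool = rest :+ group.reduce(merge2)
        }
        pool.headOption.getOrElse(Nil)
      }

      def main(args: Array[String]): Unit = {
        val runs = List(List(1, 4), List(2, 5), List(3, 6), List(0, 7))
        println(tieredMerge(runs, maxMergeFactor = 2)) // List(0, 1, 2, 3, 4, 5, 6, 7)
      }
    }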
