@@ -61,10 +61,10 @@ class ExternalAppendOnlyMap[K, V, C](
     blockManager: BlockManager = SparkEnv.get.blockManager,
     context: TaskContext = TaskContext.get(),
     serializerManager: SerializerManager = SparkEnv.get.serializerManager)
-  extends Iterable[(K, C)]
+  extends Spillable[SizeTracker](context.taskMemoryManager())
   with Serializable
   with Logging
-  with Spillable[SizeTracker] {
+  with Iterable[(K, C)] {
 
   if (context == null) {
     throw new IllegalStateException(
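Note: `Spillable` moves from a mixin to the superclass because it now takes the
TaskMemoryManager as a constructor argument, which is also why the explicit
`taskMemoryManager` override is deleted in the next hunk. A minimal sketch of the
assumed base-class shape (names follow Spark's Spillable/MemoryConsumer, but
treat the exact signatures as illustrative):

    private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
      extends MemoryConsumer(taskMemoryManager) with Logging {
      // Spills the given in-memory collection to disk once this consumer grows
      // past its memory threshold.
      protected[this] def spill(collection: C): Unit
      // Callback used when the task cannot acquire more memory; returns true
      // if this consumer actually released anything.
      protected[this] def forceSpill(): Boolean
    }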
@@ -81,9 +81,7 @@ class ExternalAppendOnlyMap[K, V, C](
     this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
   }
 
-  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
-
-  private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
+  @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
@@ -117,6 +115,8 @@ class ExternalAppendOnlyMap[K, V, C](
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
+  @volatile private var readingIterator: SpillableIterator = null
+
   /**
    * Number of files this map has spilled so far.
    * Exposed for testing.
@@ -182,6 +182,29 @@ class ExternalAppendOnlyMap[K, V, C](
    * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
    */
   override protected[this] def spill(collection: SizeTracker): Unit = {
+    val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator)
+    val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
+    spilledMaps.append(diskMapIterator)
+  }
+
+  /**
+   * Force spilling the current in-memory collection to disk to release memory.
+   * It will be called by TaskMemoryManager when there is not enough memory for the task.
+   */
+  override protected[this] def forceSpill(): Boolean = {
+    assert(readingIterator != null)
+    val isSpilled = readingIterator.spill()
+    if (isSpilled) {
+      currentMap = null
+    }
+    isSpilled
+  }
+
+  /**
+   * Spill the in-memory Iterator to a temporary file on disk.
+   */
+  private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)])
+      : DiskMapIterator = {
     val (blockId, file) = diskBlockManager.createTempLocalBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -202,9 +225,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     var success = false
     try {
-      val it = currentMap.destructiveSortedIterator(keyComparator)
-      while (it.hasNext) {
-        val kv = it.next()
+      while (inMemoryIterator.hasNext) {
+        val kv = inMemoryIterator.next()
         writer.write(kv._1, kv._2)
         objectsWritten += 1
 
@@ -237,7 +259,17 @@ class ExternalAppendOnlyMap[K, V, C](
       }
     }
 
-    spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
+    new DiskMapIterator(file, blockId, batchSizes)
+  }
+
+  /**
+   * Returns a destructive iterator for iterating over the entries of this map.
+   * If this iterator is forced to spill to disk to release memory when there is not enough
+   * memory, it returns pairs from an on-disk map.
+   */
+  def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = {
+    readingIterator = new SpillableIterator(inMemoryIterator)
+    readingIterator
   }
 
   /**
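How the pieces above fit together: when some consumer in the task cannot acquire
memory, the TaskMemoryManager asks the other consumers to spill, and for this map
that request lands in forceSpill(), which delegates to the SpillableIterator
currently being read. A simplified sketch of the assumed dispatch in the Spillable
base class; the bookkeeping fields are illustrative, not Spark's exact code:

    // Invoked by TaskMemoryManager when `trigger` (another consumer) needs memory.
    override def spill(size: Long, trigger: MemoryConsumer): Long = {
      if (trigger != this && forceSpill()) {
        val freed = myMemoryThreshold   // bytes this consumer was holding
        releaseMemory()                 // hand them back to the task
        freed
      } else {
        0L                              // nothing could be released
      }
    }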
@@ -250,15 +282,18 @@ class ExternalAppendOnlyMap[K, V, C](
250282 " ExternalAppendOnlyMap.iterator is destructive and should only be called once." )
251283 }
252284 if (spilledMaps.isEmpty) {
253- CompletionIterator [(K , C ), Iterator [(K , C )]](currentMap.iterator, freeCurrentMap())
285+ CompletionIterator [(K , C ), Iterator [(K , C )]](
286+ destructiveIterator(currentMap.iterator), freeCurrentMap())
254287 } else {
255288 new ExternalIterator ()
256289 }
257290 }
258291
259292 private def freeCurrentMap (): Unit = {
260- currentMap = null // So that the memory can be garbage-collected
261- releaseMemory()
293+ if (currentMap != null ) {
294+ currentMap = null // So that the memory can be garbage-collected
295+ releaseMemory()
296+ }
262297 }
263298
264299 /**
@@ -272,8 +307,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
     // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
-      currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
+    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
+      currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
     private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
     inputStreams.foreach { it =>
@@ -532,8 +567,56 @@ class ExternalAppendOnlyMap[K, V, C](
     context.addTaskCompletionListener(context => cleanup())
   }
 
+  private[this] class SpillableIterator(var upstream: Iterator[(K, C)])
+    extends Iterator[(K, C)] {
+
+    private val SPILL_LOCK = new Object()
+
+    private var nextUpstream: Iterator[(K, C)] = null
+
+    private var cur: (K, C) = readNext()
+
+    private var hasSpilled: Boolean = false
+
+    def spill(): Boolean = SPILL_LOCK.synchronized {
+      if (hasSpilled) {
+        false
+      } else {
+        logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
+          s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+        nextUpstream = spillMemoryIteratorToDisk(upstream)
+        hasSpilled = true
+        true
+      }
+    }
+
+    def readNext(): (K, C) = SPILL_LOCK.synchronized {
+      if (nextUpstream != null) {
+        upstream = nextUpstream
+        nextUpstream = null
+      }
+      if (upstream.hasNext) {
+        upstream.next()
+      } else {
+        null
+      }
+    }
+
+    override def hasNext(): Boolean = cur != null
+
+    override def next(): (K, C) = {
+      val r = cur
+      cur = readNext()
+      r
+    }
+  }
+
   /** Convenience function to hash the given (K, C) pair by the key. */
   private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)
+
+  override def toString(): String = {
+    this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
+  }
 }
 
 private[spark] object ExternalAppendOnlyMap {
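Two closing notes. The toString override is not cosmetic: Iterable's default
toString materializes the collection, which would destructively consume this map
if an instance were ever logged, so the override falls back to the identity-based
java.lang.Object format. And the SpillableIterator handoff is the heart of the
change: reads and the swap of the backing iterator are serialized on SPILL_LOCK,
so a forced spill can never race a concurrent next(). A self-contained sketch of
that pattern, stripped of all Spark types:

    // Iterator whose backing source can be swapped out between element reads
    // (e.g. for an on-disk copy); `null` marks exhaustion, as in the diff.
    class HandoffIterator[T >: Null](private var upstream: Iterator[T])
      extends Iterator[T] {
      private val lock = new Object()
      private var replacement: Iterator[T] = null

      // Stage a new backing iterator; it takes effect on the next read.
      def swap(it: Iterator[T]): Unit = lock.synchronized { replacement = it }

      private def readNext(): T = lock.synchronized {
        if (replacement != null) { upstream = replacement; replacement = null }
        if (upstream.hasNext) upstream.next() else null
      }

      private var cur: T = readNext()
      override def hasNext: Boolean = cur != null
      override def next(): T = { val r = cur; cur = readNext(); r }
    }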