diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d83da0d126d89..19ff109b673e1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -80,7 +80,10 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + /** + * Exposed for testing + */ + @volatile private[collection] var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager @@ -267,7 +270,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { readingIterator = new SpillableIterator(inMemoryIterator) - readingIterator + readingIterator.toCompletionIterator } /** @@ -280,8 +283,7 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]]( - destructiveIterator(currentMap.iterator), freeCurrentMap()) + destructiveIterator(currentMap.iterator) } else { new ExternalIterator() } @@ -305,8 +307,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( - currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) + private val sortedMap = destructiveIterator( + 
currentMap.destructiveSortedIterator(keyComparator)) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -568,13 +570,11 @@ class ExternalAppendOnlyMap[K, V, C]( context.addTaskCompletionListener[Unit](context => cleanup()) } - private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) + private class SpillableIterator(var upstream: Iterator[(K, C)]) extends Iterator[(K, C)] { private val SPILL_LOCK = new Object() - private var nextUpstream: Iterator[(K, C)] = null - private var cur: (K, C) = readNext() private var hasSpilled: Boolean = false @@ -585,17 +585,24 @@ class ExternalAppendOnlyMap[K, V, C]( } else { logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - nextUpstream = spillMemoryIteratorToDisk(upstream) + val nextUpstream = spillMemoryIteratorToDisk(upstream) + assert(!upstream.hasNext) hasSpilled = true + upstream = nextUpstream true } } + private def destroy(): Unit = { + freeCurrentMap() + upstream = Iterator.empty + } + + def toCompletionIterator: CompletionIterator[(K, C), SpillableIterator] = { + CompletionIterator[(K, C), SpillableIterator](this, this.destroy) + } + def readNext(): (K, C) = SPILL_LOCK.synchronized { - if (nextUpstream != null) { - upstream = nextUpstream - nextUpstream = null - } if (upstream.hasNext) { upstream.next() } else { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 35312f2d71131..d542ba0b6640d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.util.collection +import java.util.Objects + import 
scala.collection.mutable.ArrayBuffer +import scala.ref.WeakReference + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.util.CompletionIterator -class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite + with LocalSparkContext + with Eventually + with Matchers{ import TestUtils.{assertNotSpilled, assertSpilled} private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS @@ -414,7 +424,112 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("external aggregation updates peak execution memory") { + test("SPARK-22713 spill during iteration leaks internal map") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + // org.apache.spark.util.collection.AppendOnlyMap.destructiveSortedIterator returns + // an instance of an anonymous Iterator class. + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val first50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(map.numSpills == 0) + map.spill(Long.MaxValue, null) + // these asserts try to show that we're no longer holding references to the underlying map.
+ // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + // assert(map.currentMap == null) + eventually { + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + + val next50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(!it.hasNext) + val keys = (first50Keys ++ next50Keys).sorted + assert(keys == (0 until 100)) + } + + test("drop all references to the underlying map once the iterator is exhausted") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val it = map.iterator + assert( it.isInstanceOf[CompletionIterator[_, _]]) + + + val keys = it.map{ + case (k, vs) => + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + .toList + .sorted + + assert(it.isEmpty) + assert(keys == (0 until 100)) + + assert(map.numSpills == 0) + // these asserts try to show that we're no longer holding references to the underlying map. 
+ // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + assert(map.currentMap == null) + + eventually { + Thread.sleep(500) + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + assert(it.toList.isEmpty) + } + + test("SPARK-22713 external aggregation updates peak execution memory") { val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString)