@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
2424import java .util .concurrent .ConcurrentHashMap
2525
2626import org .apache .spark .Logging
27+ import java .util .concurrent .atomic .AtomicInteger
2728
2829private [util] case class TimeStampedWeakValue [T ](timestamp : Long , weakValue : WeakReference [T ]) {
2930 def this (timestamp : Long , value : T ) = this (timestamp, new WeakReference [T ](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea
4445private [spark] class TimeStampedWeakValueHashMap [A , B ]()
4546 extends WrappedJavaHashMap [A , B , A , TimeStampedWeakValue [B ]] with Logging {
4647
48+ /** Number of inserts after which keys whose weak ref values are null will be cleaned */
49+ private val CLEANUP_INTERVAL = 1000
50+
51+ /** Counter for counting the number of inserts */
52+ private val insertCounts = new AtomicInteger (0 )
53+
4754 protected [util] val internalJavaMap : util.Map [A , TimeStampedWeakValue [B ]] = {
4855 new ConcurrentHashMap [A , TimeStampedWeakValue [B ]]()
4956 }
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
5259 new TimeStampedWeakValueHashMap [K1 , V1 ]()
5360 }
5461
62+ override def += (kv : (A , B )): this .type = {
63+ // Cleanup null value at certain intervals
64+ if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0 ) {
65+ cleanNullValues()
66+ }
67+ super .+= (kv)
68+ }
69+
5570 override def get (key : A ): Option [B ] = {
5671 Option (internalJavaMap.get(key)) match {
5772 case Some (weakValue) =>
5873 val value = weakValue.weakValue.get
59- if (value == null ) cleanupKey(key)
74+ if (value == null ) {
75+ internalJavaMap.remove(key)
76+ }
6077 Option (value)
6178 case None =>
6279 None
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
7289 }
7390
7491 override def iterator : Iterator [(A , B )] = {
75- val jIterator = internalJavaMap.entrySet().iterator()
76- JavaConversions .asScalaIterator(jIterator).flatMap(kv => {
77- val key = kv.getKey
78- val value = kv.getValue.weakValue.get
79- if (value == null ) {
80- cleanupKey(key)
81- Seq .empty
82- } else {
83- Seq ((key, value))
84- }
92+ val iterator = internalJavaMap.entrySet().iterator()
93+ JavaConversions .asScalaIterator(iterator).flatMap(kv => {
94+ val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
95+ if (value != null ) Seq ((key, value)) else Seq .empty
8596 })
8697 }
8798
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
104115 }
105116 }
106117
107- private def cleanupKey (key : A ) {
108- // TODO: Consider cleaning up keys to empty weak ref values automatically in future.
118+ /**
119+ * Removes keys whose weak referenced values have become null.
120+ */
121+ private def cleanNullValues () {
122+ val iterator = internalJavaMap.entrySet().iterator()
123+ while (iterator.hasNext) {
124+ val entry = iterator.next()
125+ if (entry.getValue.weakValue.get == null ) {
126+ logDebug(" Removing key " + entry.getKey)
127+ iterator.remove()
128+ }
129+ }
109130 }
110131
111132 private def currentTime = System .currentTimeMillis()
0 commit comments