@@ -21,20 +21,24 @@ import java.util.Date
2121import java .util .concurrent .ConcurrentLinkedQueue
2222
2323import scala .collection .JavaConverters ._
24+ import scala .collection .mutable .{HashMap , HashSet }
2425import scala .concurrent .duration ._
2526import scala .io .Source
2627import scala .language .postfixOps
28+ import scala .reflect .ClassTag
2729
2830import org .json4s ._
2931import org .json4s .jackson .JsonMethods ._
32+ import org .mockito .Mockito .{mock , when }
3033import org .scalatest .{BeforeAndAfter , Matchers , PrivateMethodTester }
3134import org .scalatest .concurrent .Eventually
3235import other .supplier .{CustomPersistenceEngine , CustomRecoveryModeFactory }
3336
3437import org .apache .spark .{SecurityManager , SparkConf , SparkFunSuite }
3538import org .apache .spark .deploy ._
3639import org .apache .spark .deploy .DeployMessages ._
37- import org .apache .spark .rpc .{RpcAddress , RpcEndpoint , RpcEnv }
40+ import org .apache .spark .rpc .{RpcAddress , RpcEndpoint , RpcEndpointRef , RpcEnv }
41+ import org .apache .spark .serializer
3842
3943class MasterSuite extends SparkFunSuite
4044 with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
@@ -134,6 +138,81 @@ class MasterSuite extends SparkFunSuite
134138 CustomRecoveryModeFactory .instantiationAttempts should be > instantiationAttempts
135139 }
136140
141+ test(" master correctly recover the application" ) {
142+ val conf = new SparkConf (loadDefaults = false )
143+ conf.set(" spark.deploy.recoveryMode" , " CUSTOM" )
144+ conf.set(" spark.deploy.recoveryMode.factory" ,
145+ classOf [FakeRecoveryModeFactory ].getCanonicalName)
146+ conf.set(" spark.master.rest.enabled" , " false" )
147+
148+ val fakeAppInfo = makeAppInfo(1024 )
149+ val fakeWorkerInfo = makeWorkerInfo(8192 , 16 )
150+ val fakeDriverInfo = new DriverInfo (
151+ startTime = 0 ,
152+ id = " test_driver" ,
153+ desc = new DriverDescription (
154+ jarUrl = " " ,
155+ mem = 1024 ,
156+ cores = 1 ,
157+ supervise = false ,
158+ command = new Command (" " , Nil , Map .empty, Nil , Nil , Nil )),
159+ submitDate = new Date ())
160+
161+ // Build the fake recovery data
162+ FakeRecoveryModeFactory .persistentData.put(s " app_ ${fakeAppInfo.id}" , fakeAppInfo)
163+ FakeRecoveryModeFactory .persistentData.put(s " driver_ ${fakeDriverInfo.id}" , fakeDriverInfo)
164+ FakeRecoveryModeFactory .persistentData.put(s " worker_ ${fakeWorkerInfo.id}" , fakeWorkerInfo)
165+
166+ var master : Master = null
167+ try {
168+ master = makeMaster(conf)
169+ master.rpcEnv.setupEndpoint(Master .ENDPOINT_NAME , master)
170+ // Wait until Master recover from checkpoint data.
171+ eventually(timeout(5 seconds), interval(100 milliseconds)) {
172+ master.idToApp.size should be(1 )
173+ }
174+
175+ master.idToApp.keySet should be(Set (fakeAppInfo.id))
176+ getDrivers(master) should be(Set (fakeDriverInfo))
177+ master.workers should be(Set (fakeWorkerInfo))
178+
179+ // Notify Master about the executor and driver info to make it correctly recovered.
180+ val fakeExecutors = List (
181+ new ExecutorDescription (fakeAppInfo.id, 0 , 8 , ExecutorState .RUNNING ),
182+ new ExecutorDescription (fakeAppInfo.id, 0 , 7 , ExecutorState .RUNNING ))
183+
184+ fakeAppInfo.state should be(ApplicationState .UNKNOWN )
185+ fakeWorkerInfo.coresFree should be(16 )
186+ fakeWorkerInfo.coresUsed should be(0 )
187+
188+ master.self.send(MasterChangeAcknowledged (fakeAppInfo.id))
189+ eventually(timeout(1 second), interval(10 milliseconds)) {
190+ // Application state should be WAITING when "MasterChangeAcknowledged" event executed.
191+ fakeAppInfo.state should be(ApplicationState .WAITING )
192+ }
193+
194+ master.self.send(
195+ WorkerSchedulerStateResponse (fakeWorkerInfo.id, fakeExecutors, Seq (fakeDriverInfo.id)))
196+
197+ eventually(timeout(5 seconds), interval(100 milliseconds)) {
198+ getState(master) should be(RecoveryState .ALIVE )
199+ }
200+
201+ // If driver's resource is also counted, free cores should 0
202+ fakeWorkerInfo.coresFree should be(0 )
203+ fakeWorkerInfo.coresUsed should be(16 )
204+ // State of application should be RUNNING
205+ fakeAppInfo.state should be(ApplicationState .RUNNING )
206+ } finally {
207+ if (master != null ) {
208+ master.rpcEnv.shutdown()
209+ master.rpcEnv.awaitTermination()
210+ master = null
211+ FakeRecoveryModeFactory .persistentData.clear()
212+ }
213+ }
214+ }
215+
137216 test(" master/worker web ui available" ) {
138217 implicit val formats = org.json4s.DefaultFormats
139218 val conf = new SparkConf ()
@@ -394,6 +473,9 @@ class MasterSuite extends SparkFunSuite
394473 // ==========================================
395474
396475 private val _scheduleExecutorsOnWorkers = PrivateMethod [Array [Int ]](' scheduleExecutorsOnWorkers )
476+ private val _drivers = PrivateMethod [HashSet [DriverInfo ]](' drivers )
477+ private val _state = PrivateMethod [RecoveryState .Value ](' state )
478+
397479 private val workerInfo = makeWorkerInfo(4096 , 10 )
398480 private val workerInfos = Array (workerInfo, workerInfo, workerInfo)
399481
@@ -412,12 +494,18 @@ class MasterSuite extends SparkFunSuite
412494 val desc = new ApplicationDescription (
413495 " test" , maxCores, memoryPerExecutorMb, null , " " , None , None , coresPerExecutor)
414496 val appId = System .currentTimeMillis.toString
415- new ApplicationInfo (0 , appId, desc, new Date , null , Int .MaxValue )
497+ val endpointRef = mock(classOf [RpcEndpointRef ])
498+ val mockAddress = mock(classOf [RpcAddress ])
499+ when(endpointRef.address).thenReturn(mockAddress)
500+ new ApplicationInfo (0 , appId, desc, new Date , endpointRef, Int .MaxValue )
416501 }
417502
418503 private def makeWorkerInfo (memoryMb : Int , cores : Int ): WorkerInfo = {
419504 val workerId = System .currentTimeMillis.toString
420- new WorkerInfo (workerId, " host" , 100 , cores, memoryMb, null , " http://localhost:80" )
505+ val endpointRef = mock(classOf [RpcEndpointRef ])
506+ val mockAddress = mock(classOf [RpcAddress ])
507+ when(endpointRef.address).thenReturn(mockAddress)
508+ new WorkerInfo (workerId, " host" , 100 , cores, memoryMb, endpointRef, " http://localhost:80" )
421509 }
422510
423511 private def scheduleExecutorsOnWorkers (
@@ -499,4 +587,40 @@ class MasterSuite extends SparkFunSuite
499587 assert(receivedMasterAddress === RpcAddress (" localhost2" , 10000 ))
500588 }
501589 }
590+
591+ private def getDrivers (master : Master ): HashSet [DriverInfo ] = {
592+ master.invokePrivate(_drivers())
593+ }
594+
595+ private def getState (master : Master ): RecoveryState .Value = {
596+ master.invokePrivate(_state())
597+ }
598+ }
599+
600+ private class FakeRecoveryModeFactory (conf : SparkConf , ser : serializer.Serializer )
601+ extends StandaloneRecoveryModeFactory (conf, ser) {
602+ import FakeRecoveryModeFactory .persistentData
603+
604+ override def createPersistenceEngine (): PersistenceEngine = new PersistenceEngine {
605+
606+ override def unpersist (name : String ): Unit = {
607+ persistentData.remove(name)
608+ }
609+
610+ override def persist (name : String , obj : Object ): Unit = {
611+ persistentData(name) = obj
612+ }
613+
614+ override def read [T : ClassTag ](prefix : String ): Seq [T ] = {
615+ persistentData.filter(_._1.startsWith(prefix)).map(_._2.asInstanceOf [T ]).toSeq
616+ }
617+ }
618+
619+ override def createLeaderElectionAgent (master : LeaderElectable ): LeaderElectionAgent = {
620+ new MonarchyLeaderAgent (master)
621+ }
622+ }
623+
624+ private object FakeRecoveryModeFactory {
625+ val persistentData = new HashMap [String , Object ]()
502626}
0 commit comments