Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.hadoop.yarn.server.resourcemanager.{ClientRMService, RMAppMana
import org.apache.hadoop.yarn.server.resourcemanager.ahs.RMApplicationHistoryWriter
import org.apache.hadoop.yarn.server.resourcemanager.metrics.SystemMetricsPublisher
import org.apache.hadoop.yarn.server.resourcemanager.rmapp.RMApp
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.YarnScheduler
import org.apache.hadoop.yarn.server.security.ApplicationACLsManager
import org.apache.hadoop.yarn.util.Records
import org.mockito.ArgumentMatchers.{any, anyBoolean, anyShort, eq => meq}
Expand Down Expand Up @@ -222,11 +223,50 @@ class ClientSuite extends SparkFunSuite with Matchers {
3 -> ("SPARK-SQL", "SPARK-SQL"),
4 -> ("012345678901234567890123", "01234567890123456789"))

// Mock yarn submit application
val yarnClient = mock(classOf[YarnClient])
val rmApps = new ConcurrentHashMap[ApplicationId, RMApp]()
val rmContext = mock(classOf[RMContext])
when(rmContext.getRMApps).thenReturn(rmApps)
val dispatcher = mock(classOf[Dispatcher])
when(rmContext.getDispatcher).thenReturn(dispatcher)
when[EventHandler[_]](dispatcher.getEventHandler).thenReturn(
new EventHandler[Event[_]] {
override def handle(event: Event[_]): Unit = {}
}
)
val writer = mock(classOf[RMApplicationHistoryWriter])
when(rmContext.getRMApplicationHistoryWriter).thenReturn(writer)
val publisher = mock(classOf[SystemMetricsPublisher])
when(rmContext.getSystemMetricsPublisher).thenReturn(publisher)
val yarnScheduler = mock(classOf[YarnScheduler])
val rmAppManager = new RMAppManager(rmContext,
yarnScheduler,
null,
mock(classOf[ApplicationACLsManager]),
new Configuration())
val clientRMService = new ClientRMService(rmContext,
yarnScheduler,
rmAppManager,
null,
null,
null)
clientRMService.init(new Configuration())
when(yarnClient.submitApplication(any())).thenAnswer((invocationOnMock: InvocationOnMock) => {
val subContext = invocationOnMock.getArguments()(0)
.asInstanceOf[ApplicationSubmissionContext]
val request = Records.newRecord(classOf[SubmitApplicationRequest])
request.setApplicationSubmissionContext(subContext)
clientRMService.submitApplication(request)
null
})

// Spark submit application
val appContext = spy(Records.newRecord(classOf[ApplicationSubmissionContext]))
when(appContext.getUnmanagedAM).thenReturn(true)
for ((id, (sourceType, targetType)) <- appTypes) {
val sparkConf = new SparkConf().set("spark.yarn.applicationType", sourceType)
val args = new ClientArguments(Array())

val appContext = spy(Records.newRecord(classOf[ApplicationSubmissionContext]))
val appId = ApplicationId.newInstance(123456, id)
appContext.setApplicationId(appId)
val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
Expand All @@ -237,48 +277,8 @@ class ClientSuite extends SparkFunSuite with Matchers {
new YarnClientApplication(getNewApplicationResponse, appContext),
containerLaunchContext)

val yarnClient = mock(classOf[YarnClient])
when(yarnClient.submitApplication(any())).thenAnswer((invocationOnMock: InvocationOnMock) => {
val subContext = invocationOnMock.getArguments()(0)
.asInstanceOf[ApplicationSubmissionContext]
val request = Records.newRecord(classOf[SubmitApplicationRequest])
request.setApplicationSubmissionContext(subContext)

val rmContext = mock(classOf[RMContext])
val conf = mock(classOf[Configuration])
val map = new ConcurrentHashMap[ApplicationId, RMApp]()
when(rmContext.getRMApps).thenReturn(map)
val dispatcher = mock(classOf[Dispatcher])
when(rmContext.getDispatcher).thenReturn(dispatcher)
when[EventHandler[_]](dispatcher.getEventHandler).thenReturn(
new EventHandler[Event[_]] {
override def handle(event: Event[_]): Unit = {}
}
)
val writer = mock(classOf[RMApplicationHistoryWriter])
when(rmContext.getRMApplicationHistoryWriter).thenReturn(writer)
val publisher = mock(classOf[SystemMetricsPublisher])
when(rmContext.getSystemMetricsPublisher).thenReturn(publisher)
when(appContext.getUnmanagedAM).thenReturn(true)

val rmAppManager = new RMAppManager(rmContext,
null,
null,
mock(classOf[ApplicationACLsManager]),
conf)
val clientRMService = new ClientRMService(rmContext,
null,
rmAppManager,
null,
null,
null)
clientRMService.submitApplication(request)

assert(map.get(subContext.getApplicationId).getApplicationType === targetType)
null
})

yarnClient.submitApplication(context)
assert(rmApps.get(appId).getApplicationType === targetType)
}
}

Expand Down