diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index a354f44a1be1..cf957ffcec9a 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -185,6 +185,8 @@ class BarrierTaskContext private[spark] ( taskContext.getMetricsSources(sourceName) } + override def resources(): Map[String, ResourceInformation] = taskContext.resources() + override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted() override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 15f1730ca483..227f4a5bb3a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -507,6 +507,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + /** + * Get task resource requirements. + */ + private[spark] def getTaskResourceRequirements(): Map[String, Int] = { + getAllWithPrefix(SPARK_TASK_RESOURCE_PREFIX) + .withFilter { case (k, v) => k.endsWith(SPARK_RESOURCE_COUNT_SUFFIX)} + .map { case (k, v) => (k.dropRight(SPARK_RESOURCE_COUNT_SUFFIX.length), v.toInt)}.toMap + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. @@ -603,30 +612,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria require(executorTimeoutThresholdMs > executorHeartbeatIntervalMs, "The value of " + s"${networkTimeout}=${executorTimeoutThresholdMs}ms must be no less than the value of " + s"${EXECUTOR_HEARTBEAT_INTERVAL.key}=${executorHeartbeatIntervalMs}ms.") - - // Make sure the executor resources were specified and are large enough if - // any task resources were specified. - val taskResourcesAndCount = - getAllWithPrefixAndSuffix(SPARK_TASK_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap - val executorResourcesAndCounts = - getAllWithPrefixAndSuffix(SPARK_EXECUTOR_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap - - taskResourcesAndCount.foreach { case (rName, taskCount) => - val execCount = executorResourcesAndCounts.get(rName).getOrElse( - throw new SparkException( - s"The executor resource config: " + - s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " + - "needs to be specified since a task requirement config: " + - s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} was specified") - ) - if (execCount.toLong < taskCount.toLong) { - throw new SparkException( - s"The executor resource config: " + - s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " + - s"= $execCount has to be >= the task config: " + - s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} = $taskCount") - } - } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6266ce62669e..b00bb9add5c8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2707,27 +2707,73 @@ object SparkContext extends Logging { // When running locally, don't try to re-execute tasks on failure. 
val MAX_LOCAL_TASK_FAILURES = 1 - // SPARK-26340: Ensure that executor's core num meets at least one task requirement. - def checkCpusPerTask( - clusterMode: Boolean, - maxCoresPerExecutor: Option[Int]): Unit = { - val cpusPerTask = sc.conf.get(CPUS_PER_TASK) - if (clusterMode && sc.conf.contains(EXECUTOR_CORES)) { - if (sc.conf.get(EXECUTOR_CORES) < cpusPerTask) { - throw new SparkException(s"${CPUS_PER_TASK.key}" + - s" must be <= ${EXECUTOR_CORES.key} when run on $master.") + // Ensure that the executor's resources satisfy at least one task's requirement. + def checkResourcesPerTask(clusterMode: Boolean, executorCores: Option[Int]): Unit = { + val taskCores = sc.conf.get(CPUS_PER_TASK) + val execCores = if (clusterMode) { + executorCores.getOrElse(sc.conf.get(EXECUTOR_CORES)) + } else { + executorCores.get + } + + // Number of cores per executor must meet at least one task requirement. + if (execCores < taskCores) { + throw new SparkException(s"The number of cores per executor (=$execCores) has to be >= " + + s"the task config: ${CPUS_PER_TASK.key} = $taskCores when run on $master.") + } + + // Calculate the max slots each executor can provide based on resources available on each + // executor and resources required by each task. + val taskResourcesAndCount = sc.conf.getTaskResourceRequirements() + val executorResourcesAndCounts = sc.conf.getAllWithPrefixAndSuffix( + SPARK_EXECUTOR_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap + var numSlots = execCores / taskCores + var limitingResourceName = "CPU" + taskResourcesAndCount.foreach { case (rName, taskCount) => + // Make sure the executor resources were specified through config. + val execCount = executorResourcesAndCounts.getOrElse(rName, + throw new SparkException( + s"The executor resource config: " + + s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " + + "needs to be specified since a task requirement config: " + + s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} was specified") + ) + // Make sure the executor resources are large enough to launch at least one task. + if (execCount.toLong < taskCount.toLong) { + throw new SparkException( + s"The executor resource config: " + + s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " + + s"= $execCount has to be >= the task config: " + + s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} = $taskCount") + } + // Compare and update the max slots each executor can provide. + val resourceNumSlots = execCount.toInt / taskCount + if (resourceNumSlots < numSlots) { + numSlots = resourceNumSlots + limitingResourceName = rName } - } else if (maxCoresPerExecutor.isDefined) { - if (maxCoresPerExecutor.get < cpusPerTask) { - throw new SparkException(s"Only ${maxCoresPerExecutor.get} cores available per executor" + - s" when run on $master, and ${CPUS_PER_TASK.key} must be <= it.") + } + // The checks above have already ensured that the executor resources were specified and are + // large enough if any task resources were specified. + taskResourcesAndCount.foreach { case (rName, taskCount) => + val execCount = executorResourcesAndCounts(rName) + if (taskCount.toInt * numSlots < execCount.toInt) { + val message = s"The configuration of resource: $rName (exec = ${execCount.toInt}, " + + s"task = ${taskCount}) will result in wasted resources due to resource " + + s"${limitingResourceName} limiting the number of runnable tasks per executor to: " + + s"${numSlots}. Please adjust your configuration."
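+ // Illustrative numbers (not from any particular config): with 1 executor core, + // spark.task.cpus = 1, 4 executor GPUs and 2 GPUs per task, CPU limits numSlots to 1, + // so at most 2 of the 4 GPUs are ever used and this message is emitted.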
+ if (Utils.isTesting) { + throw new SparkException(message) + } else { + logWarning(message) + } } } } master match { case "local" => - checkCpusPerTask(clusterMode = false, Some(1)) + checkResourcesPerTask(clusterMode = false, Some(1)) val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) @@ -2740,7 +2786,7 @@ object SparkContext extends Logging { if (threadCount <= 0) { throw new SparkException(s"Asked to run locally with $threadCount threads") } - checkCpusPerTask(clusterMode = false, Some(threadCount)) + checkResourcesPerTask(clusterMode = false, Some(threadCount)) val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) @@ -2751,14 +2797,14 @@ object SparkContext extends Logging { // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt - checkCpusPerTask(clusterMode = false, Some(threadCount)) + checkResourcesPerTask(clusterMode = false, Some(threadCount)) val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) case SPARK_REGEX(sparkUrl) => - checkCpusPerTask(clusterMode = true, None) + checkResourcesPerTask(clusterMode = true, None) val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls) @@ -2766,7 +2812,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - checkCpusPerTask(clusterMode = true, Some(coresPerSlave.toInt)) + checkResourcesPerTask(clusterMode = true, Some(coresPerSlave.toInt)) // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. val memoryPerSlaveInt = memoryPerSlave.toInt if (sc.executorMemory > memoryPerSlaveInt) { @@ -2787,7 +2833,7 @@ object SparkContext extends Logging { (backend, scheduler) case masterUrl => - checkCpusPerTask(clusterMode = true, None) + checkResourcesPerTask(clusterMode = true, None) val cm = getClusterManager(masterUrl) match { case Some(clusterMgr) => clusterMgr case None => throw new SparkException("Could not parse Master URL: '" + master + "'") diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 959f246f3f9f..803167ee95aa 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.Serializable import java.util.Properties -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Evolving} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source @@ -176,6 +176,13 @@ abstract class TaskContext extends Serializable { */ def getLocalProperty(key: String): String + /** + * Resources allocated to the task. The key is the resource name and the value is information + * about the resource. Please refer to [[ResourceInformation]] for specifics. 
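+ * + * A short, illustrative usage sketch (assumes a "gpu" resource was configured for tasks): + * {{{ + * val gpuAddrs = TaskContext.get().resources()("gpu").addresses + * }}}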
+ */ + @Evolving + def resources(): Map[String, ResourceInformation] + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 76296c5d0abd..8e40b7f1affc 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -51,7 +51,8 @@ private[spark] class TaskContextImpl( localProperties: Properties, @transient private val metricsSystem: MetricsSystem, // The default value is only used in tests. - override val taskMetrics: TaskMetrics = TaskMetrics.empty) + override val taskMetrics: TaskMetrics = TaskMetrics.empty, + override val resources: Map[String, ResourceInformation] = Map.empty) extends TaskContext with Logging { diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index c97b10ee63b1..d306eed757b6 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -38,6 +38,7 @@ import com.google.common.io.{ByteStreams, Files} import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -311,6 +312,16 @@ private[spark] object TestUtils { current ++ current.filter(_.isDirectory).flatMap(recursiveList) } + /** + * Set task resource requirement. + */ + def setTaskResourceRequirement( + conf: SparkConf, + resourceName: String, + resourceCount: Int): SparkConf = { + val key = s"${SPARK_TASK_RESOURCE_PREFIX}${resourceName}${SPARK_RESOURCE_COUNT_SUFFIX}" + conf.set(key, resourceCount.toString) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b262c235d06c..13bf8a97ef2c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -66,6 +66,13 @@ private[spark] class CoarseGrainedExecutorBackend( // to be changed so that we don't share the serializer instance across threads private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() + /** + * Map each taskId to the resources allocated to it. Please refer to + * [[ResourceInformation]] for specifics. + * Exposed for testing only.
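+ * Entries are added when a task is launched and removed in statusUpdate() once the task + * reports a finished state.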
+ */ + private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]] + override def onStart() { logInfo("Connecting to driver: " + driverUrl) val resources = parseOrFindResources(resourcesFile) @@ -151,6 +158,7 @@ private[spark] class CoarseGrainedExecutorBackend( } else { val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) + taskResources(taskDesc.taskId) = taskDesc.resources executor.launchTask(this, taskDesc) } @@ -197,7 +205,11 @@ private[spark] class CoarseGrainedExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - val msg = StatusUpdate(executorId, taskId, state, data) + val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation]) + val msg = StatusUpdate(executorId, taskId, state, data, resources) + if (TaskState.isFinished(state)) { + taskResources.remove(taskId) + } driver match { case Some(driverRef) => driverRef.send(msg) case None => logWarning(s"Drop $msg because has not yet connected to driver") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index cc3cc1604d68..2c035285c08f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -422,7 +422,8 @@ private[spark] class Executor( val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, - metricsSystem = env.metricsSystem) + metricsSystem = env.metricsSystem, + resources = taskDescription.resources) threwException = false res } { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorResourceInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorResourceInfo.scala new file mode 100644 index 000000000000..c75931d53b4b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorResourceInfo.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.util.collection.OpenHashMap + +/** + * Class to hold information about a type of Resource on an Executor. This information is managed + * by SchedulerBackend, and TaskScheduler shall schedule tasks on idle Executors based on the + * information. + * Please note that this class is intended to be used in a single thread. 
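+ * + * A minimal usage sketch (illustrative only): + * {{{ + * val info = new ExecutorResourceInfo("gpu", Seq("0", "1")) + * info.acquire(Seq("0")) // "0" becomes assigned + * info.release(Seq("0")) // "0" becomes available again + * }}}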
+ * @param name Resource name + * @param addresses Resource addresses provided by the executor + */ +private[spark] class ExecutorResourceInfo( + val name: String, + addresses: Seq[String]) extends Serializable { + + /** + * Map from an address to its availability: the value `true` means the address is available, + * while `false` means the address is assigned. + * TODO Use [[OpenHashMap]] instead to gain better performance. + */ + private val addressAvailabilityMap = mutable.HashMap(addresses.map(_ -> true): _*) + + /** + * Sequence of currently available resource addresses. + */ + def availableAddrs: Seq[String] = addressAvailabilityMap.flatMap { case (addr, available) => + if (available) Some(addr) else None + }.toSeq + + /** + * Sequence of currently assigned resource addresses. + * Exposed for testing only. + */ + private[scheduler] def assignedAddrs: Seq[String] = addressAvailabilityMap + .flatMap { case (addr, available) => + if (!available) Some(addr) else None + }.toSeq + + /** + * Acquire a sequence of resource addresses (for a launched task); these addresses must be + * available. They are released back when the task finishes. + * Throws an exception if an address is not available or doesn't exist. + */ + def acquire(addrs: Seq[String]): Unit = { + addrs.foreach { address => + if (!addressAvailabilityMap.contains(address)) { + throw new SparkException(s"Try to acquire an address that doesn't exist. $name address " + + s"$address doesn't exist.") + } + val isAvailable = addressAvailabilityMap(address) + if (isAvailable) { + addressAvailabilityMap(address) = false + } else { + throw new SparkException(s"Try to acquire an address that is not available. $name " + + s"address $address is not available.") + } + } + } + + /** + * Release a sequence of resource addresses; these addresses must have been assigned. Resource + * addresses are released when a task has finished. + * Throws an exception if an address is not assigned or doesn't exist. + */ + def release(addrs: Seq[String]): Unit = { + addrs.foreach { address => + if (!addressAvailabilityMap.contains(address)) { + throw new SparkException(s"Try to release an address that doesn't exist. $name address " + + s"$address doesn't exist.") + } + val isAvailable = addressAvailabilityMap(address) + if (!isAvailable) { + addressAvailabilityMap(address) = true + } else { + throw new SparkException(s"Try to release an address that is not assigned. $name " + + s"address $address is not assigned.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 29b4380a9c33..bb44e9ada381 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -74,12 +74,14 @@ private[spark] abstract class Task[T]( * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) + * @param resources other host resources (like GPUs) that this task attempt can access * @return the result of the task along with updates of Accumulators.
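+ * @note The passed-in resources are exposed to user code through the created TaskContext, + * via `TaskContext.get().resources()`.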
*/ final def run( taskAttemptId: Long, attemptNumber: Int, - metricsSystem: MetricsSystem): T = { + metricsSystem: MetricsSystem, + resources: Map[String, ResourceInformation]): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether // the stage is barrier. @@ -92,7 +94,8 @@ private[spark] abstract class Task[T]( taskMemoryManager, localProperties, metricsSystem, - metrics) + metrics, + resources) context = if (isBarrier) { new BarrierTaskContext(taskContext) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index bb4a4442b943..c29ee0619e5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -23,8 +23,10 @@ import java.nio.charset.StandardCharsets import java.util.Properties import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, Map} +import scala.collection.immutable +import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.spark.ResourceInformation import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -54,6 +56,7 @@ private[spark] class TaskDescription( val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, + val resources: immutable.Map[String, ResourceInformation], val serializedTask: ByteBuffer) { override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) @@ -62,12 +65,23 @@ private[spark] class TaskDescription( private[spark] object TaskDescription { private def serializeStringLongMap(map: Map[String, Long], dataOut: DataOutputStream): Unit = { dataOut.writeInt(map.size) - for ((key, value) <- map) { + map.foreach { case (key, value) => dataOut.writeUTF(key) dataOut.writeLong(value) } } + private def serializeResources(map: immutable.Map[String, ResourceInformation], + dataOut: DataOutputStream): Unit = { + dataOut.writeInt(map.size) + map.foreach { case (key, value) => + dataOut.writeUTF(key) + dataOut.writeUTF(value.name) + dataOut.writeInt(value.addresses.size) + value.addresses.foreach(dataOut.writeUTF(_)) + } + } + def encode(taskDescription: TaskDescription): ByteBuffer = { val bytesOut = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(bytesOut) @@ -95,6 +109,9 @@ private[spark] object TaskDescription { dataOut.write(bytes) } + // Write resources. + serializeResources(taskDescription.resources, dataOut) + // Write the task. The task is already serialized, so write it directly to the byte buffer. 
Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut) @@ -106,12 +123,35 @@ private[spark] object TaskDescription { private def deserializeStringLongMap(dataIn: DataInputStream): HashMap[String, Long] = { val map = new HashMap[String, Long]() val mapSize = dataIn.readInt() - for (i <- 0 until mapSize) { + var i = 0 + while (i < mapSize) { map(dataIn.readUTF()) = dataIn.readLong() + i += 1 } map } + private def deserializeResources(dataIn: DataInputStream): + immutable.Map[String, ResourceInformation] = { + val map = new HashMap[String, ResourceInformation]() + val mapSize = dataIn.readInt() + var i = 0 + while (i < mapSize) { + val resType = dataIn.readUTF() + val name = dataIn.readUTF() + val numIdentifier = dataIn.readInt() + val identifiers = new ArrayBuffer[String](numIdentifier) + var j = 0 + while (j < numIdentifier) { + identifiers += dataIn.readUTF() + j += 1 + } + map(resType) = new ResourceInformation(name, identifiers.toArray) + i += 1 + } + map.toMap + } + def decode(byteBuffer: ByteBuffer): TaskDescription = { val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer)) val taskId = dataIn.readLong() @@ -138,10 +178,13 @@ private[spark] object TaskDescription { properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) } + // Read resources. + val resources = deserializeResources(dataIn) + // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, - taskJars, properties, serializedTask) + taskJars, properties, resources, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 532eb322769a..cf07847190f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,7 +22,7 @@ import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap, HashSet} import scala.util.Random import org.apache.spark._ @@ -92,6 +92,9 @@ private[spark] class TaskSchedulerImpl( // CPUs to request per task val CPUS_PER_TASK = conf.get(config.CPUS_PER_TASK) + // Resources to request per task + val resourcesPerTask = conf.getTaskResourceRequirements() + // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. 
Protected by `this` private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] @@ -327,6 +330,7 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], + availableResources: Array[Map[String, Buffer[String]]], tasks: IndexedSeq[ArrayBuffer[TaskDescription]], addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = { var launchedTask = false for (i <- 0 until shuffledOffers.size) { val execId = shuffledOffers(i).executorId val host = shuffledOffers(i).host - if (availableCpus(i) >= CPUS_PER_TASK) { + if (availableCpus(i) >= CPUS_PER_TASK && + resourcesMeetTaskRequirements(availableResources(i))) { try { - for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { + for (task <- taskSet.resourceOffer(execId, host, maxLocality, availableResources(i))) { tasks(i) += task val tid = task.taskId taskIdToTaskSetManager.put(tid, taskSet) executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + task.resources.foreach { case (rName, rInfo) => + // Remove the first n elements from availableResources addresses; the removed + // addresses are the same as those we allocated in taskSet.resourceOffer() since it's + // synchronized. We don't remove the exact addresses allocated because the current + // approach produces an identical result with lower time complexity. + availableResources(i).getOrElse(rName, + throw new SparkException(s"Try to acquire resource $rName that doesn't exist.")) + .remove(0, rInfo.addresses.size) + } // Only update hosts for a barrier task. if (taskSet.isBarrier) { // The executor address is expected to be non empty. @@ -364,6 +378,15 @@ private[spark] class TaskSchedulerImpl( launchedTask } + /** + * Check whether the resources from the WorkerOffer are enough to run at least one task. + */ + private def resourcesMeetTaskRequirements(resources: Map[String, Buffer[String]]): Boolean = { + resourcesPerTask.forall { case (rName, rNum) => + resources.contains(rName) && resources(rName).size >= rNum + } + } + /** * Called by cluster manager to offer resources on slaves. We respond by asking our active task * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so @@ -405,6 +428,7 @@ private[spark] class TaskSchedulerImpl( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker.
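+ // Each offer also carries the executor's currently free resource addresses (e.g. GPU + // indices), which are checked alongside free CPU cores when matching tasks to offers.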
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) + val availableResources = shuffledOffers.map(_.resources).toArray val availableCpus = shuffledOffers.map(o => o.cores).toArray val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val sortedTaskSets = rootPool.getSortedTaskSetQueue @@ -436,7 +460,8 @@ private[spark] class TaskSchedulerImpl( var launchedTaskAtCurrentMaxLocality = false do { launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, - currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs) + currentMaxLocality, shuffledOffers, availableCpus, + availableResources, tasks, addressesWithDescs) launchedAnyTask |= launchedTaskAtCurrentMaxLocality } while (launchedTaskAtCurrentMaxLocality) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 52323b3331d7..6f2b982d1793 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -21,6 +21,7 @@ import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.immutable.Map import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.max import scala.util.control.NonFatal @@ -467,7 +468,8 @@ private[spark] class TaskSetManager( def resourceOffer( execId: String, host: String, - maxLocality: TaskLocality.TaskLocality) + maxLocality: TaskLocality.TaskLocality, + availableResources: Map[String, Seq[String]] = Map.empty) : Option[TaskDescription] = { val offerBlacklisted = taskSetBlacklistHelperOpt.exists { blacklist => @@ -532,6 +534,15 @@ private[spark] class TaskSetManager( logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit()} bytes)") + val extraResources = sched.resourcesPerTask.map { case (rName, rNum) => + val rAddresses = availableResources.getOrElse(rName, Seq.empty) + assert(rAddresses.size >= rNum, s"Required $rNum $rName addresses, but only " + + s"${rAddresses.size} available.") + // We'll drop the allocated addresses later inside TaskSchedulerImpl. + val allocatedAddresses = rAddresses.take(rNum) + (rName, new ResourceInformation(rName, allocatedAddresses.toArray)) + } + sched.dagScheduler.taskStarted(task, info) new TaskDescription( taskId, @@ -543,6 +554,7 @@ private[spark] class TaskSetManager( addedFiles, addedJars, task.localProperties, + extraResources, serializedTask) } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 6ec74913e42f..522dbfa9457b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.Buffer + /** * Represents free resources available on an executor. */ @@ -27,4 +29,5 @@ case class WorkerOffer( cores: Int, // `address` is an optional hostPort string, it provide more useful information than `host` // when multiple executors are launched on the same host. 
- address: Option[String] = None) + address: Option[String] = None, + resources: Map[String, Buffer[String]] = Map.empty) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 89425e702677..82d51f8a169a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -69,14 +69,19 @@ private[spark] object CoarseGrainedClusterMessages { resources: Map[String, ResourceInformation]) extends CoarseGrainedClusterMessage - case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends CoarseGrainedClusterMessage + case class StatusUpdate( + executorId: String, + taskId: Long, + state: TaskState, + data: SerializableBuffer, + resources: Map[String, ResourceInformation] = Map.empty) + extends CoarseGrainedClusterMessage object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer) - : StatusUpdate = { - StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) + def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer, + resources: Map[String, ResourceInformation]): StatusUpdate = { + StatusUpdate(executorId, taskId, state, new SerializableBuffer(data), resources) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f7cf212d0bfe..9f535889193f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.Future @@ -139,12 +140,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def receive: PartialFunction[Any, Unit] = { - case StatusUpdate(executorId, taskId, state, data) => + case StatusUpdate(executorId, taskId, state, data, resources) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { executorDataMap.get(executorId) match { case Some(executorInfo) => executorInfo.freeCores += scheduler.CPUS_PER_TASK + resources.foreach { case (k, v) => + executorInfo.resourcesInfo.get(k).foreach { r => + r.release(v.addresses) + } + } makeOffers(executorId) case None => // Ignoring the update since we don't know about the executor. 
@@ -209,8 +215,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) + val resourcesInfo = resources.map{ case (k, v) => + (v.name, new ExecutorResourceInfo(v.name, v.addresses))} val data = new ExecutorData(executorRef, executorAddress, hostname, - cores, cores, logUrlHandler.applyPattern(logUrls, attributes), attributes) + cores, cores, logUrlHandler.applyPattern(logUrls, attributes), attributes, + resourcesInfo) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -263,7 +272,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores, - Some(executorData.executorAddress.hostPort)) + Some(executorData.executorAddress.hostPort), + executorData.resourcesInfo.map { case (rName, rInfo) => + (rName, rInfo.availableAddrs.toBuffer) + }) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -289,7 +301,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, - Some(executorData.executorAddress.hostPort))) + Some(executorData.executorAddress.hostPort), + executorData.resourcesInfo.map { case (rName, rInfo) => + (rName, rInfo.availableAddrs.toBuffer) + })) scheduler.resourceOffers(workOffers) } else { Seq.empty @@ -324,7 +339,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } else { val executorData = executorDataMap(task.executorId) + // Do resources allocation here. The allocated resources will get released after the task + // finishes. executorData.freeCores -= scheduler.CPUS_PER_TASK + task.resources.foreach { case (rName, rInfo) => + assert(executorData.resourcesInfo.contains(rName)) + executorData.resourcesInfo(rName).acquire(rInfo.addresses) + } logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + s"${executorData.executorHost}.") @@ -525,6 +546,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp }.sum } + // this function is for testing only + def getExecutorAvailableResources(executorId: String): Map[String, ExecutorResourceInfo] = { + executorDataMap.get(executorId).map(_.resourcesInfo).getOrElse(Map.empty) + } + /** * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index ebe1c1eb0a35..435365d5b6e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.scheduler.ExecutorResourceInfo /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
@@ -27,6 +28,7 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor + * @param resourcesInfo The information of the currently available resources on the executor */ private[cluster] class ExecutorData( val executorEndpoint: RpcEndpointRef, @@ -35,5 +37,6 @@ private[cluster] class ExecutorData( var freeCores: Int, override val totalCores: Int, override val logUrlMap: Map[String, String], - override val attributes: Map[String, String] + override val attributes: Map[String, String], + val resourcesInfo: Map[String, ExecutorResourceInfo] ) extends ExecutorInfo(executorHost, totalCores, logUrlMap, attributes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index fde2a328f02f..cbcc5310a59f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,6 +81,7 @@ private[spark] class LocalEndpoint( } def reviveOffers() { + // local mode doesn't support extra resources like GPUs right now val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index f8e233a05a44..62a0b85915ef 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -40,6 +40,10 @@ public static void test() { tc.stageId(); tc.stageAttemptNumber(); tc.taskAttemptId(); + tc.resources(); + tc.taskMetrics(); + tc.taskMemoryManager(); + tc.getLocalProperties(); } /** diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index cf8565c74e95..1c78f1a01900 100644 --- a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1,3 +1,4 @@ org.apache.spark.scheduler.DummyExternalClusterManager org.apache.spark.scheduler.MockExternalClusterManager org.apache.spark.DummyLocalExternalClusterManager +org.apache.spark.scheduler.CSMockExternalClusterManager diff --git a/core/src/test/scala/org/apache/spark/ResourceName.scala b/core/src/test/scala/org/apache/spark/ResourceName.scala new file mode 100644 index 000000000000..6efe064a7773 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ResourceName.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +private[spark] object ResourceName { + // known types of resources + final val GPU: String = "gpu" + final val FPGA: String = "fpga" +} diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 965c7df69cfe..6978f303b376 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -25,6 +25,7 @@ import scala.util.{Random, Try} import com.esotericsoftware.kryo.Kryo +import org.apache.spark.ResourceName._ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.History._ import org.apache.spark.internal.config.Kryo._ @@ -446,6 +447,29 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(thrown.getMessage.contains(key)) } } + + test("get task resource requirement from config") { + val conf = new SparkConf() + conf.set(SPARK_TASK_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "2") + conf.set(SPARK_TASK_RESOURCE_PREFIX + FPGA + SPARK_RESOURCE_COUNT_SUFFIX, "1") + var taskResourceRequirement = conf.getTaskResourceRequirements() + assert(taskResourceRequirement.size == 2) + assert(taskResourceRequirement(GPU) == 2) + assert(taskResourceRequirement(FPGA) == 1) + + conf.remove(SPARK_TASK_RESOURCE_PREFIX + FPGA + SPARK_RESOURCE_COUNT_SUFFIX) + // Ignore invalid prefix + conf.set("spark.invalid.prefix" + FPGA + SPARK_RESOURCE_COUNT_SUFFIX, "1") + taskResourceRequirement = conf.getTaskResourceRequirements() + assert(taskResourceRequirement.size == 1) + assert(taskResourceRequirement.get(FPGA).isEmpty) + + // Ignore invalid suffix + conf.set(SPARK_TASK_RESOURCE_PREFIX + FPGA + "invalid.suffix", "1") + taskResourceRequirement = conf.getTaskResourceRequirements() + assert(taskResourceRequirement.size == 1) + assert(taskResourceRequirement.get(FPGA).isEmpty) + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ded914d3d896..40ec1b2194fe 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{compact, render} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually +import org.apache.spark.ResourceName.GPU import org.apache.spark.internal.config._ import org.apache.spark.internal.config.UI._ import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} @@ -718,7 +719,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test(s"Avoid setting ${CPUS_PER_TASK.key} unreasonably (SPARK-27192)") { - val FAIL_REASON = s"${CPUS_PER_TASK.key} must be <=" + val FAIL_REASON = s"has to be >= the task config: ${CPUS_PER_TASK.key}" Seq( ("local", 2, None), ("local[2]", 3, None), @@ -745,9 +746,9 @@ class SparkContextSuite extends SparkFunSuite with 
LocalSparkContext with Eventu """'{"name": "gpu","addresses":["5", "6"]}'""") val conf = new SparkConf() - .set(SPARK_DRIVER_RESOURCE_PREFIX + "gpu" + + .set(SPARK_DRIVER_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "1") - .set(SPARK_DRIVER_RESOURCE_PREFIX + "gpu" + + .set(SPARK_DRIVER_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_DISCOVERY_SCRIPT_SUFFIX, scriptPath) .setMaster("local-cluster[1, 1, 1024]") .setAppName("test-cluster") @@ -758,8 +759,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc.statusTracker.getExecutorInfos.size == 1) } assert(sc.resources.size === 1) - assert(sc.resources.get("gpu").get.addresses === Array("5", "6")) - assert(sc.resources.get("gpu").get.name === "gpu") + assert(sc.resources.get(GPU).get.addresses === Array("5", "6")) + assert(sc.resources.get(GPU).get.name === "gpu") } } @@ -782,9 +783,9 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val resourcesFile = writeJsonFile(dir, ja) val conf = new SparkConf() - .set(SPARK_DRIVER_RESOURCE_PREFIX + "gpu" + + .set(SPARK_DRIVER_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "1") - .set(SPARK_DRIVER_RESOURCE_PREFIX + "gpu" + + .set(SPARK_DRIVER_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_DISCOVERY_SCRIPT_SUFFIX, scriptPath) .set(DRIVER_RESOURCES_FILE, resourcesFile) .setMaster("local-cluster[1, 1, 1024]") @@ -797,14 +798,14 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } // driver gpu addresses config should take precedence over the script assert(sc.resources.size === 1) - assert(sc.resources.get("gpu").get.addresses === Array("0", "1", "8")) - assert(sc.resources.get("gpu").get.name === "gpu") + assert(sc.resources.get(GPU).get.addresses === Array("0", "1", "8")) + assert(sc.resources.get(GPU).get.name === "gpu") } } test("Test parsing resources task configs with missing executor config") { val conf = new SparkConf() - .set(SPARK_TASK_RESOURCE_PREFIX + "gpu" + + .set(SPARK_TASK_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "1") .setMaster("local-cluster[1, 1, 1024]") .setAppName("test-cluster") @@ -820,9 +821,9 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Test parsing resources executor config < task requirements") { val conf = new SparkConf() - .set(SPARK_TASK_RESOURCE_PREFIX + "gpu" + + .set(SPARK_TASK_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "2") - .set(SPARK_EXECUTOR_RESOURCE_PREFIX + "gpu" + + .set(SPARK_EXECUTOR_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "1") .setMaster("local-cluster[1, 1, 1024]") .setAppName("test-cluster") @@ -836,6 +837,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu "spark.task.resource.gpu.count = 2")) } + test("Parse resources executor config not the same multiple numbers of the task requirements") { + val conf = new SparkConf() + .set(SPARK_TASK_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "2") + .set(SPARK_EXECUTOR_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "4") + .setMaster("local-cluster[1, 1, 1024]") + .setAppName("test-cluster") + + var error = intercept[SparkException] { + sc = new SparkContext(conf) + }.getMessage() + + assert(error.contains("The configuration of resource: gpu (exec = 4, task = 2) will result " + + "in wasted resources due to resource CPU limiting the number of runnable tasks per " + + "executor to: 1. 
Please adjust your configuration.")) + } + def mockDiscoveryScript(file: File, result: String): String = { Files.write(s"echo $result", file, StandardCharsets.UTF_8) JavaFiles.setPosixFilePermissions(file.toPath(), @@ -843,6 +860,44 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu file.getPath() } + test("test resource scheduling under local-cluster mode") { + import org.apache.spark.TestUtils._ + + assume(!(Utils.isWindows)) + withTempDir { dir => + val resourceFile = new File(dir, "resourceDiscoverScript") + val resources = """'{"name": "gpu", "addresses": ["0", "1", "2"]}'""" + Files.write(s"echo $resources", resourceFile, StandardCharsets.UTF_8) + JavaFiles.setPosixFilePermissions(resourceFile.toPath(), + EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE)) + val discoveryScript = resourceFile.getPath() + + val conf = new SparkConf() + .set(s"${SPARK_EXECUTOR_RESOURCE_PREFIX}${GPU}${SPARK_RESOURCE_COUNT_SUFFIX}", "3") + .set(s"${SPARK_EXECUTOR_RESOURCE_PREFIX}${GPU}${SPARK_RESOURCE_DISCOVERY_SCRIPT_SUFFIX}", + discoveryScript) + .setMaster("local-cluster[3, 3, 1024]") + .setAppName("test-cluster") + setTaskResourceRequirement(conf, GPU, 1) + sc = new SparkContext(conf) + + // Ensure all executors have started + eventually(timeout(60.seconds)) { + assert(sc.statusTracker.getExecutorInfos.size == 3) + } + + val rdd = sc.makeRDD(1 to 10, 9).mapPartitions { it => + val context = TaskContext.get() + context.resources().get(GPU).get.addresses.iterator + } + val gpus = rdd.collect() + assert(gpus.sorted === Seq("0", "0", "0", "1", "1", "1", "2", "2", "2")) + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index b3d16d179cb4..c66feeb58fa9 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -20,23 +20,32 @@ package org.apache.spark.executor import java.io.File import java.net.URL +import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.nio.file.{Files => JavaFiles} import java.nio.file.attribute.PosixFilePermission.{OWNER_EXECUTE, OWNER_READ, OWNER_WRITE} -import java.util.EnumSet +import java.util.{EnumSet, Properties} + +import scala.collection.mutable +import scala.concurrent.duration._ import com.google.common.io.Files import org.json4s.JsonAST.{JArray, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, render} import org.mockito.Mockito.when +import org.scalatest.concurrent.Eventually.{eventually, timeout} import org.scalatest.mockito.MockitoSugar import org.apache.spark._ +import org.apache.spark.ResourceInformation +import org.apache.spark.ResourceName.GPU import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableBuffer, Utils} class CoarseGrainedExecutorBackendSuite
extends SparkFunSuite } } - private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + test("track allocated resources by taskId") { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val serializer = new JavaSerializer(conf) + var backend: CoarseGrainedExecutorBackend = null + + try { + val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr) + val env = createMockEnv(conf, serializer, Some(rpcEnv)) + backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1", + "host1", 4, Seq.empty[URL], env, None) + assert(backend.taskResources.isEmpty) + + val taskId = 1000000 + // We don't really verify the data, just pass it around. + val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) + val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19, 1, + mutable.Map.empty, mutable.Map.empty, new Properties, + Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data) + val serializedTaskDescription = TaskDescription.encode(taskDescription) + backend.executor = mock[Executor] + backend.rpcEnv.setupEndpoint("Executor 1", backend) + + // Launching a new task should add an entry to the `taskResources` map. + backend.self.send(LaunchTask(new SerializableBuffer(serializedTaskDescription))) + eventually(timeout(10.seconds)) { + assert(backend.taskResources.size == 1) + assert(backend.taskResources(taskId)(GPU).addresses sameElements Array("0", "1")) + } + + // Updating the status of a running task should not affect the `taskResources` map. + backend.statusUpdate(taskId, TaskState.RUNNING, data) + assert(backend.taskResources.size == 1) + assert(backend.taskResources(taskId)(GPU).addresses sameElements Array("0", "1")) + + // Updating the status of a finished task should remove the entry from the `taskResources` map.
+ backend.statusUpdate(taskId, TaskState.FINISHED, data) + assert(backend.taskResources.isEmpty) + } finally { + if (backend != null) { + backend.rpcEnv.shutdown() + } + } + } + + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer, + rpcEnv: Option[RpcEnv] = None): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] when(mockEnv.conf).thenReturn(conf) when(mockEnv.serializer).thenReturn(serializer) when(mockEnv.closureSerializer).thenReturn(serializer) - when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.rpcEnv).thenReturn(rpcEnv.getOrElse(mockRpcEnv)) SparkEnv.set(mockEnv) mockEnv } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8a4f7a34e545..495321fcba74 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -24,6 +24,7 @@ import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.immutable import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -369,6 +370,7 @@ class ExecutorSuite extends SparkFunSuite addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, + resources = immutable.Map[String, ResourceInformation](), serializedTask) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 4858d38cad40..6b3916b1f5e6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,21 +17,29 @@ package org.apache.spark.scheduler +import java.util.Properties import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.immutable +import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.postfixOps +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock import org.scalatest.concurrent.Eventually import org.scalatest.mockito.MockitoSugar._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.internal.config.{CPUS_PER_TASK, UI} +import org.apache.spark._ +import org.apache.spark.ResourceName.GPU +import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{RpcUtils, SerializableBuffer} +import org.apache.spark.util.{RpcUtils, SerializableBuffer, Utils} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -174,6 +182,77 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(executorAddedCount === 3) } + test("extra resources from 
executor") { + import TestUtils._ + + val conf = new SparkConf() + .set(EXECUTOR_CORES, 3) + .set(SPARK_EXECUTOR_RESOURCE_PREFIX + GPU + SPARK_RESOURCE_COUNT_SUFFIX, "3") + .set(SCHEDULER_REVIVE_INTERVAL.key, "1m") // don't let it auto revive during test + .setMaster( + "coarseclustermanager[org.apache.spark.scheduler.TestCoarseGrainedSchedulerBackend]") + .setAppName("test") + setTaskResourceRequirement(conf, GPU, 1) + + sc = new SparkContext(conf) + val backend = sc.schedulerBackend.asInstanceOf[TestCoarseGrainedSchedulerBackend] + val mockEndpointRef = mock[RpcEndpointRef] + val mockAddress = mock[RpcAddress] + when(mockEndpointRef.send(LaunchTask)).thenAnswer((_: InvocationOnMock) => {}) + + val resources = Map(GPU -> new ResourceInformation(GPU, Array("0", "1", "3"))) + + var executorAddedCount: Int = 0 + val listener = new SparkListener() { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + executorAddedCount += 1 + } + } + + sc.addSparkListener(listener) + + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("1", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources)) + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("2", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources)) + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("3", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources)) + + val frameSize = RpcUtils.maxMessageSizeBytes(sc.conf) + val bytebuffer = java.nio.ByteBuffer.allocate(frameSize - 100) + val buffer = new SerializableBuffer(bytebuffer) + + var execResources = backend.getExecutorAvailableResources("1") + + assert(execResources(GPU).availableAddrs.sorted === Array("0", "1", "3")) + + val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0"))) + var taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", + "t1", 0, 1, mutable.Map.empty[String, Long], mutable.Map.empty[String, Long], + new Properties(), taskResources, bytebuffer))) + val ts = backend.getTaskSchedulerImpl() + when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]])).thenReturn(taskDescs) + + backend.driverEndpoint.send(ReviveOffers) + + eventually(timeout(5 seconds)) { + execResources = backend.getExecutorAvailableResources("1") + assert(execResources(GPU).availableAddrs.sorted === Array("1", "3")) + assert(execResources(GPU).assignedAddrs === Array("0")) + } + + backend.driverEndpoint.send( + StatusUpdate("1", 1, TaskState.FINISHED, buffer, taskResources)) + + eventually(timeout(5 seconds)) { + execResources = backend.getExecutorAvailableResources("1") + assert(execResources(GPU).availableAddrs.sorted === Array("0", "1", "3")) + assert(execResources(GPU).assignedAddrs.isEmpty) + } + sc.listenerBus.waitUntilEmpty(executorUpTimeout.toMillis) + assert(executorAddedCount === 3) + } + private def testSubmitJob(sc: SparkContext, rdd: RDD[Int]): Unit = { sc.submitJob( rdd, @@ -184,3 +263,47 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo ) } } + +/** Simple cluster manager that wires up our mock backend for the resource tests. 
*/ +private class CSMockExternalClusterManager extends ExternalClusterManager { + + private var ts: TaskSchedulerImpl = _ + + private val MOCK_REGEX = """coarseclustermanager\[(.*)\]""".r + override def canCreate(masterURL: String): Boolean = MOCK_REGEX.findFirstIn(masterURL).isDefined + + override def createTaskScheduler( + sc: SparkContext, + masterURL: String): TaskScheduler = { + ts = mock[TaskSchedulerImpl] + when(ts.sc).thenReturn(sc) + when(ts.applicationId()).thenReturn("appid1") + when(ts.applicationAttemptId()).thenReturn(Some("attempt1")) + when(ts.schedulingMode).thenReturn(SchedulingMode.FIFO) + when(ts.nodeBlacklist()).thenReturn(Set.empty[String]) + ts + } + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + masterURL match { + case MOCK_REGEX(backendClassName) => + val backendClass = Utils.classForName(backendClassName) + val ctor = backendClass.getConstructor(classOf[TaskSchedulerImpl], classOf[RpcEnv]) + ctor.newInstance(scheduler, sc.env.rpcEnv).asInstanceOf[SchedulerBackend] + } + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} + +private[spark] +class TestCoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, override val rpcEnv: RpcEnv) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + def getTaskSchedulerImpl(): TaskSchedulerImpl = scheduler +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala new file mode 100644 index 000000000000..b4ff73083a75 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ResourceName.GPU + +class ExecutorResourceInfoSuite extends SparkFunSuite { + + test("Track Executor Resource information") { + // Init Executor Resource. 
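The new ExecutorResourceInfoSuite that begins here pins down an acquire/release contract: addresses move between an available set and an assigned set, and any acquire or release of an address in the wrong state fails loudly. As a reading aid, here is a minimal sketch of bookkeeping that would satisfy these assertions; the class name, constructor, and internals below are assumptions inferred from the tests, not the actual ExecutorResourceInfo implementation.

import scala.collection.mutable

import org.apache.spark.SparkException

// Sketch only: track each address against an "is free?" flag.
class ExecutorResourceInfoSketch(name: String, initial: Seq[String]) {
  // true = the address is free; false = it is assigned to a running task.
  private val state = mutable.HashMap(initial.map(_ -> true): _*)

  def availableAddrs: Seq[String] = state.collect { case (addr, true) => addr }.toSeq
  def assignedAddrs: Seq[String] = state.collect { case (addr, false) => addr }.toSeq

  def acquire(addrs: Seq[String]): Unit = addrs.foreach { addr =>
    state.get(addr) match {
      case Some(true) => state(addr) = false
      case Some(false) => throw new SparkException(
        s"Try to acquire an address that is not available. $name address $addr is not available.")
      case None => throw new SparkException(
        s"Try to acquire an address that doesn't exist. $name address $addr doesn't exist.")
    }
  }

  def release(addrs: Seq[String]): Unit = addrs.foreach { addr =>
    state.get(addr) match {
      case Some(false) => state(addr) = true
      case Some(true) => throw new SparkException(
        s"Try to release an address that is not assigned. $name address $addr is not assigned.")
      case None => throw new SparkException(
        s"Try to release an address that doesn't exist. $name address $addr doesn't exist.")
    }
  }
}

With that contract in mind, the first test body continues: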
+ val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3")) + assert(info.availableAddrs.sorted sameElements Seq("0", "1", "2", "3")) + assert(info.assignedAddrs.isEmpty) + + // Acquire addresses + info.acquire(Seq("0", "1")) + assert(info.availableAddrs.sorted sameElements Seq("2", "3")) + assert(info.assignedAddrs.sorted sameElements Seq("0", "1")) + + // Release addresses + info.release(Array("0", "1")) + assert(info.availableAddrs.sorted sameElements Seq("0", "1", "2", "3")) + assert(info.assignedAddrs.isEmpty) + } + + test("Don't allow acquire address that is not available") { + // Init Executor Resource. + val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3")) + // Acquire some addresses. + info.acquire(Seq("0", "1")) + assert(!info.availableAddrs.contains("1")) + // Acquire an address that is not available + val e = intercept[SparkException] { + info.acquire(Array("1")) + } + assert(e.getMessage.contains("Try to acquire an address that is not available.")) + } + + test("Don't allow acquire address that doesn't exist") { + // Init Executor Resource. + val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3")) + assert(!info.availableAddrs.contains("4")) + // Acquire an address that doesn't exist + val e = intercept[SparkException] { + info.acquire(Array("4")) + } + assert(e.getMessage.contains("Try to acquire an address that doesn't exist.")) + } + + test("Don't allow release address that is not assigned") { + // Init Executor Resource. + val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3")) + // Acquire addresses + info.acquire(Array("0", "1")) + assert(!info.assignedAddrs.contains("2")) + // Release an address that is not assigned + val e = intercept[SparkException] { + info.release(Array("2")) + } + assert(e.getMessage.contains("Try to release an address that is not assigned.")) + } + + test("Don't allow release address that doesn't exist") { + // Init Executor Resource.
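Taken together, the negative tests cover the four distinct error paths of that contract. Using the same assumed names as the sketch above, a compact demonstration:

object ErrorPathsDemo extends App {
  val gpu = new ExecutorResourceInfoSketch("gpu", Seq("0", "1"))
  gpu.acquire(Seq("0"))
  // Each commented call would throw SparkException with the message the suite asserts on:
  // gpu.acquire(Seq("0"))  // already assigned -> "is not available"
  // gpu.acquire(Seq("9"))  // unknown address  -> "doesn't exist"
  // gpu.release(Seq("1"))  // still free       -> "is not assigned"
  // gpu.release(Seq("9"))  // unknown address  -> "doesn't exist"
}

The body of the final negative test then continues: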
+ val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3")) + assert(!info.assignedAddrs.contains("4")) + // Release an address that doesn't exist + val e = intercept[SparkException] { + info.release(Array("4")) + } + assert(e.getMessage.contains("Try to release an address that doesn't exist.")) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 2f677776fe82..c16b552d2089 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { - task.run(0, 0, null) + task.run(0, 0, null, null) } assert(TaskContextSuite.completed) } @@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { - task.run(0, 0, null) + task.run(0, 0, null, null) } assert(TaskContextSuite.lastError.getMessage == "damn error") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index ba62eec0522d..233bc73aa741 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -23,6 +23,8 @@ import java.util.Properties import scala.collection.mutable.HashMap +import org.apache.spark.ResourceInformation +import org.apache.spark.ResourceName.GPU import org.apache.spark.SparkFunSuite class TaskDescriptionSuite extends SparkFunSuite { @@ -53,6 +55,9 @@ class TaskDescriptionSuite extends SparkFunSuite { } } + val originalResources = + Map(GPU -> new ResourceInformation(GPU, Array("1", "2", "3"))) + // Create a dummy byte buffer for the task. 
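The TaskDescriptionSuite changes check that the new resources map survives an encode/decode round trip, comparing entries by resource name and address array. As a hedged illustration of the kind of wire layout such a round trip implies (this helper is hypothetical, not the actual TaskDescription.encode code):

import java.io.DataOutputStream

import org.apache.spark.ResourceInformation

object ResourceWireSketch {
  // Assumed layout: entry count, then (map key, resource name, addresses) per entry.
  def writeResources(dos: DataOutputStream,
      resources: Map[String, ResourceInformation]): Unit = {
    dos.writeInt(resources.size)
    resources.foreach { case (key, info) =>
      dos.writeUTF(key)
      dos.writeUTF(info.name)
      dos.writeInt(info.addresses.length)
      info.addresses.foreach(dos.writeUTF)
    }
  }
}

Back in the test, the dummy task buffer is created: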
val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) @@ -66,6 +71,7 @@ class TaskDescriptionSuite extends SparkFunSuite { originalFiles, originalJars, originalProperties, + originalResources, taskBuffer ) @@ -82,6 +88,17 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) + assert(equalResources(decodedTaskDescription.resources, originalTaskDescription.resources)) assert(decodedTaskDescription.serializedTask.equals(taskBuffer)) + + def equalResources(original: Map[String, ResourceInformation], + target: Map[String, ResourceInformation]): Boolean = { + original.size == target.size && original.forall { case (name, info) => + target.get(name).exists { targetInfo => + info.name.equals(targetInfo.name) && + info.addresses.sameElements(targetInfo.addresses) + } + } + } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 29614058485a..456996c8eb3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.duration._ import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq} @@ -29,6 +29,7 @@ import org.scalatest.concurrent.Eventually import org.scalatest.mockito.MockitoSugar import org.apache.spark._ +import org.apache.spark.ResourceName.GPU import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.util.ManualClock @@ -80,6 +81,10 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B setupSchedulerWithMaster("local", confs: _*) } + def setupScheduler(numCores: Int, confs: (String, String)*): TaskSchedulerImpl = { + setupSchedulerWithMaster(s"local[$numCores]", confs: _*) + } + def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = { val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite") confs.foreach { case (k, v) => conf.set(k, v) } @@ -1238,4 +1243,37 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test")) assert(tsm.isZombie) } + + test("Scheduler correctly accounts for GPUs per task") { + val taskCpus = 1 + val taskGpus = 1 + val executorGpus = 4 + val executorCpus = 4 + val taskScheduler = setupScheduler(numCores = executorCpus, + config.CPUS_PER_TASK.key -> taskCpus.toString, + s"${config.SPARK_TASK_RESOURCE_PREFIX}${GPU}${config.SPARK_RESOURCE_COUNT_SUFFIX}" -> + taskGpus.toString, + s"${config.SPARK_EXECUTOR_RESOURCE_PREFIX}${GPU}${config.SPARK_RESOURCE_COUNT_SUFFIX}" -> + executorGpus.toString, + config.EXECUTOR_CORES.key -> executorCpus.toString) + val taskSet = FakeTask.createTaskSet(3) + + val numFreeCores = 2 + val resources = Map(GPU -> ArrayBuffer("0", "1", "2", "3")) + val singleCoreWorkerOffers = + IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores, None, resources)) + val zeroGpuWorkerOffers = + IndexedSeq(new WorkerOffer("executor0", "host0", 
numFreeCores, None, Map.empty)) + taskScheduler.submitTasks(taskSet) + // The WorkerOffer doesn't contain a GPU resource, so no tasks should be launched. + var taskDescriptions = taskScheduler.resourceOffers(zeroGpuWorkerOffers).flatten + assert(0 === taskDescriptions.length) + assert(!failedTaskSet) + // Launch tasks on an executor that satisfies the resource requirements. + taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten + assert(2 === taskDescriptions.length) + assert(!failedTaskSet) + assert(ArrayBuffer("0") === taskDescriptions(0).resources.get(GPU).get.addresses) + assert(ArrayBuffer("1") === taskDescriptions(1).resources.get(GPU).get.addresses) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 72c6ab964ccf..3d09b10a6dc8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -27,6 +27,7 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.apache.spark._ +import org.apache.spark.ResourceName.GPU import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerInstance @@ -1633,4 +1634,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // by that point. assert(FakeRackUtil.numBatchInvocation === 1) } + + test("TaskSetManager allocates resource addresses from available resources") { + import TestUtils._ + + sc = new SparkContext("local", "test") + setTaskResourceRequirement(sc.conf, GPU, 2) + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + val availableResources = Map(GPU -> ArrayBuffer("0", "1", "2", "3")) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF, availableResources) + assert(taskOption.isDefined) + val allocatedResources = taskOption.get.resources + assert(allocatedResources.size == 1) + assert(allocatedResources(GPU).addresses sameElements Array("0", "1")) + // Allocated resource addresses should still be present in `availableResources`; they are + // only removed inside TaskSchedulerImpl later.
+ assert(availableResources(GPU) sameElements Array("0", "1", "2", "3")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5ccaa38c0867..98e4764bdfe4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"), + // [SPARK-27410][MLLIB] Remove deprecated / no-op mllib.KMeans getRuns, setRuns ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"), diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 3e63c35361a5..b70be7a4b0ee 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -24,6 +24,7 @@ import java.util.Collections import java.util.Properties import scala.collection.JavaConverters._ +import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -35,7 +36,8 @@ import org.mockito.ArgumentMatchers.{any, anyLong, eq => meq} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, ResourceInformation, SparkConf, SparkContext, + SparkFunSuite} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, @@ -262,6 +264,7 @@ class MesosFineGrainedSchedulerBackendSuite addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), + resources = immutable.Map.empty[String, ResourceInformation], ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) @@ -372,6 +375,7 @@ class MesosFineGrainedSchedulerBackendSuite addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), + resources = immutable.Map.empty[String, ResourceInformation], ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1)
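The scheduler tests above reduce to one invariant: the number of tasks launched from a single offer is bounded by every resource independently, and addresses are handed out from the front of the free list (GPU "0" to the first task, "1" to the second). A self-contained illustration of that arithmetic, with purely illustrative names:

object ResourceSlotsSketch extends App {
  val taskCpus = 1
  val taskGpus = 1
  val freeCores = 2
  val freeGpuAddrs = Seq("0", "1", "2", "3")

  // Tasks that fit in one offer = the smallest per-resource quotient.
  val slots = math.min(freeCores / taskCpus, freeGpuAddrs.size / taskGpus)
  assert(slots == 2) // the two TaskDescriptions TaskSchedulerImplSuite expects

  // Peeling addresses off the front of the free list reproduces the
  // per-task addresses the tests assert on.
  val perTask = freeGpuAddrs.grouped(taskGpus).take(slots).toSeq
  assert(perTask == Seq(Seq("0"), Seq("1")))
}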