2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -185,6 +185,8 @@ class BarrierTaskContext private[spark] (
taskContext.getMetricsSources(sourceName)
}

override def resources(): Map[String, ResourceInformation] = taskContext.resources()

override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted()

override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason()
33 changes: 9 additions & 24 deletions core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -507,6 +507,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
}
}

/**
Contributor: We lost the check to make sure the executor resources are a multiple of the task requirements; do you want to add that back? Note I added a check (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/SparkConf.scala#L609) to make sure they were large enough, but not that it was an exact fit.

Contributor Author: Will add it back later.

Contributor Author: Added back in SparkContext.checkResourcesPerTask().

* Get task resource requirements.
*/
private[spark] def getTaskResourceRequirements(): Map[String, Int] = {
Contributor: Need a unit test, or leave a TODO, since this might go away with the conf refactoring.

Contributor Author: Added a test case in SparkConfSuite.

getAllWithPrefix(SPARK_TASK_RESOURCE_PREFIX)
.withFilter { case (k, v) => k.endsWith(SPARK_RESOURCE_COUNT_SUFFIX)}
.map { case (k, v) => (k.dropRight(SPARK_RESOURCE_COUNT_SUFFIX.length), v.toInt)}.toMap
}
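
A minimal sketch of how this parsing behaves (callable from the org.apache.spark package, since the method is private[spark]), assuming SPARK_TASK_RESOURCE_PREFIX and SPARK_RESOURCE_COUNT_SUFFIX resolve to "spark.task.resource." and ".count":

val conf = new SparkConf()
  .set("spark.task.resource.gpu.count", "2")
  .set("spark.task.resource.fpga.count", "1")
// The prefix and the ".count" suffix are stripped, leaving resource name -> count:
assert(conf.getTaskResourceRequirements() == Map("gpu" -> 2, "fpga" -> 1))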

/**
* Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones.
@@ -603,30 +612,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
require(executorTimeoutThresholdMs > executorHeartbeatIntervalMs, "The value of " +
s"${networkTimeout}=${executorTimeoutThresholdMs}ms must be no less than the value of " +
s"${EXECUTOR_HEARTBEAT_INTERVAL.key}=${executorHeartbeatIntervalMs}ms.")

// Make sure the executor resources were specified and are large enough if
// any task resources were specified.
val taskResourcesAndCount =
getAllWithPrefixAndSuffix(SPARK_TASK_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap
val executorResourcesAndCounts =
getAllWithPrefixAndSuffix(SPARK_EXECUTOR_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap

taskResourcesAndCount.foreach { case (rName, taskCount) =>
val execCount = executorResourcesAndCounts.get(rName).getOrElse(
throw new SparkException(
s"The executor resource config: " +
s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " +
"needs to be specified since a task requirement config: " +
s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} was specified")
)
if (execCount.toLong < taskCount.toLong) {
throw new SparkException(
s"The executor resource config: " +
s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " +
s"= $execCount has to be >= the task config: " +
s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} = $taskCount")
}
}
}

/**
84 changes: 65 additions & 19 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2707,27 +2707,73 @@ object SparkContext extends Logging {
// When running locally, don't try to re-execute tasks on failure.
val MAX_LOCAL_TASK_FAILURES = 1

// SPARK-26340: Ensure that executor's core num meets at least one task requirement.
def checkCpusPerTask(
clusterMode: Boolean,
maxCoresPerExecutor: Option[Int]): Unit = {
val cpusPerTask = sc.conf.get(CPUS_PER_TASK)
if (clusterMode && sc.conf.contains(EXECUTOR_CORES)) {
if (sc.conf.get(EXECUTOR_CORES) < cpusPerTask) {
throw new SparkException(s"${CPUS_PER_TASK.key}" +
s" must be <= ${EXECUTOR_CORES.key} when run on $master.")
// Ensure that the executor's resources satisfy the requirements of at least one task.
def checkResourcesPerTask(clusterMode: Boolean, executorCores: Option[Int]): Unit = {
val taskCores = sc.conf.get(CPUS_PER_TASK)
val execCores = if (clusterMode) {
executorCores.getOrElse(sc.conf.get(EXECUTOR_CORES))
} else {
executorCores.get
}

// Number of cores per executor must meet at least one task requirement.
if (execCores < taskCores) {
throw new SparkException(s"The number of cores per executor (=$execCores) has to be >= " +
s"the task config: ${CPUS_PER_TASK.key} = $taskCores when run on $master.")
}

// Calculate the max slots each executor can provide based on resources available on each
// executor and resources required by each task.
val taskResourcesAndCount = sc.conf.getTaskResourceRequirements()
val executorResourcesAndCounts = sc.conf.getAllWithPrefixAndSuffix(
SPARK_EXECUTOR_RESOURCE_PREFIX, SPARK_RESOURCE_COUNT_SUFFIX).toMap
var numSlots = execCores / taskCores
var limitingResourceName = "CPU"
taskResourcesAndCount.foreach { case (rName, taskCount) =>
// Make sure the executor resources were specified through config.
val execCount = executorResourcesAndCounts.getOrElse(rName,
throw new SparkException(
s"The executor resource config: " +
s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " +
"needs to be specified since a task requirement config: " +
s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} was specified")
)
// Make sure the executor resources are large enough to launch at least one task.
if (execCount.toLong < taskCount.toLong) {
throw new SparkException(
s"The executor resource config: " +
s"${SPARK_EXECUTOR_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} " +
s"= $execCount has to be >= the task config: " +
s"${SPARK_TASK_RESOURCE_PREFIX + rName + SPARK_RESOURCE_COUNT_SUFFIX} = $taskCount")
Contributor: We perhaps want to expand this message to say something like "so we don't waste resources".

Contributor Author: Updated.

}
// Compare and update the max slots each executor can provide.
val resourceNumSlots = execCount.toInt / taskCount
if (resourceNumSlots < numSlots) {
numSlots = resourceNumSlots
limitingResourceName = rName
}
} else if (maxCoresPerExecutor.isDefined) {
if (maxCoresPerExecutor.get < cpusPerTask) {
throw new SparkException(s"Only ${maxCoresPerExecutor.get} cores available per executor" +
s" when run on $master, and ${CPUS_PER_TASK.key} must be <= it.")
}
// There have been checks above to make sure the executor resources were specified and are
// large enough if any task resources were specified.
taskResourcesAndCount.foreach { case (rName, taskCount) =>
val execCount = executorResourcesAndCounts(rName)
Contributor: Need to make sure this returns something, or throw an exception saying the executor resource config is required.

Contributor Author: Combined this check with the check from SparkConf.

if (taskCount.toInt * numSlots < execCount.toInt) {
val message = s"The configuration of resource: $rName (exec = ${execCount.toInt}, " +
s"task = ${taskCount}) will result in wasted resources due to resource " +
s"${limitingResourceName} limiting the number of runnable tasks per executor to: " +
s"${numSlots}. Please adjust your configuration."
if (Utils.isTesting) {
throw new SparkException(message)
} else {
logWarning(message)
Contributor: Originally we talked about throwing here to not allow it; just want to make sure we intentionally changed our mind here? I'm really OK either way, as there were some people questioning this on the SPIP.

Contributor Author: Since we now have TaskSchedulerImpl.resourcesMeetTaskRequirements() to ensure there are enough resources before scheduling a task, I think it's safe to just place a warning here.

Contributor: I prefer a warning because the discovery script might return more and that is out of the user's control. Also, available resources might not happen to be a multiple of the task-requested counts. For example, you have 32 CPU cores and 3 GPUs.

}
}
}
}
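
A worked example of the slot and waste calculation above, using assumed config values (not taken from this PR's tests):

// spark.executor.cores = 2, spark.task.cpus = 1
//   -> numSlots starts at 2 / 1 = 2, limitingResourceName = "CPU"
// spark.executor.resource.gpu.count = 6, spark.task.resource.gpu.count = 1
//   -> resourceNumSlots = 6 / 1 = 6, not < 2, so numSlots stays 2
// Waste check for "gpu": taskCount * numSlots = 1 * 2 = 2 < execCount = 6,
// so the message reports that resource gpu (exec = 6, task = 1) is wasted because
// CPU limits the runnable tasks per executor to 2; this throws under Utils.isTesting,
// otherwise it is logged as a warning and 4 of the 6 GPUs per executor sit idle.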

master match {
case "local" =>
checkCpusPerTask(clusterMode = false, Some(1))
checkResourcesPerTask(clusterMode = false, Some(1))
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
scheduler.initialize(backend)
@@ -2740,7 +2786,7 @@
if (threadCount <= 0) {
throw new SparkException(s"Asked to run locally with $threadCount threads")
}
checkCpusPerTask(clusterMode = false, Some(threadCount))
checkResourcesPerTask(clusterMode = false, Some(threadCount))
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
scheduler.initialize(backend)
@@ -2751,22 +2797,22 @@
// local[*, M] means the number of cores on the computer with M failures
// local[N, M] means exactly N threads with M failures
val threadCount = if (threads == "*") localCpuCount else threads.toInt
checkCpusPerTask(clusterMode = false, Some(threadCount))
checkResourcesPerTask(clusterMode = false, Some(threadCount))
val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
scheduler.initialize(backend)
(backend, scheduler)

case SPARK_REGEX(sparkUrl) =>
checkCpusPerTask(clusterMode = true, None)
checkResourcesPerTask(clusterMode = true, None)
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
(backend, scheduler)

case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
checkCpusPerTask(clusterMode = true, Some(coresPerSlave.toInt))
checkResourcesPerTask(clusterMode = true, Some(coresPerSlave.toInt))
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
if (sc.executorMemory > memoryPerSlaveInt) {
@@ -2787,7 +2833,7 @@
(backend, scheduler)

case masterUrl =>
checkCpusPerTask(clusterMode = true, None)
checkResourcesPerTask(clusterMode = true, None)
val cm = getClusterManager(masterUrl) match {
case Some(clusterMgr) => clusterMgr
case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import java.io.Serializable
import java.util.Properties

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.{DeveloperApi, Evolving}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
@@ -176,6 +176,13 @@ abstract class TaskContext extends Serializable {
*/
def getLocalProperty(key: String): String

/**
* Resources allocated to the task. The key is the resource name and the value is information
* about the resource. Please refer to [[ResourceInformation]] for specifics.
*/
@Evolving
def resources(): Map[String, ResourceInformation]
Member: In the doc, it might be better to explain what the keys are.
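
A hedged usage sketch to illustrate what the keys are: they are the resource names from the task resource configs (e.g. "gpu"), and ResourceInformation is assumed here to expose the addresses allocated to this task.

// Given an RDD `rdd` inside a running Spark job:
rdd.mapPartitions { iter =>
  // Keys are resource names such as "gpu"; values describe this task's allocation.
  val gpuInfo = TaskContext.get().resources().get("gpu")
  val addresses = gpuInfo.map(_.addresses).getOrElse(Array.empty[String])
  // ... bind the partition's computation to `addresses` ...
  iter
}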


@DeveloperApi
def taskMetrics(): TaskMetrics

3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -51,7 +51,8 @@ private[spark] class TaskContextImpl(
localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
// The default value is only used in tests.
override val taskMetrics: TaskMetrics = TaskMetrics.empty)
override val taskMetrics: TaskMetrics = TaskMetrics.empty,
override val resources: Map[String, ResourceInformation] = Map.empty)
extends TaskContext
with Logging {

11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -38,6 +38,7 @@ import com.google.common.io.{ByteStreams, Files}
import org.apache.log4j.PropertyConfigurator

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config._
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils

@@ -311,6 +312,16 @@
current ++ current.filter(_.isDirectory).flatMap(recursiveList)
}

/**
* Set task resource requirement.
*/
def setTaskResourceRequirement(
conf: SparkConf,
resourceName: String,
resourceCount: Int): SparkConf = {
val key = s"${SPARK_TASK_RESOURCE_PREFIX}${resourceName}${SPARK_RESOURCE_COUNT_SUFFIX}"
conf.set(key, resourceCount.toString)
}
}
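
A hedged sketch of how a test suite might use this helper, again assuming the prefix and suffix constants resolve to "spark.task.resource." and ".count":

val conf = new SparkConf()
  .setMaster("local-cluster[1, 1, 1024]")
  .setAppName("task-resource-test")
TestUtils.setTaskResourceRequirement(conf, "gpu", 1)
// Equivalent to conf.set("spark.task.resource.gpu.count", "1")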


@@ -66,6 +66,13 @@ private[spark] class CoarseGrainedExecutorBackend(
// to be changed so that we don't share the serializer instance across threads
private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()

/**
* Map each taskId to the information about the resources allocated to it. Please refer to
* [[ResourceInformation]] for specifics.
* Exposed for testing only.
*/
private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]]

override def onStart() {
logInfo("Connecting to driver: " + driverUrl)
val resources = parseOrFindResources(resourcesFile)
@@ -151,6 +158,7 @@
} else {
val taskDesc = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
taskResources(taskDesc.taskId) = taskDesc.resources
executor.launchTask(this, taskDesc)
}

@@ -197,7 +205,11 @@
}

override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
val msg = StatusUpdate(executorId, taskId, state, data)
val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation])
val msg = StatusUpdate(executorId, taskId, state, data, resources)
if (TaskState.isFinished(state)) {
taskResources.remove(taskId)
}
driver match {
case Some(driverRef) => driverRef.send(msg)
case None => logWarning(s"Drop $msg because has not yet connected to driver")
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -422,7 +422,8 @@ private[spark] class Executor(
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
metricsSystem = env.metricsSystem,
resources = taskDescription.resources)
threwException = false
res
} {
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.util.collection.OpenHashMap

/**
* Class to hold information about a type of Resource on an Executor. This information is managed
* by SchedulerBackend, and TaskScheduler shall schedule tasks on idle Executors based on the
* information.
* Please note that this class is intended to be used in a single thread.
* @param name Resource name
* @param addresses Resource addresses provided by the executor
*/
private[spark] class ExecutorResourceInfo(
val name: String,
addresses: Seq[String]) extends Serializable {

/**
* Map from an address to its availability, the value `true` means the address is available,
* while value `false` means the address is assigned.
* TODO Use [[OpenHashMap]] instead to gain better performance.
*/
private val addressAvailabilityMap = mutable.HashMap(addresses.map(_ -> true): _*)

/**
* Sequence of currently available resource addresses.
*/
def availableAddrs: Seq[String] = addressAvailabilityMap.flatMap { case (addr, available) =>
if (available) Some(addr) else None
}.toSeq

/**
* Sequence of currently assigned resource addresses.
* Exposed for testing only.
*/
private[scheduler] def assignedAddrs: Seq[String] = addressAvailabilityMap
.flatMap { case (addr, available) =>
if (!available) Some(addr) else None
}.toSeq

/**
* Acquire a sequence of resource addresses (to a launched task); these addresses must be
* available. When the task finishes, it will return the acquired resource addresses.
* Throw an Exception if an address is not available or doesn't exist.
*/
def acquire(addrs: Seq[String]): Unit = {
addrs.foreach { address =>
if (!addressAvailabilityMap.contains(address)) {
throw new SparkException(s"Try to acquire an address that doesn't exist. $name address " +
s"$address doesn't exist.")
}
val isAvailable = addressAvailabilityMap(address)
if (isAvailable) {
addressAvailabilityMap(address) = false
} else {
throw new SparkException(s"Try to acquire an address that is not available. $name " +
s"address $address is not available.")
}
}
}

/**
* Release a sequence of resource addresses; these addresses must have been assigned. Resource
* addresses are released when a task has finished.
* Throw an Exception if an address is not assigned or doesn't exist.
*/
def release(addrs: Seq[String]): Unit = {
addrs.foreach { address =>
if (!addressAvailabilityMap.contains(address)) {
throw new SparkException(s"Try to release an address that doesn't exist. $name address " +
s"$address doesn't exist.")
}
val isAvailable = addressAvailabilityMap(address)
if (!isAvailable) {
addressAvailabilityMap(address) = true
} else {
throw new SparkException(s"Try to release an address that is not assigned. $name " +
s"address $address is not assigned.")
}
}
}
}
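
A minimal usage sketch for this class (assignedAddrs is package-private, so this would live in scheduler-package test code); the values are illustrative, not from this PR's tests:

val info = new ExecutorResourceInfo("gpu", Seq("0", "1", "2"))
assert(info.availableAddrs.sorted == Seq("0", "1", "2"))
info.acquire(Seq("0", "2"))          // assign two addresses to a launched task
assert(info.assignedAddrs.sorted == Seq("0", "2"))
info.release(Seq("0"))               // the task using address "0" finished
assert(info.availableAddrs.sorted == Seq("0", "1"))
// info.acquire(Seq("2")) would now throw a SparkException: address "2" is not available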