16 changes: 16 additions & 0 deletions R/pkg/R/context.R
@@ -25,6 +25,22 @@ getMinPartitions <- function(sc, minPartitions) {
as.integer(minPartitions)
}

#' Total number of CPU cores across all executors currently registered in the cluster.
#'
#' @param sc SparkContext to use
#' @return Current number of CPU cores in the cluster.
numCores <- function(sc) {
callJMethod(sc, "numCores")
}

#' Total number of executors currently registered in the cluster.
#'
#' @param sc SparkContext to use
#' @return Current number of executors in the cluster.
numExecutors <- function(sc) {
callJMethod(sc, "numExecutors")
}

#' Create an RDD from a text file.
#'
#' This function reads a text file from HDFS, a local file system (available on all
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2336,6 +2336,18 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)

/**
* Total number of CPU cores across all executors currently registered in the cluster.
* This number reflects the current state of the cluster and can change over time.
*/
def numCores: Int = taskScheduler.numCores

/**
* Total number of executors currently registered in the cluster.
* This number reflects the current state of the cluster and can change over time.
*/
def numExecutors: Int = taskScheduler.numExecutors

private val nextShuffleId = new AtomicInteger(0)

private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
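As a usage illustration (not part of the change itself), the new accessors make it easy to size a job from the cluster's current capacity. A minimal sketch, assuming a `local[2]` master and an arbitrary factor of three partitions per core:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object NumCoresExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("num-cores-example")
    val sc = new SparkContext(conf)
    // Size the job from the cluster's current capacity: roughly three partitions per core.
    val numPartitions = math.max(sc.numCores * 3, 1)
    val counts = sc.parallelize(1 to 1000000, numPartitions).map(_ % 10).countByValue()
    println(s"cores=${sc.numCores}, executors=${sc.numExecutors}, partitions=$numPartitions, distinct=${counts.size}")
    sc.stop()
  }
}
```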
@@ -128,6 +128,18 @@ class JavaSparkContext(val sc: SparkContext)
/** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions

/**
* Total number of CPU cores across all executors currently registered in the cluster.
* This number reflects the current state of the cluster and can change over time.
*/
def numCores: java.lang.Integer = sc.numCores

/**
* Total number of executors currently registered in the cluster.
* This number reflects the current state of the cluster and can change over time.
*/
def numExecutors: java.lang.Integer = sc.numExecutors

/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
@@ -29,6 +29,8 @@ private[spark] trait SchedulerBackend {
def stop(): Unit
def reviveOffers(): Unit
def defaultParallelism(): Int
def numCores(): Int
def numExecutors(): Int

/**
* Requests that an executor kills a running task.
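Since `numCores()` and `numExecutors()` are added as abstract members, every `SchedulerBackend` implementation, including the test doubles further down in this diff, has to supply them. A minimal stub might look like the sketch below (the class name and returned values are illustrative; it sits under `org.apache.spark.scheduler` because the trait is `private[spark]`):

```scala
package org.apache.spark.scheduler

// Illustrative stand-in, mirroring the dummy/fake backends in the test suites later in this diff.
private class StubSchedulerBackend extends SchedulerBackend {
  def start(): Unit = {}
  def stop(): Unit = {}
  def reviveOffers(): Unit = {}
  def defaultParallelism(): Int = 1
  def numCores(): Int = 1      // pretend a single core is available
  def numExecutors(): Int = 1  // pretend a single executor is registered
}
```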
@@ -67,6 +67,10 @@ private[spark] trait TaskScheduler {
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int

def numCores(): Int

def numExecutors(): Int

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
@@ -515,6 +515,10 @@ private[spark] class TaskSchedulerImpl(

override def defaultParallelism(): Int = backend.defaultParallelism()

override def numCores(): Int = backend.numCores()

override def numExecutors(): Int = backend.numExecutors()

// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@@ -457,6 +457,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
}

override def numCores(): Int = totalCoreCount.get()
override def numExecutors(): Int = totalRegisteredExecutors.get()

/**
* Called by subclasses when notified of a lost worker. It just fires the message and returns
* at once.
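Because `totalCoreCount` and `totalRegisteredExecutors` change as executors register and are lost, the reported values can fluctuate, for example under dynamic allocation. One plausible driver-side use, sketched under that assumption (the helper name and polling interval are made up for illustration), is waiting for a minimum number of executors before submitting work:

```scala
import org.apache.spark.SparkContext

object ExecutorWait {
  // Hypothetical helper: block until at least `minExecutors` executors have registered,
  // or give up when the timeout expires. Relies only on the new numExecutors accessor.
  def waitForExecutors(sc: SparkContext, minExecutors: Int, timeoutMs: Long = 60000L): Boolean = {
    val deadline = System.currentTimeMillis() + timeoutMs
    while (sc.numExecutors < minExecutors && System.currentTimeMillis() < deadline) {
      Thread.sleep(200)
    }
    sc.numExecutors >= minExecutors
  }
}
```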
@@ -144,6 +144,9 @@ private[spark] class LocalSchedulerBackend(
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)

override def numCores(): Int = totalCores
override def numExecutors(): Int = 1

override def killTask(
taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
localEndpoint.send(KillTask(taskId, interruptThread, reason))
@@ -45,6 +45,11 @@ public void javaSparkContext() {
new JavaSparkContext("local", "name", "sparkHome", "jarFile").stop();
new JavaSparkContext("local", "name", "sparkHome", jars).stop();
new JavaSparkContext("local", "name", "sparkHome", jars, environment).stop();

JavaSparkContext sc = new JavaSparkContext(new SparkConf().setMaster("local[2]").setAppName("name"));
assert sc.numCores() == 2;
assert sc.numExecutors() == 1;
sc.stop();
}

@Test
@@ -1376,6 +1376,9 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend

override def defaultParallelism(): Int = sb.defaultParallelism()

override def numCores(): Int = sb.numCores()
override def numExecutors(): Int = sb.numExecutors()

override def killExecutorsOnHost(host: String): Boolean = {
false
}
@@ -144,7 +144,11 @@ class MultiExecutorMockBackend(
}.toMap
}

override def defaultParallelism(): Int = nHosts * nExecutorsPerHost * nCoresPerExecutor
override def defaultParallelism(): Int = numCores

override def numCores(): Int = nHosts * nExecutorsPerHost * nCoresPerExecutor
override def numExecutors(): Int = nHosts * nExecutorsPerHost

}

class MockRDDWithLocalityPrefs(
@@ -132,7 +132,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
override def defaultParallelism() = numCores
override def numCores(): Int = 2
override def numExecutors(): Int = 1

override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
@@ -630,7 +633,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
throw new UnsupportedOperationException
}
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def defaultParallelism(): Int = numCores
override def numCores(): Int = 2
override def numExecutors(): Int = 1
override def executorHeartbeatReceived(
execId: String,
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
@@ -69,6 +69,8 @@ private class DummySchedulerBackend extends SchedulerBackend {
def stop() {}
def reviveOffers() {}
def defaultParallelism(): Int = 1
def numCores(): Int = 1
def numExecutors(): Int = 1
}

private class DummyTaskScheduler extends TaskScheduler {
@@ -83,6 +85,8 @@ private class DummyTaskScheduler extends TaskScheduler {
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def numCores(): Int = 2
override def numExecutors(): Int = 1
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
@@ -430,6 +430,8 @@ private[spark] class SingleCoreMockBackend(
val cores = 1

override def defaultParallelism(): Int = conf.getInt("spark.default.parallelism", cores)
override def numCores(): Int = cores
override def numExecutors(): Int = 1

freeCores = cores
val localExecutorId = SparkContext.DRIVER_IDENTIFIER
@@ -36,6 +36,8 @@ class FakeSchedulerBackend extends SchedulerBackend {
def stop() {}
def reviveOffers() {}
def defaultParallelism(): Int = 1
def numCores(): Int = 1
def numExecutors(): Int = 1
}

class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach
16 changes: 16 additions & 0 deletions python/pyspark/context.py
@@ -406,6 +406,22 @@ def defaultMinPartitions(self):
"""
return self._jsc.sc().defaultMinPartitions()

@property
def numCores(self):
"""
Total number of CPU cores across all executors currently registered in the cluster.
This number reflects the current state of the cluster and can change over time.
"""
return self._jsc.sc().numCores()

@property
def numExecutors(self):
"""
Total number of executors currently registered in the cluster.
This number reflects the current state of the cluster and can change over time.
"""
return self._jsc.sc().numExecutors()

def stop(self):
"""
Shut down the SparkContext.
8 changes: 8 additions & 0 deletions python/pyspark/tests.py
@@ -519,6 +519,14 @@ def setUp(self):
# Allow retries even though they are normally disabled in local mode
self.sc = SparkContext('local[4, 2]', class_name)

def test_num_cores(self):
"""Test for number of cores in the cluster"""
self.assertEqual(self.sc.numCores, 4)

def test_num_executors(self):
"""Test for number of executors in the cluster"""
self.assertEqual(self.sc.numExecutors, 1)

def test_stage_id(self):
"""Test the stage ids are available and incrementing as expected."""
rdd = self.sc.parallelize(range(10))