
Commit 1c310e4

Author: Andrew Or (committed)
Wrap a few more RDD functions in an operation scope
1 parent 3ffe566 commit 1c310e4

File tree

3 files changed (+33, -17 lines)

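Every change in this commit follows the same pattern: the body of a public RDD method is wrapped in `self.withScope { ... }`. The helper itself is not part of this diff, so the sketch below is only an assumption about its shape: presumably it runs the body inside an RDD operation scope, so that RDDs created by the wrapped call are tagged with the name of the public API method and can be grouped in the web UI's DAG visualization.

  // Hypothetical sketch of the helper assumed by this commit; the real
  // definition lives in RDD.scala and may differ in detail.
  private[spark] def withScope[U](body: => U): U = {
    // Record this call as the current operation scope, then evaluate the
    // body; RDDs instantiated inside inherit that scope for the UI.
    RDDOperationScope.withScope[U](sc)(body)
  }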

core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala

Lines changed: 27 additions & 11 deletions
@@ -30,45 +30,57 @@ import org.apache.spark.util.StatCounter
  */
 class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
   /** Add up the elements in this RDD. */
-  def sum(): Double = {
+  def sum(): Double = self.withScope {
     self.fold(0.0)(_ + _)
   }

   /**
    * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
    * count of the RDD's elements in one operation.
    */
-  def stats(): StatCounter = {
+  def stats(): StatCounter = self.withScope {
     self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
   }

   /** Compute the mean of this RDD's elements. */
-  def mean(): Double = stats().mean
+  def mean(): Double = self.withScope {
+    stats().mean
+  }

   /** Compute the variance of this RDD's elements. */
-  def variance(): Double = stats().variance
+  def variance(): Double = self.withScope {
+    stats().variance
+  }

   /** Compute the standard deviation of this RDD's elements. */
-  def stdev(): Double = stats().stdev
+  def stdev(): Double = self.withScope {
+    stats().stdev
+  }

   /**
    * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
    * estimating the standard deviation by dividing by N-1 instead of N).
    */
-  def sampleStdev(): Double = stats().sampleStdev
+  def sampleStdev(): Double = self.withScope {
+    stats().sampleStdev
+  }

   /**
    * Compute the sample variance of this RDD's elements (which corrects for bias in
    * estimating the variance by dividing by N-1 instead of N).
    */
-  def sampleVariance(): Double = stats().sampleVariance
+  def sampleVariance(): Double = self.withScope {
+    stats().sampleVariance
+  }

   /**
    * :: Experimental ::
    * Approximate operation to return the mean within a timeout.
    */
   @Experimental
-  def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+  def meanApprox(
+      timeout: Long,
+      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
     val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
     val evaluator = new MeanEvaluator(self.partitions.length, confidence)
     self.context.runApproximateJob(self, processPartition, evaluator, timeout)

@@ -79,7 +91,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * Approximate operation to return the sum within a timeout.
    */
   @Experimental
-  def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+  def sumApprox(
+      timeout: Long,
+      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
     val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
     val evaluator = new SumEvaluator(self.partitions.length, confidence)
     self.context.runApproximateJob(self, processPartition, evaluator, timeout)

@@ -93,7 +107,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * If the RDD contains infinity, NaN throws an exception
    * If the elements in RDD do not vary (max == min) always returns a single bucket.
    */
-  def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+  def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope {
     // Scala's built-in range has issues. See #SI-8782
     def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
       val span = max - min

@@ -140,7 +154,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * the maximum value of the last position and all NaN entries will be counted
    * in that bucket.
    */
-  def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+  def histogram(
+      buckets: Array[Double],
+      evenBuckets: Boolean = false): Array[Long] = self.withScope {
     if (buckets.length < 2) {
       throw new IllegalArgumentException("buckets array must have at least two elements")
     }
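None of these wrappers change what callers see; the double-specific actions behave as before and only gain an operation scope around the jobs they launch. A minimal usage sketch follows; the SparkContext setup is illustrative and not part of this commit.

  import org.apache.spark.{SparkConf, SparkContext}

  val sc = new SparkContext(new SparkConf().setAppName("scope-example").setMaster("local[*]"))
  val nums = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))

  val total   = nums.sum()                 // 10.0; the job it runs is now wrapped in a scope
  val summary = nums.stats()               // count, mean, variance, ... in a single pass
  val (edges, counts) = nums.histogram(2)  // two evenly spaced buckets and their counts

  sc.stop()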

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 3 additions & 5 deletions
@@ -47,8 +47,6 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
     self: RDD[P])
   extends Logging with Serializable
 {
-  // TODO: Don't forget to scope me later
-
   private val ordering = implicitly[Ordering[K]]

   /**

@@ -59,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   // TODO: this currently doesn't work on P other than Tuple2!
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
-      : RDD[(K, V)] =
+      : RDD[(K, V)] = self.withScope
   {
     val part = new RangePartitioner(numPartitions, self, ascending)
     new ShuffledRDD[K, V, V](self, part)

@@ -73,7 +71,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * This is more efficient than calling `repartition` and then sorting within each partition
    * because it can push the sorting down into the shuffle machinery.
    */
-  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope {
     new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
   }

@@ -83,7 +81,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * performed efficiently by only scanning the partitions that might contain matching elements.
    * Otherwise, a standard `filter` is applied to all partitions.
    */
-  def filterByRange(lower: K, upper: K): RDD[P] = {
+  def filterByRange(lower: K, upper: K): RDD[P] = self.withScope {

     def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)

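For reference, the methods wrapped here are the sort-related extensions on key-value RDDs, and their usage is unchanged. A small sketch, reusing the illustrative `sc` from the example above:

  val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))

  val sorted = pairs.sortByKey()           // ascending by key: (1,"a"), (2,"b"), (3,"c")
  val slice  = pairs.filterByRange(1, 2)   // keeps keys within [1, 2], inclusive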
core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala

Lines changed: 3 additions & 1 deletion
@@ -85,7 +85,9 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
    * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
    * file system.
    */
-  def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
+  def saveAsSequenceFile(
+      path: String,
+      codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope {
     def anyToWritable[U <% Writable](u: U): Writable = u

     // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and
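As the scaladoc above notes, saveAsSequenceFile converts keys and values to Writables before writing, so call sites are unaffected by the new scope. A brief sketch, again reusing the illustrative `sc`; the output path is hypothetical:

  val kv = sc.parallelize(Seq(("a", 1), ("b", 2)))
  kv.saveAsSequenceFile("/tmp/example-seqfile")   // String -> Text, Int -> IntWritable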
