
Commit b904518

Niklas Wilcke authored and mengxr (Xiangrui Meng) committed
[Spark-4060] [MLlib] exposing special rdd functions to the public
Author: Niklas Wilcke <[email protected]>

Closes #2907 from numbnut/master and squashes the following commits:

7f7c767 [Niklas Wilcke] [Spark-4060] [MLlib] exposing special rdd functions to the public, #2907

(cherry picked from commit f90ad5d)
Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 4b13bff commit b904518

4 files changed (+13, -11 lines)

mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve {
    */
   def of(curve: RDD[(Double, Double)]): Double = {
     curve.sliding(2).aggregate(0.0)(
-      seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+      seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
       combOp = _ + _
     )
   }
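
The only change here is the seqOp parameter type: sliding(2) now yields Array[(Double, Double)] windows rather than Seq, so the lambda's signature must match. The trapezoidal rule itself is untouched. A minimal local sketch of what the aggregate computes, where trapezoid is an illustrative re-implementation of the file's private helper, not the Spark source:

// Illustrative stand-in for AreaUnderCurve's private trapezoid helper:
// area under the line segment between two (x, y) points.
def trapezoid(points: Array[(Double, Double)]): Double = {
  require(points.length == 2, "trapezoid expects a window of exactly 2 points")
  val Array((x1, y1), (x2, y2)) = points
  (y1 + y2) / 2.0 * (x2 - x1)  // average height times width
}

val curve = Array((0.0, 0.0), (1.0, 1.0), (2.0, 1.0))
// Local analogue of curve.sliding(2).aggregate(0.0)(seqOp, combOp)
val auc = curve.sliding(2).map(trapezoid).sum
println(auc)  // 0.5 + 1.0 = 1.5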

mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala

Lines changed: 6 additions & 5 deletions

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.HashPartitioner
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
@@ -28,8 +29,8 @@ import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
  */
-private[mllib]
-class RDDFunctions[T: ClassTag](self: RDD[T]) {
+@DeveloperApi
+class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
 
   /**
    * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
@@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * trigger a Spark job if the parent RDD has more than one partitions and the window size is
    * greater than 1.
    */
-  def sliding(windowSize: Int): RDD[Seq[T]] = {
+  def sliding(windowSize: Int): RDD[Array[T]] = {
     require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
     if (windowSize == 1) {
-      self.map(Seq(_))
+      self.map(Array(_))
     } else {
       new SlidingRDD[T](self, windowSize)
     }
@@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
   }
 }
 
-private[mllib]
+@DeveloperApi
 object RDDFunctions {
 
   /** Implicit conversion from an RDD to RDDFunctions. */
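
With private[mllib] replaced by @DeveloperApi, sliding becomes reachable from user code through the implicit conversion declared in the RDDFunctions companion object. A minimal sketch of the now-public API, assuming a live SparkContext named sc (e.g. in spark-shell):

import org.apache.spark.mllib.rdd.RDDFunctions._  // brings the implicit conversion into scope

val rdd = sc.parallelize(1 to 6, numSlices = 2)

// Each window is an Array[Int] of 3 consecutive elements; windows that
// straddle the partition boundary are produced as well.
rdd.sliding(3).collect().foreach(w => println(w.mkString("[", ", ", "]")))
// [1, 2, 3]
// [2, 3, 4]
// [3, 4, 5]
// [4, 5, 6]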

mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala

Lines changed: 3 additions & 2 deletions

@@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]
  */
 private[mllib]
 class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
-  extends RDD[Seq[T]](parent) {
+  extends RDD[Array[T]](parent) {
 
   require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
     val part = split.asInstanceOf[SlidingRDDPartition[T]]
     (firstParent[T].iterator(part.prev, context) ++ part.tail)
       .sliding(windowSize)
       .withPartial(false)
+      .map(_.toArray)
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] =
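
The added .map(_.toArray) exists because Scala's Iterator.sliding produces Seq windows, while the RDD's element type is now Array[T]. The surrounding pipeline is unchanged: each partition's iterator is extended with a tail of elements borrowed from the following partition so windows crossing the boundary are not lost, and withPartial(false) drops incomplete trailing windows. A local, Spark-free model of the same pipeline (names are illustrative):

val windowSize = 3
val partitionElems = Iterator(1, 2, 3)  // this partition's own elements
val tail = Seq(4, 5)                    // first windowSize - 1 elements of the next partition

val windows: Iterator[Array[Int]] =
  (partitionElems ++ tail)
    .sliding(windowSize)
    .withPartial(false)  // suppress trailing windows shorter than windowSize
    .map(_.toArray)      // Iterator.sliding yields Seq; convert to match RDD[Array[T]]

windows.foreach(w => println(w.mkString(", ")))
// 1, 2, 3
// 2, 3, 4
// 3, 4, 5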

mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
     val rdd = sc.parallelize(data, data.length).flatMap(s => s)
     assert(rdd.partitions.size === data.length)
-    val sliding = rdd.sliding(3)
-    val expected = data.flatMap(x => x).sliding(3).toList
-    assert(sliding.collect().toList === expected)
+    val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq)
+    val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
+    assert(sliding === expected)
   }
 
   test("treeAggregate") {
