Skip to content

Commit 217b5e9

Browse files
committed
[SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn
It is useful in streaming to allow users to carry extra data with the prediction, for example to monitor the prediction error.

freeman-lab

Author: Xiangrui Meng <[email protected]>

Closes apache#2023 from mengxr/predict-on-values and squashes the following commits:

- cac47b8 [Xiangrui Meng] add classtag
- 2821b3b [Xiangrui Meng] use mapValues
- 0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn
1 parent c8b16ca commit 217b5e9

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ object StreamingLinearRegression {
5959
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
6060

6161
val model = new StreamingLinearRegressionWithSGD()
62-
.setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0)))
62+
.setInitialWeights(Vectors.zeros(args(3).toInt))
6363

6464
model.trainOn(trainingData)
65-
model.predictOn(testData).print()
65+
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
6666

6767
ssc.start()
6868
ssc.awaitTermination()

mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.apache.spark.annotation.DeveloperApi
20+
import scala.reflect.ClassTag
21+
2122
import org.apache.spark.Logging
23+
import org.apache.spark.annotation.DeveloperApi
24+
import org.apache.spark.mllib.linalg.Vector
25+
import org.apache.spark.streaming.StreamingContext._
2226
import org.apache.spark.streaming.dstream.DStream
2327

2428
/**
@@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[
9296
/**
9397
* Use the model to make predictions on batches of data from a DStream
9498
*
95-
* @param data DStream containing labeled data
99+
* @param data DStream containing feature vectors
96100
* @return DStream containing predictions
97101
*/
98-
def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
102+
def predictOn(data: DStream[Vector]): DStream[Double] = {
99103
if (Option(model.weights) == None) {
100-
logError("Initial weights must be set before starting prediction")
101-
throw new IllegalArgumentException
104+
val msg = "Initial weights must be set before starting prediction"
105+
logError(msg)
106+
throw new IllegalArgumentException(msg)
102107
}
103-
data.map(x => model.predict(x.features))
108+
data.map(model.predict)
104109
}
105110

111+
/**
112+
* Use the model to make predictions on the values of a DStream and carry over its keys.
113+
* @param data DStream containing (key, feature vector) pairs; predictions are made on the values
114+
* @tparam K key type
115+
* @return DStream containing the input keys and the predictions as values
116+
*/
117+
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
118+
if (Option(model.weights) == None) {
119+
val msg = "Initial weights must be set before starting prediction"
120+
logError(msg)
121+
throw new IllegalArgumentException(msg)
122+
}
123+
data.mapValues(model.predict)
124+
}
106125
}

0 commit comments

Comments
 (0)