Skip to content

Commit 8f5daf9

Browse files
SPARK-3278 added comments and cleaned up api to consistently handle weights
1 parent 629a1ce commit 8f5daf9

File tree

4 files changed

+57
-66
lines changed

4 files changed

+57
-66
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,39 +17,51 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.apache.spark.mllib.linalg.{Vectors, Vector}
21-
import org.apache.spark.mllib.regression.MonotonicityConstraint.Enum.MonotonicityConstraint
20+
import org.apache.spark.mllib.linalg.Vector
21+
import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint.{Isotonic, MonotonicityConstraint}
2222
import org.apache.spark.rdd.RDD
2323

24+
/**
25+
* Monotonicity constrains for monotone regression
26+
* Isotonic (increasing)
27+
* Antitonic (decreasing)
28+
*/
2429
object MonotonicityConstraint {
2530

26-
object Enum {
31+
object MonotonicityConstraint {
2732

2833
sealed trait MonotonicityConstraint {
2934
private[regression] def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean
3035
}
3136

37+
/**
38+
* Isotonic monotonicity constraint. Increasing sequence
39+
*/
3240
case object Isotonic extends MonotonicityConstraint {
3341
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
3442
current.label <= next.label
3543
}
3644
}
3745

46+
/**
47+
* Antitonic monotonicity constrain. Decreasing sequence
48+
*/
3849
case object Antitonic extends MonotonicityConstraint {
3950
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
4051
current.label >= next.label
4152
}
4253
}
4354
}
4455

45-
val Isotonic = Enum.Isotonic
46-
val Antitonic = Enum.Antitonic
56+
val Isotonic = MonotonicityConstraint.Isotonic
57+
val Antitonic = MonotonicityConstraint.Antitonic
4758
}
4859

4960
/**
5061
* Regression model for Isotonic regression
5162
*
5263
* @param predictions Weights computed for every feature.
64+
* @param monotonicityConstraint specifies if the sequence is increasing or decreasing
5365
*/
5466
class IsotonicRegressionModel(
5567
val predictions: Seq[WeightedLabeledPoint],
@@ -59,10 +71,11 @@ class IsotonicRegressionModel(
5971
override def predict(testData: RDD[Vector]): RDD[Double] =
6072
testData.map(predict)
6173

62-
//take the highest of elements smaller than our feature or weight with lowest feature
63-
override def predict(testData: Vector): Double =
74+
override def predict(testData: Vector): Double = {
75+
//take the highest of data points smaller than our feature or data point with lowest feature
6476
(predictions.head +:
6577
predictions.filter(y => y.features.toArray.head <= testData.toArray.head)).last.label
78+
}
6679
}
6780

6881
/**
@@ -71,49 +84,40 @@ class IsotonicRegressionModel(
7184
trait IsotonicRegressionAlgorithm
7285
extends Serializable {
7386

87+
/**
88+
* Creates isotonic regression model with given parameters
89+
*
90+
* @param predictions labels estimated using isotonic regression algorithm. Used for predictions on new data points.
91+
* @param monotonicityConstraint isotonic or antitonic
92+
* @return isotonic regression model
93+
*/
7494
protected def createModel(
75-
weights: Seq[WeightedLabeledPoint],
95+
predictions: Seq[WeightedLabeledPoint],
7696
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
7797

7898
/**
7999
* Run algorithm to obtain isotonic regression model
100+
*
80101
* @param input data
81102
* @param monotonicityConstraint ascending or descenting
82-
* @return model
103+
* @return isotonic regression model
83104
*/
84105
def run(
85106
input: RDD[WeightedLabeledPoint],
86107
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
87-
88-
/**
89-
* Run algorithm to obtain isotonic regression model
90-
* @param input data
91-
* @param monotonicityConstraint asc or desc
92-
* @param weights weights
93-
* @return
94-
*/
95-
def run(
96-
input: RDD[WeightedLabeledPoint],
97-
monotonicityConstraint: MonotonicityConstraint,
98-
weights: Vector): IsotonicRegressionModel
99108
}
100109

101-
class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
110+
/**
111+
* Parallel pool adjacent violators algorithm for monotone regression
112+
*/
113+
class PoolAdjacentViolators private [mllib]
114+
extends IsotonicRegressionAlgorithm {
102115

103116
override def run(
104117
input: RDD[WeightedLabeledPoint],
105118
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
106119
createModel(
107-
parallelPoolAdjacentViolators(input, monotonicityConstraint, Vectors.dense(Array(0d))),
108-
monotonicityConstraint)
109-
}
110-
111-
override def run(
112-
input: RDD[WeightedLabeledPoint],
113-
monotonicityConstraint: MonotonicityConstraint,
114-
weights: Vector): IsotonicRegressionModel = {
115-
createModel(
116-
parallelPoolAdjacentViolators(input, monotonicityConstraint, weights),
120+
parallelPoolAdjacentViolators(input, monotonicityConstraint),
117121
monotonicityConstraint)
118122
}
119123

@@ -180,16 +184,15 @@ class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
180184

181185
/**
182186
* Performs parallel pool adjacent violators algorithm
183-
* Calls PAVA on each partition and then again on the result
187+
* Calls Pool adjacent violators on each partition and then again on the result
184188
*
185189
* @param testData input
186190
* @param monotonicityConstraint asc or desc
187191
* @return result
188192
*/
189193
private def parallelPoolAdjacentViolators(
190194
testData: RDD[WeightedLabeledPoint],
191-
monotonicityConstraint: MonotonicityConstraint,
192-
weights: Vector): Seq[WeightedLabeledPoint] = {
195+
monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = {
193196

194197
poolAdjacentViolators(
195198
testData
@@ -201,39 +204,24 @@ class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
201204
}
202205

203206
/**
204-
* Top-level methods for calling IsotonicRegression.
207+
* Top-level methods for monotone regression (either isotonic or antitonic).
205208
*/
206209
object IsotonicRegression {
207210

208211
/**
209-
* Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number
210-
* of iterations of gradient descent using the specified step size. Each iteration uses
211-
* `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
212-
* in gradient descent are initialized using the initial weights provided.
212+
* Train a monotone regression model given an RDD of (label, features, weight).
213+
* Currently only one dimensional algorithm is supported (features.length is one)
214+
* Label is the dependent y value
215+
* Weight of the data point is the number of measurements. Default is 1
213216
*
214-
* @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
215-
* matrix A as well as the corresponding right hand side label y
216-
* @param weights Initial set of weights to be used. Array should be equal in size to
217-
* the number of features in the data.
217+
* @param input RDD of (label, array of features, weight). Each point describes a row of the data
218+
* matrix A as well as the corresponding right hand side label y
219+
* and weight as number of measurements
220+
* @param monotonicityConstraint
218221
*/
219222
def train(
220223
input: RDD[WeightedLabeledPoint],
221-
monotonicityConstraint: MonotonicityConstraint,
222-
weights: Vector): IsotonicRegressionModel = {
223-
new PoolAdjacentViolators().run(input, monotonicityConstraint, weights)
224-
}
225-
226-
/**
227-
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
228-
* of iterations of gradient descent using the specified step size. Each iteration uses
229-
* `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
230-
*
231-
* @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
232-
* matrix A as well as the corresponding right hand side label y
233-
*/
234-
def train(
235-
input: RDD[WeightedLabeledPoint],
236-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
224+
monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
237225
new PoolAdjacentViolators().run(input, monotonicityConstraint)
238226
}
239227
}

mllib/src/main/scala/org/apache/spark/mllib/regression/WeightedLabeledPoint.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,21 @@ import scala.beans.BeanInfo
2525
object WeightedLabeledPointConversions {
2626
implicit def labeledPointToWeightedLabeledPoint(
2727
labeledPoint: LabeledPoint): WeightedLabeledPoint = {
28-
WeightedLabeledPoint(labeledPoint.label, labeledPoint.features, 1)
28+
WeightedLabeledPoint(labeledPoint.label, labeledPoint.features)
2929
}
3030

3131
implicit def labeledPointRDDToWeightedLabeledPointRDD(
3232
rdd: RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = {
33-
rdd.map(lp => WeightedLabeledPoint(lp.label, lp.features, 1))
33+
rdd.map(lp => WeightedLabeledPoint(lp.label, lp.features))
3434
}
3535
}
3636

3737
/**
38-
* Labeled point with weight
38+
* Class that represents the features and labels of a data point with associated weight
39+
*
40+
* @param label Label for this data point.
41+
* @param features List of features for this data point.
42+
* @param weight Weight of the data point. Defaults to 1.
3943
*/
4044
@BeanInfo
41-
case class WeightedLabeledPoint(label: Double, features: Vector, weight: Double)
45+
case class WeightedLabeledPoint(label: Double, features: Vector, weight: Double = 1)

mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.junit.Test;
2929

3030
import java.io.Serializable;
31-
import java.util.Arrays;
3231
import java.util.List;
3332

3433
public class JavaIsotonicRegressionSuite implements Serializable {

mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.mllib.regression
1919

2020
import org.apache.spark.mllib.linalg.Vectors
21-
import org.apache.spark.mllib.regression.MonotonicityConstraint.Enum.{Antitonic, Isotonic}
21+
import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint.{Antitonic, Isotonic}
2222
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
2323
import org.scalatest.{Matchers, FunSuite}
2424
import WeightedLabeledPointConversions._

0 commit comments

Comments
 (0)