package org.apache.spark.mllib.regression

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint.{Isotonic, MonotonicityConstraint}
import org.apache.spark.rdd.RDD
2323
/**
 * Monotonicity constraints for monotone (isotonic) regression.
 *
 * Two constraints are available:
 *  - [[MonotonicityConstraint.MonotonicityConstraint.Isotonic]] (increasing sequence)
 *  - [[MonotonicityConstraint.MonotonicityConstraint.Antitonic]] (decreasing sequence)
 */
object MonotonicityConstraint {

  object MonotonicityConstraint {

    /**
     * Ordering constraint between two consecutive data points.
     * Sealed so that pattern matches over constraints are exhaustive.
     */
    sealed trait MonotonicityConstraint {
      /**
       * Returns true if the pair (current, next) satisfies this constraint.
       *
       * @param current the data point earlier in the sequence
       * @param next the data point immediately following `current`
       */
      private[regression] def holds(
        current: WeightedLabeledPoint,
        next: WeightedLabeledPoint): Boolean
    }

    /**
     * Isotonic monotonicity constraint: labels form a non-decreasing sequence.
     */
    case object Isotonic extends MonotonicityConstraint {
      override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
        current.label <= next.label
      }
    }

    /**
     * Antitonic monotonicity constraint: labels form a non-increasing sequence.
     */
    case object Antitonic extends MonotonicityConstraint {
      override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
        current.label >= next.label
      }
    }
  }

  // Convenience aliases so callers can write MonotonicityConstraint.Isotonic directly.
  val Isotonic = MonotonicityConstraint.Isotonic
  val Antitonic = MonotonicityConstraint.Antitonic
}
4859
/**
 * Regression model for isotonic regression.
 *
 * @param predictions labels computed by the isotonic regression algorithm,
 *                    ordered by feature value; used to predict new data points
 * @param monotonicityConstraint specifies if the sequence is increasing or decreasing
 */
class IsotonicRegressionModel(
    val predictions: Seq[WeightedLabeledPoint],
    // NOTE(review): constructor tail reconstructed from diff context — TODO confirm
    val monotonicityConstraint: MonotonicityConstraint)
  extends RegressionModel {

  override def predict(testData: RDD[Vector]): RDD[Double] =
    testData.map(predict)

  override def predict(testData: Vector): Double = {
    // Take the highest of the data points with feature smaller than or equal to
    // the query feature, falling back to the data point with the lowest feature.
    (predictions.head +:
      predictions.filter(y => y.features.toArray.head <= testData.toArray.head)).last.label
  }
}
6780
/**
 * Base trait for algorithms producing an [[IsotonicRegressionModel]].
 */
trait IsotonicRegressionAlgorithm
  extends Serializable {

  /**
   * Creates an isotonic regression model with the given parameters.
   *
   * @param predictions labels estimated using the isotonic regression algorithm;
   *                    used for predictions on new data points
   * @param monotonicityConstraint isotonic or antitonic
   * @return isotonic regression model
   */
  protected def createModel(
      predictions: Seq[WeightedLabeledPoint],
      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel

  /**
   * Runs the algorithm to obtain an isotonic regression model.
   *
   * @param input data points (label, features, weight)
   * @param monotonicityConstraint ascending or descending
   * @return isotonic regression model
   */
  def run(
      input: RDD[WeightedLabeledPoint],
      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
}
100109
/**
 * Parallel pool adjacent violators algorithm for monotone regression.
 */
class PoolAdjacentViolators private[mllib]
  extends IsotonicRegressionAlgorithm {
103116 override def run (
104117 input : RDD [WeightedLabeledPoint ],
105118 monotonicityConstraint : MonotonicityConstraint ): IsotonicRegressionModel = {
106119 createModel(
107- parallelPoolAdjacentViolators(input, monotonicityConstraint, Vectors .dense(Array (0d ))),
108- monotonicityConstraint)
109- }
110-
111- override def run (
112- input : RDD [WeightedLabeledPoint ],
113- monotonicityConstraint : MonotonicityConstraint ,
114- weights : Vector ): IsotonicRegressionModel = {
115- createModel(
116- parallelPoolAdjacentViolators(input, monotonicityConstraint, weights),
120+ parallelPoolAdjacentViolators(input, monotonicityConstraint),
117121 monotonicityConstraint)
118122 }
119123
  /**
   * Performs the parallel pool adjacent violators algorithm.
   * Calls pool adjacent violators on each partition and then again on the result.
   *
   * @param testData input data points
   * @param monotonicityConstraint ascending or descending
   * @return sequence of points satisfying the constraint
   */
  private def parallelPoolAdjacentViolators(
      testData: RDD[WeightedLabeledPoint],
      monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = {

    poolAdjacentViolators(
      testData
@@ -201,39 +204,24 @@ class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
201204}
202205
/**
 * Top-level methods for monotone regression (either isotonic or antitonic).
 */
object IsotonicRegression {

  /**
   * Trains a monotone regression model given an RDD of (label, features, weight).
   * Currently only the one-dimensional algorithm is supported (features.length is one).
   * Label is the dependent y value.
   * Weight of the data point is the number of measurements. Default is 1.
   *
   * @param input RDD of (label, array of features, weight). Each point describes
   *              a row of the data matrix A as well as the corresponding right
   *              hand side label y and weight as number of measurements
   * @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing);
   *                               defaults to Isotonic
   * @return isotonic regression model
   */
  def train(
      input: RDD[WeightedLabeledPoint],
      monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
    new PoolAdjacentViolators().run(input, monotonicityConstraint)
  }
}
0 commit comments