@@ -43,7 +43,7 @@ object LogisticRegressionSuite {
4343 offset : Double ,
4444 scale : Double ,
4545 nPoints : Int ,
46- seed : Int ): Seq [LabeledPoint ] = {
46+ seed : Int ): Seq [LabeledPoint ] = {
4747 val rnd = new Random (seed)
4848 val x1 = Array .fill[Double ](nPoints)(rnd.nextGaussian())
4949
@@ -58,12 +58,15 @@ object LogisticRegressionSuite {
5858}
5959
6060class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
61- def validatePrediction (predictions : Seq [Double ], input : Seq [LabeledPoint ]) {
61+ def validatePrediction (
62+ predictions : Seq [Double ],
63+ input : Seq [LabeledPoint ],
64+ expectedAcc : Double = 0.83 ) {
6265 val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
6366 prediction != expected.label
6467 }
6568 // At least 83% of the predictions should be on.
66- ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83
69+ ((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc
6770 }
6871
6972 // Test if we can correctly learn A, B where Y = logistic(A + B*X)
@@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
155158 validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
156159 }
157160
161+ test(" logistic regression with initial weights and non-default regularization parameter" ) {
162+ val nPoints = 10000
163+ val A = 2.0
164+ val B = - 1.5
165+
166+ val testData = LogisticRegressionSuite .generateLogisticInput(A , B , nPoints, 42 )
167+
168+ val initialB = - 1.0
169+ val initialWeights = Vectors .dense(initialB)
170+
171+ val testRDD = sc.parallelize(testData, 2 )
172+ testRDD.cache()
173+
174+ // Use half as many iterations as the previous test.
175+ val lr = new LogisticRegressionWithSGD ().setIntercept(true )
176+ lr.optimizer.
177+ setStepSize(10.0 ).
178+ setNumIterations(10 ).
179+ setRegParam(1.0 )
180+
181+ val model = lr.run(testRDD, initialWeights)
182+
183+ // Test the weights
184+ assert(model.weights(0 ) ~== - 430000.0 relTol 20000.0 )
185+ assert(model.intercept ~== 370000.0 relTol 20000.0 )
186+
187+ val validationData = LogisticRegressionSuite .generateLogisticInput(A , B , nPoints, 17 )
188+ val validationRDD = sc.parallelize(validationData, 2 )
189+ // Test prediction on RDD.
190+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8 )
191+
192+ // Test prediction on Array.
193+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8 )
194+ }
195+
158196 test(" logistic regression with initial weights with LBFGS" ) {
159197 val nPoints = 10000
160198 val A = 2.0
0 commit comments