
Commit 80f3bcb

viirya authored and mengxr committed
[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel
`LogisticRegressionModel`'s `predictPoint` should directly use broadcasted weights. This PR also fixes the compilation errors of two unit test suites: `JavaLogisticRegressionSuite` and `JavaLinearRegressionSuite`.

Author: Liang-Chi Hsieh <[email protected]>

Closes #4429 from viirya/use_bcvalue and squashes the following commits:

5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error.
1 parent 0d74bd7 commit 80f3bcb
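Context for the change: `predictPoint` receives the broadcasted weights as its `weightMatrix` parameter, so it should use that parameter consistently instead of re-capturing the model's driver-side `weights` field inside the closure. A minimal, self-contained sketch of the broadcast pattern involved (the `DenseVec` type and the `predict` helper below are illustrative stand-ins, not the actual MLlib API):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Illustrative stand-in for a dense weight/feature vector; not MLlib's Vector type.
case class DenseVec(values: Array[Double]) {
  def dot(other: DenseVec): Double =
    values.zip(other.values).map { case (a, b) => a * b }.sum
}

object BroadcastPredictSketch {
  // Score every point of an RDD with a single broadcast of the weights.
  def predict(sc: SparkContext, data: RDD[DenseVec],
              weights: DenseVec, intercept: Double): RDD[Double] = {
    val bcWeights: Broadcast[DenseVec] = sc.broadcast(weights)
    data.mapPartitions { iter =>
      // Resolve the broadcast inside the closure: this is the value the
      // per-point prediction should use, not the driver-side `weights` field.
      val w = bcWeights.value
      iter.map { x =>
        val margin = w.dot(x) + intercept
        1.0 / (1.0 + math.exp(-margin)) // binary logistic score
      }
    }
  }
}
```

Because the broadcast is resolved inside `mapPartitions`, each executor fetches the weight vector once and each task captures only a small broadcast handle, rather than serializing the full weights with every task.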

3 files changed: +8 additions, 8 deletions


mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 4 additions & 4 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
  *
  * @param weights Weights computed for every feature.
  * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
- *                  In Multinomial Logistic Regression, the intercepts will not be a single values,
+ *                  In Multinomial Logistic Regression, the intercepts will not be a single value,
  *                  so the intercepts will be part of the weights.)
  * @param numFeatures the dimension of the features.
  * @param numClasses the number of possible outcomes for k classes classification problem in
@@ -107,7 +107,7 @@ class LogisticRegressionModel (
     // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
     if (numClasses == 2) {
       require(numFeatures == weightMatrix.size)
-      val margin = dot(weights, dataMatrix) + intercept
+      val margin = dot(weightMatrix, dataMatrix) + intercept
       val score = 1.0 / (1.0 + math.exp(-margin))
       threshold match {
         case Some(t) => if (score > t) 1.0 else 0.0
@@ -116,11 +116,11 @@ class LogisticRegressionModel (
     } else {
       val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
 
-      val weightsArray = weights match {
+      val weightsArray = weightMatrix match {
         case dv: DenseVector => dv.values
         case _ =>
           throw new IllegalArgumentException(
-            s"weights only supports dense vector but got type ${weights.getClass}.")
+            s"weights only supports dense vector but got type ${weightMatrix.getClass}.")
       }
 
       val margins = (0 until numClasses - 1).map { i =>
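To make the patched logic easier to follow outside the diff, here is a condensed plain-Scala paraphrase of what `predictPoint` computes for the binary and multinomial cases, using only the passed-in weights (plain arrays stand in for MLlib's `Vector` types, and the method names are illustrative):

```scala
// Plain-Scala paraphrase of the prediction logic touched by this patch; not the MLlib code.
object PredictPointSketch {

  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => x * y }.sum

  /** Binary case: sigmoid of (weights . x + intercept), optionally thresholded. */
  def predictBinary(weightMatrix: Array[Double], dataMatrix: Array[Double],
                    intercept: Double, threshold: Option[Double]): Double = {
    val margin = dot(weightMatrix, dataMatrix) + intercept
    val score = 1.0 / (1.0 + math.exp(-margin))
    threshold match {
      case Some(t) => if (score > t) 1.0 else 0.0
      case None    => score
    }
  }

  /** Multinomial case: class 0 is the pivot (margin 0); pick the class with the
    * largest margin among the remaining numClasses - 1 weight blocks. */
  def predictMultinomial(weightsArray: Array[Double], dataMatrix: Array[Double],
                         numClasses: Int): Double = {
    val dataWithBiasSize = weightsArray.length / (numClasses - 1)
    val withBias = dataMatrix.length + 1 == dataWithBiasSize
    var bestClass = 0
    var maxMargin = 0.0
    (0 until numClasses - 1).foreach { i =>
      var margin = 0.0
      dataMatrix.zipWithIndex.foreach { case (value, index) =>
        if (value != 0.0) margin += value * weightsArray(i * dataWithBiasSize + index)
      }
      if (withBias) margin += weightsArray(i * dataWithBiasSize + dataMatrix.length)
      if (margin > maxMargin) {
        maxMargin = margin
        bestClass = i + 1
      }
    }
    bestClass.toDouble
  }
}
```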

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 2 additions & 2 deletions
@@ -84,7 +84,7 @@ public void logisticRegressionWithSetters() {
       .setThreshold(0.6)
       .setProbabilityCol("myProbability");
     LogisticRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
     assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
     assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
     assert(model.getThreshold() == 0.6);
@@ -109,7 +109,7 @@ public void logisticRegressionWithSetters() {
     // Call fit() with new params, and check as many params as we can.
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
       lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
-    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
     assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
     assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
     assert(model2.getThreshold() == 0.4);
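These assertions fail to compile with `==` presumably because `ParamMap.apply` is generic on the Scala side, so from Java the returned value is an erased `Object` that cannot be compared to a primitive literal with `==`; `.equals(...)` on the boxed value works. A hypothetical, stripped-down Scala illustration of that shape (the `SimpleParamMap` and `Param` classes below are not the real spark.ml API):

```scala
// Hypothetical stand-ins, only to show why the Java tests compare with .equals(...).
class Param[T](val name: String)

class SimpleParamMap(values: Map[Param[_], Any]) {
  // Generic lookup: T is erased, so Java callers see the result as Object and
  // must compare it with .equals(...) rather than ==.
  def apply[T](param: Param[T]): T = values(param).asInstanceOf[T]
}

object ParamMapEqualsSketch {
  def main(args: Array[String]): Unit = {
    val maxIter = new Param[Int]("maxIter")
    val map = new SimpleParamMap(Map(maxIter -> 10))

    // Scala: == delegates to equals, so this is fine.
    assert(map(maxIter) == 10)
    // Java equivalent: map.apply(maxIter).equals(10); using == there would not
    // compile, since the erased return type is Object.
    assert(map(maxIter).equals(10))
  }
}
```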

mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java

Lines changed: 2 additions & 2 deletions
@@ -76,13 +76,13 @@ public void linearRegressionWithSetters() {
       .setMaxIter(10)
       .setRegParam(1.0);
     LinearRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
     assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
 
     // Call fit() with new params, and check as many params as we can.
     LinearRegressionModel model2 =
       lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
-    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
     assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
     assert(model2.getPredictionCol().equals("thePred"));
   }

0 commit comments
