diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 4d763cbd29d3c..a43ad466a7c80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1100,20 +1100,24 @@ class LogisticRegressionModel private[spark] (
   private lazy val _intercept = interceptVector(0)
   private lazy val _interceptVector = interceptVector.toDense
 
-  private var _threshold = Double.NaN
-  private var _rawThreshold = Double.NaN
-
-  updateBinaryThreshold()
+  private lazy val _binaryThresholdArray = {
+    val array = Array(Double.NaN, Double.NaN)
+    updateBinaryThresholds(array)
+    array
+  }
+  private def _threshold: Double = _binaryThresholdArray(0)
+  private def _rawThreshold: Double = _binaryThresholdArray(1)
 
-  private def updateBinaryThreshold(): Unit = {
+  private def updateBinaryThresholds(array: Array[Double]): Unit = {
     if (!isMultinomial) {
-      _threshold = getThreshold
+      val _threshold = getThreshold
+      array(0) = _threshold
       if (_threshold == 0.0) {
-        _rawThreshold = Double.NegativeInfinity
+        array(1) = Double.NegativeInfinity
       } else if (_threshold == 1.0) {
-        _rawThreshold = Double.PositiveInfinity
+        array(1) = Double.PositiveInfinity
       } else {
-        _rawThreshold = math.log(_threshold / (1.0 - _threshold))
+        array(1) = math.log(_threshold / (1.0 - _threshold))
       }
     }
   }
@@ -1121,7 +1125,7 @@ class LogisticRegressionModel private[spark] (
   @Since("1.5.0")
   override def setThreshold(value: Double): this.type = {
     super.setThreshold(value)
-    updateBinaryThreshold()
+    updateBinaryThresholds(_binaryThresholdArray)
     this
   }
 
@@ -1131,7 +1135,7 @@ class LogisticRegressionModel private[spark] (
   @Since("1.5.0")
   override def setThresholds(value: Array[Double]): this.type = {
     super.setThresholds(value)
-    updateBinaryThreshold()
+    updateBinaryThresholds(_binaryThresholdArray)
     this
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 56eadff6df078..51a6ae3c7e49b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -400,10 +400,9 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("thresholds prediction") {
-    val blr = new LogisticRegression().setFamily("binomial")
+    val blr = new LogisticRegression().setFamily("binomial").setThreshold(1.0)
     val binaryModel = blr.fit(smallBinaryDataset)
-    binaryModel.setThreshold(1.0)
     testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { row =>
       assert(row.getDouble(0) === 0.0)
     }