From 27eab001af295008ed09ece35e4663cd21bf8fa9 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Mon, 12 Oct 2020 15:48:16 +0800 Subject: [PATCH 1/3] check code --- .../org/apache/spark/ml/classification/LogisticRegression.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 4d763cbd29d3c..b535ed59497ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1104,6 +1104,7 @@ class LogisticRegressionModel private[spark] ( private var _rawThreshold = Double.NaN updateBinaryThreshold() + log.warn(s"_threshold=${_threshold}, _rawThreshold=${_rawThreshold}") private def updateBinaryThreshold(): Unit = { if (!isMultinomial) { @@ -1205,6 +1206,7 @@ class LogisticRegressionModel private[spark] ( super.predict(features) } else { // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden. + log.warn(s"_threshold=${_threshold}, _rawThreshold=${_rawThreshold}") if (score(features) > _threshold) 1 else 0 } From 49690ebc06855231394000704a59e9f53cc69c5c Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Mon, 12 Oct 2020 16:04:33 +0800 Subject: [PATCH 2/3] use lazy array use lazy array --- .../classification/LogisticRegression.scala | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index b535ed59497ad..a43ad466a7c80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1100,21 +1100,24 @@ class LogisticRegressionModel private[spark] ( private lazy val _intercept = interceptVector(0) private lazy val _interceptVector = interceptVector.toDense - private var _threshold = Double.NaN - private var _rawThreshold = Double.NaN - - updateBinaryThreshold() - log.warn(s"_threshold=${_threshold}, _rawThreshold=${_rawThreshold}") + private lazy val _binaryThresholdArray = { + val array = Array(Double.NaN, Double.NaN) + updateBinaryThresholds(array) + array + } + private def _threshold: Double = _binaryThresholdArray(0) + private def _rawThreshold: Double = _binaryThresholdArray(1) - private def updateBinaryThreshold(): Unit = { + private def updateBinaryThresholds(array: Array[Double]): Unit = { if (!isMultinomial) { - _threshold = getThreshold + val _threshold = getThreshold + array(0) = _threshold if (_threshold == 0.0) { - _rawThreshold = Double.NegativeInfinity + array(1) = Double.NegativeInfinity } else if (_threshold == 1.0) { - _rawThreshold = Double.PositiveInfinity + array(1) = Double.PositiveInfinity } else { - _rawThreshold = math.log(_threshold / (1.0 - _threshold)) + array(1) = math.log(_threshold / (1.0 - _threshold)) } } } @@ -1122,7 +1125,7 @@ class LogisticRegressionModel private[spark] ( @Since("1.5.0") override def setThreshold(value: Double): this.type = { super.setThreshold(value) - updateBinaryThreshold() + updateBinaryThresholds(_binaryThresholdArray) this } @@ -1132,7 +1135,7 @@ class LogisticRegressionModel private[spark] ( @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = { super.setThresholds(value) - updateBinaryThreshold() + updateBinaryThresholds(_binaryThresholdArray) this } @@ -1206,7 +1209,6 @@ class LogisticRegressionModel private[spark] ( super.predict(features) } else { // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden. - log.warn(s"_threshold=${_threshold}, _rawThreshold=${_rawThreshold}") if (score(features) > _threshold) 1 else 0 } From 7a1749c2e4c40ded1546c7ce3f0a6b190c340c84 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Tue, 13 Oct 2020 09:56:06 +0800 Subject: [PATCH 3/3] update testsuite --- .../spark/ml/classification/LogisticRegressionSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 56eadff6df078..51a6ae3c7e49b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -400,10 +400,9 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { } test("thresholds prediction") { - val blr = new LogisticRegression().setFamily("binomial") + val blr = new LogisticRegression().setFamily("binomial").setThreshold(1.0) val binaryModel = blr.fit(smallBinaryDataset) - binaryModel.setThreshold(1.0) testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { row => assert(row.getDouble(0) === 0.0) }