Skip to content

Commit a718694

Browse files
committed
LogisticRegressionModel.toString should summarize model
1 parent 845c039 commit a718694

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,10 @@ class LogisticRegressionModel private[spark] (
11691169
*/
11701170
@Since("1.6.0")
11711171
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
1172+
1173+
override def toString: String = {
1174+
s"${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
1175+
}
11721176
}
11731177

11741178

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,12 @@ class LogisticRegressionSuite
25942594
assert(model.getFamily === family)
25952595
}
25962596
}
2597+
2598+
test("toString") {
2599+
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
2600+
val expected = "logReg, numClasses = 2, numFeatures = 3"
2601+
assert(model.toString === expected)
2602+
}
25972603
}
25982604

25992605
object LogisticRegressionSuite {

python/pyspark/ml/classification.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
237237
True
238238
>>> blorModel.intercept == model2.intercept
239239
True
240+
>>> model2._resetUid("logReg")
241+
uid = logReg, numClasses = 2, numFeatures = 2
240242
241243
.. versionadded:: 1.3.0
242244
"""
@@ -558,6 +560,11 @@ def evaluate(self, dataset):
558560
java_blr_summary = self._call_java("evaluate", dataset)
559561
return BinaryLogisticRegressionSummary(java_blr_summary)
560562

563+
def __repr__(self):
564+
numClasses = str(self._call_java("numClasses"))
565+
numFeatures = str(self._call_java("numFeatures"))
566+
return "uid = %s, numClasses = %s, numFeatures = %s" % (self.uid, numClasses, numFeatures)
567+
561568

562569
class LogisticRegressionSummary(JavaWrapper):
563570
"""

python/pyspark/mllib/classification.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ def load(cls, sc, path):
257257
model.setThreshold(threshold)
258258
return model
259259

260+
def __repr__(self):
261+
return self._call_java("toString")
262+
260263

261264
class LogisticRegressionWithSGD(object):
262265
"""

0 commit comments

Comments
 (0)