diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 51a3e8efe8b4..1c308308a02a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -404,6 +404,18 @@ def test_copy_param_extras(self): self.assertEqual(tp._paramMap, copied_no_extra) self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + class EvaluatorTests(SparkSessionTestCase): @@ -807,18 +819,6 @@ def test_logistic_regression(self): except OSError: pass - def logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegressionModel - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - def _compare_params(self, m1, m2, param): """ Compare 2 ML Params instances for the given param, and assert both have the same param value