diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 0e3ec4d35f9b6..715916296e417 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -96,6 +96,13 @@ def numClasses(self): """ return self._call_java("numClasses") + @since("3.0.0") + def predictRaw(self, value): + """ + Raw prediction for each possible label. + """ + return self._call_java("predictRaw", value) + class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _JavaClassifierParams): """ @@ -149,6 +156,13 @@ def setThresholds(self, value): """ return self._set(thresholds=value) + @since("3.0.0") + def predictProbability(self, value): + """ + Predict the probability of each class given the features. + """ + return self._call_java("predictProbability", value) + class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold): @@ -211,6 +225,8 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF() >>> model.predict(test0.head().features) 1.0 + >>> model.predictRaw(test0.head().features) + DenseVector([-1.4831, 1.4831]) >>> result = model.transform(test0).head() >>> result.newPrediction 1.0 @@ -568,6 +584,10 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams, >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF() >>> blorModel.predict(test0.head().features) 1.0 + >>> blorModel.predictRaw(test0.head().features) + DenseVector([-3.54..., 3.54...]) + >>> blorModel.predictProbability(test0.head().features) + DenseVector([0.028, 0.972]) >>> result = blorModel.transform(test0).head() >>> result.prediction 1.0 @@ -1148,6 +1168,10 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.predict(test0.head().features) 0.0 + >>> model.predictRaw(test0.head().features) + DenseVector([1.0, 0.0]) + >>> model.predictProbability(test0.head().features) + DenseVector([1.0, 0.0]) >>> result = model.transform(test0).head() >>> result.prediction 0.0 @@ -1379,6 +1403,10 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.predict(test0.head().features) 0.0 + >>> model.predictRaw(test0.head().features) + DenseVector([2.0, 0.0]) + >>> model.predictProbability(test0.head().features) + DenseVector([1.0, 0.0]) >>> result = model.transform(test0).head() >>> result.prediction 0.0 @@ -1640,6 +1668,10 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams, >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.predict(test0.head().features) 0.0 + >>> model.predictRaw(test0.head().features) + DenseVector([1.1697, -1.1697]) + >>> model.predictProbability(test0.head().features) + DenseVector([0.9121, 0.0879]) >>> result = model.transform(test0).head() >>> result.prediction 0.0 @@ -1959,6 +1991,10 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() >>> model.predict(test0.head().features) 1.0 + >>> model.predictRaw(test0.head().features) + DenseVector([-1.72..., -0.99...]) + >>> model.predictProbability(test0.head().features) + DenseVector([0.32..., 0.67...]) >>> result = model.transform(test0).head() >>> result.prediction 1.0 @@ -2174,6 +2210,10 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer ... (Vectors.dense([0.0, 0.0]),)], ["features"]) >>> model.predict(testDF.head().features) 1.0 + >>> model.predictRaw(testDF.head().features) + DenseVector([-16.208, 16.344]) + >>> model.predictProbability(testDF.head().features) + DenseVector([0.0, 1.0]) >>> model.transform(testDF).select("features", "prediction").show() +---------+----------+ | features|prediction| @@ -2791,6 +2831,10 @@ class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, Ja ... (Vectors.dense(0.5),), ... (Vectors.dense(1.0),), ... (Vectors.dense(2.0),)], ["features"]) + >>> model.predictRaw(test0.head().features) + DenseVector([22.13..., -22.13...]) + >>> model.predictProbability(test0.head().features) + DenseVector([1.0, 0.0]) >>> model.transform(test0).select("features", "probability").show(10, False) +--------+------------------------------------------+ |features|probability |