Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)

override val numFeatures: Int = coefficients.size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll need a @Since("2.1.0") on this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numFeatures has always been here its just been the default implementation - but I guess the since wouldn't be too confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding a Since might actually be somewhat counter intuitive - how about a Javadoc note which says that it is now defined for this model starting in 2.1.0?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok fair point - I don't feel super strongly about it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to add this don't we? Otherwise it is the only public method in this class that doesn't have it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base class has @Since("1.6.0") on the method - so it has been public since 1.6 already.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that reflected in the documentation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

@Since("2.0.0")
Expand Down
37 changes: 31 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@
'OneVsRest', 'OneVsRestModel']


@inherit_doc
class JavaClassificationModel(JavaPredictionModel):
"""
(Private) Java Model produced by a ``Classifier``.
Classes are indexed {0, 1, ..., numClasses - 1}.
To be mixed in with class:`pyspark.ml.JavaModel`
"""

@property
@since("2.1.0")
def numClasses(self):
"""
Number of classes (values which the label can take).
"""
return self._call_java("numClasses")


@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
Expand Down Expand Up @@ -212,7 +229,7 @@ def _checkThresholdConsistency(self):
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))


class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.

Expand Down Expand Up @@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
Expand Down Expand Up @@ -595,7 +616,8 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk are we not missing out GBTClassificationModel, RandomForestClassificationModel in classification? I think GBT should just be JavaPredictionModel

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RandomForestClassificationModel and NaiveBayesModel should be extended from JavaClassificationModel, GBTClassificationModel should be JavaPredictionModel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious to know why we don't expose numClasses in GBTClassificationModel. Do we not support multiclass currently, or is there some other reason?

JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.

Expand Down Expand Up @@ -722,7 +744,8 @@ def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)


class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by RandomForestClassifier.

Expand Down Expand Up @@ -873,7 +896,8 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by GBTClassifier.

Expand Down Expand Up @@ -1027,7 +1051,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.

Expand Down Expand Up @@ -1226,7 +1250,8 @@ def getInitialWeights(self):
return self.getOrDefault(self.initialWeights)


class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
.. note:: Experimental

Expand Down
22 changes: 17 additions & 5 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
True
>>> model.intercept == model2.intercept
True
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -126,7 +128,7 @@ def _create_model(self, java_model):
return LinearRegressionModel(java_model)


class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`LinearRegression`.

Expand Down Expand Up @@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -719,7 +723,7 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeModel(JavaModel):
class DecisionTreeModel(JavaModel, JavaPredictionModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk what about LinearRegressionModel, GeneralizedLinearRegressionModel (though numFeatures is missing on the Scala side), GBTRegressionModel, RandomForestRegressionModel?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to leave the ones which aren't base classes for classification models for whoever picks up update the regression models (DecisionTreeModel is the base for DecisionTreeClassificationModel so I updated it).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk it doesn't look like anyone's doing the regression side, so should we just roll those into this one PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems reasonable, I've got some less direct code things I need to catch up on for the next few days, but if no one picks up the regression stuff in the meantime I'll roll that into this PR.

"""
Abstraction for Decision Tree models.

Expand Down Expand Up @@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
Expand Down Expand Up @@ -909,7 +915,8 @@ def _create_model(self, java_model):
return RandomForestRegressionModel(java_model)


class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by :class:`RandomForestRegressor`.

Expand Down Expand Up @@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
Expand Down Expand Up @@ -1047,7 +1056,7 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`GBTRegressor`.

Expand Down Expand Up @@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
Expand Down Expand Up @@ -1412,7 +1423,8 @@ def getLink(self):
return self.getOrDefault(self.link)


class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
.. note:: Experimental

Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,19 @@ class JavaMLReadable(MLReadable):
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)


@inherit_doc
class JavaPredictionModel():
"""
(Private) Java Model for prediction tasks (regression and classification).
To be mixed in with class:`pyspark.ml.JavaModel`
"""

@property
@since("2.1.0")
def numFeatures(self):
"""
Returns the number of features the model was trained on. If unknown, returns -1
"""
return self._call_java("numFeatures")