-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-15113][PySpark][ML] Add missing num features num classes #12889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0ef9826
256c550
c1961ae
beb6920
e3b01f5
f5c69f1
6e35559
a127e7c
020c096
45570f5
0d7defa
a15c2a4
8a30c7a
c0f2e80
9283e3d
5045f7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,23 @@ | |
| 'OneVsRest', 'OneVsRestModel'] | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class JavaClassificationModel(JavaPredictionModel): | ||
| """ | ||
| (Private) Java Model produced by a ``Classifier``. | ||
| Classes are indexed {0, 1, ..., numClasses - 1}. | ||
| To be mixed in with :class:`pyspark.ml.JavaModel` | ||
| """ | ||
|
|
||
| @property | ||
| @since("2.1.0") | ||
| def numClasses(self): | ||
| """ | ||
| Number of classes (values which the label can take). | ||
| """ | ||
| return self._call_java("numClasses") | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, | ||
| HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, | ||
|
|
@@ -212,7 +229,7 @@ def _checkThresholdConsistency(self): | |
| " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) | ||
|
|
||
|
|
||
| class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
| class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): | ||
| """ | ||
| Model fitted by LogisticRegression. | ||
|
|
||
|
|
@@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred | |
| 1 | ||
| >>> model.featureImportances | ||
| SparseVector(1, {0: 1.0}) | ||
| >>> model.numFeatures | ||
| 1 | ||
| >>> model.numClasses | ||
| 2 | ||
| >>> print(model.toDebugString) | ||
| DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes... | ||
| >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) | ||
|
|
@@ -595,7 +616,8 @@ def _create_model(self, java_model): | |
|
|
||
|
|
||
| @inherit_doc | ||
| class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): | ||
| class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @holdenk are we not missing out
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just curious to know why we don't expose |
||
| JavaMLReadable): | ||
| """ | ||
| Model fitted by DecisionTreeClassifier. | ||
|
|
||
|
|
@@ -722,7 +744,8 @@ def _create_model(self, java_model): | |
| return RandomForestClassificationModel(java_model) | ||
|
|
||
|
|
||
| class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): | ||
| class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, | ||
| JavaMLReadable): | ||
| """ | ||
| Model fitted by RandomForestClassifier. | ||
|
|
||
|
|
@@ -873,7 +896,8 @@ def getLossType(self): | |
| return self.getOrDefault(self.lossType) | ||
|
|
||
|
|
||
| class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): | ||
| class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, | ||
| JavaMLReadable): | ||
| """ | ||
| Model fitted by GBTClassifier. | ||
|
|
||
|
|
@@ -1027,7 +1051,7 @@ def getModelType(self): | |
| return self.getOrDefault(self.modelType) | ||
|
|
||
|
|
||
| class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
| class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): | ||
| """ | ||
| Model fitted by NaiveBayes. | ||
|
|
||
|
|
@@ -1226,7 +1250,8 @@ def getInitialWeights(self): | |
| return self.getOrDefault(self.initialWeights) | ||
|
|
||
|
|
||
| class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
| class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable, | ||
| JavaMLReadable): | ||
| """ | ||
| .. note:: Experimental | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction | |
| True | ||
| >>> model.intercept == model2.intercept | ||
| True | ||
| >>> model.numFeatures | ||
| 1 | ||
|
|
||
| .. versionadded:: 1.4.0 | ||
| """ | ||
|
|
@@ -126,7 +128,7 @@ def _create_model(self, java_model): | |
| return LinearRegressionModel(java_model) | ||
|
|
||
|
|
||
| class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
| class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): | ||
| """ | ||
| Model fitted by :class:`LinearRegression`. | ||
|
|
||
|
|
@@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | |
| 3 | ||
| >>> model.featureImportances | ||
| SparseVector(1, {0: 1.0}) | ||
| >>> model.numFeatures | ||
| 1 | ||
| >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) | ||
| >>> model.transform(test0).head().prediction | ||
| 0.0 | ||
|
|
@@ -719,7 +723,7 @@ def _create_model(self, java_model): | |
|
|
||
|
|
||
| @inherit_doc | ||
| class DecisionTreeModel(JavaModel): | ||
| class DecisionTreeModel(JavaModel, JavaPredictionModel): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @holdenk what about
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. +1
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was going to leave the ones which aren't base classes for classification models for whoever picks up updating the regression models (
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @holdenk it doesn't look like anyone's doing the regression side, so should we just roll those into this one PR?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That seems reasonable. I've got some less direct code things I need to catch up on for the next few days, but if no one picks up the regression stuff in the meantime I'll roll that into this PR. |
||
| """ | ||
| Abstraction for Decision Tree models. | ||
|
|
||
|
|
@@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | |
| >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) | ||
| >>> model.transform(test0).head().prediction | ||
| 0.0 | ||
| >>> model.numFeatures | ||
| 1 | ||
| >>> model.trees | ||
| [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] | ||
| >>> model.getNumTrees | ||
|
|
@@ -909,7 +915,8 @@ def _create_model(self, java_model): | |
| return RandomForestRegressionModel(java_model) | ||
|
|
||
|
|
||
| class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): | ||
| class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, | ||
| JavaMLReadable): | ||
| """ | ||
| Model fitted by :class:`RandomForestRegressor`. | ||
|
|
||
|
|
@@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, | |
| >>> model = gbt.fit(df) | ||
| >>> model.featureImportances | ||
| SparseVector(1, {0: 1.0}) | ||
| >>> model.numFeatures | ||
| 1 | ||
| >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) | ||
| True | ||
| >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) | ||
|
|
@@ -1047,7 +1056,7 @@ def getLossType(self): | |
| return self.getOrDefault(self.lossType) | ||
|
|
||
|
|
||
| class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): | ||
| class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): | ||
| """ | ||
| Model fitted by :class:`GBTRegressor`. | ||
|
|
||
|
|
@@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha | |
| True | ||
| >>> model.coefficients | ||
| DenseVector([1.5..., -1.0...]) | ||
| >>> model.numFeatures | ||
| 2 | ||
| >>> abs(model.intercept - 1.5) < 0.001 | ||
| True | ||
| >>> glr_path = temp_path + "/glr" | ||
|
|
@@ -1412,7 +1423,8 @@ def getLink(self): | |
| return self.getOrDefault(self.link) | ||
|
|
||
|
|
||
| class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
| class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, | ||
| JavaMLReadable): | ||
| """ | ||
| .. note:: Experimental | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we'll need a `@Since("2.1.0")` on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numFeatures has always been here, it's just been the default implementation - but I guess the `@Since` wouldn't be too confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think adding a `@Since` might actually be somewhat counterintuitive - how about a Javadoc note which says that it is now defined for this model starting in 2.1.0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok fair point - I don't feel super strongly about it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still need to add this, don't we? Otherwise it is the only public method in this class that doesn't have it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The base class has `@Since("1.6.0")` on the method - so it has been public since 1.6 already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that reflected in the documentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, e.g. LogisticRegressionModel, RandomForestClassificationModel, etc.