diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index d96a6f43c0596..fd41c12ca3351 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -41,6 +41,25 @@ 'FMRegressor', 'FMRegressionModel'] +class JavaRegressor(JavaPredictor, _JavaPredictorParams): + """ + Java Regressor for regression tasks. + + .. versionadded:: 3.0.0 + """ + pass + + +class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams): + """ + Java Model produced by a ``_JavaRegressor``. + To be mixed in with class:`pyspark.ml.JavaModel` + + .. versionadded:: 3.0.0 + """ + pass + + class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter, HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver, HasAggregationDepth, HasLoss): @@ -69,7 +88,7 @@ def getEpsilon(self): @inherit_doc -class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable): +class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable): """ Linear regression. @@ -251,7 +270,7 @@ def setLoss(self, value): return self._set(lossType=value) -class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable, +class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by :class:`LinearRegression`. @@ -758,7 +777,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha @inherit_doc -class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable, +class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable, JavaMLReadable): """ `Decision tree `_ @@ -953,8 +972,10 @@ def setVarianceCol(self, value): @inherit_doc -class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams, - JavaMLWritable, JavaMLReadable): +class DecisionTreeRegressionModel( + JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams, + JavaMLWritable, JavaMLReadable +): """ Model fitted by :class:`DecisionTreeRegressor`. @@ -1000,7 +1021,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): @inherit_doc -class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable, +class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable, JavaMLReadable): """ `Random Forest `_ @@ -1198,8 +1219,10 @@ def setMinWeightFractionPerNode(self, value): return self._set(minWeightFractionPerNode=value) -class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams, - JavaMLWritable, JavaMLReadable): +class RandomForestRegressionModel( + JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams, + JavaMLWritable, JavaMLReadable +): """ Model fitted by :class:`RandomForestRegressor`. @@ -1251,7 +1274,7 @@ def getLossType(self): @inherit_doc -class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): +class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): """ `Gradient-Boosted Trees (GBTs) `_ learning algorithm for regression. @@ -1492,7 +1515,10 @@ def setMinWeightFractionPerNode(self, value): return self._set(minWeightFractionPerNode=value) -class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): +class GBTRegressionModel( + JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams, + JavaMLWritable, JavaMLReadable +): """ Model fitted by :class:`GBTRegressor`. @@ -1582,7 +1608,7 @@ def getQuantilesCol(self): @inherit_doc -class AFTSurvivalRegression(JavaPredictor, _AFTSurvivalRegressionParams, +class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -1723,7 +1749,7 @@ def setAggregationDepth(self, value): return self._set(aggregationDepth=value) -class AFTSurvivalRegressionModel(JavaPredictionModel, _AFTSurvivalRegressionParams, +class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`AFTSurvivalRegression`. @@ -1855,7 +1881,7 @@ def getOffsetCol(self): @inherit_doc -class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionParams, +class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams, JavaMLWritable, JavaMLReadable): """ Generalized Linear Regression. @@ -2060,7 +2086,7 @@ def setAggregationDepth(self, value): return self._set(aggregationDepth=value) -class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams, +class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by :class:`GeneralizedLinearRegression`. @@ -2348,7 +2374,7 @@ def getInitStd(self): @inherit_doc -class FMRegressor(JavaPredictor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): +class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): """ Factorization Machines learning algorithm for regression. @@ -2512,7 +2538,7 @@ def setRegParam(self, value): return self._set(regParam=value) -class FMRegressionModel(JavaPredictionModel, _FactorizationMachinesParams, JavaMLWritable, +class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`FMRegressor`.