Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@
'FMRegressor', 'FMRegressionModel']


class JavaRegressor(JavaPredictor, _JavaPredictorParams):
"""
Java Regressor for regression tasks.

.. versionadded:: 3.0.0
"""
pass


class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams):
"""
Java Model produced by a ``_JavaRegressor``.
To be mixed in with class:`pyspark.ml.JavaModel`

.. versionadded:: 3.0.0
"""
pass


class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
HasAggregationDepth, HasLoss):
Expand Down Expand Up @@ -69,7 +88,7 @@ def getEpsilon(self):


@inherit_doc
class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
"""
Linear regression.

Expand Down Expand Up @@ -251,7 +270,7 @@ def setLoss(self, value):
return self._set(lossType=value)


class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable,
class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
JavaMLReadable, HasTrainingSummary):
"""
Model fitted by :class:`LinearRegression`.
Expand Down Expand Up @@ -758,7 +777,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha


@inherit_doc
class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable,
class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
Expand Down Expand Up @@ -953,8 +972,10 @@ def setVarianceCol(self, value):


@inherit_doc
class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
JavaMLWritable, JavaMLReadable):
class DecisionTreeRegressionModel(
JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
Model fitted by :class:`DecisionTreeRegressor`.

Expand Down Expand Up @@ -1000,7 +1021,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):


@inherit_doc
class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable,
class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
Expand Down Expand Up @@ -1198,8 +1219,10 @@ def setMinWeightFractionPerNode(self, value):
return self._set(minWeightFractionPerNode=value)


class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
JavaMLWritable, JavaMLReadable):
class RandomForestRegressionModel(
JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
Model fitted by :class:`RandomForestRegressor`.

Expand Down Expand Up @@ -1251,7 +1274,7 @@ def getLossType(self):


@inherit_doc
class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for regression.
Expand Down Expand Up @@ -1492,7 +1515,10 @@ def setMinWeightFractionPerNode(self, value):
return self._set(minWeightFractionPerNode=value)


class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
class GBTRegressionModel(
JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
Model fitted by :class:`GBTRegressor`.

Expand Down Expand Up @@ -1582,7 +1608,7 @@ def getQuantilesCol(self):


@inherit_doc
class AFTSurvivalRegression(JavaPredictor, _AFTSurvivalRegressionParams,
class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
Expand Down Expand Up @@ -1723,7 +1749,7 @@ def setAggregationDepth(self, value):
return self._set(aggregationDepth=value)


class AFTSurvivalRegressionModel(JavaPredictionModel, _AFTSurvivalRegressionParams,
class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`AFTSurvivalRegression`.
Expand Down Expand Up @@ -1855,7 +1881,7 @@ def getOffsetCol(self):


@inherit_doc
class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionParams,
class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Generalized Linear Regression.
Expand Down Expand Up @@ -2060,7 +2086,7 @@ def setAggregationDepth(self, value):
return self._set(aggregationDepth=value)


class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams,
class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams,
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by :class:`GeneralizedLinearRegression`.
Expand Down Expand Up @@ -2348,7 +2374,7 @@ def getInitStd(self):


@inherit_doc
class FMRegressor(JavaPredictor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
"""
Factorization Machines learning algorithm for regression.

Expand Down Expand Up @@ -2512,7 +2538,7 @@ def setRegParam(self, value):
return self._set(regParam=value)


class FMRegressionModel(JavaPredictionModel, _FactorizationMachinesParams, JavaMLWritable,
class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by :class:`FMRegressor`.
Expand Down