diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f2bcc662030c6..1a7d39ba89450 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1480,9 +1480,56 @@ def evaluateEachIteration(self, dataset, loss): return self._call_java("evaluateEachIteration", dataset, loss) +class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasFitIntercept, + HasAggregationDepth): + """ + Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`. + + .. versionadded:: 3.0.0 + """ + + censorCol = Param( + Params._dummy(), "censorCol", + "censor column name. The value of this column could be 0 or 1. " + + "If the value is 1, it means the event has occurred i.e. " + + "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) + quantileProbabilities = Param( + Params._dummy(), "quantileProbabilities", + "quantile probabilities array. Values of the quantile probabilities array " + + "should be in the range (0, 1) and the array should be non-empty.", + typeConverter=TypeConverters.toListFloat) + quantilesCol = Param( + Params._dummy(), "quantilesCol", + "quantiles column name. This column will output quantiles of " + + "corresponding quantileProbabilities if it is set.", + typeConverter=TypeConverters.toString) + + @since("1.6.0") + def getCensorCol(self): + """ + Gets the value of censorCol or its default value. + """ + return self.getOrDefault(self.censorCol) + + @since("1.6.0") + def getQuantileProbabilities(self): + """ + Gets the value of quantileProbabilities or its default value. + """ + return self.getOrDefault(self.quantileProbabilities) + + @since("1.6.0") + def getQuantilesCol(self): + """ + Gets the value of quantilesCol or its default value. + """ + return self.getOrDefault(self.quantilesCol) + + @inherit_doc -class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, - HasAggregationDepth, JavaMLWritable, JavaMLReadable): +class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams, + JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, .. versionadded:: 1.6.0 """ - censorCol = Param(Params._dummy(), "censorCol", - "censor column name. The value of this column could be 0 or 1. " + - "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) - quantileProbabilities = \ - Param(Params._dummy(), "quantileProbabilities", - "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.", - typeConverter=TypeConverters.toListFloat) - quantilesCol = Param(Params._dummy(), "quantilesCol", - "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", @@ -1588,13 +1621,6 @@ def setCensorCol(self, value): """ return self._set(censorCol=value) - @since("1.6.0") - def getCensorCol(self): - """ - Gets the value of censorCol or its default value. - """ - return self.getOrDefault(self.censorCol) - @since("1.6.0") def setQuantileProbabilities(self, value): """ @@ -1602,13 +1628,6 @@ def setQuantileProbabilities(self, value): """ return self._set(quantileProbabilities=value) - @since("1.6.0") - def getQuantileProbabilities(self): - """ - Gets the value of quantileProbabilities or its default value. - """ - return self.getOrDefault(self.quantileProbabilities) - @since("1.6.0") def setQuantilesCol(self, value): """ @@ -1616,15 +1635,9 @@ def setQuantilesCol(self, value): """ return self._set(quantilesCol=value) - @since("1.6.0") - def getQuantilesCol(self): - """ - Gets the value of quantilesCol or its default value. - """ - return self.getOrDefault(self.quantilesCol) - -class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams, + JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`AFTSurvivalRegression`.