Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 51 additions & 38 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,9 +1480,56 @@ def evaluateEachIteration(self, dataset, loss):
return self._call_java("evaluateEachIteration", dataset, loss)


class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasMaxIter, HasTol, HasFitIntercept,
HasAggregationDepth):
"""
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.

.. versionadded:: 3.0.0
"""

censorCol = Param(
Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
quantileProbabilities = Param(
Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.",
typeConverter=TypeConverters.toListFloat)
quantilesCol = Param(
Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.",
typeConverter=TypeConverters.toString)

@since("1.6.0")
def getCensorCol(self):
"""
Gets the value of censorCol or its default value.
"""
return self.getOrDefault(self.censorCol)

@since("1.6.0")
def getQuantileProbabilities(self):
"""
Gets the value of quantileProbabilities or its default value.
"""
return self.getOrDefault(self.quantileProbabilities)

@since("1.6.0")
def getQuantilesCol(self):
"""
Gets the value of quantilesCol or its default value.
"""
return self.getOrDefault(self.quantilesCol)


@inherit_doc
class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
HasAggregationDepth, JavaMLWritable, JavaMLReadable):
class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression

Expand Down Expand Up @@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
.. versionadded:: 1.6.0
"""

censorCol = Param(Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
quantileProbabilities = \
Param(Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.",
typeConverter=TypeConverters.toListFloat)
quantilesCol = Param(Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
Expand Down Expand Up @@ -1588,43 +1621,23 @@ def setCensorCol(self, value):
"""
return self._set(censorCol=value)

@since("1.6.0")
def getCensorCol(self):
"""
Gets the value of censorCol or its default value.
"""
return self.getOrDefault(self.censorCol)

@since("1.6.0")
def setQuantileProbabilities(self, value):
"""
Sets the value of :py:attr:`quantileProbabilities`.
"""
return self._set(quantileProbabilities=value)

@since("1.6.0")
def getQuantileProbabilities(self):
"""
Gets the value of quantileProbabilities or its default value.
"""
return self.getOrDefault(self.quantileProbabilities)

@since("1.6.0")
def setQuantilesCol(self, value):
"""
Sets the value of :py:attr:`quantilesCol`.
"""
return self._set(quantilesCol=value)

@since("1.6.0")
def getQuantilesCol(self):
"""
Gets the value of quantilesCol or its default value.
"""
return self.getOrDefault(self.quantilesCol)


class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`AFTSurvivalRegression`.

Expand Down