Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
image, pipeline, recommendation, regression, stat, tuning, util, linalg, param

__all__ = [
"Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel",
"Transformer", "UnaryTransformer", "Estimator", "Model",
"Predictor", "PredictionModel", "Pipeline", "PipelineModel",
"classification", "clustering", "evaluation", "feature", "fpm", "image",
"recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
]
81 changes: 80 additions & 1 deletion python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod, abstractproperty

import copy
import threading
Expand Down Expand Up @@ -246,3 +246,82 @@ def _transform(self, dataset):
transformedDataset = dataset.withColumn(self.getOutputCol(),
transformUDF(dataset[self.getInputCol()]))
return transformedDataset


@inherit_doc
class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
"""
Params for :py:class:`Predictor` and :py:class:`PredictorModel`.

.. versionadded:: 3.0.0
"""
pass


@inherit_doc
class Predictor(Estimator, _PredictorParams):
"""
Estimator for prediction tasks (regression and classification).
"""

__metaclass__ = ABCMeta

@since("3.0.0")
def setLabelCol(self, value):
"""
Sets the value of :py:attr:`labelCol`.
"""
return self._set(labelCol=value)

@since("3.0.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)


@inherit_doc
class PredictionModel(Model, _PredictorParams):
"""
Model for prediction tasks (regression and classification).
"""

__metaclass__ = ABCMeta

@since("3.0.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@abstractproperty
@since("2.1.0")
def numFeatures(self):
"""
Returns the number of features the model was trained on. If unknown, returns -1
"""
raise NotImplementedError()

@abstractmethod
@since("3.0.0")
def predict(self, value):
"""
Predict label for the given features.
"""
raise NotImplementedError()
Loading