Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
image, pipeline, recommendation, regression, stat, tuning, util, linalg, param

__all__ = [
"Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel",
"Transformer", "UnaryTransformer", "Estimator", "Model",
"Predictor", "PredictionModel",
"Pipeline", "PipelineModel",
Comment thread
zero323 marked this conversation as resolved.
Outdated
"classification", "clustering", "evaluation", "feature", "fpm", "image",
"recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
]
81 changes: 80 additions & 1 deletion python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod, abstractproperty

import copy
import threading
Expand Down Expand Up @@ -246,3 +246,82 @@ def _transform(self, dataset):
transformedDataset = dataset.withColumn(self.getOutputCol(),
transformUDF(dataset[self.getInputCol()]))
return transformedDataset


@inherit_doc
class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
    """
    Params for :py:class:`Predictor` and :py:class:`PredictionModel`.

    Combines the ``HasLabelCol``, ``HasFeaturesCol`` and ``HasPredictionCol``
    mixins into a single base class; it declares no members of its own.

    .. versionadded:: 3.0.0
    """
    pass


@inherit_doc
class Predictor(Estimator, _PredictorParams):
    """
    Estimator for prediction tasks, i.e. regression and classification.

    Adds setters for the label, features and prediction column params
    inherited from :py:class:`_PredictorParams`.
    """

    # Python 2 style metaclass declaration (Python 3 does not honour
    # ``__metaclass__``).
    __metaclass__ = ABCMeta

    @since("3.0.0")
    def setLabelCol(self, value):
        """
        Sets :py:attr:`labelCol` to the given value.
        """
        return self._set(labelCol=value)

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets :py:attr:`featuresCol` to the given value.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets :py:attr:`predictionCol` to the given value.
        """
        return self._set(predictionCol=value)


@inherit_doc
class PredictionModel(Model, _PredictorParams):
    """
    Model for prediction tasks, i.e. regression and classification.

    Adds setters for the features and prediction column params and declares
    the abstract members :py:attr:`numFeatures` and :py:meth:`predict`.
    """

    # Python 2 style metaclass declaration (Python 3 does not honour
    # ``__metaclass__``).
    __metaclass__ = ABCMeta

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets :py:attr:`featuresCol` to the given value.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets :py:attr:`predictionCol` to the given value.
        """
        return self._set(predictionCol=value)

    @abstractproperty
    @since("2.1.0")
    def numFeatures(self):
        """
        Number of features the model was trained on, or -1 if unknown.
        """
        raise NotImplementedError

    @abstractmethod
    @since("3.0.0")
    def predict(self, value):
        """
        Predicts the label for the given features.
        """
        raise NotImplementedError
Loading