python/pyspark/ml/pipeline.py (10 additions, 1 deletion)
@@ -22,7 +22,7 @@
from pyspark.mllib.common import inherit_doc


__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']


@inherit_doc
@@ -70,6 +70,15 @@ def transform(self, dataset, params={}):
raise NotImplementedError()


@inherit_doc
class Model(Transformer):
"""
Abstract class for models that fitted by estimators.
Member:
"that fitted" --> "that are fitted"

"""

__metaclass__ = ABCMeta


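For orientation, a minimal concrete Estimator/Model pair under this contract might look like the sketch below (hypothetical classes, not part of this diff): fit produces a Model, and the Model's transform applies it to new data.

class IdentityEstimator(Estimator):
    def fit(self, dataset, params={}):
        # "training" here is trivial; a real estimator would learn from dataset
        return IdentityModel()

class IdentityModel(Model):
    def transform(self, dataset, params={}):
        # a trivial fitted model that returns its input unchanged
        return dataset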
@inherit_doc
class Pipeline(Estimator):
"""
python/pyspark/ml/tuning.py (175 additions, 2 deletions)
@@ -16,8 +16,14 @@
#

import itertools
import numpy as np

__all__ = ['ParamGridBuilder']
from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand

__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']


class ParamGridBuilder(object):
@@ -79,6 +85,173 @@ def build(self):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
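As a quick illustration of build() (hypothetical values, with lr as in the doctest below): grids of sizes 2 and 3 expand via itertools.product into all 6 combinations, each a {Param: value} dict.

grid = ParamGridBuilder() \
    .addGrid(lr.maxIter, [0, 5]) \
    .addGrid(lr.regParam, [0.0, 0.01, 0.1]) \
    .build()
assert len(grid) == 6  # 2 x 3 combinations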


class CrossValidator(Estimator):
"""
K-fold cross validation.

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
... [(Vectors.dense([0.0, 1.0]), 0.0),
... (Vectors.dense([1.0, 2.0]), 1.0),
... (Vectors.dense([0.55, 3.0]), 0.0),
... (Vectors.dense([0.45, 4.0]), 1.0),
... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
>>> cvModel = cv.fit(dataset)
>>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
>>> cvModel.transform(dataset).collect() == expected.collect()
True
"""

# a placeholder to make it appear in the generated doc
estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")

# a placeholder to make it appear in the generated doc
estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")

# a placeholder to make it appear in the generated doc
evaluator = Param(Params._dummy(), "evaluator", "evaluator for selection")
Member:
Here and in Scala, we should say CrossValidator tries to maximize the evaluation metric.

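One way the suggested clarification could read (hypothetical wording, not part of this diff):

evaluator = Param(Params._dummy(), "evaluator",
                  "evaluator used to select hyper-parameters that maximize the cross-validated metric")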

# a placeholder to make it appear in the generated doc
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
"""
super(CrossValidator, self).__init__()
#: param for estimator to be cross-validated
self.estimator = Param(self, "estimator", "estimator to be cross-validated")
#: param for estimator param maps
self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
#: param for evaluator for selection
self.evaluator = Param(self, "evaluator", "evaluator for selection")
#: param for number of folds for cross validation
self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
self._setDefault(numFolds=3)
kwargs = self.__init__._input_kwargs
self._set(**kwargs)

@keyword_only
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
Sets params for cross validator.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)

def setEstimator(self, value):
"""
Sets the value of :py:attr:`estimator`.
"""
self.paramMap[self.estimator] = value
return self

def getEstimator(self):
"""
Gets the value of estimator or its default value.
"""
return self.getOrDefault(self.estimator)

def setEstimatorParamMaps(self, value):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
self.paramMap[self.estimatorParamMaps] = value
return self

def getEstimatorParamMaps(self):
"""
Gets the value of estimatorParamMaps or its default value.
"""
return self.getOrDefault(self.estimatorParamMaps)

def setEvaluator(self, value):
"""
Sets the value of :py:attr:`evaluator`.
"""
self.paramMap[self.evaluator] = value
return self

def getEvaluator(self):
"""
Gets the value of evaluator or its default value.
"""
return self.getOrDefault(self.evaluator)

def setNumFolds(self, value):
"""
Sets the value of :py:attr:`numFolds`.
"""
self.paramMap[self.numFolds] = value
return self

def getNumFolds(self):
"""
Gets the value of numFolds or its default value.
"""
return self.getOrDefault(self.numFolds)

def fit(self, dataset, params={}):
Member:
Here and in Scala, we should check to see if users pass in parameters for the evaluator.

Contributor Author:
This could be done by implementing the copy method (in a separate PR).

Member:
Good point; I guess I'm doing that right now for https://issues.apache.org/jira/browse/SPARK-7380
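For context, a rough sketch of the reviewer's point (hypothetical, not what this diff implements): any caller-supplied params would also need to reach the evaluator, e.g. by forwarding them at evaluation time rather than always evaluating with the evaluator's own defaults.

metric = eva.evaluate(model.transform(validation, epm[j]), params)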

paramMap = self.extractParamMap(params)
est = paramMap[self.estimator]
epm = paramMap[self.estimatorParamMaps]
numModels = len(epm)
eva = paramMap[self.evaluator]
nFolds = paramMap[self.numFolds]
h = 1.0 / nFolds
Member:
Until we have per-parameter validation, we should check for nFolds < 2 here

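A minimal sketch of the suggested guard (hypothetical, not part of this diff):

if nFolds < 2:
    raise ValueError("numFolds must be at least 2, but got %d" % nFolds)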
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
metrics = np.zeros(numModels)
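# each row carries a uniform [0, 1) draw from rand(0); fold i takes the rows
# whose draw falls in [i*h, (i+1)*h) as validation and trains on the rest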
for i in range(nFolds):
validateLB = i * h
validateUB = (i + 1) * h
condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
validation = df.filter(condition)
train = df.filter(~condition)
for j in range(numModels):
model = est.fit(train, epm[j])
metric = eva.evaluate(model.transform(validation, epm[j]))
metrics[j] += metric
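# summing per-fold metrics ranks the param maps the same as averaging them,
# so argmax over the sums selects the best map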
bestIndex = np.argmax(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)


class CrossValidatorModel(Model):
"""
Model from k-fold corss validation.
Member:
typo: "corss"

"""

def __init__(self, bestModel):
#: best model from cross validation
self.bestModel = bestModel

def transform(self, dataset, params={}):
return self.bestModel.transform(dataset, params)


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.tuning tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
exit(-1)
python/pyspark/ml/wrapper.py (2 additions, 2 deletions)
@@ -20,7 +20,7 @@
from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
from pyspark.mllib.common import inherit_doc


@@ -133,7 +133,7 @@ def transform(self, dataset, params={}):


@inherit_doc
class JavaModel(JavaTransformer):
class JavaModel(Model, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations.