[SPARK-6940][MLLIB] Add CrossValidator to Python ML pipeline API #5926
```diff
@@ -16,8 +16,14 @@
 #
 
 import itertools
+import numpy as np
 
-__all__ = ['ParamGridBuilder']
+from pyspark.ml.param import Params, Param
+from pyspark.ml import Estimator, Model
+from pyspark.ml.util import keyword_only
+from pyspark.sql.functions import rand
+
+__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
 
 
 class ParamGridBuilder(object):
```
@@ -79,6 +85,173 @@ def build(self):

```python
        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


class CrossValidator(Estimator):
    """
    K-fold cross validation.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.mllib.linalg import Vectors
    >>> dataset = sqlContext.createDataFrame(
    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
    ...      (Vectors.dense([1.0, 2.0]), 1.0),
    ...      (Vectors.dense([0.55, 3.0]), 0.0),
    ...      (Vectors.dense([0.45, 4.0]), 1.0),
    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
    >>> cvModel = cv.fit(dataset)
    >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
    >>> cvModel.transform(dataset).collect() == expected.collect()
    True
    """

    # a placeholder to make it appear in the generated doc
    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")

    # a placeholder to make it appear in the generated doc
    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")

    # a placeholder to make it appear in the generated doc
    evaluator = Param(Params._dummy(), "evaluator", "evaluator for selection")
```
Member: Here and in Scala, we should say CrossValidator tries to maximize the evaluation metric.
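A minimal sketch of the rewording the reviewer is asking for (hypothetical phrasing, not part of this diff; `fit()` below selects the best model with `np.argmax`, so the evaluator's metric is treated as larger-is-better):

```python
# Hypothetical rewording of the placeholder Param doc (not in this PR):
# spell out that CrossValidator picks the params that maximize the metric.
evaluator = Param(Params._dummy(), "evaluator",
                  "evaluator used to select hyper-parameters that maximize the validation metric")
```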
```python
    # a placeholder to make it appear in the generated doc
    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
        """
        super(CrossValidator, self).__init__()
        #: param for estimator to be cross-validated
        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
        #: param for estimator param maps
        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
        #: param for evaluator for selection
        self.evaluator = Param(self, "evaluator", "evaluator for selection")
        #: param for number of folds for cross validation
        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
        self._setDefault(numFolds=3)
        kwargs = self.__init__._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        Sets params for cross validator.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        self.paramMap[self.estimator] = value
        return self

    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        self.paramMap[self.estimatorParamMaps] = value
        return self

    def getEstimatorParamMaps(self):
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        self.paramMap[self.evaluator] = value
        return self

    def getEvaluator(self):
        """
        Gets the value of evaluator or its default value.
        """
        return self.getOrDefault(self.evaluator)

    def setNumFolds(self, value):
        """
        Sets the value of :py:attr:`numFolds`.
        """
        self.paramMap[self.numFolds] = value
        return self

    def getNumFolds(self):
        """
        Gets the value of numFolds or its default value.
        """
        return self.getOrDefault(self.numFolds)

    def fit(self, dataset, params={}):
```
Member: Here and in Scala, we should check to see if users pass in parameters for the evaluator.

Contributor (author): This could be done by implementing the

Member: Good point; I guess I'm doing that right now for https://issues.apache.org/jira/browse/SPARK-7380
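A hedged sketch of one way the reviewer's suggestion could look: split the user-supplied params by owner before running cross validation, so evaluator-specific settings are honored. The helper name `_splitParams` and the use of `Param.parent` are assumptions for illustration, not part of this PR or a confirmed pyspark API:

```python
# Hypothetical helper (not in this PR): route user-supplied params to the
# component that owns them, so evaluator params passed to fit() are applied
# instead of being silently dropped.
def _splitParams(params, estimator, evaluator):
    estimatorParams = {p: v for p, v in params.items() if p.parent is estimator}
    evaluatorParams = {p: v for p, v in params.items() if p.parent is evaluator}
    return estimatorParams, evaluatorParams
```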
```python
        paramMap = self.extractParamMap(params)
        est = paramMap[self.estimator]
        epm = paramMap[self.estimatorParamMaps]
        numModels = len(epm)
        eva = paramMap[self.evaluator]
        nFolds = paramMap[self.numFolds]
        h = 1.0 / nFolds
```
Member: Until we have per-parameter validation, we should check for nFolds < 2 here.
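A minimal sketch of the suggested guard, assuming it would sit right after `nFolds` is read from the param map (hypothetical, not part of this PR):

```python
# Hypothetical validation (not in this PR): fail fast on fold counts that
# make k-fold cross validation meaningless.
if nFolds < 2:
    raise ValueError("numFolds must be >= 2, but got %d" % nFolds)
```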
```python
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(0).alias(randCol))
        metrics = np.zeros(numModels)
        for i in range(nFolds):
            validateLB = i * h
            validateUB = (i + 1) * h
            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
            validation = df.filter(condition)
            train = df.filter(~condition)
            for j in range(numModels):
                model = est.fit(train, epm[j])
                metric = eva.evaluate(model.transform(validation, epm[j]))
                metrics[j] += metric
        bestIndex = np.argmax(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return CrossValidatorModel(bestModel)


class CrossValidatorModel(Model):
    """
    Model from k-fold corss validation.
```
Member: typo: "corss"
| """ | ||
|
|
||
| def __init__(self, bestModel): | ||
| #: best model from cross validation | ||
| self.bestModel = bestModel | ||
|
|
||
| def transform(self, dataset, params={}): | ||
| return self.bestModel.transform(dataset, params) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import doctest | ||
| doctest.testmod() | ||
| from pyspark.context import SparkContext | ||
| from pyspark.sql import SQLContext | ||
| globs = globals().copy() | ||
| # The small batch size here ensures that we see multiple batches, | ||
| # even in these small test examples: | ||
| sc = SparkContext("local[2]", "ml.tuning tests") | ||
| sqlContext = SQLContext(sc) | ||
| globs['sc'] = sc | ||
| globs['sqlContext'] = sqlContext | ||
| (failure_count, test_count) = doctest.testmod( | ||
| globs=globs, optionflags=doctest.ELLIPSIS) | ||
| sc.stop() | ||
| if failure_count: | ||
| exit(-1) | ||
Review comment: "that fitted" --> "that are fitted"