
Commit b4e3b1e

mengxr authored and jeanlyn committed
[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API
Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#5926 from mengxr/SPARK-6940 and squashes the following commits:

6af181f [Xiangrui Meng] add TODOs
8285134 [Xiangrui Meng] update doc
060f7c3 [Xiangrui Meng] update doctest
acac727 [Xiangrui Meng] add keyword args
cdddecd [Xiangrui Meng] add CrossValidator in Python
1 parent 0a33ccd commit b4e3b1e
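
The doctest in python/pyspark/ml/tuning.py (below) exercises the new API end to end. As a quick orientation, here is a minimal usage sketch; train_df and test_df are hypothetical DataFrames with "features" and "label" columns, the column names taken from the doctest:

# Sketch of the new Python API added by this commit; train_df/test_df
# are assumed DataFrames with "features" and "label" columns.
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

lr = LogisticRegression()
# One param map per maxIter value; CrossValidator tries all three.
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                    evaluator=BinaryClassificationEvaluator(), numFolds=3)
cvModel = cv.fit(train_df)                # refits the best params on all of train_df
predictions = cvModel.transform(test_df)  # CrossValidatorModel delegates to bestModel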


4 files changed: +199 additions, -8 deletions


mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 5 additions & 2 deletions

@@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params {
   def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
 
   /**
-   * param for the evaluator for selection
+   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
+   * metric
    * @group param
    */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)

@@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     trainingDataset.unpersist()
     var i = 0
     while (i < numModels) {
+      // TODO: duplicate evaluator to take extra params from input
       val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
       logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
       metrics(i) += metric
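
One point the reworded doc makes explicit, and which the Python port below inherits: model selection is an argmax over the averaged metrics, so the evaluator's metric must be larger-is-better (for example, BinaryClassificationEvaluator's areaUnderROC). A sketch of the selection rule with made-up metric values; a smaller-is-better metric such as RMSE would have to be negated by its evaluator:

import numpy as np

# Hypothetical averaged cross-validation metrics, one per param map.
metrics = np.array([0.62, 0.71, 0.69])
bestIndex = int(np.argmax(metrics))  # picks the second param map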

python/pyspark/ml/pipeline.py

Lines changed: 11 additions & 2 deletions

@@ -22,7 +22,7 @@
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']
 
 
 @inherit_doc
@@ -70,6 +70,15 @@ def transform(self, dataset, params={}):
         raise NotImplementedError()
 
 
+@inherit_doc
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+    """
+
+    __metaclass__ = ABCMeta
+
+
 @inherit_doc
 class Pipeline(Estimator):
     """
@@ -154,7 +163,7 @@ def fit(self, dataset, params={}):
 
 
 @inherit_doc
-class PipelineModel(Transformer):
+class PipelineModel(Model):
     """
     Represents a compiled pipeline with transformers and fitted models.
     """

python/pyspark/ml/tuning.py

Lines changed: 181 additions & 2 deletions

@@ -16,8 +16,14 @@
 #
 
 import itertools
+import numpy as np
 
-__all__ = ['ParamGridBuilder']
+from pyspark.ml.param import Params, Param
+from pyspark.ml import Estimator, Model
+from pyspark.ml.util import keyword_only
+from pyspark.sql.functions import rand
+
+__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
 
 
 class ParamGridBuilder(object):
@@ -79,6 +85,179 @@ def build(self):
         return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
+class CrossValidator(Estimator):
+    """
+    K-fold cross validation.
+
+    >>> from pyspark.ml.classification import LogisticRegression
+    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlContext.createDataFrame(
+    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
+    ...      (Vectors.dense([1.0, 2.0]), 1.0),
+    ...      (Vectors.dense([0.55, 3.0]), 0.0),
+    ...      (Vectors.dense([0.45, 4.0]), 1.0),
+    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+    ...     ["features", "label"])
+    >>> lr = LogisticRegression()
+    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+    >>> evaluator = BinaryClassificationEvaluator()
+    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+    >>> cvModel = cv.fit(dataset)
+    >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
+    >>> cvModel.transform(dataset).collect() == expected.collect()
+    True
+    """
+
+    # a placeholder to make it appear in the generated doc
+    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+
+    # a placeholder to make it appear in the generated doc
+    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+
+    # a placeholder to make it appear in the generated doc
+    evaluator = Param(
+        Params._dummy(), "evaluator",
+        "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+
+    # a placeholder to make it appear in the generated doc
+    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
+
+    @keyword_only
+    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+        """
+        super(CrossValidator, self).__init__()
+        #: param for estimator to be cross-validated
+        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
+        #: param for estimator param maps
+        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
+        #: param for the evaluator used to select hyper-parameters that
+        #: maximize the cross-validated metric
+        self.evaluator = Param(
+            self, "evaluator",
+            "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+        #: param for number of folds for cross validation
+        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
+        self._setDefault(numFolds=3)
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    @keyword_only
+    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        Sets params for cross validator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setEstimator(self, value):
+        """
+        Sets the value of :py:attr:`estimator`.
+        """
+        self.paramMap[self.estimator] = value
+        return self
+
+    def getEstimator(self):
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    def setEstimatorParamMaps(self, value):
+        """
+        Sets the value of :py:attr:`estimatorParamMaps`.
+        """
+        self.paramMap[self.estimatorParamMaps] = value
+        return self
+
+    def getEstimatorParamMaps(self):
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    def setEvaluator(self, value):
+        """
+        Sets the value of :py:attr:`evaluator`.
+        """
+        self.paramMap[self.evaluator] = value
+        return self
+
+    def getEvaluator(self):
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+    def setNumFolds(self, value):
+        """
+        Sets the value of :py:attr:`numFolds`.
+        """
+        self.paramMap[self.numFolds] = value
+        return self
+
+    def getNumFolds(self):
+        """
+        Gets the value of numFolds or its default value.
+        """
+        return self.getOrDefault(self.numFolds)
+
+    def fit(self, dataset, params={}):
+        paramMap = self.extractParamMap(params)
+        est = paramMap[self.estimator]
+        epm = paramMap[self.estimatorParamMaps]
+        numModels = len(epm)
+        eva = paramMap[self.evaluator]
+        nFolds = paramMap[self.numFolds]
+        h = 1.0 / nFolds
+        randCol = self.uid + "_rand"
+        df = dataset.select("*", rand(0).alias(randCol))
+        metrics = np.zeros(numModels)
+        for i in range(nFolds):
+            validateLB = i * h
+            validateUB = (i + 1) * h
+            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
+            validation = df.filter(condition)
+            train = df.filter(~condition)
+            for j in range(numModels):
+                model = est.fit(train, epm[j])
+                # TODO: duplicate evaluator to take extra params from input
+                metric = eva.evaluate(model.transform(validation, epm[j]))
+                metrics[j] += metric
+        bestIndex = np.argmax(metrics)
+        bestModel = est.fit(dataset, epm[bestIndex])
+        return CrossValidatorModel(bestModel)
+
+
+class CrossValidatorModel(Model):
+    """
+    Model from k-fold cross validation.
+    """
+
+    def __init__(self, bestModel):
+        #: best model from cross validation
+        self.bestModel = bestModel
+
+    def transform(self, dataset, params={}):
+        return self.bestModel.transform(dataset, params)
+
+
 if __name__ == "__main__":
     import doctest
-    doctest.testmod()
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.tuning tests")
+    sqlContext = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlContext'] = sqlContext
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
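
Note how fit() carves folds out of a single seeded rand(0) column instead of materializing k splits: the half-open intervals [i*h, (i+1)*h) tile [0, 1), so each row lands in exactly one validation fold. A minimal sketch of just that mechanism, assuming a DataFrame df (the real code names the column self.uid + "_rand"):

from pyspark.sql.functions import rand

nFolds = 3
h = 1.0 / nFolds
# Seeding rand() keeps the fold assignment consistent between the
# validation filter and its complement.
df_r = df.select("*", rand(0).alias("_rand"))
for i in range(nFolds):
    condition = (df_r["_rand"] >= i * h) & (df_r["_rand"] < (i + 1) * h)
    validation = df_r.filter(condition)  # roughly 1/nFolds of the rows
    train = df_r.filter(~condition)      # the remaining rows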

python/pyspark/ml/wrapper.py

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
 from pyspark.mllib.common import inherit_doc
 
 
@@ -133,7 +133,7 @@ def transform(self, dataset, params={}):
 
 
 @inherit_doc
-class JavaModel(JavaTransformer):
+class JavaModel(Model, JavaTransformer):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations.
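
With JavaModel mixing in Model, fitted Java-backed models now participate in the new class hierarchy, so generic pipeline code can type-check against the abstract base. A small sketch of the observable effect, assuming a training DataFrame train_df as in the earlier example:

from pyspark.ml import Model
from pyspark.ml.classification import LogisticRegression

lr_model = LogisticRegression(maxIter=5).fit(train_df)
# JavaModel(Model, JavaTransformer) places fitted wrappers under Model.
assert isinstance(lr_model, Model)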
