Skip to content

Commit cdddecd

Browse files
committed
add CrossValidator in Python
1 parent 489700c commit cdddecd

File tree

3 files changed

+175
-5
lines changed

3 files changed

+175
-5
lines changed

python/pyspark/ml/pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pyspark.mllib.common import inherit_doc
2323

2424

25-
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
25+
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']
2626

2727

2828
@inherit_doc
@@ -70,6 +70,15 @@ def transform(self, dataset, params={}):
7070
raise NotImplementedError()
7171

7272

73+
@inherit_doc
74+
class Model(Transformer):
75+
"""
76+
Abstract class for models that fitted by estimators.
77+
"""
78+
79+
__metaclass__ = ABCMeta
80+
81+
7382
@inherit_doc
7483
class Pipeline(Estimator):
7584
"""

python/pyspark/ml/tuning.py

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
#
1717

1818
import itertools
19+
import numpy as np
1920

20-
__all__ = ['ParamGridBuilder']
21+
from pyspark.ml.param import Params, Param
22+
from pyspark.ml import Estimator, Model
23+
from pyspark.sql.functions import rand
24+
25+
__all__ = ['ParamGridBuilder', 'CrossValidator']
2126

2227

2328
class ParamGridBuilder(object):
@@ -78,7 +83,163 @@ def build(self):
7883
grid_values = self._param_grid.values()
7984
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
8085

86+
class CrossValidator(Estimator):
87+
"""
88+
K-fold cross validation.
89+
90+
>>> from pyspark.ml.classification import LogisticRegression
91+
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
92+
>>> from pyspark.mllib.linalg import Vectors
93+
>>> dataset = sqlContext.createDataFrame(
94+
... [(Vectors.dense([0.0, 1.0]), 0.0),
95+
... (Vectors.dense([1.0, 2.0]), 1.0),
96+
... (Vectors.dense([0.55, 3.0]), 0.0),
97+
... (Vectors.dense([0.45, 4.0]), 1.0),
98+
... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
99+
... ["features", "label"])
100+
>>> lr = LogisticRegression()
101+
>>> grid = ParamGridBuilder() \
102+
.addGrid(lr.maxIter, [0, 1, 5]) \
103+
.build()
104+
>>> evaluator = BinaryClassificationEvaluator()
105+
>>> cv = CrossValidator() \
106+
.setEstimator(lr) \
107+
.setEstimatorParamMaps(grid) \
108+
.setEvaluator(evaluator) \
109+
.setNumFolds(3)
110+
>>> cvModel = cv.fit(dataset)
111+
>>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
112+
>>> cvModel.transform(dataset).collect() == expected.collect()
113+
True
114+
"""
115+
116+
# a placeholder to make it appear in the generated doc
117+
estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
118+
119+
# a placeholder to make it appear in the generated doc
120+
estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
121+
122+
# a placeholder to make it appear in the generated doc
123+
evaluator = Param(Params._dummy(), "evaluator", "evaluator for selection")
124+
125+
# a placeholder to make it appear in the generated doc
126+
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
127+
128+
def __init__(self):
129+
super(CrossValidator, self).__init__()
130+
#: param for estimator to be cross-validated
131+
self.estimator = Param(self, "estimator", "estimator to be cross-validated")
132+
#: param for estimator param maps
133+
self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
134+
#: param for evaluator for selection
135+
self.evaluator = Param(self, "evaluator", "evaluator for selection")
136+
#: param for number of folds for cross validation
137+
self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
138+
139+
def setEstimator(self, value):
140+
"""
141+
Sets the value of :py:attr:`estimator`.
142+
"""
143+
self.paramMap[self.estimator] = value
144+
return self
145+
146+
def getEstimator(self):
147+
"""
148+
Gets the value of estimator or its default value.
149+
"""
150+
return self.getOrDefault(self.estimator)
151+
152+
def setEstimatorParamMaps(self, value):
153+
"""
154+
Sets the value of :py:attr:`estimatorParamMaps`.
155+
"""
156+
self.paramMap[self.estimatorParamMaps] = value
157+
return self
158+
159+
def getEstimatorParamMaps(self):
160+
"""
161+
Gets the value of estimatorParamMaps or its default value.
162+
"""
163+
return self.getOrDefault(self.estimatorParamMaps)
164+
165+
def setEvaluator(self, value):
166+
"""
167+
Sets the value of :py:attr:`evaluator`.
168+
"""
169+
self.paramMap[self.evaluator] = value
170+
return self
171+
172+
def getEvaluator(self):
173+
"""
174+
Gets the value of evaluator or its default value.
175+
"""
176+
return self.getOrDefault(self.evaluator)
177+
178+
def setNumFolds(self, value):
179+
"""
180+
Sets the value of :py:attr:`numFolds`.
181+
"""
182+
self.paramMap[self.numFolds] = value
183+
return self
184+
185+
def getNumFolds(self):
186+
"""
187+
Gets the value of numFolds or its default value.
188+
"""
189+
return self.getOrDefault(self.numFolds)
190+
191+
def fit(self, dataset, params={}):
192+
paramMap = self.extractParamMap(params)
193+
est = paramMap[self.estimator]
194+
epm = paramMap[self.estimatorParamMaps]
195+
numModels = len(epm)
196+
eva = paramMap[self.evaluator]
197+
nFolds = paramMap[self.numFolds]
198+
h = 1.0 / nFolds
199+
randCol = self.uid + "_rand"
200+
df = dataset.select("*", rand(0).alias(randCol))
201+
metrics = np.zeros(numModels)
202+
for i in range(nFolds):
203+
validateLB = i * h
204+
validateUB = (i + 1) * h
205+
condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
206+
validation = df.filter(condition)
207+
train = df.filter(~condition)
208+
for j in range(numModels):
209+
model = est.fit(train, epm[j])
210+
metric = eva.evaluate(model.transform(validation, epm[j]))
211+
metrics[j] += metric
212+
bestIndex = np.argmax(metrics)
213+
bestModel = est.fit(dataset, epm[bestIndex])
214+
return CrossValidatorModel(bestModel)
215+
216+
217+
class CrossValidatorModel(Model):
218+
"""
219+
Model from k-fold corss validation.
220+
"""
221+
222+
def __init__(self, bestModel):
223+
#: best model from cross validation
224+
self.bestModel = bestModel
225+
226+
def transform(self, dataset, params={}):
227+
return self.bestModel.transform(dataset, params)
228+
81229

82230
if __name__ == "__main__":
83231
import doctest
84-
doctest.testmod()
232+
from pyspark.context import SparkContext
233+
from pyspark.sql import SQLContext
234+
globs = globals().copy()
235+
# The small batch size here ensures that we see multiple batches,
236+
# even in these small test examples:
237+
sc = SparkContext("local[2]", "ml.tuning tests")
238+
sqlContext = SQLContext(sc)
239+
globs['sc'] = sc
240+
globs['sqlContext'] = sqlContext
241+
(failure_count, test_count) = doctest.testmod(
242+
globs=globs, optionflags=doctest.ELLIPSIS)
243+
sc.stop()
244+
if failure_count:
245+
exit(-1)

python/pyspark/ml/wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pyspark import SparkContext
2121
from pyspark.sql import DataFrame
2222
from pyspark.ml.param import Params
23-
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
23+
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
2424
from pyspark.mllib.common import inherit_doc
2525

2626

@@ -133,7 +133,7 @@ def transform(self, dataset, params={}):
133133

134134

135135
@inherit_doc
136-
class JavaModel(JavaTransformer):
136+
class JavaModel(Model, JavaTransformer):
137137
"""
138138
Base class for :py:class:`Model`s that wrap Java/Scala
139139
implementations.

0 commit comments

Comments
 (0)