Skip to content

Commit 060f7c3

Browse files
committed
update doctest
1 parent acac727 commit 060f7c3

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

python/pyspark/ml/tuning.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020

2121
from pyspark.ml.param import Params, Param
2222
from pyspark.ml import Estimator, Model
23-
from pyspark.sql.functions import rand
2423
from pyspark.ml.util import keyword_only
24+
from pyspark.sql.functions import rand
2525

26-
__all__ = ['ParamGridBuilder', 'CrossValidator']
26+
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
2727

2828

2929
class ParamGridBuilder(object):
@@ -84,6 +84,7 @@ def build(self):
8484
grid_values = self._param_grid.values()
8585
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
8686

87+
8788
class CrossValidator(Estimator):
8889
"""
8990
K-fold cross validation.
@@ -99,9 +100,7 @@ class CrossValidator(Estimator):
99100
... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
100101
... ["features", "label"])
101102
>>> lr = LogisticRegression()
102-
>>> grid = ParamGridBuilder() \
103-
.addGrid(lr.maxIter, [0, 1, 5]) \
104-
.build()
103+
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
105104
>>> evaluator = BinaryClassificationEvaluator()
106105
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
107106
>>> cvModel = cv.fit(dataset)

0 commit comments

Comments
 (0)