2020
2121from pyspark .ml .param import Params , Param
2222from pyspark .ml import Estimator , Model
23- from pyspark .sql .functions import rand
2423from pyspark .ml .util import keyword_only
24+ from pyspark .sql .functions import rand
2525
26- __all__ = ['ParamGridBuilder' , 'CrossValidator' ]
26+ __all__ = ['ParamGridBuilder' , 'CrossValidator' , 'CrossValidatorModel' ]
2727
2828
2929class ParamGridBuilder (object ):
@@ -84,6 +84,7 @@ def build(self):
8484 grid_values = self ._param_grid .values ()
8585 return [dict (zip (keys , prod )) for prod in itertools .product (* grid_values )]
8686
87+
8788class CrossValidator (Estimator ):
8889 """
8990 K-fold cross validation.
@@ -99,9 +100,7 @@ class CrossValidator(Estimator):
99100 ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
100101 ... ["features", "label"])
101102 >>> lr = LogisticRegression()
102- >>> grid = ParamGridBuilder() \
103- .addGrid(lr.maxIter, [0, 1, 5]) \
104- .build()
103+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
105104 >>> evaluator = BinaryClassificationEvaluator()
106105 >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
107106 >>> cvModel = cv.fit(dataset)
0 commit comments