diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 0bf988fd72f1..584efa7d35fe 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -121,7 +121,7 @@ class CrossValidator(Estimator):
     numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
 
     @keyword_only
-    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=0):
         """
-        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=0)
         """
@@ -136,6 +136,9 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
             self, "evaluator",
             "evaluator used to select hyper-parameters that maximize the cross-validated metric")
         #: param for number of folds for cross validation
         self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
         self._setDefault(numFolds=3)
+        #: param for the random seed used to assign rows to folds
+        self.seed = Param(self, "seed", "seed value used for k-fold")
+        self._setDefault(seed=0)
         kwargs = self.__init__._input_kwargs
@@ -210,7 +213,7 @@ def _fit(self, dataset):
         nFolds = self.getOrDefault(self.numFolds)
         h = 1.0 / nFolds
         randCol = self.uid + "_rand"
-        df = dataset.select("*", rand(0).alias(randCol))
+        df = dataset.select("*", rand(self.getOrDefault(self.seed)).alias(randCol))
         metrics = np.zeros(numModels)
         for i in range(nFolds):
             validateLB = i * h