diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3d8883b486e4c..a770bad32ecd2 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -950,6 +950,13 @@ def test_fit_maximize_metric(self):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_param_grid_type_coercion(self):
+        lr = LogisticRegression(maxIter=10)
+        paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
+        for param in paramGrid:
+            for v in param.values():
+                assert(type(v) == float)
+
     def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 0c8029f293cfe..1f4abf5157335 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -115,7 +115,11 @@ def build(self):
         """
         keys = self._param_grid.keys()
         grid_values = self._param_grid.values()
-        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
+
+        def to_key_value_pairs(keys, values):
+            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
+
+        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
 class ValidatorParams(HasSeed):