Skip to content

Commit 118b158

Browse files
committed
Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does.
CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel
1 parent 71a452b commit 118b158

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
366366
/**
367367
* Sets default values for a list of params.
368368
*
369-
* Note: Java developers should use the single-parameter [[setDefault()]].
370-
* Annotating this with varargs causes compilation failures.
371-
*
372369
* @param paramPairs a list of param pairs that specify params and their default values to set
373370
* respectively. Make sure that the params are initialized before this method
374371
* gets called.
375372
*/
373+
@varargs
376374
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
377375
paramPairs.foreach { p =>
378376
setDefault(p.param.asInstanceOf[Param[Any]], p.value)

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
105105

106106
override def fit(dataset: DataFrame): CrossValidatorModel = {
107107
val schema = dataset.schema
108-
transformSchema(dataset.schema, logging = true)
108+
transformSchema(schema, logging = true)
109109
val sqlCtx = dataset.sqlContext
110110
val est = $(estimator)
111111
val eval = $(evaluator)
@@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
159159
}
160160

161161
override def transform(dataset: DataFrame): DataFrame = {
162+
transformSchema(dataset.schema, logging = true)
162163
bestModel.transform(dataset)
163164
}
164165

mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,6 @@ public JavaTestParams() {
5959
ParamValidators.inArray(validStrings));
6060
setDefault(myIntParam, 1);
6161
setDefault(myDoubleParam, 0.5);
62+
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
6263
}
6364
}

0 commit comments

Comments
 (0)