diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 42dc7fbebe4c..053487242edd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the label distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - ParamValidators.inArray[String](supportedFamilyNames)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.1.0") @@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") ( case None => histogram.length } - val isMultinomial = $(family) match { + val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match { case "binomial" => require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " + s"outcome classes but found $numClasses.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index e3026c8efa82..3da29b1c816b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), - (o: String) => - ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) + (value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") @@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + s" length either 1 (scalar) or k (num topics).") } - getOptimizer match { + getOptimizer.toLowerCase(Locale.ROOT) match { case "online" => require(getDocConcentration.forall(_ >= 0), "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + @@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } if (isSet(topicConcentration)) { - getOptimizer match { + getOptimizer.toLowerCase(Locale.ROOT) match { case "online" => require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + s" must be >= 0. Found value: $getTopicConcentration") @@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } - private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { - case "online" => - new OldOnlineLDAOptimizer() - .setTau0($(learningOffset)) - .setKappa($(learningDecay)) - .setMiniBatchFraction($(subsamplingRate)) - .setOptimizeDocConcentration($(optimizeDocConcentration)) - case "em" => - new OldEMLDAOptimizer() - .setKeepLastCheckpoint($(keepLastCheckpoint)) - } + private[clustering] def getOldOptimizer: OldLDAOptimizer = + getOptimizer.toLowerCase(Locale.ROOT) match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + .setKeepLastCheckpoint($(keepLastCheckpoint)) + } } private object LDAParams { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bf6bfe30bfe2..1ffd8dcd53d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2582,6 +2582,17 @@ class LogisticRegressionSuite assert(expected.coefficients.toArray === actual.coefficients.toArray) } } + + test("string params should be case-insensitive") { + val lr = new LogisticRegression() + Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset), + ("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data) => + lr.setFamily(family) + assert(lr.getFamily === family) + val model = lr.fit(data) + assert(model.getFamily === family) + } + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b4fe63a89f87..e73bbc18d76b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getCheckpointFiles.isEmpty) } + + test("string params should be case-insensitive") { + val lda = new LDA() + Seq("eM", "oNLinE").foreach { optimizer => + lda.setOptimizer(optimizer) + assert(lda.getOptimizer === optimizer) + val model = lda.fit(dataset) + assert(model.getOptimizer === optimizer) + } + } }