From e5a27110779bf24b9c4423062f76fddce98b592b Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Thu, 11 May 2017 17:25:21 +0800
Subject: [PATCH 1/3] recreate pr

---
 .../classification/LogisticRegression.scala   |  4 +--
 .../org/apache/spark/ml/clustering/LDA.scala  | 30 +++++++++----------
 .../LogisticRegressionSuite.scala             |  4 +--
 .../apache/spark/ml/clustering/LDASuite.scala |  4 +--
 4 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 42dc7fbebe4c3..053487242edd8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
   final val family: Param[String] = new Param(this, "family",
     "The name of family which is a description of the label distribution to be used in the " +
       s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
-    ParamValidators.inArray[String](supportedFamilyNames))
+    (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
 
   /** @group getParam */
   @Since("2.1.0")
@@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") (
       case None => histogram.length
     }
 
-    val isMultinomial = $(family) match {
+    val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
       case "binomial" =>
         require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
           s"outcome classes but found $numClasses.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index e3026c8efa823..3da29b1c816b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   @Since("1.6.0")
   final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
     " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
-    (o: String) =>
-      ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT)))
+    (value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT)))
 
   /** @group getParam */
   @Since("1.6.0")
@@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
           s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" +
           s" length either 1 (scalar) or k (num topics).")
       }
-      getOptimizer match {
+      getOptimizer.toLowerCase(Locale.ROOT) match {
         case "online" =>
           require(getDocConcentration.forall(_ >= 0),
             "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " +
@@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
       }
     }
     if (isSet(topicConcentration)) {
-      getOptimizer match {
+      getOptimizer.toLowerCase(Locale.ROOT) match {
         case "online" =>
           require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
             s" must be >= 0. Found value: $getTopicConcentration")
@@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
     SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }
 
-  private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
-    case "online" =>
-      new OldOnlineLDAOptimizer()
-        .setTau0($(learningOffset))
-        .setKappa($(learningDecay))
-        .setMiniBatchFraction($(subsamplingRate))
-        .setOptimizeDocConcentration($(optimizeDocConcentration))
-    case "em" =>
-      new OldEMLDAOptimizer()
-        .setKeepLastCheckpoint($(keepLastCheckpoint))
-  }
+  private[clustering] def getOldOptimizer: OldLDAOptimizer =
+    getOptimizer.toLowerCase(Locale.ROOT) match {
+      case "online" =>
+        new OldOnlineLDAOptimizer()
+          .setTau0($(learningOffset))
+          .setKappa($(learningDecay))
+          .setMiniBatchFraction($(subsamplingRate))
+          .setOptimizeDocConcentration($(optimizeDocConcentration))
+      case "em" =>
+        new OldEMLDAOptimizer()
+          .setKeepLastCheckpoint($(keepLastCheckpoint))
+    }
 }
 
 private object LDAParams {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bf6bfe30bfe20..7b69ddbe8bd8b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -255,7 +255,7 @@ class LogisticRegressionSuite
   }
 
   test("thresholds prediction") {
-    val blr = new LogisticRegression().setFamily("binomial")
+    val blr = new LogisticRegression().setFamily("BiNomial")
     val binaryModel = blr.fit(smallBinaryDataset)
 
     binaryModel.setThreshold(1.0)
@@ -269,7 +269,7 @@ class LogisticRegressionSuite
 
     assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
 
-    val mlr = new LogisticRegression().setFamily("multinomial")
+    val mlr = new LogisticRegression().setFamily("MulTinoMial")
     val model = mlr.fit(smallMultinomialDataset)
     val basePredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index b4fe63a89f871..1a6c4c060cc8d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -173,7 +173,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   }
 
   test("fit & transform with Online LDA") {
-    val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
+    val lda = new LDA().setK(k).setSeed(1).setOptimizer("oNlIne").setMaxIter(2)
     val model = lda.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(lda, model)
@@ -218,7 +218,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   }
 
   test("fit & transform with EM LDA") {
-    val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
+    val lda = new LDA().setK(k).setSeed(1).setOptimizer("eM").setMaxIter(2)
     val model_ = lda.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(lda, model_)

From 43108fa138e89d60f53eb4737e44421ad6e0d9bc Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 15 May 2017 16:35:43 +0800
Subject: [PATCH 2/3] add tests

---
 .../LogisticRegressionSuite.scala             | 19 +++++++++++++++++--
 .../apache/spark/ml/clustering/LDASuite.scala | 15 +++++++++++++--
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 7b69ddbe8bd8b..236a96b05f55f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -255,7 +255,7 @@ class LogisticRegressionSuite
   }
 
   test("thresholds prediction") {
-    val blr = new LogisticRegression().setFamily("BiNomial")
+    val blr = new LogisticRegression().setFamily("binomial")
     val binaryModel = blr.fit(smallBinaryDataset)
 
     binaryModel.setThreshold(1.0)
@@ -269,7 +269,7 @@ class LogisticRegressionSuite
 
     assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
 
-    val mlr = new LogisticRegression().setFamily("MulTinoMial")
+    val mlr = new LogisticRegression().setFamily("multinomial")
     val model = mlr.fit(smallMultinomialDataset)
     val basePredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
@@ -2582,6 +2582,21 @@ class LogisticRegressionSuite
       assert(expected.coefficients.toArray === actual.coefficients.toArray)
     }
   }
+
+  test("string params should be case-insensitive") {
+    val lr = new LogisticRegression()
+    lr.setFamily("AuTo")
+    assert(lr.getFamily === "AuTo")
+    lr.fit(smallBinaryDataset)
+
+    lr.setFamily("biNoMial")
+    assert(lr.getFamily === "biNoMial")
+    lr.fit(smallBinaryDataset)
+
+    lr.setFamily("mulTinomIAl")
+    assert(lr.getFamily === "mulTinomIAl")
+    lr.fit(smallMultinomialDataset)
+  }
 }
 
 object LogisticRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 1a6c4c060cc8d..ea4c48afa6f17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -173,7 +173,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   }
 
   test("fit & transform with Online LDA") {
-    val lda = new LDA().setK(k).setSeed(1).setOptimizer("oNlIne").setMaxIter(2)
+    val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
     val model = lda.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(lda, model)
@@ -218,7 +218,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   }
 
   test("fit & transform with EM LDA") {
-    val lda = new LDA().setK(k).setSeed(1).setOptimizer("eM").setMaxIter(2)
+    val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
     val model_ = lda.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(lda, model_)
@@ -313,4 +313,15 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
 
     assert(model.getCheckpointFiles.isEmpty)
   }
+
+  test("string params should be case-insensitive") {
+    val lda = new LDA()
+    lda.setOptimizer("eM")
+    assert(lda.getOptimizer === "eM")
+    lda.fit(dataset)
+
+    lda.setOptimizer("oNlinE")
+    assert(lda.getOptimizer === "oNlinE")
+    lda.fit(dataset)
+  }
 }

From 7abbe3616aa4746644fd3b89c343ab1d38017254 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 15 May 2017 19:25:44 +0800
Subject: [PATCH 3/3] add model getter check

---
 .../LogisticRegressionSuite.scala             | 18 +++++++-----------
 .../apache/spark/ml/clustering/LDASuite.scala | 13 ++++++-------
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 236a96b05f55f..1ffd8dcd53d61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2585,17 +2585,13 @@ class LogisticRegressionSuite
 
   test("string params should be case-insensitive") {
     val lr = new LogisticRegression()
-    lr.setFamily("AuTo")
-    assert(lr.getFamily === "AuTo")
-    lr.fit(smallBinaryDataset)
-
-    lr.setFamily("biNoMial")
-    assert(lr.getFamily === "biNoMial")
-    lr.fit(smallBinaryDataset)
-
-    lr.setFamily("mulTinomIAl")
-    assert(lr.getFamily === "mulTinomIAl")
-    lr.fit(smallMultinomialDataset)
+    Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset),
+      ("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data) =>
+      lr.setFamily(family)
+      assert(lr.getFamily === family)
+      val model = lr.fit(data)
+      assert(model.getFamily === family)
+    }
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index ea4c48afa6f17..e73bbc18d76bd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -316,12 +316,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
 
   test("string params should be case-insensitive") {
     val lda = new LDA()
-    lda.setOptimizer("eM")
-    assert(lda.getOptimizer === "eM")
-    lda.fit(dataset)
-
-    lda.setOptimizer("oNlinE")
-    assert(lda.getOptimizer === "oNlinE")
-    lda.fit(dataset)
+    Seq("eM", "oNLinE").foreach { optimizer =>
+      lda.setOptimizer(optimizer)
+      assert(lda.getOptimizer === optimizer)
+      val model = lda.fit(dataset)
+      assert(model.getOptimizer === optimizer)
+    }
   }
 }