From 45b74930eea787411855fc35a7ad7198b35d577e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 12:02:13 +0800 Subject: [PATCH 01/22] TST: add test case --- .../apache/spark/ml/tree/impl/RandomForestSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e1ab7c2d6520b..fa77bfe114614 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -87,6 +87,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits(0).length === 0) } + test("SPARK-16957: Use weighted midpoints for split values.") { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(2), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array(0.5)) + } + test("find splits for a continuous feature") { // find splits for normal case { From c49d3ae7db0e66855b0c896375b11bf51d9ac482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 12:05:36 +0800 Subject: [PATCH 02/22] ENH: use weighted midpoints --- .../spark/ml/tree/impl/RandomForest.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 008dd19c2498d..c57be687530ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = if (featureSamples.isEmpty) { + val splits: Array[Double] = if (featureSamples.isEmpty) { Array.empty[Double] } else { val numSplits = metadata.numSplits(featureIndex) @@ -1009,10 +1009,20 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = { + val (preValue, preCount) = pre + val (curValue, curCount) = cru + (preValue * preCount + curValue * curCount) / (preCount + curCount) + } + // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 if (possibleSplits <= numSplits) { - valueCounts.map(_._1).init + valueCounts + .sliding(2) + .map{x => weightedMean(x(0), x(1))} + .toArray + } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1037,7 +1047,10 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 + val pre = valueCounts(index - 1) + val cru = valueCounts(index) + + splitsBuilder += weightedMean(pre, cru) targetCount += stride } index += 1 From 387eb498054289149706ecd2f88593d008fd074f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 12:13:44 +0800 Subject: [PATCH 03/22] BUG: constant feature, outOfIndex --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index c57be687530ac..3a2ce9dbca4a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1017,7 +1017,11 @@ private[spark] object RandomForest extends Logging { // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + + } else if (possibleSplits <= numSplits) { valueCounts .sliding(2) .map{x => weightedMean(x(0), x(1))} From 2e68f1efca59772d1e905474c2392ad0d8b413c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 12:15:09 +0800 Subject: [PATCH 04/22] TST: modify split's test case --- .../spark/ml/tree/impl/RandomForestSuite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index fa77bfe114614..b48f386aa1a27 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -123,9 +123,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0, 2.0)) + assert(splits === Array(1.8, 2.2)) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -137,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0, 3.0)) + assert(splits === Array(2.0625, 3.5)) } // find splits when most samples close to the maximum @@ -149,9 +150,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0)) + assert(splits === Array(1.9375)) } // find splits for constant feature From 6a5806f35185596ffda2c88c4879ecaf0be3bda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 12:24:02 +0800 Subject: [PATCH 05/22] CLN: move test case --- .../ml/tree/impl/RandomForestSuite.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index b48f386aa1a27..da2397817097a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -87,17 +87,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits(0).length === 0) } - test("SPARK-16957: Use weighted midpoints for split values.") { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(2), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) - val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(0.5)) - } - test("find splits for a continuous feature") { // find splits for normal case { @@ -115,6 +104,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } + // SPARK-16957: Use weighted midpoints for split values. + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(2), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array(0.5)) + } + // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { From 7ad590df5f41cbc4ff621aa63932ac47501c1060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 21:06:31 +0800 Subject: [PATCH 06/22] CLN: fix a typo --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 3a2ce9dbca4a8..f2a5976d84b62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1009,9 +1009,9 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = { + def weightedMean(pre: (Double, Int), cur: (Double, Int)): Double = { val (preValue, preCount) = pre - val (curValue, curCount) = cru + val (curValue, curCount) = cur (preValue * preCount + curValue * curCount) / (preCount + curCount) } @@ -1052,9 +1052,9 @@ private[spark] object RandomForest extends Logging { // previous value is a split threshold. if (previousGap < currentGap) { val pre = valueCounts(index - 1) - val cru = valueCounts(index) + val cur = valueCounts(index) - splitsBuilder += weightedMean(pre, cru) + splitsBuilder += weightedMean(pre, cur) targetCount += stride } index += 1 From 0aaed66e563378ebb1df9f56ad26913921a6d500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 21:07:38 +0800 Subject: [PATCH 07/22] BUG: int, overflow --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index f2a5976d84b62..b06df40b6b829 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1012,7 +1012,7 @@ private[spark] object RandomForest extends Logging { def weightedMean(pre: (Double, Int), cur: (Double, Int)): Double = { val (preValue, preCount) = pre val (curValue, curCount) = cur - (preValue * preCount + curValue * curCount) / (preCount + curCount) + (preValue * preCount + curValue * curCount) / (preCount.toDouble + curCount) } // if possible splits is not enough or just enough, just return all possible splits From c07ffac370d633607a129198ea37ca946f5812bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Fri, 7 Apr 2017 21:08:37 +0800 Subject: [PATCH 08/22] CLN: style mistake, { -> ( --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b06df40b6b829..f9da55b8f1ae7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1024,7 +1024,7 @@ private[spark] object RandomForest extends Logging { } else if (possibleSplits <= numSplits) { valueCounts .sliding(2) - .map{x => weightedMean(x(0), x(1))} + .map(x => weightedMean(x(0), x(1))) .toArray } else { From 9ca57505c8211954478a2d54ced48c2561cfb9f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Sat, 8 Apr 2017 08:00:40 +0800 Subject: [PATCH 09/22] CLN: mv comment --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index f9da55b8f1ae7..5cafdd0533ec0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1015,13 +1015,13 @@ private[spark] object RandomForest extends Logging { (preValue * preCount + curValue * curCount) / (preCount.toDouble + curCount) } - // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 if (possibleSplits == 0) { // constant feature Array.empty[Double] } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits valueCounts .sliding(2) .map(x => weightedMean(x(0), x(1))) From b74702afa958fa3552e494cbe77590d9940bf1fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 11 Apr 2017 11:27:57 +0800 Subject: [PATCH 10/22] TST: revise unit test in python --- python/pyspark/mllib/tree.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a6089fc8b9d32..619fa16d463f5 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -199,9 +199,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.0) + If (feature 0 <= 0.5) Predict: 0.0 - Else (feature 0 > 0.0) + Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) @@ -383,14 +383,14 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Tree 0: Predict: 1.0 Tree 1: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 Tree 2: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 >>> model.predict([2.0]) From 76f4ae8f2ec15cb30990e116ebe0997e768852ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sat, 15 Apr 2017 07:59:02 +0800 Subject: [PATCH 11/22] TST: explicitly show calculation --- .../spark/ml/tree/impl/RandomForestSuite.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index da2397817097a..3f42b016dad1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -113,7 +113,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(0.5)) + val expSplits = Array((0 * 4 + 0 * 4) / (4 + 4)) // = 0.5 + assert(splits === expSplits) } // find splits should not return identical splits @@ -126,7 +127,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.8, 2.2)) + val expSplits = Array((1.0 * 2 + 2.0 * 8) / (2 + 8), + (2.0 * 8 + 3.0 * 2) / (8 + 2)) // = (1.8, 2.2) + assert(splits === expSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -141,7 +144,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0625, 3.5)) + val expSplits = Array((2.0 * 15 + 3.0 * 1) / (15 + 1), + (3.0 * 1 + 4.0 * 1) / (1 + 1)) // = (2.0625, 3.5) + assert(splits === expSplits) } // find splits when most samples close to the maximum @@ -153,7 +158,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.9375)) + val expSplits = Array((1.0 * 1 + 2.0 * 15) / (1 + 15)) // = (1.9375) + assert(splits === expSplits) } // find splits for constant feature From 031c61a60d0638dc75133c60c045be2c9204b64b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sat, 15 Apr 2017 08:47:30 +0800 Subject: [PATCH 12/22] TST: add possibleSplits > numSplits --- .../ml/tree/impl/RandomForestSuite.scala | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 3f42b016dad1e..24360481fba57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -108,13 +108,26 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), - Array(2), Gini, QuantileStrategy.Sort, + Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) - val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((0 * 4 + 0 * 4) / (4 + 4)) // = 0.5 - assert(splits === expSplits) + + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expSplits = Array((0.0 * 4 + 1.0 * 4) / (4 + 4)) // = 0.5 + assert(splits === expSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expSplits = Array((0.0 * 2 + 1.0 * 2) / (2 + 2), + (2.0 * 2 + 3.0 * 2) / (2 + 2)) // = (0.5, 2.5) + assert(splits === expSplits) + } } // find splits should not return identical splits From 1459b142b1b70e1da649209206df85548feb0719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sat, 29 Apr 2017 10:03:34 +0800 Subject: [PATCH 13/22] TST: expSplits -> expectedSplits --- .../ml/tree/impl/RandomForestSuite.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 24360481fba57..c444737a7c795 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -116,17 +116,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { { val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((0.0 * 4 + 1.0 * 4) / (4 + 4)) // = 0.5 - assert(splits === expSplits) + val expectedSplits = Array((0.0 * 4 + 1.0 * 4) / (4 + 4)) // = 0.5 + assert(splits === expectedSplits) } // possibleSplits > numSplits { val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((0.0 * 2 + 1.0 * 2) / (2 + 2), - (2.0 * 2 + 3.0 * 2) / (2 + 2)) // = (0.5, 2.5) - assert(splits === expSplits) + val expectedSplits = Array((0.0 * 2 + 1.0 * 2) / (2 + 2), + (2.0 * 2 + 3.0 * 2) / (2 + 2)) // = (0.5, 2.5) + assert(splits === expectedSplits) } } @@ -140,9 +140,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((1.0 * 2 + 2.0 * 8) / (2 + 8), - (2.0 * 8 + 3.0 * 2) / (8 + 2)) // = (1.8, 2.2) - assert(splits === expSplits) + val expectedSplits = Array((1.0 * 2 + 2.0 * 8) / (2 + 8), + (2.0 * 8 + 3.0 * 2) / (8 + 2)) // = (1.8, 2.2) + assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -157,9 +157,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((2.0 * 15 + 3.0 * 1) / (15 + 1), - (3.0 * 1 + 4.0 * 1) / (1 + 1)) // = (2.0625, 3.5) - assert(splits === expSplits) + val expectedSplits = Array((2.0 * 15 + 3.0 * 1) / (15 + 1), + (3.0 * 1 + 4.0 * 1) / (1 + 1)) // = (2.0625, 3.5) + assert(splits === expectedSplits) } // find splits when most samples close to the maximum @@ -171,8 +171,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expSplits = Array((1.0 * 1 + 2.0 * 15) / (1 + 15)) // = (1.9375) - assert(splits === expSplits) + val expectedSplits = Array((1.0 * 1 + 2.0 * 15) / (1 + 15)) // = (1.9375) + assert(splits === expectedSplits) } // find splits for constant feature From a09402968e903524e23a5623f7497388af3dd7bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sat, 29 Apr 2017 10:04:56 +0800 Subject: [PATCH 14/22] CLN: remove blank --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 5cafdd0533ec0..f9e3ab9576c86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1019,7 +1019,6 @@ private[spark] object RandomForest extends Logging { if (possibleSplits == 0) { // constant feature Array.empty[Double] - } else if (possibleSplits <= numSplits) { // if possible splits is not enough or just enough, just return all possible splits valueCounts @@ -1053,7 +1052,6 @@ private[spark] object RandomForest extends Logging { if (previousGap < currentGap) { val pre = valueCounts(index - 1) val cur = valueCounts(index) - splitsBuilder += weightedMean(pre, cur) targetCount += stride } From 7c50d4aa04ff85bb1720e3b0ece8708aa278f306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sun, 30 Apr 2017 19:17:16 +0800 Subject: [PATCH 15/22] ENH: weighted mean -> mean --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index f9e3ab9576c86..b19a28811af66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1009,10 +1009,11 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - def weightedMean(pre: (Double, Int), cur: (Double, Int)): Double = { + // perhaps weighted mean is better in the future, see SPARK-16957 and Github PR 17556. + def mean(pre: (Double, Int), cur: (Double, Int)): Double = { val (preValue, preCount) = pre val (curValue, curCount) = cur - (preValue * preCount + curValue * curCount) / (preCount.toDouble + curCount) + (preValue + curValue) / 2 } val possibleSplits = valueCounts.length - 1 @@ -1023,7 +1024,7 @@ private[spark] object RandomForest extends Logging { // if possible splits is not enough or just enough, just return all possible splits valueCounts .sliding(2) - .map(x => weightedMean(x(0), x(1))) + .map(x => mean(x(0), x(1))) .toArray } else { @@ -1052,7 +1053,7 @@ private[spark] object RandomForest extends Logging { if (previousGap < currentGap) { val pre = valueCounts(index - 1) val cur = valueCounts(index) - splitsBuilder += weightedMean(pre, cur) + splitsBuilder += mean(pre, cur) targetCount += stride } index += 1 From 7bb11dd788702ac14a60c9c45e2f663673764e97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sun, 30 Apr 2017 19:21:51 +0800 Subject: [PATCH 16/22] TST: revise unit test in scala --- .../spark/ml/tree/impl/RandomForestSuite.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index c444737a7c795..df155b464c64b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -104,7 +104,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } - // SPARK-16957: Use weighted midpoints for split values. + // SPARK-16957: Use midpoints for split values. { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), @@ -116,7 +116,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { { val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expectedSplits = Array((0.0 * 4 + 1.0 * 4) / (4 + 4)) // = 0.5 + val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) } @@ -124,8 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { { val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expectedSplits = Array((0.0 * 2 + 1.0 * 2) / (2 + 2), - (2.0 * 2 + 3.0 * 2) / (2 + 2)) // = (0.5, 2.5) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) } } @@ -140,8 +139,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expectedSplits = Array((1.0 * 2 + 2.0 * 8) / (2 + 8), - (2.0 * 8 + 3.0 * 2) / (8 + 2)) // = (1.8, 2.2) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) @@ -157,8 +155,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expectedSplits = Array((2.0 * 15 + 3.0 * 1) / (15 + 1), - (3.0 * 1 + 4.0 * 1) / (1 + 1)) // = (2.0625, 3.5) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) assert(splits === expectedSplits) } @@ -171,7 +168,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - val expectedSplits = Array((1.0 * 1 + 2.0 * 15) / (1 + 15)) // = (1.9375) + val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) } From ae0e48e7deb611bf2be6f3e5ddb4d008f89b62f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 2 May 2017 10:10:03 +0800 Subject: [PATCH 17/22] CLN: mean method is removed --- .../main/scala/org/apache/spark/ml/tree/Node.scala | 2 +- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..2605f19e49132 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -31,7 +31,7 @@ sealed abstract class Node extends Serializable { // code into the new API and deprecate the old API. SPARK-3727 /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ - def prediction: Double + def prediction: Double /** Impurity measure at this node (for training data) */ def impurity: Double diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b19a28811af66..d69744528ef9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1009,13 +1009,6 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - // perhaps weighted mean is better in the future, see SPARK-16957 and Github PR 17556. - def mean(pre: (Double, Int), cur: (Double, Int)): Double = { - val (preValue, preCount) = pre - val (curValue, curCount) = cur - (preValue + curValue) / 2 - } - val possibleSplits = valueCounts.length - 1 if (possibleSplits == 0) { // constant feature @@ -1024,7 +1017,7 @@ private[spark] object RandomForest extends Logging { // if possible splits is not enough or just enough, just return all possible splits valueCounts .sliding(2) - .map(x => mean(x(0), x(1))) + .map(x => (x(0)._1 + x(1)._1) / 2) .toArray } else { @@ -1053,7 +1046,8 @@ private[spark] object RandomForest extends Logging { if (previousGap < currentGap) { val pre = valueCounts(index - 1) val cur = valueCounts(index) - splitsBuilder += mean(pre, cur) + // perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556. + splitsBuilder += (pre._1 + cur._1) / 2 targetCount += stride } index += 1 From 59866fac6104400167335f0cdee8c387919f1981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 2 May 2017 10:14:33 +0800 Subject: [PATCH 18/22] CLN: trim whitespace at end of line --- mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 2605f19e49132..07e98a142b10e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -31,7 +31,7 @@ sealed abstract class Node extends Serializable { // code into the new API and deprecate the old API. SPARK-3727 /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ - def prediction: Double + def prediction: Double /** Impurity measure at this node (for training data) */ def impurity: Double From 10037eaa7b844f7e3dc711b3e23ac511e292b2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 2 May 2017 10:26:31 +0800 Subject: [PATCH 19/22] CLN: refine, possibleSplits <= numSplits --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index d69744528ef9c..d09b3cbb5510a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1015,11 +1015,11 @@ private[spark] object RandomForest extends Logging { Array.empty[Double] } else if (possibleSplits <= numSplits) { // if possible splits is not enough or just enough, just return all possible splits - valueCounts - .sliding(2) - .map(x => (x(0)._1 + x(1)._1) / 2) - .toArray + val splits = for { + i <- 0 until valueCounts.length - 1 + } yield (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2 + splits.toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) From 1cae998398bb276024dde22b87ae30f0c35c53d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 2 May 2017 10:28:59 +0800 Subject: [PATCH 20/22] CLN: use possibleSplits --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index d09b3cbb5510a..43efd5fe5a097 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1016,7 +1016,7 @@ private[spark] object RandomForest extends Logging { } else if (possibleSplits <= numSplits) { // if possible splits is not enough or just enough, just return all possible splits val splits = for { - i <- 0 until valueCounts.length - 1 + i <- 0 until possibleSplits } yield (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2 splits.toArray From 92df1c8c58081fc5dbe867249e22d364a8f66e7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 2 May 2017 18:20:38 +0800 Subject: [PATCH 21/22] CLN: use map to replace for...yield --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 43efd5fe5a097..03bf63f18c57f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1015,11 +1015,9 @@ private[spark] object RandomForest extends Logging { Array.empty[Double] } else if (possibleSplits <= numSplits) { // if possible splits is not enough or just enough, just return all possible splits - val splits = for { - i <- 0 until possibleSplits - } yield (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2 - - splits.toArray + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1044,10 +1042,8 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - val pre = valueCounts(index - 1) - val cur = valueCounts(index) // perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556. - splitsBuilder += (pre._1 + cur._1) / 2 + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride } index += 1 From 591d7900fa25a2c0abdbc0de7dadc105e35dd52d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 3 May 2017 16:27:38 +0800 Subject: [PATCH 22/22] CLN: remove comment --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 03bf63f18c57f..82e1ed85a0a14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1042,7 +1042,6 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - // perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556. splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride }