Skip to content

Commit a0ed0da

Browse files
committed
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
1 parent 3726d20 commit a0ed0da

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
2727
import org.apache.spark.mllib.tree.configuration.Algo._
2828
import org.apache.spark.mllib.tree.configuration.FeatureType._
2929
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
30-
import org.apache.spark.mllib.tree.impl.{DTMetadata, TimeTracker, TreePoint}
30+
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
3131
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
3232
import org.apache.spark.mllib.tree.model._
3333
import org.apache.spark.rdd.RDD
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
6262
timer.start("init")
6363

6464
val retaggedInput = input.retag(classOf[LabeledPoint])
65-
val metadata = DTMetadata.buildMetadata(retaggedInput, strategy)
65+
val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
6666
logDebug("algo = " + strategy.algo)
6767

6868
// Find the splits and the corresponding bins (interval between the splits) using a sample
@@ -443,7 +443,7 @@ object DecisionTree extends Serializable with Logging {
443443
protected[tree] def findBestSplits(
444444
input: RDD[TreePoint],
445445
parentImpurities: Array[Double],
446-
metadata: DTMetadata,
446+
metadata: DecisionTreeMetadata,
447447
level: Int,
448448
nodes: Array[Node],
449449
splits: Array[Array[Split]],
@@ -489,7 +489,7 @@ object DecisionTree extends Serializable with Logging {
489489
private def findBestSplitsPerGroup(
490490
input: RDD[TreePoint],
491491
parentImpurities: Array[Double],
492-
metadata: DTMetadata,
492+
metadata: DecisionTreeMetadata,
493493
level: Int,
494494
nodes: Array[Node],
495495
splits: Array[Array[Split]],
@@ -551,7 +551,9 @@ object DecisionTree extends Serializable with Logging {
551551

552552
/**
553553
* Get the node index corresponding to this data point.
554-
* This is used during training, mimicking prediction.
554+
* This function mimics prediction, passing an example from the root node down to a node
555+
* at the current level being trained; that node's index is returned.
556+
*
555557
* @return Leaf index if the data point reaches a leaf.
556558
* Otherwise, last node reachable in tree matching this example.
557559
*/
@@ -608,7 +610,8 @@ object DecisionTree extends Serializable with Logging {
608610
val levelOffset = (1 << level) - 1
609611

610612
/**
611-
* Find the node (indexed from 0 at the start of this level) for the given example.
613+
* Find the node index for the given example.
614+
* Nodes are indexed from 0 at the start of this (level, group).
612615
* If the example does not reach this level, returns a value < 0.
613616
*/
614617
def treePointToNodeIndex(treePoint: TreePoint): Int = {
@@ -1261,7 +1264,7 @@ object DecisionTree extends Serializable with Logging {
12611264
*
12621265
* @param numBins Number of bins = 1 + number of possible splits.
12631266
*/
1264-
private def getElementsPerNode(metadata: DTMetadata, numBins: Int): Int = {
1267+
private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
12651268
if (metadata.isClassification) {
12661269
if (metadata.isMulticlassWithCategoricalFeatures) {
12671270
2 * metadata.numClasses * numBins * metadata.numFeatures
@@ -1304,7 +1307,7 @@ object DecisionTree extends Serializable with Logging {
13041307
*/
13051308
protected[tree] def findSplitsBins(
13061309
input: RDD[LabeledPoint],
1307-
metadata: DTMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
1310+
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
13081311

13091312
val count = input.count()
13101313

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTMetadata.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD
3535
* @param featureArity Map: categorical feature index --> arity.
3636
* I.e., the feature takes values in {0, ..., arity - 1}.
3737
*/
38-
private[tree] class DTMetadata(
38+
private[tree] class DecisionTreeMetadata(
3939
val numFeatures: Int,
4040
val numExamples: Long,
4141
val numClasses: Int,
@@ -59,9 +59,9 @@ private[tree] class DTMetadata(
5959

6060
}
6161

62-
private[tree] object DTMetadata {
62+
private[tree] object DecisionTreeMetadata {
6363

64-
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DTMetadata = {
64+
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
6565

6666
val numFeatures = input.take(1)(0).features.size
6767
val numExamples = input.count()
@@ -93,7 +93,7 @@ private[tree] object DTMetadata {
9393
}
9494
}
9595

96-
new DTMetadata(numFeatures, numExamples, numClasses, maxBins,
96+
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
9797
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
9898
strategy.impurity, strategy.quantileCalculationStrategy)
9999
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ private[tree] object TreePoint {
5454
def convertToTreeRDD(
5555
input: RDD[LabeledPoint],
5656
bins: Array[Array[Bin]],
57-
metadata: DTMetadata): RDD[TreePoint] = {
57+
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
5858
input.map { x =>
5959
TreePoint.labeledPointToTreePoint(x, bins, metadata)
6060
}
@@ -67,7 +67,7 @@ private[tree] object TreePoint {
6767
private def labeledPointToTreePoint(
6868
labeledPoint: LabeledPoint,
6969
bins: Array[Array[Bin]],
70-
metadata: DTMetadata): TreePoint = {
70+
metadata: DecisionTreeMetadata): TreePoint = {
7171

7272
val numFeatures = labeledPoint.features.size
7373
val numBins = bins(0).size

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
2424
import org.apache.spark.mllib.tree.configuration.Algo._
2525
import org.apache.spark.mllib.tree.configuration.FeatureType._
2626
import org.apache.spark.mllib.tree.configuration.Strategy
27-
import org.apache.spark.mllib.tree.impl.{DTMetadata, TreePoint}
27+
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
2828
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
2929
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
3030
import org.apache.spark.mllib.linalg.Vectors
@@ -64,7 +64,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
6464
assert(arr.length === 1000)
6565
val rdd = sc.parallelize(arr)
6666
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
67-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
67+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
6868
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
6969
assert(splits.length === 2)
7070
assert(bins.length === 2)
@@ -83,7 +83,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
8383
numClassesForClassification = 2,
8484
maxBins = 100,
8585
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
86-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
86+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
8787
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
8888
assert(splits.length === 2)
8989
assert(bins.length === 2)
@@ -164,7 +164,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
164164
numClassesForClassification = 2,
165165
maxBins = 100,
166166
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
167-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
167+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
168168
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
169169

170170
// Check splits.
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
282282
numClassesForClassification = 100,
283283
maxBins = 100,
284284
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
285-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
285+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
286286
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
287287

288288
// Expecting 2^2 - 1 = 3 bins/splits
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
377377
numClassesForClassification = 100,
378378
maxBins = 100,
379379
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
380-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
380+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
381381
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
382382

383383
// 2^10 - 1 > 100, so categorical variables will be ordered
@@ -433,7 +433,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
433433
maxDepth = 2,
434434
maxBins = 100,
435435
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
436-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
436+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
437437
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
438438
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
439439
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
@@ -462,7 +462,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
462462
maxDepth = 2,
463463
maxBins = 100,
464464
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
465-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
465+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
466466
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
467467
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
468468
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
@@ -502,7 +502,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
502502
assert(arr.length === 1000)
503503
val rdd = sc.parallelize(arr)
504504
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
505-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
505+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
506506
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
507507
assert(splits.length === 2)
508508
assert(splits(0).length === 99)
@@ -526,7 +526,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
526526
assert(arr.length === 1000)
527527
val rdd = sc.parallelize(arr)
528528
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
529-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
529+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
530530
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
531531
assert(splits.length === 2)
532532
assert(splits(0).length === 99)
@@ -551,7 +551,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
551551
assert(arr.length === 1000)
552552
val rdd = sc.parallelize(arr)
553553
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
554-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
554+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
555555
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
556556
assert(splits.length === 2)
557557
assert(splits(0).length === 99)
@@ -576,7 +576,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
576576
assert(arr.length === 1000)
577577
val rdd = sc.parallelize(arr)
578578
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
579-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
579+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
580580
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
581581
assert(splits.length === 2)
582582
assert(splits(0).length === 99)
@@ -601,7 +601,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
601601
assert(arr.length === 1000)
602602
val rdd = sc.parallelize(arr)
603603
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
604-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
604+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
605605
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
606606
assert(splits.length === 2)
607607
assert(splits(0).length === 99)
@@ -653,7 +653,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
653653
val rdd = sc.parallelize(arr)
654654
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
655655
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
656-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
656+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
657657
assert(strategy.isMulticlassClassification)
658658
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
659659
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -710,7 +710,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
710710
numClassesForClassification = 3, maxBins = maxBins,
711711
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
712712
assert(strategy.isMulticlassClassification)
713-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
713+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
714714

715715
val model = DecisionTree.train(rdd, strategy)
716716
validateClassifier(model, arr, 1.0)
@@ -739,7 +739,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
739739
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
740740
numClassesForClassification = 3)
741741
assert(strategy.isMulticlassClassification)
742-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
742+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
743743

744744
val model = DecisionTree.train(rdd, strategy)
745745
validateClassifier(model, arr, 0.9)
@@ -765,7 +765,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
765765
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
766766
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
767767
assert(strategy.isMulticlassClassification)
768-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
768+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
769769

770770
val model = DecisionTree.train(rdd, strategy)
771771
validateClassifier(model, arr, 0.9)
@@ -790,7 +790,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
790790
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
791791
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
792792
assert(strategy.isMulticlassClassification)
793-
val metadata = DTMetadata.buildMetadata(rdd, strategy)
793+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
794794

795795
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
796796
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)

0 commit comments

Comments
 (0)