Skip to content

Commit e72c7e4

Browse files
author
qiping.lqp
committed
add comments
1 parent 845c6fa commit e72c7e4

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
739739
val rightCount = rightImpurityCalculator.count
740740

741741
// If left child or right child doesn't satisfy minimum instances per node,
742-
// then this split is invalid, return invalid information gain stats
742+
// then this split is invalid, return invalid information gain stats.
743743
if ((leftCount < metadata.minInstancesPerNode) ||
744744
(rightCount < metadata.minInstancesPerNode)) {
745745
return InformationGainStats.invalidInformationGainStats
@@ -764,13 +764,23 @@ object DecisionTree extends Serializable with Logging {
764764
val rightWeight = rightCount / totalCount.toDouble
765765

766766
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
767+
768+
// if information gain doesn't satisfy minimum information gain,
769+
// then this split is invalid, return invalid information gain stats.
767770
if (gain < metadata.minInfoGain) {
768771
return InformationGainStats.invalidInformationGainStats
769772
}
770773

771774
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
772775
}
773776

777+
/**
778+
* Calculate predict value for current node, given stats of any split.
779+
* Note that this function is called only once for each node.
780+
* @param leftImpurityCalculator left node aggregates for a split
781+
* @param rightImpurityCalculator right node aggregates for a node
782+
* @return predict value for current node
783+
*/
774784
private def calculatePredict(
775785
leftImpurityCalculator: ImpurityCalculator,
776786
rightImpurityCalculator: ImpurityCalculator): Predict = {
@@ -799,6 +809,7 @@ object DecisionTree extends Serializable with Logging {
799809

800810
logDebug("node impurity = " + nodeImpurity)
801811

812+
// calculate predict only once
802813
var predict: Option[Predict] = None
803814

804815
// For each (feature, split), calculate the gain, and select the best (feature, split).

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,10 @@ class InformationGainStats(
4242

4343

4444
private[tree] object InformationGainStats {
45+
/**
46+
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
47+
* denote that current split doesn't satisfies minimum info gain or
48+
* minimum number of instances per node.
49+
*/
4550
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
4651
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
6868
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
6969
extends Split(feature, Double.MaxValue, featureType, List())
7070

71-
7271
private[tree] object Split {
72+
/**
73+
* A [[org.apache.spark.mllib.tree.model.Split]] object to denote that
74+
* we can't find a valid split that satisfies minimum info gain
75+
* or minimum number of instances per node.
76+
*/
7377
val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List())
7478
}

0 commit comments

Comments
 (0)