@@ -739,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
739739 val rightCount = rightImpurityCalculator.count
740740
741741 // If left child or right child doesn't satisfy minimum instances per node,
742- // then this split is invalid, return invalid information gain stats
742+ // then this split is invalid, return invalid information gain stats.
743743 if ((leftCount < metadata.minInstancesPerNode) ||
744744 (rightCount < metadata.minInstancesPerNode)) {
745745 return InformationGainStats .invalidInformationGainStats
@@ -764,13 +764,23 @@ object DecisionTree extends Serializable with Logging {
764764 val rightWeight = rightCount / totalCount.toDouble
765765
766766 val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
767+
768+ // If the information gain doesn't satisfy the minimum information gain,
769+ // then this split is invalid, return invalid information gain stats.
767770 if (gain < metadata.minInfoGain) {
768771 return InformationGainStats .invalidInformationGainStats
769772 }
770773
771774 new InformationGainStats (gain, impurity, leftImpurity, rightImpurity)
772775 }
773776
777+ /**
778+ * Calculate the predicted value for the current node, given the stats of any split.
779+ * Note that this function is called only once for each node.
780+ * @param leftImpurityCalculator left node aggregates for a split
781+ * @param rightImpurityCalculator right node aggregates for a split
782+ * @return predicted value for the current node
783+ */
774784 private def calculatePredict (
775785 leftImpurityCalculator : ImpurityCalculator ,
776786 rightImpurityCalculator : ImpurityCalculator ): Predict = {
@@ -799,6 +809,7 @@ object DecisionTree extends Serializable with Logging {
799809
800810 logDebug(" node impurity = " + nodeImpurity)
801811
812+ // calculate predict only once
802813 var predict : Option [Predict ] = None
803814
804815 // For each (feature, split), calculate the gain, and select the best (feature, split).
0 commit comments