@@ -231,19 +231,43 @@ object DecisionTree extends Serializable with Logging {
231231 * @param maxDepth maximum depth of the tree
232232 * @return a DecisionTreeModel that can be used for prediction
233233 */
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int): DecisionTreeModel = {
  // Bundle the caller's settings into the Strategy the tree builder is constructed from.
  val treeStrategy = new Strategy(algo, impurity, maxDepth)
  // Tree training works on the weighted input format, so re-wrap every
  // standard (label, features) instance before handing the data over.
  val weightedPoints: RDD[WeightedLabeledPoint] =
    input.map(point => WeightedLabeledPoint(point.label, point.features))
  val treeBuilder = new DecisionTree(treeStrategy)
  treeBuilder.train(weightedPoints)
}
244+
245+ /**
246+ * Method to train a decision tree model where the instances are represented as an RDD of
247+ * (label, features) pairs. The method supports binary classification and regression. For the
248+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
249+ * classes.
250+ *
251+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
252+ * training data
253+ * @param algo algorithm, classification or regression
254+ * @param impurity impurity criterion used for information gain calculation
255+ * @param maxDepth maximum depth of the tree
256+ * @param numClasses number of classes for classification
257+ * @return a DecisionTreeModel that can be used for prediction
258+ */
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClasses: Int): DecisionTreeModel = {
  // Bundle the caller's settings, including the class count, into the Strategy
  // the tree builder is constructed from.
  val treeStrategy = new Strategy(algo, impurity, maxDepth, numClasses)
  // Tree training works on the weighted input format, so re-wrap every
  // standard (label, features) instance before handing the data over.
  val weightedPoints: RDD[WeightedLabeledPoint] =
    input.map(point => WeightedLabeledPoint(point.label, point.features))
  val treeBuilder = new DecisionTree(treeStrategy)
  treeBuilder.train(weightedPoints)
}
244270
245- // TODO: Add multiclass classification support
246-
247271 // TODO: Add sample weight support
248272
249273 /**
@@ -258,6 +282,7 @@ object DecisionTree extends Serializable with Logging {
258282 * @param algo classification or regression
259283 * @param impurity criterion used for information gain calculation
260284 * @param maxDepth maximum depth of the tree
285+ * @param numClasses number of classes for classification
261286 * @param maxBins maximum number of bins used for splitting features
262287 * @param quantileCalculationStrategy algorithm for calculating quantiles
263288 * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -272,11 +297,12 @@ object DecisionTree extends Serializable with Logging {
272297 algo : Algo ,
273298 impurity : Impurity ,
274299 maxDepth : Int ,
300+ numClasses : Int ,
275301 maxBins : Int ,
276302 quantileCalculationStrategy : QuantileStrategy ,
277303 categoricalFeaturesInfo : Map [Int ,Int ]): DecisionTreeModel = {
278- val strategy = new Strategy (algo, impurity, maxDepth, maxBins, quantileCalculationStrategy ,
279- categoricalFeaturesInfo)
304+ val strategy = new Strategy (algo, impurity, maxDepth, numClasses, maxBins ,
305+ quantileCalculationStrategy, categoricalFeaturesInfo)
280306 // Converting from standard instance format to weighted input format for tree training
281307 val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
282308 new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
@@ -737,10 +763,26 @@ object DecisionTree extends Serializable with Logging {
737763 }
738764 }
739765
740- // TODO: Make multiclass modification here
741- val predict = (leftCounts(1 ) + rightCounts(1 )) / (leftTotalCount + rightTotalCount)
766+ val totalCount = leftTotalCount + rightTotalCount
742767
743- new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict)
768+ // Sum of count for each label
769+ val leftRightCounts : Array [Double ]
770+ = leftCounts.zip(rightCounts)
771+ .map{case (leftCount, rightCount) => leftCount + rightCount}
772+
/**
 * Finds the position of the largest value in `array`.
 *
 * Ties are resolved in favour of the earliest position because the comparison is strict.
 * When no element is strictly greater than `Double.MinValue` (in particular, when `array`
 * is empty) the method falls back to index 0 instead of returning an invalid index.
 *
 * @param array values to scan
 * @return index of the first largest element, or 0 if no element qualifies
 */
def indexOfLargest(array: Seq[Double]): Int = {
  // Track (index of best so far, best value so far); -1 marks "nothing selected yet".
  val (bestIndex, _) = array.zipWithIndex.foldLeft((-1, Double.MinValue)) {
    case ((maxIndex, maxValue), (currentValue, currentIndex)) =>
      if (currentValue > maxValue) (currentIndex, currentValue)
      else (maxIndex, maxValue)
  }
  // The previous version had these branches swapped, so it returned 0 for every
  // non-empty input and -1 for an empty one.
  if (bestIndex < 0) 0 else bestIndex
}
781+
782+ val predict = indexOfLargest(leftRightCounts)
783+ val prob = leftRightCounts(predict) / totalCount
784+
785+ new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
744786 case Regression =>
745787 val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0 )
746788 val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1 )
@@ -793,8 +835,9 @@ object DecisionTree extends Serializable with Logging {
793835 /**
794836 * Extracts left and right split aggregates.
795837 * @param binData Array[Double] of size 2*numFeatures*numSplits
796- * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
797- * Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
838+ * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
839+ * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature,
840+ * (numBins - 1), numClasses)
798841 */
799842 def extractLeftRightNodeAggregates (
800843 binData : Array [Double ]): (Array [Array [Array [Double ]]], Array [Array [Array [Double ]]]) = {
0 commit comments