1717
1818package org .apache .spark .mllib .tree
1919
20- import scala .collection .JavaConverters ._
2120import scala .collection .mutable
21+ import scala .collection .JavaConverters ._
2222
2323import org .apache .spark .Logging
2424import org .apache .spark .annotation .Experimental
2525import org .apache .spark .api .java .JavaRDD
2626import org .apache .spark .mllib .regression .LabeledPoint
27+ import org .apache .spark .mllib .tree .configuration .Strategy
2728import org .apache .spark .mllib .tree .configuration .Algo ._
2829import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
29- import org .apache .spark .mllib .tree .configuration .EnsembleCombiningStrategy .Average
30- import org .apache .spark .mllib .tree .configuration .Strategy
31- import org .apache .spark .mllib .tree .impl .{BaggedPoint , TreePoint , DecisionTreeMetadata , TimeTracker , NodeIdCache }
30+ import org .apache .spark .mllib .tree .impl .{BaggedPoint , DecisionTreeMetadata , NodeIdCache ,
31+ TimeTracker , TreePoint }
3232import org .apache .spark .mllib .tree .impurity .Impurities
3333import org .apache .spark .mllib .tree .model ._
3434import org .apache .spark .rdd .RDD
@@ -79,9 +79,9 @@ private class RandomForest (
7979 /**
8080 * Method to train a decision tree model over an RDD
8181 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]]
82- * @return WeightedEnsembleModel that can be used for prediction
82+ * @return a random forest model that can be used for prediction
8383 */
84- def run (input : RDD [LabeledPoint ]): TreeEnsembleModel = {
84+ def run (input : RDD [LabeledPoint ]): RandomForestModel = {
8585
8686 val timer = new TimeTracker ()
8787
@@ -212,8 +212,7 @@ private class RandomForest (
212212 }
213213
214214 val trees = topNodes.map(topNode => new DecisionTreeModel (topNode, strategy.algo))
215- val treeWeights = Array .fill[Double ](numTrees)(1.0 )
216- new TreeEnsembleModel (trees, treeWeights, strategy.algo, Average )
215+ new RandomForestModel (strategy.algo, trees)
217216 }
218217
219218}
@@ -234,14 +233,14 @@ object RandomForest extends Serializable with Logging {
234233 * if numTrees > 1 (forest) set to "sqrt" for classification and
235234 * to "onethird" for regression.
236235 * @param seed Random seed for bootstrapping and choosing feature subsets.
237- * @return WeightedEnsembleModel that can be used for prediction
236+ * @return a random forest model that can be used for prediction
238237 */
239238 def trainClassifier (
240239 input : RDD [LabeledPoint ],
241240 strategy : Strategy ,
242241 numTrees : Int ,
243242 featureSubsetStrategy : String ,
244- seed : Int ): TreeEnsembleModel = {
243+ seed : Int ): RandomForestModel = {
245244 require(strategy.algo == Classification ,
246245 s " RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}" )
247246 val rf = new RandomForest (strategy, numTrees, featureSubsetStrategy, seed)
@@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
272271 * @param maxBins maximum number of bins used for splitting features
273272 * (suggested value: 100)
274273 * @param seed Random seed for bootstrapping and choosing feature subsets.
275- * @return WeightedEnsembleModel that can be used for prediction
274+ * @return a random forest model that can be used for prediction
276275 */
277276 def trainClassifier (
278277 input : RDD [LabeledPoint ],
@@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
283282 impurity : String ,
284283 maxDepth : Int ,
285284 maxBins : Int ,
286- seed : Int = Utils .random.nextInt()): TreeEnsembleModel = {
285+ seed : Int = Utils .random.nextInt()): RandomForestModel = {
287286 val impurityType = Impurities .fromString(impurity)
288287 val strategy = new Strategy (Classification , impurityType, maxDepth,
289288 numClassesForClassification, maxBins, Sort , categoricalFeaturesInfo)
@@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
302301 impurity : String ,
303302 maxDepth : Int ,
304303 maxBins : Int ,
305- seed : Int ): TreeEnsembleModel = {
304+ seed : Int ): RandomForestModel = {
306305 trainClassifier(input.rdd, numClassesForClassification,
307306 categoricalFeaturesInfo.asInstanceOf [java.util.Map [Int , Int ]].asScala.toMap,
308307 numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -322,14 +321,14 @@ object RandomForest extends Serializable with Logging {
322321 * if numTrees > 1 (forest) set to "sqrt" for classification and
323322 * to "onethird" for regression.
324323 * @param seed Random seed for bootstrapping and choosing feature subsets.
325- * @return WeightedEnsembleModel that can be used for prediction
324+ * @return a random forest model that can be used for prediction
326325 */
327326 def trainRegressor (
328327 input : RDD [LabeledPoint ],
329328 strategy : Strategy ,
330329 numTrees : Int ,
331330 featureSubsetStrategy : String ,
332- seed : Int ): TreeEnsembleModel = {
331+ seed : Int ): RandomForestModel = {
333332 require(strategy.algo == Regression ,
334333 s " RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}" )
335334 val rf = new RandomForest (strategy, numTrees, featureSubsetStrategy, seed)
@@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
359358 * @param maxBins maximum number of bins used for splitting features
360359 * (suggested value: 100)
361360 * @param seed Random seed for bootstrapping and choosing feature subsets.
362- * @return WeightedEnsembleModel that can be used for prediction
361+ * @return a random forest model that can be used for prediction
363362 */
364363 def trainRegressor (
365364 input : RDD [LabeledPoint ],
@@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
369368 impurity : String ,
370369 maxDepth : Int ,
371370 maxBins : Int ,
372- seed : Int = Utils .random.nextInt()): TreeEnsembleModel = {
371+ seed : Int = Utils .random.nextInt()): RandomForestModel = {
373372 val impurityType = Impurities .fromString(impurity)
374373 val strategy = new Strategy (Regression , impurityType, maxDepth,
375374 0 , maxBins, Sort , categoricalFeaturesInfo)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
479478 3 * totalBins
480479 }
481480 }
482-
483481}
0 commit comments