@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
2828import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
2929import org .apache .spark .mllib .tree .configuration .FeatureType ._
3030import org .apache .spark .mllib .tree .configuration .Algo ._
31+ import org .apache .spark .mllib .tree .impurity .Impurity
3132
3233/**
3334A class that implements a decision tree algorithm for classification and regression.
@@ -38,7 +39,7 @@ algorithm (classification,
3839regression, etc.), feature type (continuous, categorical), depth of the tree,
3940quantile calculation strategy, etc.
4041 */
41- class DecisionTree (val strategy : Strategy ) extends Serializable with Logging {
42+ class DecisionTree private (val strategy : Strategy ) extends Serializable with Logging {
4243
4344 /**
4445 Method to train a decision tree model over an RDD
@@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
157158
158159object DecisionTree extends Serializable with Logging {
159160
161+ /**
162+ Method to train a decision tree model over an RDD
163+
164+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
165+ for DecisionTree
166+ @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm
167+ (classification, regression, etc.), feature type (continuous, categorical),
168+ depth of the tree, quantile calculation strategy, etc.
169+ @return a DecisionTreeModel that can be used for prediction
170+ */
171+ def train (input : RDD [LabeledPoint ], strategy : Strategy ) : DecisionTreeModel = {
172+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
173+ }
174+
175+ /**
176+ Method to train a decision tree model over an RDD
177+
178+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
179+ for DecisionTree
180+ @param algo classification or regression
181+ @param impurity criterion used for information gain calculation
182+ @param maxDepth maximum depth of the tree
183+ @return a DecisionTreeModel that can be used for prediction
184+ */
185+ def train (
186+ input : RDD [LabeledPoint ],
187+ algo : Algo ,
188+ impurity : Impurity ,
189+ maxDepth : Int
190+ ) : DecisionTreeModel = {
191+ val strategy = new Strategy (algo,impurity,maxDepth)
192+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
193+ }
194+
195+
196+ /**
197+ Method to train a decision tree model over an RDD
198+
199+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
200+ for DecisionTree
201+ @param algo classification or regression
202+ @param impurity criterion used for information gain calculation
203+ @param maxDepth maximum depth of the tree
204+ @param maxBins maximum number of bins used for splitting features
205+ @param quantileCalculationStrategy algorithm for calculating quantiles
206+ @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete
207+ values they take. For example, an entry (n -> k) implies the feature n is
208+ categorical with k categories 0, 1, 2, ... , k-1. It's important to note that
209+ features are zero-indexed.
210+ @return a DecisionTreeModel that can be used for prediction
211+ */
212+ def train (
213+ input : RDD [LabeledPoint ],
214+ algo : Algo ,
215+ impurity : Impurity ,
216+ maxDepth : Int ,
217+ maxBins : Int ,
218+ quantileCalculationStrategy : QuantileStrategy ,
219+ categoricalFeaturesInfo : Map [Int ,Int ]
220+ ) : DecisionTreeModel = {
221+ val strategy = new Strategy (algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo)
222+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
223+ }
224+
160225 /**
161226 Returns an Array[Split] of optimal splits for all nodes at a given level
162227
@@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging {
717782 for DecisionTree
718783 @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy ]] instance containing
719784 parameters for construction the DecisionTree
720- @return a tuple of (splits,bins) where Split is an Array[Array[ Split]] of size (numFeatures,
721- numSplits-1) and bins is an
722- Array[Array[Bin]] of size (numFeatures,numSplits1)
785+ @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model. Split] of
786+ size (numFeatures, numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of
787+ size (numFeatures,numSplits1)
723788 */
724789 def findSplitsBins (
725790 input : RDD [LabeledPoint ],
726- strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
791+ strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
727792
728793 val count = input.count()
729794
0 commit comments