Closed

52 commits
0ae1c0a
basic gradient boosting code from earlier branches
manishamde Sep 28, 2014
5538521
disable checkpointing for now
manishamde Sep 28, 2014
6251fd5
modified method name
manishamde Oct 1, 2014
cdceeef
added documentation
manishamde Oct 1, 2014
f1c9ef7
Merge branch 'master' into gbt
manishamde Oct 6, 2014
1a8031c
sampling with replacement
manishamde Oct 6, 2014
aa8fae7
minor refactoring
manishamde Oct 6, 2014
3973dd1
minor indicating subsample is double during comparison
manishamde Oct 6, 2014
78ed452
added newline and fixed if statement
manishamde Oct 6, 2014
4784091
formatting
manishamde Oct 7, 2014
62cc000
basic checkpointing
manishamde Oct 11, 2014
9af0231
classification attempt
manishamde Oct 11, 2014
6dd4dd8
added support for log loss
manishamde Oct 12, 2014
2fbc9c7
fixing binomial classification prediction
manishamde Oct 12, 2014
bdca43a
added timing parameters
manishamde Oct 12, 2014
f62bc48
added unpersist
manishamde Oct 12, 2014
8e10c63
modified unpersist strategy
manishamde Oct 12, 2014
3b8ffc0
added documentation
manishamde Oct 12, 2014
2cb1258
public API support
manishamde Oct 13, 2014
631baea
Merge branch 'master' into gbt
manishamde Oct 13, 2014
9155a9d
consolidated boosting configuration and added public API
manishamde Oct 13, 2014
5ab3796
minor reformatting
manishamde Oct 13, 2014
5b67102
shortened parameter list
manishamde Oct 14, 2014
1f47941
changing access modifier
manishamde Oct 14, 2014
823691b
fixing RF test
manishamde Oct 14, 2014
6a11c02
fixing formatting
manishamde Oct 20, 2014
9b2e35e
Merge branch 'master' into gbt
manishamde Oct 20, 2014
3b43896
added learning rate for prediction
manishamde Oct 20, 2014
9366b8f
minor: using numTrees instead of trees.size
manishamde Oct 20, 2014
2ae97b7
added documentation for the loss classes
manishamde Oct 20, 2014
d2c8323
Merge branch 'master' into gbt
manishamde Oct 26, 2014
9bc6e74
adding random seed as parameter
manishamde Oct 26, 2014
1b01943
add weights for base learners
manishamde Oct 26, 2014
fee06d3
added weighted ensemble model
manishamde Oct 26, 2014
d971f73
moved RF code to use WeightedEnsembleModel class
manishamde Oct 26, 2014
0e81906
improving caching unpersisting logic
manishamde Oct 26, 2014
3a18cc1
cleaned up api for conversion to bagged point and moved tests to it's…
manishamde Oct 26, 2014
781542a
added support for fractional subsampling with replacement
manishamde Oct 26, 2014
a32a5ab
added test for subsampling without replacement
manishamde Oct 26, 2014
3fd0528
moved helper methods to new class
manishamde Oct 27, 2014
eff21fe
Added gradient boosting tests
manishamde Oct 27, 2014
49ba107
merged from master
manishamde Oct 27, 2014
9f7359d
simplified gbt logic and added more tests
manishamde Oct 30, 2014
035a2ed
jkbradley formatting suggestions
manishamde Oct 30, 2014
eadbf09
parameter renaming
manishamde Oct 30, 2014
e33ab61
minor comment
manishamde Oct 30, 2014
1c40c33
add newline, removed spaces
manishamde Oct 30, 2014
0183cb9
fixed naming and formatting issues
manishamde Oct 30, 2014
8476b6b
fixing line length
manishamde Oct 30, 2014
b4c1318
removing spaces
manishamde Oct 30, 2014
ff2a796
addressing comments
manishamde Oct 31, 2014
991c7b5
public api
manishamde Oct 31, 2014

@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}

I personally think that Boosted Model can be a separate one from RandomForestModel. E.g., it's not inconceivable to have boosted models to use RandomForestModel as its base learners.

And if this were a truly generic weighted ensemble model, then it could probably live outside of tree.model namespace, since boosting at least in theory doesn't care whether base learners are trees or not.

Member

These generalizations will rely on the new ML API (for which there will be a PR any day now); it makes sense to keep it in the tree namespace since there is no generic Estimator concept currently. But once we can, I agree it will be important to generalize meta-algorithms.

With respect to the models, I don't see how the model concepts are different. The learning algorithms are different, but that will not prevent a meta-algorithm to use another meta-algorithm as a weak learner (once the new API is available). (I think it's good to separate the concepts of Estimator (learning algorithm) and Transformer (learned model) here.) What do you think?

Yea, I guess from the design perspective, it's tempting to unify these under the same umbrella.

IMO, RandomForest is mostly a specific instance of a generic ensemble model, so this makes sense.

However, I think that boosted models have some specific things about them due to their sequential nature (as opposed to parallel nature of RandomForest). E.g., if you have 1000 models, you can potentially predict based on the first 100 models whereas with RandomForest you can pick any 100. You also have to do overfitting/underfitting analyses on boosted models sequentially, etc.
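
To make the prefix point concrete, here is a minimal sketch (not code from this PR; the helper name and the plain trees/weights representation are assumptions) of predicting with only the first k weak hypotheses of a sequentially built ensemble:

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Predict with only the first k trees of a boosted ensemble. For boosting the
// prefix is itself a meaningful (smaller) model, whereas for a random forest any
// k-tree subset would serve equally well.
def predictWithPrefix(
    trees: Array[DecisionTreeModel],
    weights: Array[Double],
    k: Int,
    features: Vector): Double = {
  (0 until k).map(i => weights(i) * trees(i).predict(features)).sum
}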

Contributor Author

@codedeft I started with a separate model for boosting but @jkbradley (quite correctly IMO) convinced me otherwise. :-)

I agree methods like boosting require support such as early stopping, sequential selection of models, etc., but maybe we can handle it as part of the model configuration. AdaBoost and RF are, in some ways, more similar than AdaBoost and GBT in their combining operation. It might be better to capture all these nuances in one place. Of course, we can always split them later if we end up writing a lot of custom logic for each algorithm. Thoughts?

@manishamde Sounds good.

Just a side note. Because RF models tend to be much bigger than boosted ensembles, we've encountered situations where the model was too big to fit in a single machine's memory. A RandomForest model is in a way a good fit for embarrassingly parallel predictions, so a model could potentially reside in a distributed fashion.

But we haven't yet decided whether we really want to do this (i.e. are humongous models really useful in practice and do we really expect crazy scenarios of gigantic models surpassing dozens of GBs?)

Contributor Author

@codedeft Agree about the distributed storage, though I never bothered to check the size of deep trees in memory! :-) In fact, such storage might be a good option for a Partial Forest implementation.

import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -317,7 +317,7 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
-private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err

@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
-rfModel.trees(0)
+rfModel.weakHypotheses(0)
}

}

@@ -0,0 +1,314 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
import org.apache.spark.Logging
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum

/**
* :: Experimental ::
* A class that implements gradient boosting for regression and binary classification problems.
* @param boostingStrategy Parameters for the gradient boosting algorithm
*/
@Experimental
class GradientBoosting (
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {

/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return WeightedEnsembleModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
algo match {
case Regression => GradientBoosting.boost(input, boostingStrategy)
case Classification =>
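// Map labels {0, 1} to {-1, +1} so that binary classification can be handled as
// regression on a raw score (presumably so the log loss gradient can work with +/-1 labels).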
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoosting.boost(remappedInput, boostingStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
}

}


object GradientBoosting extends Logging {

/**
* Method to train a gradient boosting model.
*
* Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
* is recommended to clearly specify regression.
* Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
* is recommended to clearly specify classification.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting classification model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting regression model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
new GradientBoosting(boostingStrategy).train(input)
}

/**
Member

newline needed

* Method to train a gradient boosting binary classification model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param numEstimators Number of estimators used in boosting stages. In other words,
* number of boosting iterations performed.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1].
* @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example,
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
* supported.)
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: Map[Int, Int],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
val lossType = Losses.fromString(loss)
val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
weakLearnerParams)
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting regression model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param numEstimators Number of estimators used in boosting stages. In other words,
* number of boosting iterations performed.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1].
* @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example,
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
* supported.)
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: Map[Int, Int],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
val lossType = Losses.fromString(loss)
val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
weakLearnerParams)
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
*/
def trainClassifier(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
weakLearnerParams)
}

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
*/
def trainRegressor(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
weakLearnerParams)
}


/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param boostingStrategy boosting parameters
* @return WeightedEnsembleModel that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {

val timer = new TimeTracker()
timer.start("total")
timer.start("init")

Member

extra newline

// Initialize gradient boosting parameters
val numEstimators = boostingStrategy.numEstimators
val baseLearners = new Array[DecisionTreeModel](numEstimators)
val baseLearnerWeights = new Array[Double](numEstimators)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
val strategy = boostingStrategy.weakLearnerParams

// Cache input
input.persist(StorageLevel.MEMORY_AND_DISK)

timer.stop("init")

logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
var data = input

// 1. Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(strategy).train(data)
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = 1.0
Member

This should be learningRate too, right?

Contributor Author

I think the learning rate is applied after the first model.

Member

In the Friedman paper, the first "model" is just the average label (for squared error). I think it's fine to keep it as is; that way, running for just 1 iteration will behave reasonably.
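
For reference, with squared-error loss the initial model in Friedman's formulation is the constant minimizing the loss, i.e., the mean label:

F_0(x) = \arg\min_{\gamma} \sum_{i=1}^{N} (y_i - \gamma)^2 = \frac{1}{N} \sum_{i=1}^{N} y_i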

Contributor Author

Yup.

val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
Sum)

Member

(learningRate)
logDebug("error of gbt = " + loss.computeError(startingModel, input))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")

// pseudo-residual for second iteration
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
point.features))

var m = 1
while (m < numEstimators) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(strategy).train(data)

It seems to me that this will result in repetitive sampling, re-discretization, etc. of the entire data set every iteration. Additionally, re-persisting the entire dataset seems very expensive, in particular if the dataset (LabeledPoint) initially comes from disk.

I think that the optimal thing to do is:

  1. Discretize the features and persist the entire discretized features only once.
  2. Calculate the new labels after each iteration, and create a separate RDD of these new labels, and persist them.
  3. zip the new labels with the discretized features and reuse the DecisionTree's regression logic.

This will require some modifications of internal DecisionTree.train but it seems to me the better thing to do.
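
A rough sketch of the zipping idea (illustrative only; it uses raw feature vectors rather than the internal discretized TreePoint representation, and the helper name is made up):

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Persist the features once; each iteration recompute only the labels
// (pseudo-residuals) and zip them back with the cached features for the next
// weak learner. zip() requires identical partitioning and element order, which
// holds when the labels are derived from the features via map().
def nextIterationData(
    features: RDD[Vector],
    pseudoResiduals: RDD[Double]): RDD[LabeledPoint] = {
  features.zip(pseudoResiduals).map { case (f, y) => LabeledPoint(y, f) }
}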

Member

Based on @manishamde's PR description, I think the plan is to do this optimization later. Keeping it as a separate PR is helpful for reducing conflicts with your node ID caching PR [https://github.com//pull/2868]. I feel like it is easier to break things into smaller PRs. Also, since this type of optimization will likely be useful for other meta-algorithms, it will be good to think about a standard interface for getting a learning algorithm's internal data representation (and the related prediction methods which take that internal representation).

Makes sense. So I suppose you want to provide the functionality first and then optimize later ;).

I'm not sure, though, whether this is going to result in re-reading the input from disk at every iteration. Maybe I'm wrong. But a simple change could be simply persisting the features all the time, and re-persisting the newly calculated labels periodically.

Contributor Author

@codedeft Thanks for your comment. Your observation is correct. Conversion to the internal discretized/binned storage format will definitely lead to a faster implementation and lower memory consumption on the cluster. As @jkbradley mentioned, we decided to work on it after the generic MLlib API work has been completed. We can then use methods such as trainUsingInternalFormat and predictUsingInternalFormat if the underlying algo (in this case DecisionTree) supports it.

We won't be re-reading from disk at every iteration; we cache the training data at the first iteration and checkpoint/persist every few iterations to avoid long lineage chains. Will comment further on the checkpointing in the other thread.
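
As an illustration of the cache-then-periodically-checkpoint idea (a sketch, not the PR's code; the helper and the interval are assumptions, and checkpointing requires sc.setCheckpointDir to have been called):

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Persist the per-iteration training data and, every few iterations, checkpoint it
// so that the RDD lineage does not grow without bound as boosting rounds accumulate.
def persistAndMaybeCheckpoint(
    data: RDD[LabeledPoint],
    iteration: Int,
    checkpointInterval: Int = 10): RDD[LabeledPoint] = {
  data.persist(StorageLevel.MEMORY_AND_DISK)
  if (checkpointInterval > 0 && iteration % checkpointInterval == 0) {
    data.checkpoint()
  }
  data
}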

One more thing: I think that the decision tree itself persists the discretized data. So it seems that this could potentially require doubly persisted datasets (one LabeledPoint and the other TreePoint)?

Contributor Author

Correct. That's the big disadvantage of not using the internal format. It won't affect other algos as much since there is no discretization.

We have a few options:

  1. Keep the implementation as is and inform the user about memory requirements.
  2. Persist RDD[TreePoint] (essential since we perform multiple passes over it during each tree construction) and read RDD[LabeledPoint] from disk every time.
  3. Persist RDD[LabeledPoint] and do not cache RDD[TreePoint] during tree construction, leading to repeated LabeledPoint -> TreePoint conversions for each NodeGroup.

Thoughts? cc: @jkbradley

Member

Eventually, I envision:
(1) GradientBoosting gets 1 copy of the data from the weak learner (RDD[TreePoint] for DecisionTree) and persists it.
(2) DecisionTree persists the NodeIdCache (possibly storing 2 copies of the cache). GradientBoosting tells DecisionTree not to serialize anything.
(3) GradientBoosting persists the label (as @codedeft suggested) only, and periodically serializes it.

For now, I say we either keep it as is and add a warning, or spend a little time refactoring to just persist the labels per the suggestion from @codedeft.

timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
baseLearnerWeights.slice(0, m + 1), Regression, Sum)
logDebug("error of gbt = " + loss.computeError(partialModel, input))
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
point.features))
m += 1
}

timer.stop("total")

logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")


// 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)

}

}
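
For orientation, here is an end-to-end usage sketch against the public API in this diff. The dataset path, the weak-learner Strategy settings, and the loss string are illustrative assumptions (accepted loss names are defined by Losses.fromString), and an existing SparkContext sc is assumed:

import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.util.MLUtils

// Load a LIBSVM-formatted regression dataset.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt").cache()

// Weak learner configuration: shallow regression trees (algo, impurity, maxDepth).
val weakLearnerParams = new Strategy(Algo.Regression, Variance, 3)

val model = GradientBoosting.trainRegressor(
  data,
  10,                    // numEstimators (boosting iterations)
  "leastSquaresError",   // loss; assumed name, see Losses.fromString for accepted values
  0.1,                   // learningRate
  1.0,                   // subsamplingRate
  2,                     // numClassesForClassification (ignored for regression)
  Map.empty[Int, Int],   // categoricalFeaturesInfo: all features continuous
  weakLearnerParams)

// Training-set MSE, mirroring meanSquaredError in DecisionTreeRunner above.
val trainMSE = data.map { p =>
  val err = model.predict(p.features) - p.label
  err * err
}.mean()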