Skip to content

Commit 0b5d659

Browse files
committed
minor
1 parent 32d409d commit 0b5d659

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,18 @@ class GradientBoostedTreesModel(
130130

131131
val numIterations = trees.length
132132
val evaluationArray = Array.fill(numIterations)(0.0)
133+
val localTreeWeights = treeWeights
133134

134135
var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
135-
remappedData, treeWeights(0), trees(0), loss)
136+
remappedData, localTreeWeights(0), trees(0), loss)
136137

137138
evaluationArray(0) = predictionAndError.values.mean()
138139

139140
val broadcastTrees = sc.broadcast(trees)
140141
(1 until numIterations).map { nTree =>
141142
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
142143
val currentTree = broadcastTrees.value(nTree)
143-
val currentTreeWeight = treeWeights(nTree)
144+
val currentTreeWeight = localTreeWeights(nTree)
144145
iter.map { case (point, (pred, error)) =>
145146
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
146147
val newError = loss.computeError(newPred, point.label)
@@ -198,8 +199,6 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
198199
tree: DecisionTreeModel,
199200
loss: Loss): RDD[(Double, Double)] = {
200201

201-
val sc = data.sparkContext
202-
203202
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
204203
iter.map {
205204
case (lp, (pred, error)) => {

0 commit comments

Comments
 (0)