Skip to content

Commit 352001f

Browse files
committed
Minor
1 parent 6e8aa10 commit 352001f

File tree

5 files changed: +14 −38 lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,8 @@ object AbsoluteError extends Loss {
     if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint.
-   * @return Absolute error of model on the given datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val err = datum.label - prediction
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = label - prediction
     math.abs(err)
   }

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,8 @@ object LogLoss extends Loss {
     - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint
-   * @return log loss of model on the datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val margin = 2.0 * datum.label * prediction
+  override def computeError(prediction: Double, label: Double): Double = {
+    val margin = 2.0 * label * prediction
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,17 @@ trait Loss extends Serializable {
    * @return Measure of model error on data
    */
   def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    data.map(point => computeError(model.predict(point.features), point)).mean()
+    data.map(point => computeError(model.predict(point.features), point.label)).mean()
   }
 
   /**
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
    * @param prediction Predicted label.
-   * @param datum LabeledPoint
+   * @param label True label.
    * @return Measure of model error on datapoint.
    */
-  def computeError(prediction: Double, datum: LabeledPoint): Double
+  def computeError(prediction: Double, label: Double): Double
 
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,8 @@ object SquaredError extends Loss {
     2.0 * (model.predict(point.features) - point.label)
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint
-   * @return Mean squared error of model on datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val err = prediction - datum.label
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = prediction - label
     err * err
   }
6254

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class GradientBoostedTreesModel(
 
     var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
       val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i)
+      val error = loss.computeError(pred, i.label)
       (pred, error)
     }
     evaluationArray(0) = predictionAndError.values.mean()
@@ -143,13 +143,13 @@
     val broadcastWeights = sc.broadcast(treeWeights)
 
     (1 until numIterations).map { nTree =>
-      val currentTree = broadcastTrees.value(nTree)
-      val currentTreeWeight = broadcastWeights.value(nTree)
       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
-        iter map {
+        val currentTree = broadcastTrees.value(nTree)
+        val currentTreeWeight = broadcastWeights.value(nTree)
+        iter.map {
           case (point, (pred, error)) => {
             val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point)
+            val newError = loss.computeError(newPred, point.label)
             (newPred, newError)
           }
         }

0 commit comments

Comments (0)