File tree Expand file tree Collapse file tree 5 files changed +14
-38
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree Expand file tree Collapse file tree 5 files changed +14
-38
lines changed Original file line number Diff line number Diff line change @@ -47,16 +47,8 @@ object AbsoluteError extends Loss {
4747 if ((point.label - model.predict(point.features)) < 0 ) 1.0 else - 1.0
4848 }
4949
50- /**
51- * Method to calculate loss when the predictions are already known.
52- * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
53- * predicted values from previously fit trees.
54- * @param prediction Predicted label.
55- * @param datum LabeledPoint.
56- * @return Absolute error of model on the given datapoint.
57- */
58- override def computeError (prediction : Double , datum : LabeledPoint ): Double = {
59- val err = datum.label - prediction
50+ override def computeError (prediction : Double , label : Double ): Double = {
51+ val err = label - prediction
6052 math.abs(err)
6153 }
6254
Original file line number Diff line number Diff line change @@ -50,16 +50,8 @@ object LogLoss extends Loss {
5050 - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
5151 }
5252
53- /**
54- * Method to calculate loss when the predictions are already known.
55- * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
56- * predicted values from previously fit trees.
57- * @param prediction Predicted label.
58- * @param datum LabeledPoint
59- * @return log loss of model on the datapoint.
60- */
61- override def computeError (prediction : Double , datum : LabeledPoint ): Double = {
62- val margin = 2.0 * datum.label * prediction
53+ override def computeError (prediction : Double , label : Double ): Double = {
54+ val margin = 2.0 * label * prediction
6355 // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
6456 2.0 * MLUtils .log1pExp(- margin)
6557 }
Original file line number Diff line number Diff line change @@ -48,17 +48,17 @@ trait Loss extends Serializable {
4848 * @return Measure of model error on data
4949 */
5050 def computeError (model : TreeEnsembleModel , data : RDD [LabeledPoint ]): Double = {
51- data.map(point => computeError(model.predict(point.features), point)).mean()
51+ data.map(point => computeError(model.predict(point.features), point.label )).mean()
5252 }
5353
5454 /**
5555 * Method to calculate loss when the predictions are already known.
5656 * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
5757 * predicted values from previously fit trees.
5858 * @param prediction Predicted label.
59- * @param datum LabeledPoint
59+ * @param label True label.
6060 * @return Measure of model error on datapoint.
6161 */
62- def computeError (prediction : Double , datum : LabeledPoint ): Double
62+ def computeError (prediction : Double , label : Double ): Double
6363
6464}
Original file line number Diff line number Diff line change @@ -47,16 +47,8 @@ object SquaredError extends Loss {
4747 2.0 * (model.predict(point.features) - point.label)
4848 }
4949
50- /**
51- * Method to calculate loss when the predictions are already known.
52- * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
53- * predicted values from previously fit trees.
54- * @param prediction Predicted label.
55- * @param datum LabeledPoint
56- * @return Mean squared error of model on datapoint.
57- */
58- override def computeError (prediction : Double , datum : LabeledPoint ): Double = {
59- val err = prediction - datum.label
50+ override def computeError (prediction : Double , label : Double ): Double = {
51+ val err = prediction - label
6052 err * err
6153 }
6254
Original file line number Diff line number Diff line change @@ -133,7 +133,7 @@ class GradientBoostedTreesModel(
133133
134134 var predictionAndError : RDD [(Double , Double )] = remappedData.map { i =>
135135 val pred = treeWeights(0 ) * trees(0 ).predict(i.features)
136- val error = loss.computeError(pred, i)
136+ val error = loss.computeError(pred, i.label )
137137 (pred, error)
138138 }
139139 evaluationArray(0 ) = predictionAndError.values.mean()
@@ -143,13 +143,13 @@ class GradientBoostedTreesModel(
143143 val broadcastWeights = sc.broadcast(treeWeights)
144144
145145 (1 until numIterations).map { nTree =>
146- val currentTree = broadcastTrees.value(nTree)
147- val currentTreeWeight = broadcastWeights.value(nTree)
148146 predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
149- iter map {
147+ val currentTree = broadcastTrees.value(nTree)
148+ val currentTreeWeight = broadcastWeights.value(nTree)
149+ iter.map {
150150 case (point, (pred, error)) => {
151151 val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
152- val newError = loss.computeError(newPred, point)
152+ val newError = loss.computeError(newPred, point.label )
153153 (newPred, newError)
154154 }
155155 }
You can’t perform that action at this time.
0 commit comments