diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index e49363c2c64d..8d9559856ddc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -241,16 +241,27 @@ object LBFGS extends Logging {
       val bcW = data.context.broadcast(w)
       val localGradient = gradient
 
-      val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
-        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-          val l = localGradient.compute(
-            features, label, bcW.value, grad)
+      /** Given (current accumulated gradient, current loss) and (label, features)
+       * tuples, updates the current gradient and current loss
+       */
+      val seqOp = (c: (Vector, Double), v: (Double, Vector)) =>
+        (c, v) match {
+          case ((grad, loss), (label, features)) =>
+            val l = localGradient.compute(features, label, bcW.value, grad)
             (grad, loss + l)
-        },
-        combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+        }
+
+      // Adds two (gradient, loss) tuples
+      val combOp = (c1: (Vector, Double), c2: (Vector, Double)) =>
+        (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
           axpy(1.0, grad2, grad1)
           (grad1, loss1 + loss2)
-        })
+        }
+
+      val (gradientSum, lossSum) = data.mapPartitions { it => {
+        val inPartitionAggregated = it.aggregate((Vectors.zeros(n), 0.0))(seqOp, combOp)
+        Iterator(inPartitionAggregated)
+      }}.treeReduce(combOp)
 
       /**
        * regVal is sum of weight squares if it's L2 updater;
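
Note on the refactor: the patch replaces a single `treeAggregate` call with an explicit per-partition `aggregate` inside `mapPartitions`, followed by `treeReduce` over the per-partition partials. For an associative, commutative `combOp` the two formulations compute the same `(gradientSum, lossSum)`. The sketch below demonstrates that equivalence on a toy (sum, count) accumulator; it is a minimal standalone example, not part of the patch, and the object name and local Spark setup are hypothetical.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical standalone sketch: shows that
//   rdd.mapPartitions(it => Iterator(it.aggregate(zero)(seqOp, combOp))).treeReduce(combOp)
// computes the same result as
//   rdd.treeAggregate(zero)(seqOp, combOp)
// which is the shape of the LBFGS change above.
object TreeAggregateEquivalenceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("sketch"))
    val data = sc.parallelize(1 to 1000, numSlices = 8)

    // (sum, count) plays the role of the (gradient, loss) pair in the patch.
    val seqOp = (c: (Long, Long), v: Int) => (c._1 + v, c._2 + 1L)
    val combOp = (c1: (Long, Long), c2: (Long, Long)) => (c1._1 + c2._1, c1._2 + c2._2)

    // Original formulation: fold each partition with seqOp, then merge
    // partial results with combOp in a tree pattern.
    val a = data.treeAggregate((0L, 0L))(seqOp, combOp)

    // Refactored formulation: emit one locally aggregated partial per
    // partition, then tree-reduce the partials.
    val b = data
      .mapPartitions(it => Iterator(it.aggregate((0L, 0L))(seqOp, combOp)))
      .treeReduce(combOp)

    assert(a == b)
    sc.stop()
  }
}
```

An empty partition contributes the zero value (here `(0L, 0L)`), since `Iterator.aggregate` returns its zero for an empty iterator; the results therefore still agree as long as `combOp(zero, x) == x`, which holds for gradient and loss sums.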