Skip to content

Commit 0b2c29c

Browse files
author
DB Tsai
committed
first commit
1 parent 8e253eb commit 0b2c29c

File tree

1 file changed

+5
-5
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/optimization

1 file changed

+5
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,16 @@ class LogisticGradient extends Gradient {
9494
* :: DeveloperApi ::
9595
* Compute gradient and loss for a Least-squared loss function, as used in linear regression.
9696
* This is correct for the averaged least squares loss function (mean squared error)
97-
* L = 1/n ||A weights-y||^2
97+
* L = 1/2n ||A weights-y||^2
9898
* See also the documentation for the precise formulation.
9999
*/
100100
@DeveloperApi
101101
class LeastSquaresGradient extends Gradient {
102102
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
103103
val diff = dot(data, weights) - label
104-
val loss = diff * diff
104+
val loss = diff * diff / 2.0
105105
val gradient = data.copy
106-
scal(2.0 * diff, gradient)
106+
scal(diff, gradient)
107107
(gradient, loss)
108108
}
109109

@@ -113,8 +113,8 @@ class LeastSquaresGradient extends Gradient {
113113
weights: Vector,
114114
cumGradient: Vector): Double = {
115115
val diff = dot(data, weights) - label
116-
axpy(2.0 * diff, data, cumGradient)
117-
diff * diff
116+
axpy(diff, data, cumGradient)
117+
diff * diff / 2.0
118118
}
119119
}
120120

0 commit comments

Comments
 (0)