Skip to content

Commit b84f41c

Browse files
committed
address comments
1 parent a76da7b commit b84f41c

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,15 @@ object MovieLensALS {
142142
val predictions = model.transform(test).cache()
143143

144144
// Evaluate the model.
145+
// TODO: Create an evaluator to compute RMSE.
145146
val mse = predictions.select('rating, 'prediction)
146147
.flatMap { case Row(rating: Float, prediction: Float) =>
147148
val err = rating.toDouble - prediction
148149
val err2 = err * err
149150
if (err2.isNaN) {
150-
Iterator.empty
151+
None
151152
} else {
152-
Iterator.single(err2)
153+
Some(err2)
153154
}
154155
}.mean()
155156
val rmse = math.sqrt(mse)

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.recommendation
1919

20-
import java.{util => javaUtil}
20+
import java.{util => ju}
2121

2222
import scala.collection.mutable
2323

@@ -179,7 +179,7 @@ private object ALSModel {
179179
*
180180
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
181181
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
182-
* r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of
182+
* r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
183183
* indicated user
184184
* preferences rather than explicit ratings given to items.
185185
*/
@@ -314,9 +314,12 @@ private[recommendation] object ALS extends Logging {
314314
*/
315315
def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
316316
require(a.size == k)
317+
// Extension to the original paper to handle b < 0. Confidence is a function of |b| instead
318+
// so that it is never negative.
317319
val confidence = 1.0 + alpha * math.abs(b)
318320
copyToDouble(a)
319321
blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
322+
// For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
320323
if (b > 0) {
321324
blas.daxpy(k, confidence, da, 1, atb, 1)
322325
}
@@ -334,8 +337,8 @@ private[recommendation] object ALS extends Logging {
334337

335338
/** Resets everything to zero, which should be called after each solve. */
336339
def reset(): Unit = {
337-
javaUtil.Arrays.fill(ata, 0.0)
338-
javaUtil.Arrays.fill(atb, 0.0)
340+
ju.Arrays.fill(ata, 0.0)
341+
ju.Arrays.fill(atb, 0.0)
339342
n = 0
340343
}
341344
}
@@ -461,6 +464,7 @@ private[recommendation] object ALS extends Logging {
461464
ratings: Array[Float]) {
462465
/** Size of the block. */
463466
val size: Int = ratings.size
467+
464468
require(dstEncodedIndices.size == size)
465469
require(dstPtrs.size == srcIds.size + 1)
466470
}
@@ -473,6 +477,11 @@ private[recommendation] object ALS extends Logging {
473477
* @return initialized factor blocks
474478
*/
475479
private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
480+
// Choose a unit vector uniformly at random from the unit sphere, but from the
481+
// "first quadrant" where all elements are nonnegative. This can be done by choosing
482+
// elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
483+
// This appears to create factorizations that have a slightly better reconstruction
484+
// (<1%) compared to picking elements uniformly at random in [0,1].
476485
inBlocks.map { case (srcBlockId, inBlock) =>
477486
val random = new XORShiftRandom(srcBlockId)
478487
val factors = Array.fill(inBlock.srcIds.size) {
@@ -799,7 +808,7 @@ private[recommendation] object ALS extends Logging {
799808
i += 1
800809
}
801810
assert(i == dstIdSet.size)
802-
javaUtil.Arrays.sort(sortedDstIds)
811+
ju.Arrays.sort(sortedDstIds)
803812
val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
804813
i = 0
805814
while (i < sortedDstIds.size) {
@@ -826,7 +835,7 @@ private[recommendation] object ALS extends Logging {
826835
val seen = new Array[Boolean](dstPart.numPartitions)
827836
while (i < srcIds.size) {
828837
var j = dstPtrs(i)
829-
javaUtil.Arrays.fill(seen, false)
838+
ju.Arrays.fill(seen, false)
830839
while (j < dstPtrs(i + 1)) {
831840
val dstBlockId = encoder.blockId(dstEncodedIndices(j))
832841
if (!seen(dstBlockId)) {

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double)
9090
*
9191
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
9292
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
93-
* r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of
93+
* r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
9494
* indicated user
9595
* preferences rather than explicit ratings given to items.
9696
*/

0 commit comments

Comments
 (0)