@@ -447,21 +447,20 @@ class ALS private (
447447 dspr(1.0 , x, tempXtX)
448448 val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
449449 for (i <- 0 until us.length) {
450- implicitPrefs match {
451- case false =>
452- userXtX(us(i)).addi(tempXtX)
453- SimpleBlas .axpy(rs(i), x, userXy(us(i)))
454- case true =>
455- // Extension to the original paper to handle rs(i) < 0. confidence is a function
456- // of |rs(i)| instead so that it is never negative:
457- val confidence = 1 + alpha * abs(rs(i))
458- SimpleBlas .axpy(confidence - 1.0 , tempXtX, userXtX(us(i)))
459- // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
460- // means we try to reconstruct 0. We add terms only where P = 1, so, term below
461- // is now only added for rs(i) > 0:
462- if (rs(i) > 0 ) {
463- SimpleBlas .axpy(confidence, x, userXy(us(i)))
464- }
450+ if (implicitPrefs) {
451+ // Extension to the original paper to handle rs(i) < 0. confidence is a function
452+ // of |rs(i)| instead so that it is never negative:
453+ val confidence = 1 + alpha * abs(rs(i))
454+ SimpleBlas .axpy(confidence - 1.0 , tempXtX, userXtX(us(i)))
455+ // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
456+ // means we try to reconstruct 0. We add terms only where P = 1, so, term below
457+ // is now only added for rs(i) > 0:
458+ if (rs(i) > 0 ) {
459+ SimpleBlas .axpy(confidence, x, userXy(us(i)))
460+ }
461+ } else {
462+ userXtX(us(i)).addi(tempXtX)
463+ SimpleBlas .axpy(rs(i), x, userXy(us(i)))
465464 }
466465 }
467466 }
@@ -474,9 +473,10 @@ class ALS private (
474473 // Add regularization
475474 (0 until rank).foreach(i => fullXtX.data(i* rank + i) += lambda)
476475 // Solve the resulting matrix, which is symmetric and positive-definite
477- implicitPrefs match {
478- case false => Solve .solvePositive(fullXtX, userXy(index)).data
479- case true => Solve .solvePositive(fullXtX.addi(YtY .get.value), userXy(index)).data
476+ if (implicitPrefs) {
477+ Solve .solvePositive(fullXtX.addi(YtY .get.value), userXy(index)).data
478+ } else {
479+ Solve .solvePositive(fullXtX, userXy(index)).data
480480 }
481481 }
482482 }
0 commit comments