@@ -148,8 +148,10 @@ class ALS private (
    * Returns a MatrixFactorizationModel with feature vectors for each user and product.
    */
   def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
+    val sc = ratings.context
+
     val numBlocks = if (this.numBlocks == -1) {
-      math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
+      math.max(sc.defaultParallelism, ratings.partitions.size / 2)
     } else {
       this.numBlocks
     }
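For context on the unchanged default: with, say, `sc.defaultParallelism = 8` and a `ratings` RDD of 10 partitions, `numBlocks` resolves to `math.max(8, 10 / 2) = 8`. The hoisted `sc` only deduplicates the `ratings.context` call sites that follow.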
@@ -187,53 +189,81 @@ class ALS private (
       }
     }

-    for (iter <- 1 to iterations) {
-      // perform ALS update
-      logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
-      // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
-      val YtY = computeYtY(users)
-      val YtYb = ratings.context.broadcast(YtY)
-      products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
-        alpha, YtYb)
-      logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
-      val XtX = computeYtY(products)
-      val XtXb = ratings.context.broadcast(XtX)
-      users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
-        alpha, XtXb)
+    if (implicitPrefs) {
+      for (iter <- 1 to iterations) {
+        // perform ALS update
+        logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+        // Persist users because it will be used twice: by computeYtY and by updateFeatures.
+        users.persist()
+        val YtY = Some(sc.broadcast(computeYtY(users)))
+        val previousProducts = products
+        products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+          alpha, YtY)
+        previousProducts.unpersist()
+        logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+        products.persist()
+        val XtX = Some(sc.broadcast(computeYtY(products)))
+        val previousUsers = users
+        users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+          alpha, XtX)
+        previousUsers.unpersist()
+      }
+    } else {
+      for (iter <- 1 to iterations) {
+        // perform ALS update
+        logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+        products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+          alpha, YtY = None)
+        logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+        users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+          alpha, YtY = None)
+      }
     }

+    // The last `products` will be used twice: once to compute the final `users` (through
+    // `usersOut`) and once to compute `productsOut`. So we cache it for better performance.
+    products.persist()
+
     // Flatten and cache the two final RDDs to un-block them
     val usersOut = unblockFactors(users, userOutLinks)
     val productsOut = unblockFactors(products, productOutLinks)

     usersOut.persist()
     productsOut.persist()

+    // Materialize usersOut and productsOut so the caches below can be dropped safely.
+    usersOut.count()
+    productsOut.count()
+
+    products.unpersist()
+
+    // Clean up.
+    userInLinks.unpersist()
+    userOutLinks.unpersist()
+    productInLinks.unpersist()
+    productOutLinks.unpersist()
+
     new MatrixFactorizationModel(rank, usersOut, productsOut)
   }

   /**
    * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
-   * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
-   * the driver program requires `YtY` to broadcast it to the slaves
+   * for each user (or product), in a distributed fashion.
+   *
    * @param factors the (block-distributed) user or product factor vectors
-   * @return Option[YtY] - whose value is only used in the implicit preference model
+   * @return YtY - whose value is only used in the implicit preference model
    */
-  def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
-    if (implicitPrefs) {
-      val n = rank * (rank + 1) / 2
-      val LYtY = factors.values.aggregate(new DoubleMatrix(n))(seqOp = (L, Y) => {
-        Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
-        L
-      }, combOp = (L1, L2) => {
-        L1.addi(L2)
-      })
-      val YtY = new DoubleMatrix(rank, rank)
-      fillFullMatrix(LYtY, YtY)
-      Option(YtY)
-    } else {
-      None
-    }
+  private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
+    val n = rank * (rank + 1) / 2
+    val LYtY = factors.values.aggregate(new DoubleMatrix(n))(seqOp = (L, Y) => {
+      Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
+      L
+    }, combOp = (L1, L2) => {
+      L1.addi(L2)
+    })
+    val YtY = new DoubleMatrix(rank, rank)
+    fillFullMatrix(LYtY, YtY)
+    YtY
   }

   /**
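Most of this hunk is cache lifecycle management: each iteration persists the factor RDD it is about to read twice (once by `computeYtY`, whose `aggregate` is an action and therefore materializes the cache, and once by `updateFeatures`), then unpersists the previous generation, whose blocks no longer appear in any lineage that still has to run. A condensed sketch of that discipline, with the hypothetical `updateFrom` standing in for the broadcast-`YtY`-plus-`updateFeatures` pair:

    import org.apache.spark.rdd.RDD

    // Illustrative only: mirrors the implicit-feedback branch of run().
    def train[T](initUsers: RDD[T], initProducts: RDD[T], iterations: Int)
                (updateFrom: RDD[T] => RDD[T]): (RDD[T], RDD[T]) = {
      var users = initUsers
      var products = initProducts
      for (iter <- 1 to iterations) {
        users.persist()                // read twice: an action inside updateFrom
                                       // materializes it, then it is scanned again
        val previousProducts = products
        products = updateFrom(users)   // the new RDD's lineage touches only the cached users
        previousProducts.unpersist()   // so the previous generation can be dropped
        products.persist()
        val previousUsers = users
        users = updateFrom(products)
        previousUsers.unpersist()
      }
      (users, products)
    }

The `count()` calls at the end of `run()` are the other half of the contract: they force `usersOut` and `productsOut` into their own caches, so unpersisting `products` and the four link RDDs afterwards cannot force a recomputation of anything the returned model still needs.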
@@ -264,7 +294,7 @@ class ALS private (
   /**
    * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
    */
-  def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
+  private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
       outLinks: RDD[(Int, OutLinkBlock)]) = {
     blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) =>
       for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
@@ -332,8 +362,11 @@ class ALS private (
       val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
     }, true)
-    links.persist(StorageLevel.MEMORY_AND_DISK)
-    (links.mapValues(_._1), links.mapValues(_._2))
+    val inLinks = links.mapValues(_._1)
+    val outLinks = links.mapValues(_._2)
+    inLinks.persist(StorageLevel.MEMORY_AND_DISK)
+    outLinks.persist(StorageLevel.MEMORY_AND_DISK)
+    (inLinks, outLinks)
   }

   /**
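Splitting the cached `links` pair into two separately persisted RDDs is what makes the per-RDD `unpersist()` calls added to `run()` possible, and it costs nothing in layout: `mapValues` preserves the parent's partitioner, so the co-partitioning the rest of the algorithm relies on is intact. A minimal sketch of the pattern, assuming an ambient `sc: SparkContext` (the toy data is mine):

    import org.apache.spark.{HashPartitioner, SparkContext}
    import org.apache.spark.storage.StorageLevel

    def splitAndCache(sc: SparkContext): Unit = {
      val links = sc.parallelize(Seq(0 -> ("in0", "out0"), 1 -> ("in1", "out1")))
        .partitionBy(new HashPartitioner(2))
      val inLinks = links.mapValues(_._1)    // keeps links' HashPartitioner
      val outLinks = links.mapValues(_._2)
      inLinks.persist(StorageLevel.MEMORY_AND_DISK)
      outLinks.persist(StorageLevel.MEMORY_AND_DISK)
      // ... use them ...
      inLinks.unpersist()                    // each cache is now released independently
      outLinks.unpersist()
    }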
@@ -365,7 +398,7 @@ class ALS private (
       rank: Int,
       lambda: Double,
       alpha: Double,
-      YtY: Broadcast[Option[DoubleMatrix]])
+      YtY: Option[Broadcast[DoubleMatrix]])
     : RDD[(Int, Array[Array[Double]])] =
   {
     val numBlocks = products.partitions.size
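Flipping `Broadcast[Option[DoubleMatrix]]` into `Option[Broadcast[DoubleMatrix]]` is more than a cosmetic swap: a broadcast variable now exists only when implicit feedback actually needs one, where previously a broadcast wrapping `None` was created and shipped on every iteration of the explicit-feedback path. Sketch of the call-site difference (`sc`, `implicitPrefs`, `computeYtY`, and `users` as in the patch; the `val` names are mine):

    // Before: a broadcast per iteration, even when the payload is None
    // (computeYtY used to return Option[DoubleMatrix]).
    val before: Broadcast[Option[DoubleMatrix]] = sc.broadcast(computeYtY(users))

    // After: the Option wraps the Broadcast, so the explicit-feedback path
    // passes None and never touches the broadcast machinery.
    val after: Option[Broadcast[DoubleMatrix]] =
      if (implicitPrefs) Some(sc.broadcast(computeYtY(users))) else None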
@@ -388,8 +421,8 @@ class ALS private (
    * Compute the new feature vectors for a block of the users matrix given the list of factors
    * it received from each product and its InLinkBlock.
    */
-  def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
-    rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
+  private def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
+    rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]])
   : Array[Array[Double]] =
   {
     // Sort the incoming block factor messages by block ID and make them an array
@@ -416,21 +449,20 @@ class ALS private (
           dspr(1.0, x, tempXtX)
           val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
           for (i <- 0 until us.length) {
-            implicitPrefs match {
-              case false =>
-                userXtX(us(i)).addi(tempXtX)
-                SimpleBlas.axpy(rs(i), x, userXy(us(i)))
-              case true =>
-                // Extension to the original paper to handle rs(i) < 0. confidence is a function
-                // of |rs(i)| instead so that it is never negative:
-                val confidence = 1 + alpha * abs(rs(i))
-                SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
-                // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
-                // means we try to reconstruct 0. We add terms only where P = 1, so, term below
-                // is now only added for rs(i) > 0:
-                if (rs(i) > 0) {
-                  SimpleBlas.axpy(confidence, x, userXy(us(i)))
-                }
+            if (implicitPrefs) {
+              // Extension to the original paper to handle rs(i) < 0. confidence is a function
+              // of |rs(i)| instead so that it is never negative:
+              val confidence = 1 + alpha * abs(rs(i))
+              SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
+              // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
+              // means we try to reconstruct 0. We add terms only where P = 1, so the term below
+              // is only added for rs(i) > 0:
+              if (rs(i) > 0) {
+                SimpleBlas.axpy(confidence, x, userXy(us(i)))
+              }
+            } else {
+              userXtX(us(i)).addi(tempXtX)
+              SimpleBlas.axpy(rs(i), x, userXy(us(i)))
             }
           }
         }
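For anyone checking this branch against the paper the comment cites (Hu, Koren and Volinsky, "Collaborative Filtering for Implicit Feedback Datasets", 2008): with confidence c_ui = 1 + alpha * |r_ui|, the per-user solve being assembled is, in LaTeX,

    x_u = \left( Y^\top Y + Y^\top (C^u - I)\, Y + \lambda I \right)^{-1} Y^\top C^u \, p_u

The loop above accumulates the sparse correction Y^T (C^u - I) Y one observed rating at a time (the `axpy(confidence - 1.0, tempXtX, ...)` call) and the right-hand side only where the preference p is 1; the dense shared Y^T Y term is added once per solve from the broadcast, as the next hunk shows.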
@@ -443,9 +475,10 @@ class ALS private (
       // Add regularization
       (0 until rank).foreach(i => fullXtX.data(i * rank + i) += lambda)
       // Solve the resulting matrix, which is symmetric and positive-definite
-      implicitPrefs match {
-        case false => Solve.solvePositive(fullXtX, userXy(index)).data
-        case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), userXy(index)).data
+      if (implicitPrefs) {
+        Solve.solvePositive(fullXtX.addi(YtY.get.value), userXy(index)).data
+      } else {
+        Solve.solvePositive(fullXtX, userXy(index)).data
       }
     }
   }
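A subtlety worth noting in review: jblas's `addi` mutates its receiver, not its argument, so `fullXtX.addi(YtY.get.value)` writes into the local scratch matrix `fullXtX` and never into the broadcast value that all tasks on an executor share. Assuming `fullXtX` is refilled from the triangular accumulator before each per-user solve, which the surrounding `fillFullMatrix` call takes care of, the in-place add also leaves no stale state between users.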