
Commit f9d8a83

mengxr authored and mateiz committed
[SPARK-1266] persist factors in implicit ALS
In the implicit ALS computation, the user or product factor RDD is used twice in each iteration. Caching can certainly help accelerate the computation: I saw the running time decrease by ~70% for implicit ALS on the MovieLens data. I also made the following changes:

1. Change the type of `YtYb` from `Broadcast[Option[DoubleMatrix]]` to `Option[Broadcast[DoubleMatrix]]`, so we don't need to broadcast `None` in the explicit computation.

2. Mark the methods `computeYtY`, `unblockFactors`, `updateBlock`, and `updateFeatures` private. Users do not need those methods.

3. Materialize the final matrix factors before returning the model. This allows us to clean up the other cached RDDs before returning the model. I do not have a better solution here, so I use `RDD.count()`.

JIRA: https://spark-project.atlassian.net/browse/SPARK-1266

Author: Xiangrui Meng <[email protected]>

Closes #165 from mengxr/als and squashes the following commits:

c9676a6 [Xiangrui Meng] add a comment about the last products.persist
d3a88aa [Xiangrui Meng] change implicitPrefs match to if ... else ...
63862d6 [Xiangrui Meng] persist factors in implicit ALS
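To see why change 1 above matters, compare the two broadcast layouts. The sketch below is illustrative rather than a quote from the patch; `sc`, `implicitPrefs`, `users`, and `computeYtY` stand in for the corresponding names in ALS.scala.

    // Before: Broadcast[Option[DoubleMatrix]] -- a broadcast is created on every
    // iteration, even in the explicit case, just to ship a None to the executors.
    val ytyBefore: Broadcast[Option[DoubleMatrix]] =
      sc.broadcast(if (implicitPrefs) Some(computeYtY(users)) else None)

    // After: Option[Broadcast[DoubleMatrix]] -- the broadcast only happens when
    // there is an actual matrix to ship; the explicit path passes a plain None.
    val ytyAfter: Option[Broadcast[DoubleMatrix]] =
      if (implicitPrefs) Some(sc.broadcast(computeYtY(users))) else None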
1 parent: e108b9a


mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 89 additions & 56 deletions
@@ -148,8 +148,10 @@ class ALS private (
    * Returns a MatrixFactorizationModel with feature vectors for each user and product.
    */
   def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
+    val sc = ratings.context
+
     val numBlocks = if (this.numBlocks == -1) {
-      math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
+      math.max(sc.defaultParallelism, ratings.partitions.size / 2)
     } else {
       this.numBlocks
     }
@@ -187,53 +189,81 @@ class ALS private (
       }
     }

-    for (iter <- 1 to iterations) {
-      // perform ALS update
-      logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
-      // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
-      val YtY = computeYtY(users)
-      val YtYb = ratings.context.broadcast(YtY)
-      products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
-        alpha, YtYb)
-      logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
-      val XtX = computeYtY(products)
-      val XtXb = ratings.context.broadcast(XtX)
-      users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
-        alpha, XtXb)
+    if (implicitPrefs) {
+      for (iter <- 1 to iterations) {
+        // perform ALS update
+        logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+        // Persist users because it will be called twice.
+        users.persist()
+        val YtY = Some(sc.broadcast(computeYtY(users)))
+        val previousProducts = products
+        products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+          alpha, YtY)
+        previousProducts.unpersist()
+        logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+        products.persist()
+        val XtX = Some(sc.broadcast(computeYtY(products)))
+        val previousUsers = users
+        users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+          alpha, XtX)
+        previousUsers.unpersist()
+      }
+    } else {
+      for (iter <- 1 to iterations) {
+        // perform ALS update
+        logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+        products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+          alpha, YtY = None)
+        logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+        users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+          alpha, YtY = None)
+      }
     }

+    // The last `products` will be used twice. One to generate the last `users` and the other to
+    // generate `productsOut`. So we cache it for better performance.
+    products.persist()
+
     // Flatten and cache the two final RDDs to un-block them
     val usersOut = unblockFactors(users, userOutLinks)
     val productsOut = unblockFactors(products, productOutLinks)

     usersOut.persist()
     productsOut.persist()

+    // Materialize usersOut and productsOut.
+    usersOut.count()
+    productsOut.count()
+
+    products.unpersist()
+
+    // Clean up.
+    userInLinks.unpersist()
+    userOutLinks.unpersist()
+    productInLinks.unpersist()
+    productOutLinks.unpersist()
+
     new MatrixFactorizationModel(rank, usersOut, productsOut)
   }

   /**
    * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
-   * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
-   * the driver program requires `YtY` to broadcast it to the slaves
+   * for each user (or product), in a distributed fashion.
+   *
    * @param factors the (block-distributed) user or product factor vectors
-   * @return Option[YtY] - whose value is only used in the implicit preference model
+   * @return YtY - whose value is only used in the implicit preference model
    */
-  def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
-    if (implicitPrefs) {
-      val n = rank * (rank + 1) / 2
-      val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
-        Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
-        L
-      }, combOp = (L1, L2) => {
-        L1.addi(L2)
-      })
-      val YtY = new DoubleMatrix(rank, rank)
-      fillFullMatrix(LYtY, YtY)
-      Option(YtY)
-    } else {
-      None
-    }
+  private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
+    val n = rank * (rank + 1) / 2
+    val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
+      Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
+      L
+    }, combOp = (L1, L2) => {
+      L1.addi(L2)
+    })
+    val YtY = new DoubleMatrix(rank, rank)
+    fillFullMatrix(LYtY, YtY)
+    YtY
   }

   /**
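Two caching idioms in the hunk above are worth calling out. Inside the implicit-feedback loop, the factor RDD that feeds both computeYtY and updateFeatures is persisted before its first use, and the superseded RDD is unpersisted once nothing references it. At the end of run(), the count() calls exist because RDD evaluation is lazy: without forcing usersOut and productsOut first, unpersisting the link RDDs would make the returned model recompute its factors from scratch. A minimal standalone sketch of the idiom, with hypothetical names:

    finalRdd.persist()
    finalRdd.count()             // run the job now, while the upstream caches are still warm
    intermediateRdd.unpersist()  // safe afterwards: finalRdd's partitions are already cached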
@@ -264,7 +294,7 @@ class ALS private (
   /**
    * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
    */
-  def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
+  private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
       outLinks: RDD[(Int, OutLinkBlock)]) = {
     blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
       for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
@@ -332,8 +362,11 @@ class ALS private (
       val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
     }, true)
-    links.persist(StorageLevel.MEMORY_AND_DISK)
-    (links.mapValues(_._1), links.mapValues(_._2))
+    val inLinks = links.mapValues(_._1)
+    val outLinks = links.mapValues(_._2)
+    inLinks.persist(StorageLevel.MEMORY_AND_DISK)
+    outLinks.persist(StorageLevel.MEMORY_AND_DISK)
+    (inLinks, outLinks)
   }

   /**
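A note on why the persist moved in this hunk: run() now cleans up with userInLinks.unpersist() and friends, and unpersist() only releases blocks cached on that exact RDD. Persisting the two mapValues views that are handed back, rather than the joint links RDD, ensures the later cleanup actually frees the cache. A minimal sketch of the pitfall, with hypothetical names:

    val cached = pairs.persist()
    val view = cached.mapValues(_._1)
    view.unpersist()  // releases nothing useful: the cached blocks belong to `pairs`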
@@ -365,7 +398,7 @@ class ALS private (
       rank: Int,
       lambda: Double,
       alpha: Double,
-      YtY: Broadcast[Option[DoubleMatrix]])
+      YtY: Option[Broadcast[DoubleMatrix]])
     : RDD[(Int, Array[Array[Double]])] =
   {
     val numBlocks = products.partitions.size
@@ -388,8 +421,8 @@ class ALS private (
    * Compute the new feature vectors for a block of the users matrix given the list of factors
    * it received from each product and its InLinkBlock.
    */
-  def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
-      rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
+  private def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
+      rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]])
     : Array[Array[Double]] =
   {
     // Sort the incoming block factor messages by block ID and make them an array
@@ -416,21 +449,20 @@ class ALS private (
         dspr(1.0, x, tempXtX)
         val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
         for (i <- 0 until us.length) {
-          implicitPrefs match {
-            case false =>
-              userXtX(us(i)).addi(tempXtX)
-              SimpleBlas.axpy(rs(i), x, userXy(us(i)))
-            case true =>
-              // Extension to the original paper to handle rs(i) < 0. confidence is a function
-              // of |rs(i)| instead so that it is never negative:
-              val confidence = 1 + alpha * abs(rs(i))
-              SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
-              // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
-              // means we try to reconstruct 0. We add terms only where P = 1, so, term below
-              // is now only added for rs(i) > 0:
-              if (rs(i) > 0) {
-                SimpleBlas.axpy(confidence, x, userXy(us(i)))
-              }
+          if (implicitPrefs) {
+            // Extension to the original paper to handle rs(i) < 0. confidence is a function
+            // of |rs(i)| instead so that it is never negative:
+            val confidence = 1 + alpha * abs(rs(i))
+            SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
+            // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
+            // means we try to reconstruct 0. We add terms only where P = 1, so, term below
+            // is now only added for rs(i) > 0:
+            if (rs(i) > 0) {
+              SimpleBlas.axpy(confidence, x, userXy(us(i)))
+            }
+          } else {
+            userXtX(us(i)).addi(tempXtX)
+            SimpleBlas.axpy(rs(i), x, userXy(us(i)))
           }
         }
       }
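The rewrite from match to if/else above is purely stylistic; the implicit-feedback math is unchanged. For a concrete feel of the weighting, here is a tiny standalone illustration (the value of alpha is arbitrary here; it is a parameter of the real class):

    val alpha = 40.0
    def confidence(r: Double): Double = 1.0 + alpha * math.abs(r)
    // r = 1.0  -> confidence = 41.0: adds to both the Gram matrix and the
    //             right-hand side, since the preference P is 1
    // r = -1.0 -> confidence = 41.0: adds to the Gram matrix only, since a
    //             negative rating encodes preference P = 0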
@@ -443,9 +475,10 @@ class ALS private (
       // Add regularization
       (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
       // Solve the resulting matrix, which is symmetric and positive-definite
-      implicitPrefs match {
-        case false => Solve.solvePositive(fullXtX, userXy(index)).data
-        case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), userXy(index)).data
+      if (implicitPrefs) {
+        Solve.solvePositive(fullXtX.addi(YtY.get.value), userXy(index)).data
+      } else {
+        Solve.solvePositive(fullXtX, userXy(index)).data
       }
     }
   }
