Skip to content

Commit 63862d6

Browse files
committed
persist factors in implicit ALS
1 parent 6983732 commit 63862d6

File tree

1 file changed

+70
-39
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/recommendation

1 file changed

+70
-39
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 70 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,10 @@ class ALS private (
148148
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
149149
*/
150150
def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
151+
val sc = ratings.context
152+
151153
val numBlocks = if (this.numBlocks == -1) {
152-
math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
154+
math.max(sc.defaultParallelism, ratings.partitions.size / 2)
153155
} else {
154156
this.numBlocks
155157
}
@@ -187,53 +189,79 @@ class ALS private (
187189
}
188190
}
189191

190-
for (iter <- 1 to iterations) {
191-
// perform ALS update
192-
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
193-
// YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
194-
val YtY = computeYtY(users)
195-
val YtYb = ratings.context.broadcast(YtY)
196-
products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
197-
alpha, YtYb)
198-
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
199-
val XtX = computeYtY(products)
200-
val XtXb = ratings.context.broadcast(XtX)
201-
users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
202-
alpha, XtXb)
192+
if (implicitPrefs) {
193+
for (iter <- 1 to iterations) {
194+
// perform ALS update
195+
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
196+
// Persist users because it is used twice below: by computeYtY and by updateFeatures.
197+
users.persist()
198+
val YtY = Some(sc.broadcast(computeYtY(users)))
199+
val previousProducts = products
200+
products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
201+
alpha, YtY)
202+
previousProducts.unpersist()
203+
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
204+
products.persist()
205+
val XtX = Some(sc.broadcast(computeYtY(products)))
206+
val previousUsers = users
207+
users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
208+
alpha, XtX)
209+
previousUsers.unpersist()
210+
}
211+
} else {
212+
for (iter <- 1 to iterations) {
213+
// perform ALS update
214+
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
215+
products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
216+
alpha, YtY = None)
217+
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
218+
users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
219+
alpha, YtY = None)
220+
}
203221
}
204222

223+
products.persist()
224+
205225
// Un-block the two final factor RDDs into flat (id, factor vector) pairs and cache them
206226
val usersOut = unblockFactors(users, userOutLinks)
207227
val productsOut = unblockFactors(products, productOutLinks)
208228

209229
usersOut.persist()
210230
productsOut.persist()
211231

232+
// Materialize usersOut and productsOut.
233+
usersOut.count()
234+
productsOut.count()
235+
236+
products.unpersist()
237+
238+
// Clean up.
239+
userInLinks.unpersist()
240+
userOutLinks.unpersist()
241+
productInLinks.unpersist()
242+
productOutLinks.unpersist()
243+
212244
new MatrixFactorizationModel(rank, usersOut, productsOut)
213245
}
214246

215247
/**
216248
* Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
217-
* for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
218-
* the driver program requires `YtY` to broadcast it to the slaves
249+
* for each user (or product), in a distributed fashion.
250+
*
219251
* @param factors the (block-distributed) user or product factor vectors
220-
* @return Option[YtY] - whose value is only used in the implicit preference model
252+
* @return the (`rank x rank`) matrix YtY, which is only used in the implicit preference model
221253
*/
222-
def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
223-
if (implicitPrefs) {
224-
val n = rank * (rank + 1) / 2
225-
val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
226-
Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
227-
L
228-
}, combOp = (L1, L2) => {
229-
L1.addi(L2)
230-
})
231-
val YtY = new DoubleMatrix(rank, rank)
232-
fillFullMatrix(LYtY, YtY)
233-
Option(YtY)
234-
} else {
235-
None
236-
}
254+
private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
255+
val n = rank * (rank + 1) / 2
256+
val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
257+
Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
258+
L
259+
}, combOp = (L1, L2) => {
260+
L1.addi(L2)
261+
})
262+
val YtY = new DoubleMatrix(rank, rank)
263+
fillFullMatrix(LYtY, YtY)
264+
YtY
237265
}
238266

239267
/**
@@ -264,7 +292,7 @@ class ALS private (
264292
/**
265293
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
266294
*/
267-
def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
295+
private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
268296
outLinks: RDD[(Int, OutLinkBlock)]) = {
269297
blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
270298
for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
@@ -332,8 +360,11 @@ class ALS private (
332360
val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
333361
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
334362
}, true)
335-
links.persist(StorageLevel.MEMORY_AND_DISK)
336-
(links.mapValues(_._1), links.mapValues(_._2))
363+
val inLinks = links.mapValues(_._1)
364+
val outLinks = links.mapValues(_._2)
365+
inLinks.persist(StorageLevel.MEMORY_AND_DISK)
366+
outLinks.persist(StorageLevel.MEMORY_AND_DISK)
367+
(inLinks, outLinks)
337368
}
338369

339370
/**
@@ -365,7 +396,7 @@ class ALS private (
365396
rank: Int,
366397
lambda: Double,
367398
alpha: Double,
368-
YtY: Broadcast[Option[DoubleMatrix]])
399+
YtY: Option[Broadcast[DoubleMatrix]])
369400
: RDD[(Int, Array[Array[Double]])] =
370401
{
371402
val numBlocks = products.partitions.size
@@ -388,8 +419,8 @@ class ALS private (
388419
* Compute the new feature vectors for a block of the users matrix given the list of factors
389420
* it received from each product and its InLinkBlock.
390421
*/
391-
def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
392-
rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
422+
private def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
423+
rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]])
393424
: Array[Array[Double]] =
394425
{
395426
// Sort the incoming block factor messages by block ID and make them an array
@@ -445,7 +476,7 @@ class ALS private (
445476
// Solve the resulting linear system, whose matrix is symmetric and positive-definite
446477
implicitPrefs match {
447478
case false => Solve.solvePositive(fullXtX, userXy(index)).data
448-
case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), userXy(index)).data
479+
case true => Solve.solvePositive(fullXtX.addi(YtY.get.value), userXy(index)).data
449480
}
450481
}
451482
}

0 commit comments

Comments
 (0)