diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 02e2384afe53..4e0702a7c196 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   /** @group expertGetParam */
   def getFinalStorageLevel: String = $(finalStorageLevel)
 
+  /**
+   * Param for the threshold used when computing dst factors to decide whether
+   * to stack the src factors into matrices to speed up the computation (>= 1).
+   * Default: 1024
+   * @group expertParam
+   */
+  val threshold = new IntParam(this, "threshold", "threshold used when computing dst " +
+    "factors to decide whether to stack the src factors to speed up the computation (>= 1)",
+    ParamValidators.gtEq(1))
+
+  /** @group expertGetParam */
+  def getThreshold: Int = $(threshold)
+
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10,
     numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user",
     itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+    threshold -> 1024)
 
   /**
    * Validates and transforms the input schema.
@@ -432,6 +446,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
   @Since("2.0.0")
   def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
 
+  /** @group expertSetParam */
+  @Since("2.1.0")
+  def setThreshold(value: Int): this.type = set(threshold, value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    *
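Reviewer note, not part of the patch: a minimal sketch of driving the new expert param through the Estimator API, assuming a hypothetical DataFrame named `ratings` with the default "user"/"item"/"rating" columns. Note that stacking only engages when the rank also exceeds the threshold (see the `doStack` condition in `computeFactors` further down), so the default of 1024 effectively keeps the old path for typical ranks.

import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.DataFrame

val ratings: DataFrame = ???  // hypothetical input with "user", "item", "rating" columns

val als = new ALS()
  .setRank(129)       // stacking requires rank > threshold
  .setMaxIter(10)
  .setRegParam(0.01)
  .setThreshold(128)  // lowered from the 1024 default to opt in at this rank
val model = als.fit(ratings)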
@@ -460,14 +478,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
     val instrLog = Instrumentation.create(this, ratings)
     instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
       userCol, itemCol, ratingCol, predictionCol, maxIter,
-      regParam, nonnegative, checkpointInterval, seed)
+      regParam, nonnegative, threshold, checkpointInterval, seed)
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
       numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
       intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
       finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
-      checkpointInterval = $(checkpointInterval), seed = $(seed))
+      checkpointInterval = $(checkpointInterval),
+      seed = $(seed), threshold = $(threshold))
     val userDF = userFactors.toDF("id", "features")
     val itemDF = itemFactors.toDF("id", "features")
     val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
@@ -621,6 +640,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
     val atb = new Array[Double](k)
 
     private val da = new Array[Double](k)
+    private val ata2 = new Array[Double](k * k)
     private val upper = "U"
 
     private def copyToDouble(a: Array[Float]): Unit = {
@@ -631,6 +651,22 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         da(i) = a(i)
         i += 1
       }
     }
 
+    private def copyToTri(): Unit = {
+      var i = 0
+      var j = 0
+      var ii = 0
+      while (i < k) {
+        val temp = i * k
+        j = 0
+        while (j <= i) {
+          ata(ii) += ata2(temp + j)
+          j += 1
+          ii += 1
+        }
+        i += 1
+      }
+    }
+
     /** Adds an observation. */
     def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
       require(c >= 0.0)
@@ -643,6 +679,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       this
     }
 
+    /** Adds a stack of observations. */
+    def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
+      require(a.length == n * k)
+      blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
+      copyToTri()
+      blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
+      this
+    }
+
     /** Merges another normal equation object. */
     def merge(other: NormalEquation): this.type = {
       require(other.k == k)
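As a sanity check on addStack and copyToTri, here is a standalone sketch (illustration only, not part of the patch) using the same netlib-java BLAS binding that ALS.scala already imports. It shows that one Level-3 dsyrk over the stacked k x n matrix accumulates the same Gram matrix as n packed dspr rank-1 updates, that copyToTri's column-by-column walk of the upper triangle matches dspr's packed layout, and that one dgemv replaces n daxpy calls on the right-hand side.

import com.github.fommil.netlib.BLAS.{getInstance => blas}

val k = 3
val n = 4
val rnd = new scala.util.Random(42)
val vecs = Array.fill(n)(Array.fill(k)(rnd.nextDouble()))

// Old path: one packed symmetric rank-1 update (dspr) per factor vector.
val packed = new Array[Double](k * (k + 1) / 2)
vecs.foreach(v => blas.dspr("U", k, 1.0, v, 1, packed))

// New path: stack the vectors column-wise, then issue a single Level-3 dsyrk.
val stacked = vecs.flatten             // k x n, column-major; like srcFactorBuffer
val full = new Array[Double](k * k)    // plays the role of ata2
blas.dsyrk("U", "N", k, n, 1.0, stacked, k, 0.0, full, k)

// copyToTri's indexing: reading the upper triangle of `full` column by
// column reproduces the packed layout that dspr wrote into `packed`.
var ii = 0
for (i <- 0 until k; j <- 0 to i) {
  assert(math.abs(packed(ii) - full(i * k + j)) < 1e-10)
  ii += 1
}

// Right-hand side: a single dgemv matches n daxpy accumulations.
val b = Array.fill(n)(rnd.nextDouble())
val atbOld = new Array[Double](k)
for (i <- 0 until n) blas.daxpy(k, b(i), vecs(i), 1, atbOld, 1)
val atbNew = new Array[Double](k)
blas.dgemv("N", k, n, 1.0, stacked, k, b, 1, 1.0, atbNew, 1)
for (i <- 0 until k) assert(math.abs(atbOld(i) - atbNew(i)) < 1e-10)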
@@ -654,6 +699,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       this
     }
 
     /** Resets everything to zero, which should be called after each solve. */
     def reset(): Unit = {
       ju.Arrays.fill(ata, 0.0)
+      ju.Arrays.fill(ata2, 0.0)
       ju.Arrays.fill(atb, 0.0)
     }
   }
@@ -676,7 +722,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
       finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
       checkpointInterval: Int = 10,
-      seed: Long = 0L)(
+      seed: Long = 0L,
+      threshold: Int = 1024)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
     require(intermediateRDDStorageLevel != StorageLevel.NONE,
       "ALS is not designed to run without persisting intermediate RDDs.")
@@ -721,7 +768,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
         val previousItemFactors = itemFactors
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-          userLocalIndexEncoder, implicitPrefs, alpha, solver)
+          userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
         previousItemFactors.unpersist()
         itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
         // TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -731,7 +778,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         }
         val previousUserFactors = userFactors
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-          itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+          itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
         if (shouldCheckpoint(iter)) {
           ALS.cleanShuffleDependencies(sc, deps)
           deletePreviousCheckpointFile()
@@ -742,7 +789,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
     } else {
       for (iter <- 0 until maxIter) {
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-          userLocalIndexEncoder, solver = solver)
+          userLocalIndexEncoder, solver = solver, threshold = threshold)
         if (shouldCheckpoint(iter)) {
           val deps = itemFactors.dependencies
           itemFactors.checkpoint()
@@ -752,7 +799,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
           previousCheckpointFile = itemFactors.getCheckpointFile
         }
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-          itemLocalIndexEncoder, solver = solver)
+          itemLocalIndexEncoder, solver = solver, threshold = threshold)
       }
     }
     val userIdAndFactors = userInBlocks
@@ -1266,7 +1313,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       srcEncoder: LocalIndexEncoder,
       implicitPrefs: Boolean = false,
      alpha: Double = 1.0,
-      solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
+      solver: LeastSquaresNESolver,
+      threshold: Int): RDD[(Int, FactorBlock)] = {
     val numSrcBlocks = srcFactorBlocks.partitions.length
     val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
     val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
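A sketch of exercising the new knob through the lower-level ALS.train entry point whose signature changes above (hypothetical values; assumes an RDD[Rating[Int]] is available and that ALS.train is accessible from the calling scope). The threshold is simply forwarded to computeFactors on both halves of every iteration, for both the implicit and explicit code paths.

import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.rdd.RDD

val ratings: RDD[Rating[Int]] = ???  // hypothetical input RDD

val (userFactors, itemFactors) = ALS.train(
  ratings,
  rank = 129,        // rank must exceed threshold for stacking to engage
  maxIter = 10,
  regParam = 0.01,
  threshold = 128)   // new parameter added by this patch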
@@ -1292,6 +1340,11 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
           }
           var i = srcPtrs(j)
           var numExplicits = 0
+          // Stack the src factors (vectors) into matrices to speed up the
+          // computation when the number of factors and the rank are both large.
+          val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
+          val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
+          val bBuffer = mutable.ArrayBuilder.make[Double]
           while (i < srcPtrs(j + 1)) {
             val encoded = srcEncodedIndices(i)
             val blockId = srcEncoder.blockId(encoded)
@@ -1309,11 +1362,23 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
               ls.add(srcFactor, (c1 + 1.0) / c1, c1)
             }
           } else {
-            ls.add(srcFactor, rating)
             numExplicits += 1
+            if (doStack) {
+              bBuffer += rating
+              var ii = 0
+              while (ii < srcFactor.length) {
+                srcFactorBuffer += srcFactor(ii)
+                ii += 1
+              }
+            } else {
+              ls.add(srcFactor, rating)
+            }
           }
           i += 1
         }
+        if (!implicitPrefs && doStack) {
+          ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
+        }
         // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
         dstFactors(j) = solver.solve(ls, numExplicits * regParam)
         j += 1
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index d0aa2cdfe0fd..e4d6e700791f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -301,7 +301,8 @@ class ALSSuite
       implicitPrefs: Boolean = false,
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
-      targetRMSE: Double = 0.05): Unit = {
+      targetRMSE: Double = 0.05,
+      threshold: Int = 1024): Unit = {
     val spark = this.spark
     import spark.implicits._
     val als = new ALS()
@@ -311,6 +312,7 @@ class ALSSuite
       .setNumUserBlocks(numUserBlocks)
       .setNumItemBlocks(numItemBlocks)
       .setSeed(0)
+      .setThreshold(threshold)
     val alpha = als.getAlpha
     val model = als.fit(training.toDF())
     val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
@@ -382,6 +384,12 @@ class ALSSuite
       numItemBlocks = 5, numUserBlocks = 5)
   }
 
+  test("stacking factors in matrices") {
+    val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
+    testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
+      threshold = 128)
+  }
+
  test("implicit feedback") {
    val (training, test) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2,
      noiseStd = 0.01)
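To make the trigger arithmetic concrete, a small sketch with hypothetical numbers mirroring the new test: with rank = 129 and threshold = 128, any dst factor backed by more than 128 ratings takes the stacked path, and its src factors are buffered column-major exactly as addStack's dsyrk and dgemv expect (lda = k).

import scala.collection.mutable

val rank = 129
val threshold = 128
val numRatings = 200  // srcPtrs(j + 1) - srcPtrs(j) for one dst factor
val doStack = numRatings > threshold && rank > threshold
assert(doStack)

// Buffering pattern from the patch, on toy data (rank 3, two ratings):
val k = 3
val srcFactors = Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))
val ratings = Array(0.5, 1.5)
val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
var n = 0
for ((factor, rating) <- srcFactors.zip(ratings)) {
  bBuffer += rating
  var ii = 0
  while (ii < factor.length) {
    srcFactorBuffer += factor(ii)
    ii += 1
  }
  n += 1
}
val a = srcFactorBuffer.result()
// a == Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0): a k x n column-major matrix whose
// columns are the factor vectors, the layout addStack hands to dsyrk/dgemv.
assert(a.length == n * k)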