mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala (75 additions, 10 deletions)
@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
/** @group expertGetParam */
def getFinalStorageLevel: String = $(finalStorageLevel)

/**
* Param for the threshold used when computing dst factors: when both the number
* of factors contributing to a dst point and the rank exceed this value, the
* factors are stacked into matrices to speed up the computation (>= 1).
* Default: 1024
* @group expertParam
*/
val threshold = new IntParam(this, "threshold", "threshold used when computing dst factors " +
"to decide whether to stack factors into matrices to speed up the computation (>= 1)",
ParamValidators.gtEq(1))

/** @group expertGetParam */
def getThreshold: Int = $(threshold)

setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
- intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+ intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+ threshold -> 1024)

/**
* Validates and transforms the input schema.
@@ -432,6 +446,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)

/** @group expertSetParam */
@Since("2.1.0")
def setThreshold(value: Int): this.type = set(threshold, value)

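For orientation, a minimal usage sketch of the new expert param; the rank, threshold, and ratingsDF below are illustrative assumptions, not values taken from this patch:

import org.apache.spark.ml.recommendation.ALS

// Stacking only kicks in for dst points with more than `threshold` contributing
// factors, and only when the rank also exceeds `threshold`.
val als = new ALS()
  .setRank(600)
  .setThreshold(512)
  .setMaxIter(5)
// val model = als.fit(ratingsDF)  // ratingsDF: assumed DataFrame of (user, item, rating)
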
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
@@ -460,14 +478,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val instrLog = Instrumentation.create(this, ratings)
instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
userCol, itemCol, ratingCol, predictionCol, maxIter,
- regParam, nonnegative, checkpointInterval, seed)
+ regParam, nonnegative, threshold, checkpointInterval, seed)
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
- checkpointInterval = $(checkpointInterval), seed = $(seed))
+ checkpointInterval = $(checkpointInterval),
+ seed = $(seed), threshold = $(threshold))
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
@@ -621,6 +640,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k) // dense k-by-k scratch buffer for stacked updates
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
@@ -631,6 +651,22 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

/**
* Folds the upper triangle of the dense scratch matrix ata2 into the packed
* upper-triangular array ata (column-major packed storage, as used by dspr).
*/
private def copyToTri(): Unit = {
var i = 0
var j = 0
var ii = 0
while (i < k) {
val temp = i * k
j = 0
while (j <= i) {
ata(ii) += ata2(temp + j)
j += 1
ii += 1
}
i += 1
}
}

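A note on the packed layout copyToTri targets: ata holds the upper triangle in column-major packed storage, so entry (row j, column i) with j <= i lands at index i * (i + 1) / 2 + j, which is exactly the order the loops above walk. A small self-contained sketch (plain Scala, no BLAS required) that checks the mapping:

object PackedUpperSketch {
  def main(args: Array[String]): Unit = {
    val k = 3
    var i = 0
    var ii = 0 // running index into the packed array, as in copyToTri
    while (i < k) {
      var j = 0
      while (j <= i) {
        assert(ii == i * (i + 1) / 2 + j)
        println(s"packed index $ii <- dense (row = $j, col = $i)")
        j += 1
        ii += 1
      }
      i += 1
    }
  }
}
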
/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
@@ -643,6 +679,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/**
* Adds a stack of n observations at once: `a` is a k-by-n column-major matrix
* whose columns are the factor vectors, and `b` holds the corresponding ratings.
*/
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
require(b.length == n)
// One dsyrk computes A * A^T, the sum of the n outer products. beta = 0.0
// overwrites the scratch buffer, so repeated calls cannot double-count.
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 0.0, ata2, k)
copyToTri()
// One dgemv accumulates A * b, the sum of rating-weighted factors, into atb.
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

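To see why the stacked update is equivalent to adding observations one at a time, here is a standalone sketch, assuming the com.github.fommil.netlib BLAS that Spark 2.x already depends on: one dsyrk over the k x n matrix whose columns are the factors reproduces the sum of the n rank-1 updates, and one dgemv reproduces the accumulated right-hand side.

import com.github.fommil.netlib.BLAS.{getInstance => blas}

object StackEquivalenceSketch {
  def main(args: Array[String]): Unit = {
    val k = 2
    val n = 3
    // Column-major k x n: columns are the factor vectors (1,2), (3,4), (5,6).
    val a = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    val b = Array(0.5, 1.0, 2.0)
    // Stacked path: one dsyrk for the Gramian, one dgemv for the right-hand side.
    val ata = new Array[Double](k * k)
    val atb = new Array[Double](k)
    blas.dsyrk("U", "N", k, n, 1.0, a, k, 0.0, ata, k)
    blas.dgemv("N", k, n, 1.0, a, k, b, 1, 0.0, atb, 1)
    // Naive path: accumulate the rank-1 updates one observation at a time.
    val ataRef = new Array[Double](k * k)
    val atbRef = new Array[Double](k)
    for (col <- 0 until n; i <- 0 until k) {
      atbRef(i) += b(col) * a(col * k + i)
      for (j <- 0 to i) { // upper triangle, column-major: entry (row j, col i)
        ataRef(i * k + j) += a(col * k + i) * a(col * k + j)
      }
    }
    for (i <- 0 until k; j <- 0 to i) {
      assert(math.abs(ata(i * k + j) - ataRef(i * k + j)) < 1e-12)
    }
    for (i <- 0 until k) assert(math.abs(atb(i) - atbRef(i)) < 1e-12)
    println("stacked BLAS updates match the per-observation accumulation")
  }
}
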
/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
@@ -654,6 +699,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
@@ -676,7 +722,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
checkpointInterval: Int = 10,
- seed: Long = 0L)(
+ seed: Long = 0L,
+ threshold: Int = 1024)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
"ALS is not designed to run without persisting intermediate RDDs.")
@@ -721,7 +768,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
- userLocalIndexEncoder, implicitPrefs, alpha, solver)
+ userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -731,7 +778,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
- itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+ itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
if (shouldCheckpoint(iter)) {
ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
@@ -742,7 +789,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
- userLocalIndexEncoder, solver = solver)
+ userLocalIndexEncoder, solver = solver, threshold = threshold)
if (shouldCheckpoint(iter)) {
val deps = itemFactors.dependencies
itemFactors.checkpoint()
@@ -752,7 +799,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousCheckpointFile = itemFactors.getCheckpointFile
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
- itemLocalIndexEncoder, solver = solver)
+ itemLocalIndexEncoder, solver = solver, threshold = threshold)
}
}
val userIdAndFactors = userInBlocks
@@ -1266,7 +1313,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0,
- solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
+ solver: LeastSquaresNESolver,
+ threshold: Int): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -1292,6 +1340,11 @@ }
}
var i = srcPtrs(j)
var numExplicits = 0
// Stack the factor vectors into matrices to speed up the computation when both
// the number of factors contributing to this dst point and the rank exceed the threshold.
val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
@@ -1309,11 +1362,23 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
} else {
- ls.add(srcFactor, rating)
numExplicits += 1
+ if (doStack) {
+ bBuffer += rating
+ var ii = 0
+ while (ii < srcFactor.length) {
+ srcFactorBuffer += srcFactor(ii)
+ ii += 1
+ }
+ } else {
+ ls.add(srcFactor, rating)
+ }
}
i += 1
}
// Flush the buffered factors and ratings in a single stacked BLAS update.
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

@@ -301,7 +301,8 @@ class ALSSuite
implicitPrefs: Boolean = false,
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
- targetRMSE: Double = 0.05): Unit = {
+ targetRMSE: Double = 0.05,
+ threshold: Int = 1024): Unit = {
val spark = this.spark
import spark.implicits._
val als = new ALS()
@@ -311,6 +312,7 @@ class ALSSuite
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
.setSeed(0)
.setThreshold(threshold)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
@@ -382,6 +384,12 @@ class ALSSuite
numItemBlocks = 5, numUserBlocks = 5)
}

test("do stacking factors in matrices") {
val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
threshold = 128)
}

test("implicit feedback") {
val (training, test) =
genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)