Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k)
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
Expand All @@ -635,6 +636,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

private def copyToTri(): Unit = {
var ii = 0
for(i <- 0 until k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might fail the style check for missing space before the paren.
Also, I think the received wisdom is that for loops are slow in Scala? if this is performance-critical, you may convert to while loops. Also you can cache i*k from the outer loop below in the inner loop

for(j <- 0 to i) {
ata(ii) += ata2(i * k + j)
ii += 1
}
}

/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
Expand All @@ -647,6 +657,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/** Adds a stack of observations. */
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri()
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
Expand All @@ -658,6 +677,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
Expand Down Expand Up @@ -1296,6 +1316,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
val doStack = if (srcPtrs(j + 1) - srcPtrs(j) > 10) true else false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (...) true else false is redundant

val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
Expand All @@ -1313,11 +1336,23 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
} else {
ls.add(srcFactor, rating)
numExplicits += 1
if (doStack) {
bBuffer += rating
var ii = 0
while(ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii)
ii += 1
}
} else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
Expand Down