Changes from 3 commits
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala (70 changes: 64 additions & 6 deletions)
@@ -45,7 +45,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom

@@ -356,6 +356,19 @@ class ALSModel private[ml] (

/**
* Makes recommendations for all users (or items).
*
* Note: the previous approach to computing top-k recommendations used a cross-join,
* followed by predicting a score for each row of the joined dataset.
* However, this explodes the size of the intermediate data. While Spark SQL handles the
* cross-join relatively efficiently, the approach implemented here is significantly faster.
*
* This approach groups factors into blocks and computes the top-k elements per block,
* using a dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]]. It then computes the
> Contributor: below we say that blas is not used.
>
> Contributor Author: How about "... using dot product instead of gemm and an efficient ..."

* global top-k by aggregating the per block top-k elements with a [[TopByKeyAggregator]].
* This significantly reduces the size of intermediate and shuffle data.
* This is the DataFrame equivalent to the approach used in
* [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]].
*
* @param srcFactors src factors for which to generate recommendations
* @param dstFactors dst factors used to make recommendations
* @param srcOutputColumn name of the column for the source ID in the output DataFrame
@@ -372,11 +385,45 @@ class ALSModel private[ml] (
num: Int): DataFrame = {
import srcFactors.sparkSession.implicits._

-    val ratings = srcFactors.crossJoin(dstFactors)
-      .select(
-        srcFactors("id"),
-        dstFactors("id"),
-        predict(srcFactors("features"), dstFactors("features")))
val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])])
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])])
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
.as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
.flatMap { case (srcIter, dstIter) =>
val m = srcIter.size
val n = math.min(dstIter.size, num)
val output = new Array[(Int, Int, Float)](m * n)
var j = 0
> Member: Nit: You could combine j and i; you really just need 1 counter.
>
> Contributor Author: j iterates through src ids while i iterates through dst ids in the queue for each src id, so I don't think they can be combined.
>
> Contributor Author: Anyway, the iter.next() code is a bit ugly, and since it's at most k elements it's not really performance critical, so we could just use foreach, I think.

val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
srcIter.foreach { case (srcId, srcFactor) =>
dstIter.foreach { case (dstId, dstFactor) =>
> Contributor: don't use doc notation. Maybe we can reduce it to:
>
>     /*
>      * The below code is equivalent to
>      *   `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)`
>      * The handwritten version is as efficient as, or more efficient than, BLAS calls in this case.
>      */
>
> Contributor Author: Sounds good

/*
 * The below code is equivalent to
 *   val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)
 * Compared with BLAS.dot, the hand-written version used below is more efficient than
 * a call to the native BLAS backend and matches the performance of the fallback
 * F2jBLAS backend.
 */
var score = 0.0f
var k = 0
while (k < rank) {
score += srcFactor(k) * dstFactor(k)
k += 1
}
pq += dstId -> score
> Contributor: pq += dstId -> score?
>
> Contributor Author: sure

}
val pqIter = pq.iterator
var i = 0
while (i < n) {
val (dstId, score) = pqIter.next()
output(j + i) = (srcId, dstId, score)
i += 1
}
j += n
pq.clear()
}
output.toSeq
}
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
@@ -389,6 +436,17 @@ class ALSModel private[ml] (
)
recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
> Contributor Author: Fair point - may as well fix it while here

}
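To make the per-block step concrete, here is a minimal, Spark-free sketch of the technique the method above implements: a hand-written dot product plus a size-capped priority queue. topKPerBlock is an illustrative name, and since [[BoundedPriorityQueue]] is private[spark], a capped scala.collection.mutable.PriorityQueue stands in for it:

    import scala.collection.mutable

    def topKPerBlock(
        srcBlock: Seq[(Int, Array[Float])],
        dstBlock: Seq[(Int, Array[Float])],
        rank: Int,
        num: Int): Seq[(Int, Int, Float)] = {
      srcBlock.flatMap { case (srcId, srcFactor) =>
        // Min-queue by score: with the reversed ordering, dequeue() removes the
        // smallest element, mimicking what BoundedPriorityQueue does internally.
        val pq = mutable.PriorityQueue.empty[(Int, Float)](
          Ordering.by[(Int, Float), Float](_._2).reverse)
        dstBlock.foreach { case (dstId, dstFactor) =>
          // Hand-written dot product, equivalent to blas.sdot(rank, srcFactor, 1, dstFactor, 1).
          var score = 0.0f
          var k = 0
          while (k < rank) {
            score += srcFactor(k) * dstFactor(k)
            k += 1
          }
          pq += dstId -> score
          if (pq.size > num) pq.dequeue() // cap the queue at num elements
        }
        pq.toSeq.map { case (dstId, score) => (srcId, dstId, score) }
      }
    }

Capping the queue is what shrinks the shuffle: each (src block, dst block) pair emits at most srcBlock.size * num rows rather than srcBlock.size * dstBlock.size.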

/**
 * Blockifies factors to improve the efficiency of the cross join.
 * TODO: make blockSize a param?
 */
private def blockify(
    factors: Dataset[(Int, Array[Float])],
    blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = {
> Contributor: just put the comment in the doc and reference a JIRA.
>
> Contributor Author: sure

import factors.sparkSession.implicits._
factors.mapPartitions(_.grouped(blockSize))
}
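For reference, Iterator.grouped does the blockifying work here: it chunks each partition's row iterator into Seqs of at most blockSize elements, with the last block of a partition possibly smaller. A quick plain-Scala illustration with a toy block size of 4:

    // grouped(n) yields chunks of at most n elements; 10 rows with n = 4
    // produce blocks of sizes 4, 4 and 2.
    val rows = (1 to 10).iterator.map(i => (i, Array(i.toFloat)))
    val blocks = rows.grouped(4).toList
    // blocks.map(_.map(_._1)) == List(Seq(1, 2, 3, 4), Seq(5, 6, 7, 8), Seq(9, 10))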

}
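For context on the aggregation step, the [[TopByKeyAggregator]] used above is private to org.apache.spark.ml.recommendation, but its shape can be sketched with Spark's Aggregator API. A hedged sketch, assuming a concrete (srcId, dstId, score) tuple where the real class is generic; note that BoundedPriorityQueue is private[spark], so this only compiles inside an org.apache.spark package:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.util.BoundedPriorityQueue

    // Keeps the top `num` (dstId, score) pairs per group, ordered by score.
    class TopByKeySketch(num: Int)
      extends Aggregator[(Int, Int, Float), BoundedPriorityQueue[(Int, Float)], Array[(Int, Float)]] {

      private def ord: Ordering[(Int, Float)] = Ordering.by(_._2)

      override def zero: BoundedPriorityQueue[(Int, Float)] =
        new BoundedPriorityQueue[(Int, Float)](num)(ord)

      // The bounded queue silently evicts its smallest element once full.
      override def reduce(
          q: BoundedPriorityQueue[(Int, Float)],
          row: (Int, Int, Float)): BoundedPriorityQueue[(Int, Float)] =
        q += ((row._2, row._3))

      override def merge(
          q1: BoundedPriorityQueue[(Int, Float)],
          q2: BoundedPriorityQueue[(Int, Float)]): BoundedPriorityQueue[(Int, Float)] =
        q1 ++= q2

      override def finish(q: BoundedPriorityQueue[(Int, Float)]): Array[(Int, Float)] =
        q.toArray.sorted(ord.reverse) // highest score first

      override def bufferEncoder: Encoder[BoundedPriorityQueue[(Int, Float)]] =
        Encoders.kryo[BoundedPriorityQueue[(Int, Float)]]

      override def outputEncoder: Encoder[Array[(Int, Float)]] =
        ExpressionEncoder[Array[(Int, Float)]]()
    }

As in the method above, ratings.groupByKey(_._1).agg(new TopByKeySketch(num).toColumn) would then yield, per src id, its top `num` (dstId, score) pairs.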

@Since("1.6.0")