@@ -247,9 +247,34 @@ class LDA private (
247247 new DistributedLDAModel (state, iterationTimes)
248248 }
249249
250- def runOnlineLDA (documents : RDD [(Long , Vector )]): LDAModel = {
251- val onlineLDA = new LDA .OnlineLDAOptimizer (documents, k)
252- (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
250+
/**
 * Learn an LDA model from the given dataset using the online variational Bayes (VB) algorithm.
 * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
 *
 * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
 *                   The term count vectors are "bags of words" with a fixed-size vocabulary
 *                   (where the vocabulary size is the length of the vector).
 *                   Document IDs must be unique and >= 0.
 * @param batchNumber  Requested number of mini-batches. Must be positive, or -1 (the default)
 *                     to pick the mini-batch size automatically: 1% of the corpus, clamped to
 *                     the recommended range of [4, 16384] documents per batch.
 *                     When positive, the mini-batch size is the corpus size divided by
 *                     batchNumber, rounded up and never smaller than one document.
 * @return Inferred LDA model
 */
def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
  // Validate the cheap argument first so that bad input does not trigger a Spark job.
  require(batchNumber == -1 || batchNumber > 0, "batchNumber should be positive or -1")

  // Bounds used by the automatic mode (batchNumber == -1).
  val autoBatchDivisor = 100L
  val minAutoBatchSize = 4L
  val maxAutoBatchSize = 16384L

  // Keep the count as a Long: narrowing it blindly with toInt would wrap around silently.
  val docCount = documents.count()
  require(docCount <= Int.MaxValue,
    s"Online LDA supports at most ${Int.MaxValue} documents, but got $docCount")

  val batchSize: Long =
    if (batchNumber == -1) {
      math.min(maxAutoBatchSize, math.max(minAutoBatchSize, docCount / autoBatchDivisor))
    } else {
      // Ceiling division so the optimizer never runs more batches than requested, with a
      // floor of 1 so that batchNumber > docCount cannot produce a zero-sized batch.
      math.max(1L, (docCount + batchNumber - 1) / batchNumber)
    }

  // batchSize <= max(docCount, maxAutoBatchSize), so the narrowing below is safe.
  val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize.toInt)
  // next() is called purely for its side effect of updating onlineLDA.lambda.
  (0 until onlineLDA.actualBatchNumber).foreach(_ => onlineLDA.next())
  new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
}
255280
@@ -411,28 +436,26 @@ private[clustering] object LDA {
411436 * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
412437 */
413438 private [clustering] class OnlineLDAOptimizer (
414- private val documents : RDD [(Long , Vector )],
415- private val k : Int ) extends Serializable {
439+ private val documents : RDD [(Long , Vector )],
440+ private val k : Int ,
441+ private val batchSize : Int ) extends Serializable {
416442
417443 private val vocabSize = documents.first._2.size
418444 private val D = documents.count().toInt
419- private val batchSize = if (D / 1000 > 4096 ) 4096
420- else if (D / 1000 < 4 ) 4
421- else D / 1000
422- val batchNumber = D / batchSize
445+ val actualBatchNumber = Math .ceil(D .toDouble / batchSize).toInt
423446
424- // Initialize the variational distribution q(beta|lambda)
447+ // Initialize the variational distribution q(beta|lambda)
425448 var lambda = getGammaMatrix(k, vocabSize) // K * V
426449 private var Elogbeta = dirichlet_expectation(lambda) // K * V
427450 private var expElogbeta = exp(Elogbeta ) // K * V
428451
429452 private var batchId = 0
430453 def next (): Unit = {
431- require(batchId < batchNumber )
454+ require(batchId < actualBatchNumber )
432455 // weight of the mini-batch. 1024 down weights early iterations
433456 val weight = math.pow(1024 + batchId, - 0.5 )
434- val batch = documents.filter(doc => doc._1 % batchNumber == batchId )
435-
457+ val batch = documents.sample( true , batchSize.toDouble / D )
458+ batch.cache()
436459 // Given a mini-batch of documents, estimates the parameters gamma controlling the
437460 // variational distribution over the topic weights for each document in the mini-batch.
438461 var stat = BDM .zeros[Double ](k, vocabSize)
0 commit comments