-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-28159][ML] Make the transform natively in ml framework to avoid extra conversion #24963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
3d9c916
92d555c
10ba449
bd813db
f78ed32
5730ab7
38b3872
d54d073
f1314fb
096d204
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -595,7 +595,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { | |
| * Serializable companion object containing helper methods and shared code for | ||
| * [[OnlineLDAOptimizer]] and [[LocalLDAModel]]. | ||
| */ | ||
| private[clustering] object OnlineLDAOptimizer { | ||
| private[spark] object OnlineLDAOptimizer { | ||
| /** | ||
| * Uses variational inference to infer the topic distribution `gammad` given the term counts | ||
| * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will | ||
|
|
@@ -608,27 +608,24 @@ private[clustering] object OnlineLDAOptimizer { | |
| * @return Returns a tuple of `gammad` - estimate of gamma, the topic distribution, `sstatsd` - | ||
| * statistics for updating lambda and `ids` - list of termCounts vector indices. | ||
| */ | ||
| private[clustering] def variationalTopicInference( | ||
| termCounts: Vector, | ||
| private[spark] def variationalTopicInference( | ||
| indices: List[Int], | ||
| values: Array[Double], | ||
| expElogbeta: BDM[Double], | ||
| alpha: breeze.linalg.Vector[Double], | ||
| gammaShape: Double, | ||
| k: Int, | ||
| seed: Long): (BDV[Double], BDM[Double], List[Int]) = { | ||
| val (ids: List[Int], cts: Array[Double]) = termCounts match { | ||
| case v: DenseVector => ((0 until v.size).toList, v.values) | ||
| case v: SparseVector => (v.indices.toList, v.values) | ||
| } | ||
| // Initialize the variational distribution q(theta|gamma) for the mini-batch | ||
| val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) | ||
| val gammad: BDV[Double] = | ||
| new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K | ||
| val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K | ||
| val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K | ||
| val expElogbetad = expElogbeta(indices, ::).toDenseMatrix // ids * K | ||
|
|
||
| val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids | ||
| var meanGammaChange = 1D | ||
| val ctsVector = new BDV[Double](cts) // ids | ||
| val ctsVector = new BDV[Double](values) // ids | ||
|
|
||
| // Iterate between gamma and phi until convergence | ||
| while (meanGammaChange > 1e-3) { | ||
|
|
@@ -642,6 +639,20 @@ private[clustering] object OnlineLDAOptimizer { | |
| } | ||
|
|
||
| val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix | ||
| (gammad, sstatsd, ids) | ||
| (gammad, sstatsd, indices) | ||
| } | ||
|
|
||
| private[clustering] def variationalTopicInference( | ||
| termCounts: Vector, | ||
| expElogbeta: BDM[Double], | ||
| alpha: breeze.linalg.Vector[Double], | ||
| gammaShape: Double, | ||
| k: Int, | ||
| seed: Long): (BDV[Double], BDM[Double], List[Int]) = { | ||
| val (ids: List[Int], cts: Array[Double]) = termCounts match { | ||
| case v: DenseVector => ((0 until v.size).toList, v.values) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here and elsewhere, as an optimization, can we avoid You're generally solving this with separate sparse/dense methods which could be fine too if it doesn't result in too much code duplication and improves performance in the dense case.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks good then except we might be able to make one more optimization here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just look into the usage of indices
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am afraid that an empty list may not help to simplify the impl. |
||
| case v: SparseVector => (v.indices.toList, v.values) | ||
| } | ||
| variationalTopicInference(ids, cts, expElogbeta, alpha, gammaShape, k, seed) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.