-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-13777] [ML] Remove constant features from training in normal solver (WLS) #11610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,23 +80,16 @@ private[ml] class WeightedLeastSquares( | |
| val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) | ||
| summary.validate() | ||
| logInfo(s"Number of instances: ${summary.count}.") | ||
| val k = if (fitIntercept) summary.k + 1 else summary.k | ||
| val triK = summary.triK | ||
| val wSum = summary.wSum | ||
| val bBar = summary.bBar | ||
| val bStd = summary.bStd | ||
| val aBar = summary.aBar | ||
| val aVar = summary.aVar | ||
| val abBar = summary.abBar | ||
| val aaBar = summary.aaBar | ||
| val aaValues = aaBar.values | ||
|
|
||
| if (bStd == 0) { | ||
| if (fitIntercept) { | ||
| logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + | ||
| s"zeros and the intercept will be the mean of the label; as a result, " + | ||
| s"training is not needed.") | ||
| val coefficients = new DenseVector(Array.ofDim(k-1)) | ||
| val coefficients = new DenseVector(Array.ofDim(summary.k)) | ||
| val intercept = bBar | ||
| val diagInvAtWA = new DenseVector(Array(0D)) | ||
| return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) | ||
|
|
@@ -108,6 +101,57 @@ private[ml] class WeightedLeastSquares( | |
| "Consider setting fitIntercept=true.") | ||
| } | ||
| } | ||
| /* | ||
| * If more than one of the features in the data are constant (i.e. data matrix has constant | ||
| * columns), then A^T.A is no longer positive definite and Cholesky decomposition fails | ||
| * (because the normal equation does not have a solution). | ||
| * In order to find a solution, we need to drop constant columns from the data matrix. Or, | ||
| * we can drop corresponding column and row from A^T.A matrix. | ||
| * Once we drop rows/columns from A^T.A matrix, the Cholesky decomposition will produce | ||
| * correct coefficients. But, for the final result, we need to add zeros to the list of | ||
| * coefficients corresponding to the constant features. | ||
| */ | ||
| val aVarRaw = summary.aVar.values | ||
|
||
| // this will keep track of features to keep in the model, and remove | ||
| // features with zero variance. | ||
| val nzVarIndex = aVarRaw.zipWithIndex.filter(_._1 != 0).map(_._2) | ||
|
||
| val nz = nzVarIndex.length | ||
| // if there are features with zero variance, then ATA is not positive definite, and we need to | ||
| // keep track of that. | ||
| val singular = summary.k > nz | ||
|
||
| val k = if (fitIntercept) nz + 1 else nz | ||
| val triK = nz * (nz + 1) / 2 | ||
|
|
||
| val aVar = if (singular) { | ||
|
||
| for (i <- nzVarIndex) yield aVarRaw(i) | ||
| } else { | ||
| aVarRaw | ||
| } | ||
| val aBar = if (singular) { | ||
| val aBarTemp = summary.aBar.values | ||
|
||
| for (i <- nzVarIndex) yield aBarTemp(i) | ||
| } else { | ||
| summary.aBar.values | ||
| } | ||
| val abBar = if (singular) { | ||
| val abBarTemp = summary.abBar.values | ||
|
||
| for (i <- nzVarIndex) yield {abBarTemp(i)} | ||
| } else { | ||
| summary.abBar.values | ||
| } | ||
| // NOTE: aaBar represents upper triangular part of A^T.A matrix in column major order. | ||
| // We need to drop columns and rows from A^T.A corresponding to the features which have | ||
| // zero variance. The following logic removes elements from aaBar corresponding to zerp | ||
| // variance which effectively removes columns and rows from A^T.A. | ||
| val aaBar = if (singular) { | ||
|
||
| val aaBarTemp = summary.aaBar.values | ||
| (for { col <- 0 until summary.k | ||
| row <- 0 to col | ||
| if aVarRaw(col) != 0 && aVarRaw(row) != 0 } yield | ||
| aaBarTemp(row + col * (col + 1) / 2)).toArray | ||
| } else { | ||
| summary.aaBar.values | ||
| } | ||
|
|
||
| // add regularization to diagonals | ||
| var i = 0 | ||
|
|
@@ -120,34 +164,47 @@ private[ml] class WeightedLeastSquares( | |
| if (standardizeLabel && bStd != 0) { | ||
| lambda /= bStd | ||
| } | ||
| aaValues(i) += lambda | ||
| aaBar(i) += lambda | ||
| i += j | ||
| j += 1 | ||
| } | ||
|
|
||
| val aa = if (fitIntercept) { | ||
| Array.concat(aaBar.values, aBar.values, Array(1.0)) | ||
| Array.concat(aaBar, aBar, Array(1.0)) | ||
| } else { | ||
| aaBar.values | ||
| aaBar | ||
| } | ||
| val ab = if (fitIntercept) { | ||
| Array.concat(abBar.values, Array(bBar)) | ||
| Array.concat(abBar, Array(bBar)) | ||
| } else { | ||
| abBar.values | ||
| abBar | ||
| } | ||
|
|
||
| val x = CholeskyDecomposition.solve(aa, ab) | ||
|
|
||
| val (coefs, intercept) = if (fitIntercept) { | ||
| (x.init, x.last) | ||
| } else { | ||
| (x, 0.0) | ||
| } | ||
| val aaInv = CholeskyDecomposition.inverse(aa, k) | ||
|
|
||
| // aaInv is a packed upper triangular matrix, here we get all elements on diagonal | ||
| val diagInvAtWA = new DenseVector((1 to k).map { i => | ||
| aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) | ||
| val aaInvDiag = (1 to k).map { i => | ||
| aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray | ||
|
|
||
| val (coefficients, intercept) = if (fitIntercept) { | ||
| (new DenseVector(x.slice(0, x.length - 1)), x.last) | ||
| val (coefficients, diagInvAtWA) = if (singular) { | ||
| // if there are constant features in the data, we need to add zeros for the coefficients | ||
| // for these features. | ||
| val coefTemp = Array.ofDim[Double](summary.k) | ||
| val diagTemp = Array.ofDim[Double](summary.k) | ||
| var i = 0 | ||
| while (i < nz) { | ||
| coefTemp(nzVarIndex(i)) = coefs(i) | ||
| diagTemp(nzVarIndex(i)) = aaInvDiag(i) | ||
| i += 1 | ||
| } | ||
| (new DenseVector(coefTemp), new DenseVector(diagTemp)) | ||
| } else { | ||
| (new DenseVector(x), 0.0) | ||
| (new DenseVector(coefs), new DenseVector(aaInvDiag)) | ||
| } | ||
|
|
||
| new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you make this comment adhere to the style like here?