[SPARK-16008][ML] Remove unnecessary serialization in logistic regression #13729
Changes from 3 commits
```diff
@@ -937,50 +937,47 @@ class BinaryLogisticRegressionSummary private[classification] (
  * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
  * the corresponding joint dataset.
  *
- * @param coefficients The coefficients corresponding to the features.
  * @param numClasses the number of possible outcomes for k classes classification problem in
  *                   Multinomial Logistic Regression.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
- * @param featuresMean The mean values of the features.
  */
 private class LogisticAggregator(
-    coefficients: Vector,
+    numFeatures: Int,
     numClasses: Int,
-    fitIntercept: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double]) extends Serializable {
+    fitIntercept: Boolean) extends Serializable {

   private var weightSum = 0.0
   private var lossSum = 0.0

-  private val coefficientsArray = coefficients match {
-    case dv: DenseVector => dv.values
-    case _ =>
-      throw new IllegalArgumentException(
-        s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
-  }
-
-  private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length
-
-  private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
+  private val dim = numFeatures
+  private val gradientSumArray =
+    Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)

   /**
    * Add a new training instance to this LogisticAggregator, and update the loss and gradient
    * of the objective function.
    *
    * @param instance The instance of data point to be added.
+   * @param coefficients The coefficients corresponding to the features.
+   * @param featuresStd The standard deviation values of the features.
    * @return This LogisticAggregator object.
    */
-  def add(instance: Instance): this.type = {
+  def add(instance: Instance,
+      coefficients: Vector,
+      featuresStd: Array[Double]): this.type = {
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new instance." +
         s" Expecting $dim but got ${features.size}.")
       require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

       if (weight == 0.0) return this

-      val localCoefficientsArray = coefficientsArray
+      val coefficientsArray = coefficients match {
+        case dv: DenseVector => dv.values
+        case _ =>
+          throw new IllegalArgumentException(
+            s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+      }
       val localGradientSumArray = gradientSumArray

       numClasses match {
```
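This hunk is the heart of the change: the aggregator is the zero value of a `treeAggregate` call, so a copy of it is serialized with every task, and partially merged aggregators are serialized again on their way back up the tree. With `coefficients`, `featuresStd`, and `featuresMean` removed from its state, each serialized aggregator carries only the running sums. A minimal sketch of the pattern, using hypothetical names (`SmallAggregator`, `gradientSums`) rather than Spark's actual classes:

```scala
import org.apache.spark.rdd.RDD

// A slim aggregator: only running sums live in the object that treeAggregate
// serializes; the coefficients arrive per add() call from the seqOp closure.
class SmallAggregator(numFeatures: Int) extends Serializable {
  var lossSum = 0.0
  var weightSum = 0.0
  val gradientSum = Array.ofDim[Double](numFeatures)

  // Binary logistic loss/gradient for a label in {0.0, 1.0}; margin = w . x.
  def add(features: Array[Double], label: Double, coefficients: Array[Double]): this.type = {
    var margin = 0.0
    var i = 0
    while (i < numFeatures) { margin += coefficients(i) * features(i); i += 1 }
    val multiplier = 1.0 / (1.0 + math.exp(-margin)) - label
    i = 0
    while (i < numFeatures) { gradientSum(i) += multiplier * features(i); i += 1 }
    lossSum += math.log1p(math.exp(margin)) - label * margin
    weightSum += 1.0
    this
  }

  def merge(other: SmallAggregator): this.type = {
    lossSum += other.lossSum
    weightSum += other.weightSum
    var i = 0
    while (i < numFeatures) { gradientSum(i) += other.gradientSum(i); i += 1 }
    this
  }
}

// The coefficients are captured once by the seqOp closure instead of being
// stored in (and re-serialized with) every aggregator instance.
def gradientSums(data: RDD[(Array[Double], Double)],
    coeffs: Array[Double],
    numFeatures: Int): SmallAggregator = {
  val seqOp = (agg: SmallAggregator, p: (Array[Double], Double)) => agg.add(p._1, p._2, coeffs)
  val combOp = (a: SmallAggregator, b: SmallAggregator) => a.merge(b)
  data.treeAggregate(new SmallAggregator(numFeatures))(seqOp, combOp)
}
```

Under this shape, the only per-aggregator payload is the gradient array plus a few doubles; the large read-only inputs travel with the closure instead.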
|
|
```diff
@@ -990,11 +987,11 @@ private class LogisticAggregator(
           var sum = 0.0
           features.foreachActive { (index, value) =>
             if (featuresStd(index) != 0.0 && value != 0.0) {
-              sum += localCoefficientsArray(index) * (value / featuresStd(index))
+              sum += coefficientsArray(index) * (value / featuresStd(index))
```
A review thread attached to this line:

**Member:** I could be missing something, but why are the values being normalized by the stdev but not centered by the mean? Before, the means were passed in, so I wonder if something was being overlooked.

**Author (Contributor):** I believe it is by design and not an oversight (perhaps the mean was also passed in case future versions decided it was necessary to mean-center the data). This comment is from MLlib:
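The MLlib comment being quoted is truncated in this view, but the rationale it records is standard: scaling by the standard deviation alone preserves sparsity, because `foreachActive` skips zero entries and a zero divided by the std is still zero, whereas mean-centering would turn nearly every zero into a nonzero value. A small illustration (not code from this PR):

```scala
// Dividing by the std leaves zero entries at zero, so sparse feature vectors
// stay sparse; subtracting the mean first makes almost every entry nonzero.
val std    = Array(2.0, 4.0, 1.0)
val mean   = Array(1.0, 1.0, 1.0)
val sparse = Array(0.0, 8.0, 0.0) // mostly zeros, as ML features often are

val scaledOnly = sparse.indices.map(i => if (std(i) != 0.0) sparse(i) / std(i) else 0.0)
// -> Vector(0.0, 2.0, 0.0): sparsity preserved

val centeredAndScaled = sparse.indices.map(i => (sparse(i) - mean(i)) / std(i))
// -> Vector(-0.5, 1.75, -1.0): every zero became nonzero
```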
```diff
             }
           }
           sum + {
-            if (fitIntercept) localCoefficientsArray(dim) else 0.0
+            if (fitIntercept) coefficientsArray(dim) else 0.0
           }
         }
```
```diff
@@ -1086,13 +1083,17 @@ private class LogisticCostFun(
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val numFeatures = featuresStd.length
     val coeffs = Vectors.fromBreeze(coefficients)
-    val n = coeffs.size
+    val localFeaturesStd = featuresStd

     val logisticAggregator = {
-      val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
+      val seqOp = (c: LogisticAggregator, instance: Instance) =>
+        c.add(instance, coeffs, localFeaturesStd)
       val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)

       instances.treeAggregate(
-        new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean)
+        new LogisticAggregator(numFeatures, numClasses, fitIntercept)
       )(seqOp, combOp)
     }
```
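The `val localFeaturesStd = featuresStd` line is the usual Spark closure idiom: naming a class field inside an RDD closure actually references `this.featuresStd`, which forces the whole enclosing `LogisticCostFun` to be captured and serialized with the task. Copying the field into a local `val` limits the capture to just the array. A sketch of the idiom with a hypothetical class:

```scala
import org.apache.spark.rdd.RDD

class CostFun(featuresStd: Array[Double]) { // note: not Serializable
  def scaledSum(data: RDD[Double]): Double = {
    // BAD: `data.map(x => x / featuresStd(0))` would really reference
    // `this.featuresStd`, dragging all of CostFun into the task closure
    // (and failing here, since CostFun is not serializable).

    // GOOD: copy the field into a local val; only the array is captured.
    val localStd = featuresStd
    data.map(x => x / localStd(0)).sum()
  }
}
```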
A review thread on `private val dim = numFeatures`:

**Member:** Do you need `dim` here, or can you just reference `numFeatures` later in the class? I had to look twice at the line below to make sure the logic wasn't reversed from before, but I see why it works out.

**Author:** I left it because this logic will likely change when multiclass is added. `dim` is used to check that the overall coefficients array is the correct length, which won't be `numFeatures` for multiclass. Still, I can remove it here if that seems better.

**Author:** On second thought, I like your suggestion. I updated it accordingly.
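The multiclass point above can be made concrete: with one coefficient block per class, a valid coefficients array is no longer `numFeatures` (plus an optional intercept) long, so a length check needs more than `numFeatures` alone. A hypothetical sketch, since the multinomial code the author anticipates did not exist yet in this PR:

```scala
// Hypothetical helper, for illustration only.
def expectedCoefficientsLength(
    numFeatures: Int,
    numCoefficientSets: Int, // 1 for binary, K for a K-class multinomial model
    fitIntercept: Boolean): Int =
  numCoefficientSets * (if (fitIntercept) numFeatures + 1 else numFeatures)

expectedCoefficientsLength(10, 1, fitIntercept = true) // binary: 11
expectedCoefficientsLength(10, 3, fitIntercept = true) // 3 classes: 33
```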