-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-18710][ML] Add offset in GLM #16699
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
3bf2718
0e240eb
9c41453
7823f8a
a1f5695
d071b95
d2afcb0
9eca1a6
d44974c
9c320ee
e183c08
58f93af
da4174a
52bc32b
59e10f7
1d41bdd
fb372ad
2bc3ae7
afb4643
fc64d32
90d68a6
e95c25b
4b336be
1e47a11
db0ac93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,3 +27,25 @@ import org.apache.spark.ml.linalg.Vector | |
| * @param features The vector of features for this data point. | ||
| */ | ||
| private[ml] case class Instance(label: Double, weight: Double, features: Vector) | ||
|
|
||
| /** | ||
| * Case class that represents a single data point with | ||
| * label, weight, offset and features. | ||
| * | ||
| * This is mainly used in GeneralizedLinearRegression currently. | ||
| * | ||
| * @param label Label for this data point. | ||
| * @param weight The weight of this instance. | ||
| * @param offset The offset used for this data point. | ||
| * @param features The vector of features for this data point. | ||
| */ | ||
| private[ml] case class OffsetInstance(label: Double, weight: Double, offset: Double, | ||
|
||
| features: Vector) { | ||
|
|
||
| /** Constructs an [[OffsetInstance]] from an [[Instance]] object and an offset; the offset defaults to 0.0 when omitted. */ | ||
| def this(instance: Instance, offset: Double = 0.0) = { | ||
|
||
| this(instance.label, instance.weight, offset, instance.features) | ||
| } | ||
|
|
||
| /** Converts to an [[Instance]] object by leaving out the offset; label, weight and features are carried over unchanged. */ | ||
| private[ml] def toInstance: Instance = Instance(label, weight, features) | ||
|
||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,7 +24,7 @@ import org.apache.spark.SparkException | |||||||||||||||||||||
| import org.apache.spark.annotation.{Experimental, Since} | ||||||||||||||||||||||
| import org.apache.spark.internal.Logging | ||||||||||||||||||||||
| import org.apache.spark.ml.PredictorParams | ||||||||||||||||||||||
| import org.apache.spark.ml.feature.Instance | ||||||||||||||||||||||
| import org.apache.spark.ml.feature.{Instance, OffsetInstance} | ||||||||||||||||||||||
| import org.apache.spark.ml.linalg.{BLAS, Vector} | ||||||||||||||||||||||
| import org.apache.spark.ml.optim._ | ||||||||||||||||||||||
| import org.apache.spark.ml.param._ | ||||||||||||||||||||||
|
|
@@ -134,6 +134,17 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |||||||||||||||||||||
| @Since("2.0.0") | ||||||||||||||||||||||
| def getLinkPredictionCol: String = $(linkPredictionCol) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** | ||||||||||||||||||||||
| * Param for offset column name. If this is not set or empty, we treat all | ||||||||||||||||||||||
| * instance offsets as 0.0. | ||||||||||||||||||||||
| * The feature specified as offset has a constant coefficient of 1.0, i.e. the | ||||||||||||||||||||||
| * offset is added as-is to the linear predictor (features dot coefficients + intercept). | ||||||||||||||||||||||
| * When set, the column must be of numeric type; its values are cast to Double. | ||||||||||||||||||||||
| * @group param | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " + | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||
| "column name. If this is not set or empty, we treat all instance offsets as 0.0") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** @group getParam */ | ||||||||||||||||||||||
| def getOffsetCol: String = $(offsetCol) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like you will need to update the validateAndTransformSchema method below to validate these parameters, e.g. check that the column exists (similar to what the base class does for the features/label columns).
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for fixing!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** Checks whether we should output link prediction. */ | ||||||||||||||||||||||
| private[regression] def hasLinkPredictionCol: Boolean = { | ||||||||||||||||||||||
| isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty | ||||||||||||||||||||||
|
|
@@ -168,6 +179,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) | ||||||||||||||||||||||
| if (isSet(offsetCol) && $(offsetCol).nonEmpty) { | ||||||||||||||||||||||
| SchemaUtils.checkNumericType(schema, $(offsetCol)) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need to check the numeric type for both fit & transform. |
||||||||||||||||||||||
| if (hasLinkPredictionCol) { | ||||||||||||||||||||||
| SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
|
|
@@ -302,6 +316,17 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |||||||||||||||||||||
| @Since("2.0.0") | ||||||||||||||||||||||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** | ||||||||||||||||||||||
| * Sets the value of param [[offsetCol]]. | ||||||||||||||||||||||
| * The feature specified as offset has a constant coefficient of 1.0. | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this line to param doc. Usually we keep the most integrated doc in param annotation, and for set method, we can just say |
||||||||||||||||||||||
| * If this is not set or empty, we treat all instance offsets as 0.0. | ||||||||||||||||||||||
| * Default is not set, so all instances have offset 0.0. | ||||||||||||||||||||||
| * | ||||||||||||||||||||||
| * @group setParam | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| @Since("2.2.0") | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2.2.0 -> 2.3.0 |
||||||||||||||||||||||
| def setOffsetCol(value: String): this.type = set(offsetCol, value) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** | ||||||||||||||||||||||
| * Sets the solver algorithm used for optimization. | ||||||||||||||||||||||
| * Currently only supports "irls" which is also the default solver. | ||||||||||||||||||||||
|
|
@@ -325,7 +350,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size | ||||||||||||||||||||||
| val instr = Instrumentation.create(this, dataset) | ||||||||||||||||||||||
| instr.logParams(labelCol, featuresCol, weightCol, predictionCol, linkPredictionCol, | ||||||||||||||||||||||
| instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, | ||||||||||||||||||||||
| family, solver, fitIntercept, link, maxIter, regParam, tol) | ||||||||||||||||||||||
| instr.logNumFeatures(numFeatures) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -336,14 +361,19 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||||||||||||||||||||||
| val instances: RDD[Instance] = | ||||||||||||||||||||||
| dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { | ||||||||||||||||||||||
| case Row(label: Double, weight: Double, features: Vector) => | ||||||||||||||||||||||
| Instance(label, weight, features) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| val offset = if (!isDefined(offsetCol) || $(offsetCol).isEmpty) { | ||||||||||||||||||||||
| lit(0.0) | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| col($(offsetCol)).cast(DoubleType) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { | ||||||||||||||||||||||
| // TODO: Make standardizeFeatures and standardizeLabel configurable. | ||||||||||||||||||||||
| val instances: RDD[Instance] = | ||||||||||||||||||||||
| dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { | ||||||||||||||||||||||
| case Row(label: Double, weight: Double, offset: Double, features: Vector) => | ||||||||||||||||||||||
| Instance(label - offset, weight, features) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, | ||||||||||||||||||||||
| standardizeFeatures = true, standardizeLabel = true) | ||||||||||||||||||||||
| val wlsModel = optimizer.fit(instances) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think of adding a new interface
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest we leave We discussed something relevant above here. I originally defined |
||||||||||||||||||||||
|
|
@@ -354,6 +384,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |||||||||||||||||||||
| wlsModel.diagInvAtWA.toArray, 1, getSolver) | ||||||||||||||||||||||
| model.setSummary(Some(trainingSummary)) | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| val instances: RDD[OffsetInstance] = | ||||||||||||||||||||||
| dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { | ||||||||||||||||||||||
| case Row(label: Double, weight: Double, offset: Double, features: Vector) => | ||||||||||||||||||||||
| OffsetInstance(label, weight, offset, features) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). | ||||||||||||||||||||||
| val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) | ||||||||||||||||||||||
| val optimizer = new IterativelyReweightedLeastSquares(initialModel, | ||||||||||||||||||||||
|
|
@@ -417,12 +452,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |||||||||||||||||||||
| * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| def initialize( | ||||||||||||||||||||||
| instances: RDD[Instance], | ||||||||||||||||||||||
| instances: RDD[OffsetInstance], | ||||||||||||||||||||||
| fitIntercept: Boolean, | ||||||||||||||||||||||
| regParam: Double): WeightedLeastSquaresModel = { | ||||||||||||||||||||||
| val newInstances = instances.map { instance => | ||||||||||||||||||||||
| val mu = family.initialize(instance.label, instance.weight) | ||||||||||||||||||||||
| val eta = predict(mu) | ||||||||||||||||||||||
| val eta = predict(mu) - instance.offset | ||||||||||||||||||||||
| Instance(eta, instance.weight, instance.features) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| // TODO: Make standardizeFeatures and standardizeLabel configurable. | ||||||||||||||||||||||
|
|
@@ -436,13 +471,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |||||||||||||||||||||
| * The reweight function used to update offsets and weights | ||||||||||||||||||||||
|
||||||||||||||||||||||
| * at each iteration of [[IterativelyReweightedLeastSquares]]. | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { | ||||||||||||||||||||||
| (instance: Instance, model: WeightedLeastSquaresModel) => { | ||||||||||||||||||||||
| val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { | ||||||||||||||||||||||
| (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { | ||||||||||||||||||||||
| val eta = model.predict(instance.features) | ||||||||||||||||||||||
| val mu = fitted(eta) | ||||||||||||||||||||||
| val offset = eta + (instance.label - mu) * link.deriv(mu) | ||||||||||||||||||||||
| val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) | ||||||||||||||||||||||
| (offset, weight) | ||||||||||||||||||||||
| val mu = fitted(eta + instance.offset) | ||||||||||||||||||||||
| val newLabel = eta + (instance.label - mu) * link.deriv(mu) | ||||||||||||||||||||||
| val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) | ||||||||||||||||||||||
| (newLabel, newWeight) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
@@ -940,15 +975,27 @@ class GeneralizedLinearRegressionModel private[ml] ( | |||||||||||||||||||||
| private lazy val familyAndLink = FamilyAndLink(this) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| override protected def predict(features: Vector): Double = { | ||||||||||||||||||||||
| val eta = predictLink(features) | ||||||||||||||||||||||
| if (!isSet(offsetCol) || $(offsetCol).isEmpty) { | ||||||||||||||||||||||
| val eta = BLAS.dot(features, coefficients) + intercept | ||||||||||||||||||||||
| familyAndLink.fitted(eta) | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| throw new SparkException("Must supply offset to predict when offset column is set.") | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** | ||||||||||||||||||||||
| * Calculates the predicted value when offset is set. | ||||||||||||||||||||||
| * | ||||||||||||||||||||||
| * The link prediction for the instance is computed by `predictLink(features, offset)` | ||||||||||||||||||||||
| * and the result is then passed through `familyAndLink.fitted`. | ||||||||||||||||||||||
| * | ||||||||||||||||||||||
| * @param features The feature vector of the instance. | ||||||||||||||||||||||
| * @param offset The offset of the instance, included in the link prediction. | ||||||||||||||||||||||
| * @return The value of `familyAndLink.fitted` applied to the link prediction. | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| protected def predict(features: Vector, offset: Double): Double = { | ||||||||||||||||||||||
|
||||||||||||||||||||||
| val eta = predictLink(features, offset) | ||||||||||||||||||||||
| familyAndLink.fitted(eta) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /** | ||||||||||||||||||||||
| * Calculate the link prediction (linear predictor) of the given instance. | ||||||||||||||||||||||
| * Calculates the link prediction (linear predictor) of the given instance. | ||||||||||||||||||||||
| */ | ||||||||||||||||||||||
| private def predictLink(features: Vector): Double = { | ||||||||||||||||||||||
| BLAS.dot(features, coefficients) + intercept | ||||||||||||||||||||||
| private def predictLink(features: Vector, offset: Double): Double = { | ||||||||||||||||||||||
| BLAS.dot(features, coefficients) + intercept + offset | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||||||||||||||||||||||
|
|
@@ -957,14 +1004,19 @@ class GeneralizedLinearRegressionModel private[ml] ( | |||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| override protected def transformImpl(dataset: Dataset[_]): DataFrame = { | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I summarized all four cases for making predictions as follows:
Cases 1 and 4 are not that controversial.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for summarizing the different cases. I think this is worth a deeper discussion as follow-up work. Let me work on this in another PR. |
||||||||||||||||||||||
| val predictUDF = udf { (features: Vector) => predict(features) } | ||||||||||||||||||||||
| val predictLinkUDF = udf { (features: Vector) => predictLink(features) } | ||||||||||||||||||||||
| val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } | ||||||||||||||||||||||
| val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } | ||||||||||||||||||||||
| val offset = if (!isSet(offsetCol) || $(offsetCol).isEmpty) { | ||||||||||||||||||||||
| lit(0.0) | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| col($(offsetCol)).cast(DoubleType) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
||||||||||||||||||||||
| var output = dataset | ||||||||||||||||||||||
| if ($(predictionCol).nonEmpty) { | ||||||||||||||||||||||
| output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) | ||||||||||||||||||||||
| output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| if (hasLinkPredictionCol) { | ||||||||||||||||||||||
| output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) | ||||||||||||||||||||||
| output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset)) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| output.toDF() | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,16 +18,16 @@ | |
| package org.apache.spark.ml.optim | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.feature.{Instance, OffsetInstance} | ||
| import org.apache.spark.ml.linalg.Vectors | ||
| import org.apache.spark.ml.util.TestingUtils._ | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { | ||
|
|
||
| private var instances1: RDD[Instance] = _ | ||
| private var instances2: RDD[Instance] = _ | ||
| private var instances1: RDD[OffsetInstance] = _ | ||
| private var instances2: RDD[OffsetInstance] = _ | ||
|
|
||
| override def beforeAll(): Unit = { | ||
| super.beforeAll() | ||
|
|
@@ -43,7 +43,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes | |
| Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), | ||
| Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), | ||
| Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) | ||
| ), 2) | ||
| ), 2).map(new OffsetInstance(_)) | ||
|
||
| /* | ||
| R code: | ||
|
|
||
|
|
@@ -56,7 +56,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes | |
| Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), | ||
| Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
| Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) | ||
| ), 2) | ||
| ), 2).map(new OffsetInstance(_)) | ||
| } | ||
|
|
||
| test("IRLS against GLM with Binomial errors") { | ||
|
|
@@ -156,7 +156,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes | |
| var idx = 0 | ||
| for (fitIntercept <- Seq(false, true)) { | ||
| val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, | ||
| standardizeFeatures = false, standardizeLabel = false).fit(instances2) | ||
| standardizeFeatures = false, standardizeLabel = false).fit(instances2.map(_.toInstance)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my above comment about adding interface |
||
| val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, | ||
| fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) | ||
| val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) | ||
|
|
@@ -169,29 +169,29 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes | |
| object IterativelyReweightedLeastSquaresSuite { | ||
|
|
||
| def BinomialReweightFunc( | ||
| instance: Instance, | ||
| instance: OffsetInstance, | ||
| model: WeightedLeastSquaresModel): (Double, Double) = { | ||
| val eta = model.predict(instance.features) | ||
| val eta = model.predict(instance.features) + instance.offset | ||
| val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) | ||
| val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) | ||
| val z = eta - instance.offset + (instance.label - mu) / (mu * (1.0 - mu)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed this is the correct implementation: in the IRWLS, we only include offset when computing |
||
| val w = mu * (1 - mu) * instance.weight | ||
| (z, w) | ||
| } | ||
|
|
||
| def PoissonReweightFunc( | ||
| instance: Instance, | ||
| instance: OffsetInstance, | ||
| model: WeightedLeastSquaresModel): (Double, Double) = { | ||
| val eta = model.predict(instance.features) | ||
| val eta = model.predict(instance.features) + instance.offset | ||
| val mu = math.exp(eta) | ||
| val z = eta + (instance.label - mu) / mu | ||
| val z = eta - instance.offset + (instance.label - mu) / mu | ||
| val w = mu * instance.weight | ||
| (z, w) | ||
| } | ||
|
|
||
| def L1RegressionReweightFunc( | ||
| instance: Instance, | ||
| instance: OffsetInstance, | ||
| model: WeightedLeastSquaresModel): (Double, Double) = { | ||
| val eta = model.predict(instance.features) | ||
| val eta = model.predict(instance.features) + instance.offset | ||
| val e = math.max(math.abs(eta - instance.label), 1e-7) | ||
| val w = 1 / e | ||
| val y = instance.label | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add doc
This is mainly used in GeneralizedLinearRegression currently.