Skip to content

Commit 1e47a11

Browse files
committed
address comments
1 parent 4b336be commit 1e47a11

File tree

5 files changed

+41
-77
lines changed

5 files changed

+41
-77
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ private[ml] case class Instance(label: Double, weight: Double, features: Vector)
3131
/**
3232
* Case class that represents an instance of data point with
3333
* label, weight, offset and features.
34+
* This is mainly used in GeneralizedLinearRegression currently.
3435
*
3536
* @param label Label for this data point.
3637
* @param weight The weight of this instance.

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml.optim
1919

2020
import org.apache.spark.internal.Logging
21-
import org.apache.spark.ml.feature.Instance
21+
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
2222
import org.apache.spark.ml.linalg._
2323
import org.apache.spark.rdd.RDD
2424

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
135135
def getLinkPredictionCol: String = $(linkPredictionCol)
136136

137137
/**
138-
* Param for offset column name. If this is not set or empty, we treat all
139-
* instance offsets as 0.0.
138+
* Param for offset column name. If this is not set or empty, we treat all instance offsets
139+
* as 0.0. The feature specified as offset has a constant coefficient of 1.0.
140140
* @group param
141141
*/
142142
final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " +
@@ -145,6 +145,14 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
145145
/** @group getParam */
146146
def getOffsetCol: String = $(offsetCol)
147147

148+
/** Checks whether weight column is set and nonempty. */
149+
private[regression] def hasWeightCol: Boolean =
150+
isSet(weightCol) && $(weightCol).nonEmpty
151+
152+
/** Checks whether offset column is set and nonempty. */
153+
private[regression] def hasOffsetCol: Boolean =
154+
isSet(offsetCol) && $(offsetCol).nonEmpty
155+
148156
/** Checks whether we should output link prediction. */
149157
private[regression] def hasLinkPredictionCol: Boolean = {
150158
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
@@ -179,9 +187,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
179187
}
180188

181189
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
182-
if (fitting) {
183-
if (isSetOffsetCol(this)) SchemaUtils.checkNumericType(schema, $(offsetCol))
190+
191+
if (hasOffsetCol) {
192+
SchemaUtils.checkNumericType(schema, $(offsetCol))
184193
}
194+
185195
if (hasLinkPredictionCol) {
186196
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
187197
} else {
@@ -318,7 +328,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
318328

319329
/**
320330
* Sets the value of param [[offsetCol]].
321-
* The feature specified as offset has a constant coefficient of 1.0.
322331
* If this is not set or empty, we treat all instance offsets as 0.0.
323332
* Default is not set, so all instances have offset 0.0.
324333
*
@@ -364,8 +373,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
364373
"GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
365374
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )
366375

367-
val w = if (!isSetWeightCol(this)) lit(1.0) else col($(weightCol))
368-
val offset = if (!isSetOffsetCol(this)) lit(0.0) else col($(offsetCol)).cast(DoubleType)
376+
val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
377+
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
369378

370379
val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
371380
// TODO: Make standardizeFeatures and standardizeLabel configurable.
@@ -437,14 +446,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
437446

438447
private[regression] val epsilon: Double = 1E-16
439448

440-
/** Checks whether weight column is set and nonempty */
441-
private[regression] def isSetWeightCol(params: GeneralizedLinearRegressionBase): Boolean =
442-
params.isSet(params.weightCol) && params.getWeightCol.nonEmpty
443-
444-
/** Checks whether offset column is set and nonempty */
445-
private[regression] def isSetOffsetCol(params: GeneralizedLinearRegressionBase): Boolean =
446-
params.isSet(params.offsetCol) && params.getOffsetCol.nonEmpty
447-
448449
/**
449450
* Wrapper of family and link combination used in the model.
450451
*/
@@ -476,14 +477,14 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
476477
}
477478

478479
/**
479-
* The reweight function used to update offsets and weights
480+
* The reweight function used to update working labels and weights
480481
* at each iteration of [[IterativelyReweightedLeastSquares]].
481482
*/
482483
val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = {
483484
(instance: OffsetInstance, model: WeightedLeastSquaresModel) => {
484-
val eta = model.predict(instance.features)
485-
val mu = fitted(eta + instance.offset)
486-
val newLabel = eta + (instance.label - mu) * link.deriv(mu)
485+
val eta = model.predict(instance.features) + instance.offset
486+
val mu = fitted(eta)
487+
val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu)
487488
val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
488489
(newLabel, newWeight)
489490
}
@@ -989,7 +990,7 @@ class GeneralizedLinearRegressionModel private[ml] (
989990
/**
990991
* Calculates the predicted value when offset is set.
991992
*/
992-
protected def predict(features: Vector, offset: Double): Double = {
993+
def predict(features: Vector, offset: Double): Double = {
993994
val eta = predictLink(features, offset)
994995
familyAndLink.fitted(eta)
995996
}
@@ -1009,22 +1010,8 @@ class GeneralizedLinearRegressionModel private[ml] (
10091010
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
10101011
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
10111012
val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) }
1012-
/*
1013-
Offset is only validated when it's specified in the model and available in prediction data set.
1014-
When offset is specified but missing in the prediction data set, we default it to zero.
1015-
*/
1016-
val offset = {
1017-
if (!isSetOffsetCol(this)) {
1018-
lit(0.0)
1019-
} else {
1020-
if (dataset.schema.fieldNames.contains($(offsetCol))) {
1021-
SchemaUtils.checkNumericType(dataset.schema, $(offsetCol))
1022-
col($(offsetCol)).cast(DoubleType)
1023-
} else {
1024-
lit(0.0)
1025-
}
1026-
}
1027-
}
1013+
1014+
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
10281015
var output = dataset
10291016
if ($(predictionCol).nonEmpty) {
10301017
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset))
@@ -1218,11 +1205,11 @@ class GeneralizedLinearRegressionSummary private[regression] (
12181205
private def prediction: Column = col(predictionCol)
12191206

12201207
private def weight: Column = {
1221-
if (!isSetWeightCol(model)) lit(1.0) else col(model.getWeightCol)
1208+
if (!model.hasWeightCol) lit(1.0) else col(model.getWeightCol)
12221209
}
12231210

12241211
private def offset: Column = {
1225-
if (!isSetOffsetCol(model)) lit(0.0) else col(model.getOffsetCol).cast(DoubleType)
1212+
if (!model.hasOffsetCol) lit(0.0) else col(model.getOffsetCol).cast(DoubleType)
12261213
}
12271214

12281215
private[regression] lazy val devianceResiduals: DataFrame = {
@@ -1285,8 +1272,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
12851272
Estimate intercept analytically when there is no offset, or when there is offset but
12861273
the model is Gaussian family with identity link. Otherwise, fit an intercept only model.
12871274
*/
1288-
if (!isSetOffsetCol(model) ||
1289-
(isSetOffsetCol(model) && family == Gaussian && link == Identity)) {
1275+
if (!model.hasOffsetCol ||
1276+
(model.hasOffsetCol && family == Gaussian && link == Identity)) {
12901277
val agg = predictions.agg(sum(weight.multiply(
12911278
label.minus(offset))), sum(weight)).first()
12921279
link.link(agg.getDouble(0) / agg.getDouble(1))

mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
3939
w <- c(1, 2, 3, 4)
4040
*/
4141
instances1 = sc.parallelize(Seq(
42-
Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
43-
Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
44-
Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
45-
Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
46-
), 2).map(new OffsetInstance(_))
42+
OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse),
43+
OffsetInstance(0.0, 2.0, 0.0, Vectors.dense(1.0, 2.0)),
44+
OffsetInstance(1.0, 3.0, 0.0, Vectors.dense(2.0, 1.0)),
45+
OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0))
46+
), 2)
4747
/*
4848
R code:
4949
@@ -52,11 +52,11 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
5252
w <- c(1, 2, 3, 4)
5353
*/
5454
instances2 = sc.parallelize(Seq(
55-
Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
56-
Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
57-
Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
58-
Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
59-
), 2).map(new OffsetInstance(_))
55+
OffsetInstance(2.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse),
56+
OffsetInstance(8.0, 2.0, 0.0, Vectors.dense(1.0, 7.0)),
57+
OffsetInstance(3.0, 3.0, 0.0, Vectors.dense(2.0, 11.0)),
58+
OffsetInstance(9.0, 4.0, 0.0, Vectors.dense(3.0, 13.0))
59+
), 2)
6060
}
6161

6262
test("IRLS against GLM with Binomial errors") {

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ class GeneralizedLinearRegressionSuite
798798
}
799799
}
800800

801-
test("generalized linear regression with offset") {
801+
test("generalized linear regression with weight and offset") {
802802
/*
803803
R code:
804804
library(statmod)
@@ -881,30 +881,6 @@ class GeneralizedLinearRegressionSuite
881881
}
882882
}
883883

884-
test("generalized linear regression: predict with no offset") {
885-
val trainData = Seq(
886-
OffsetInstance(2.0, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
887-
OffsetInstance(8.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)),
888-
OffsetInstance(3.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)),
889-
OffsetInstance(9.0, 4.0, 4.0, Vectors.dense(3.0, 13.0))
890-
).toDF()
891-
val testData = trainData.select("weight", "features")
892-
893-
val trainer = new GeneralizedLinearRegression()
894-
.setFamily("poisson")
895-
.setWeightCol("weight")
896-
.setOffsetCol("offset")
897-
.setLinkPredictionCol("linkPrediction")
898-
899-
val model = trainer.fit(trainData)
900-
model.transform(testData).select("features", "linkPrediction")
901-
.collect().foreach {
902-
case Row(features: DenseVector, linkPrediction1: Double) =>
903-
val linkPrediction2 = BLAS.dot(features, model.coefficients) + model.intercept
904-
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch")
905-
}
906-
}
907-
908884
test("glm summary: gaussian family with weight and offset") {
909885
/*
910886
R code:
@@ -1309,7 +1285,7 @@ class GeneralizedLinearRegressionSuite
13091285
-0.16134949 0.20807694 -0.22544551 0.03258777
13101286
residuals(model, type = "working")
13111287
1 2 3 4
1312-
0.135315831 -0.084390309 0.113219135 -0.008279688
1288+
0.135315831 -0.084390309 0.113219135 -0.008279688
13131289
residuals(model, type = "response")
13141290
1 2 3 4
13151291
-0.1923918 0.2565224 -0.1496381 0.0320653

0 commit comments

Comments
 (0)