@@ -135,8 +135,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
135135 def getLinkPredictionCol : String = $(linkPredictionCol)
136136
137137 /**
138- * Param for offset column name. If this is not set or empty, we treat all
139- * instance offsets as 0.0.
138+ * Param for offset column name. If this is not set or empty, we treat all instance offsets
139+ * as 0.0. The feature specified as offset has a constant coefficient of 1.0.
140140 * @group param
141141 */
142142 final val offsetCol : Param [String ] = new Param [String ](this , " offsetCol" , " The offset " +
@@ -145,6 +145,14 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
145145 /** @group getParam */
146146 def getOffsetCol : String = $(offsetCol)
147147
148+ /** Checks whether weight column is set and nonempty: true only if `isSet(weightCol)` holds and the configured column name is not the empty string. */
149+ private [regression] def hasWeightCol : Boolean =
150+ isSet(weightCol) && $(weightCol).nonEmpty
151+
152+ /** Checks whether offset column is set and nonempty: true only if `isSet(offsetCol)` holds and the configured column name is not the empty string. */
153+ private [regression] def hasOffsetCol : Boolean =
154+ isSet(offsetCol) && $(offsetCol).nonEmpty
155+
148156 /** Checks whether we should output link prediction. */
149157 private [regression] def hasLinkPredictionCol : Boolean = {
150158 isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
@@ -179,9 +187,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
179187 }
180188
181189 val newSchema = super .validateAndTransformSchema(schema, fitting, featuresDataType)
182- if (fitting) {
183- if (isSetOffsetCol(this )) SchemaUtils .checkNumericType(schema, $(offsetCol))
190+
191+ if (hasOffsetCol) {
192+ SchemaUtils .checkNumericType(schema, $(offsetCol))
184193 }
194+
185195 if (hasLinkPredictionCol) {
186196 SchemaUtils .appendColumn(newSchema, $(linkPredictionCol), DoubleType )
187197 } else {
@@ -318,7 +328,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
318328
319329 /**
320330 * Sets the value of param [[offsetCol ]].
321- * The feature specified as offset has a constant coefficient of 1.0.
322331 * If this is not set or empty, we treat all instance offsets as 0.0.
323332 * Default is not set, so all instances have offset 0.0.
324333 *
@@ -364,8 +373,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
364373 " GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
365374 " set to false. To fit a model with 0 features, fitIntercept must be set to true." )
366375
367- val w = if (! isSetWeightCol( this ) ) lit(1.0 ) else col($(weightCol))
368- val offset = if (! isSetOffsetCol( this ) ) lit(0.0 ) else col($(offsetCol)).cast(DoubleType )
376+ val w = if (! hasWeightCol ) lit(1.0 ) else col($(weightCol))
377+ val offset = if (! hasOffsetCol ) lit(0.0 ) else col($(offsetCol)).cast(DoubleType )
369378
370379 val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity ) {
371380 // TODO: Make standardizeFeatures and standardizeLabel configurable.
@@ -437,14 +446,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
437446
438447 private [regression] val epsilon : Double = 1E-16
439448
440- /** Checks whether weight column is set and nonempty */
441- private [regression] def isSetWeightCol (params : GeneralizedLinearRegressionBase ): Boolean =
442- params.isSet(params.weightCol) && params.getWeightCol.nonEmpty
443-
444- /** Checks whether offset column is set and nonempty */
445- private [regression] def isSetOffsetCol (params : GeneralizedLinearRegressionBase ): Boolean =
446- params.isSet(params.offsetCol) && params.getOffsetCol.nonEmpty
447-
448449 /**
449450 * Wrapper of family and link combination used in the model.
450451 */
@@ -476,14 +477,14 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
476477 }
477478
478479 /**
479- * The reweight function used to update offsets and weights
480+ * The reweight function used to update working labels and weights
480481 * at each iteration of [[IterativelyReweightedLeastSquares ]].
481482 */
482483 val reweightFunc : (OffsetInstance , WeightedLeastSquaresModel ) => (Double , Double ) = {
483484 (instance : OffsetInstance , model : WeightedLeastSquaresModel ) => {
484- val eta = model.predict(instance.features)
485- val mu = fitted(eta + instance.offset )
486- val newLabel = eta + (instance.label - mu) * link.deriv(mu)
485+ val eta = model.predict(instance.features) + instance.offset
486+ val mu = fitted(eta)
487+ val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu)
487488 val newWeight = instance.weight / (math.pow(this .link.deriv(mu), 2.0 ) * family.variance(mu))
488489 (newLabel, newWeight)
489490 }
@@ -989,7 +990,7 @@ class GeneralizedLinearRegressionModel private[ml] (
989990 /**
990991 * Calculates the predicted value when offset is set: evaluates `predictLink(features, offset)` and maps the result through `familyAndLink.fitted`.
991992 */
992- protected def predict (features : Vector , offset : Double ): Double = {
993+ def predict (features : Vector , offset : Double ): Double = {
993994 val eta = predictLink(features, offset)
994995 familyAndLink.fitted(eta)
995996 }
@@ -1009,22 +1010,8 @@ class GeneralizedLinearRegressionModel private[ml] (
10091010 override protected def transformImpl (dataset : Dataset [_]): DataFrame = {
10101011 val predictUDF = udf { (features : Vector , offset : Double ) => predict(features, offset) }
10111012 val predictLinkUDF = udf { (features : Vector , offset : Double ) => predictLink(features, offset) }
1012- /*
1013- Offset is only validated when it's specified in the model and available in prediction data set.
1014- When offset is specified but missing in the prediction data set, we default it to zero.
1015- */
1016- val offset = {
1017- if (! isSetOffsetCol(this )) {
1018- lit(0.0 )
1019- } else {
1020- if (dataset.schema.fieldNames.contains($(offsetCol))) {
1021- SchemaUtils .checkNumericType(dataset.schema, $(offsetCol))
1022- col($(offsetCol)).cast(DoubleType )
1023- } else {
1024- lit(0.0 )
1025- }
1026- }
1027- }
1013+
1014+ val offset = if (! hasOffsetCol) lit(0.0 ) else col($(offsetCol)).cast(DoubleType )
10281015 var output = dataset
10291016 if ($(predictionCol).nonEmpty) {
10301017 output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset))
@@ -1218,11 +1205,11 @@ class GeneralizedLinearRegressionSummary private[regression] (
12181205 private def prediction : Column = col(predictionCol)
12191206
12201207 private def weight : Column = { // Weight expression: the literal 1.0 when the model has no weight column, otherwise the model's weight column.
1221- if (! isSetWeightCol( model) ) lit(1.0 ) else col(model.getWeightCol)
1208+ if (! model.hasWeightCol ) lit(1.0 ) else col(model.getWeightCol)
12221209 }
12231210
12241211 private def offset : Column = { // Offset expression: the literal 0.0 when the model has no offset column, otherwise the model's offset column cast to DoubleType.
1225- if (! isSetOffsetCol( model) ) lit(0.0 ) else col(model.getOffsetCol).cast(DoubleType )
1212+ if (! model.hasOffsetCol ) lit(0.0 ) else col(model.getOffsetCol).cast(DoubleType )
12261213 }
12271214
12281215 private [regression] lazy val devianceResiduals : DataFrame = {
@@ -1285,8 +1272,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
12851272 Estimate intercept analytically when there is no offset, or when there is offset but
12861273 the model is Gaussian family with identity link. Otherwise, fit an intercept only model.
12871274 */
1288- if (! isSetOffsetCol( model) ||
1289- (isSetOffsetCol( model) && family == Gaussian && link == Identity )) {
1275+ if (! model.hasOffsetCol ||
1276+ (model.hasOffsetCol && family == Gaussian && link == Identity )) {
12901277 val agg = predictions.agg(sum(weight.multiply(
12911278 label.minus(offset))), sum(weight)).first()
12921279 link.link(agg.getDouble(0 ) / agg.getDouble(1 ))
0 commit comments