@@ -24,7 +24,7 @@ import org.apache.spark.SparkException
2424import org .apache .spark .annotation .{Experimental , Since }
2525import org .apache .spark .internal .Logging
2626import org .apache .spark .ml .PredictorParams
27- import org .apache .spark .ml .feature .Instance
27+ import org .apache .spark .ml .feature .{ Instance , OffsetInstance }
2828import org .apache .spark .ml .linalg .{BLAS , Vector }
2929import org .apache .spark .ml .optim ._
3030import org .apache .spark .ml .param ._
@@ -123,6 +123,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
123123 s " with ${$(family)} family does not support ${$(link)} link function. " )
124124 }
125125 val newSchema = super .validateAndTransformSchema(schema, fitting, featuresDataType)
126+ if (isSet(offsetCol) && $(offsetCol).nonEmpty) {
127+ SchemaUtils .checkNumericType(schema, $(offsetCol))
128+ }
126129 if (hasLinkPredictionCol) {
127130 SchemaUtils .appendColumn(newSchema, $(linkPredictionCol), DoubleType )
128131 } else {
@@ -286,16 +289,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
286289
287290 val w = if (! isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0 ) else col($(weightCol))
288291 val off = if (! isDefined(offsetCol) || $(offsetCol).isEmpty) lit(0.0 ) else col($(offsetCol))
289- val instances : RDD [GLRInstance ] =
292+ val instances : RDD [OffsetInstance ] =
290293 dataset.select(col($(labelCol)), w, off, col($(featuresCol))).rdd.map {
291294 case Row (label : Double , weight : Double , offset : Double , features : Vector ) =>
292- new GLRInstance (label, weight, offset, features)
295+ OffsetInstance (label, weight, offset, features)
293296 }
294297
295298 val model = if (familyObj == Gaussian && linkObj == Identity ) {
296299 // TODO: Make standardizeFeatures and standardizeLabel configurable.
297300 val wlsInstances : RDD [Instance ] = instances.map { instance =>
298- new Instance (instance.label - instance.offset, instance.weight, instance.features)
301+ Instance (instance.label - instance.offset, instance.weight, instance.features)
299302 }
300303 val optimizer = new WeightedLeastSquares ($(fitIntercept), $(regParam), elasticNetParam = 0.0 ,
301304 standardizeFeatures = true , standardizeLabel = true )
@@ -365,7 +368,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
365368 * Get the initial guess model for [[IterativelyReweightedLeastSquares ]].
366369 */
367370 def initialize (
368- instances : RDD [GLRInstance ],
371+ instances : RDD [OffsetInstance ],
369372 fitIntercept : Boolean ,
370373 regParam : Double ): WeightedLeastSquaresModel = {
371374 val newInstances = instances.map { instance =>
@@ -384,11 +387,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
384387 * The reweight function used to update offsets and weights
385388 * at each iteration of [[IterativelyReweightedLeastSquares ]].
386389 */
387- val reweightFunc : (GLRInstance , WeightedLeastSquaresModel ) => (Double , Double ) = {
388- (instance : GLRInstance , model : WeightedLeastSquaresModel ) => {
389- val eta = model.predict(instance.features) + instance.offset
390- val mu = fitted(eta)
391- val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu)
390+ val reweightFunc : (OffsetInstance , WeightedLeastSquaresModel ) => (Double , Double ) = {
391+ (instance : OffsetInstance , model : WeightedLeastSquaresModel ) => {
392+ val eta = model.predict(instance.features)
393+ val mu = fitted(eta + instance.offset )
394+ val newLabel = eta + (instance.label - mu) * link.deriv(mu)
392395 val newWeight = instance.weight / (math.pow(this .link.deriv(mu), 2.0 ) * family.variance(mu))
393396 (newLabel, newWeight)
394397 }
@@ -766,7 +769,7 @@ class GeneralizedLinearRegressionModel private[ml] (
766769 val eta = BLAS .dot(features, coefficients) + intercept
767770 familyAndLink.fitted(eta)
768771 } else {
769- throw new SparkException (" Must supply offset value when offset is set." )
772+ throw new SparkException (" Must supply offset to predict when offset column is set." )
770773 }
771774 }
772775
@@ -1201,25 +1204,3 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
12011204 }
12021205 }
12031206}
1204-
1205- /**
1206- * Case class that represents an instance of data point with
1207- * label, weight, offset and features.
1208- *
1209- * @param label Label for this data point.
1210- * @param weight The weight of this instance.
1211- * @param offset The offset used for this data point.
1212- * @param features The vector of features for this data point.
1213- */
1214- private [ml] case class GLRInstance (label : Double , weight : Double , offset : Double ,
1215- features : Vector ) {
1216-
1217- /** Constructs from an [[Instance ]] object and offset */
1218- def this (instance : Instance , offset : Double = 0.0 ) = {
1219- this (instance.label, instance.weight, offset, instance.features)
1220- }
1221-
1222- /** Converts to an [[Instance ]] object by leaving out the offset. */
1223- private [ml] def toInstance : Instance = Instance (label, weight, features)
1224-
1225- }
0 commit comments