
Commit d44974c

rename to OffsetInstance and add param check
1 parent 9eca1a6 commit d44974c

5 files changed: +52 -51 lines


mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala

Lines changed: 22 additions & 0 deletions
@@ -27,3 +27,25 @@ import org.apache.spark.ml.linalg.Vector
  * @param features The vector of features for this data point.
  */
 private[ml] case class Instance(label: Double, weight: Double, features: Vector)
+
+/**
+ * Case class that represents an instance of data point with
+ * label, weight, offset and features.
+ *
+ * @param label Label for this data point.
+ * @param weight The weight of this instance.
+ * @param offset The offset used for this data point.
+ * @param features The vector of features for this data point.
+ */
+private[ml] case class OffsetInstance(label: Double, weight: Double, offset: Double,
+    features: Vector) {
+
+  /** Constructs from an [[Instance]] object and offset */
+  def this(instance: Instance, offset: Double = 0.0) = {
+    this(instance.label, instance.weight, offset, instance.features)
+  }
+
+  /** Converts to an [[Instance]] object by leaving out the offset. */
+  private[ml] def toInstance: Instance = Instance(label, weight, features)
+
+}
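
A minimal, hypothetical sketch of how the new case class round-trips with Instance (illustrative only; it assumes code living inside the org.apache.spark.ml package, since both classes and toInstance are private[ml], and the values are placeholders):

    import org.apache.spark.ml.feature.{Instance, OffsetInstance}
    import org.apache.spark.ml.linalg.Vectors

    val base = Instance(1.0, 2.0, Vectors.dense(0.5, 1.5))

    // Auxiliary constructor: wrap an existing Instance; the offset defaults to 0.0.
    val withOffset = new OffsetInstance(base, offset = 0.3)

    // Primary case-class constructor with an explicit offset.
    val direct = OffsetInstance(1.0, 2.0, 0.3, Vectors.dense(0.5, 1.5))

    // Dropping the offset again yields a plain Instance.
    val back: Instance = withOffset.toInstance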

mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala

Lines changed: 3 additions & 4 deletions
@@ -18,9 +18,8 @@
 package org.apache.spark.ml.optim
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.regression.GLRInstance
 import org.apache.spark.rdd.RDD
 
 /**
@@ -58,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel(
  */
 private[ml] class IterativelyReweightedLeastSquares(
     val initialModel: WeightedLeastSquaresModel,
-    val reweightFunc: (GLRInstance, WeightedLeastSquaresModel) => (Double, Double),
+    val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double),
     val fitIntercept: Boolean,
     val regParam: Double,
     val maxIter: Int,
     val tol: Double) extends Logging with Serializable {
 
-  def fit(instances: RDD[GLRInstance]): IterativelyReweightedLeastSquaresModel = {
+  def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = {
 
     var converged = false
     var iter = 0
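
To make the new signature concrete, here is a hypothetical sketch of an offset-aware reweight function and a fit call, modeled on the PoissonReweightFunc pattern in IterativelyReweightedLeastSquaresSuite. It assumes code inside org.apache.spark.ml (these classes are private[ml]); the hyperparameter values and the initial model are placeholders.

    import org.apache.spark.ml.feature.OffsetInstance
    import org.apache.spark.ml.optim.{IterativelyReweightedLeastSquares, IterativelyReweightedLeastSquaresModel, WeightedLeastSquaresModel}
    import org.apache.spark.rdd.RDD

    // Poisson family with log link: the offset enters the linear predictor, but the
    // working label handed to weighted least squares has it subtracted back out.
    val poissonReweight: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = {
      (instance: OffsetInstance, model: WeightedLeastSquaresModel) => {
        val eta = model.predict(instance.features) + instance.offset
        val mu = math.exp(eta)
        val newLabel = eta - instance.offset + (instance.label - mu) / mu
        val newWeight = instance.weight * mu
        (newLabel, newWeight)
      }
    }

    // instances and initialModel are assumed to be provided by the caller.
    def fitPoisson(
        instances: RDD[OffsetInstance],
        initialModel: WeightedLeastSquaresModel): IterativelyReweightedLeastSquaresModel = {
      new IterativelyReweightedLeastSquares(
        initialModel, poissonReweight, fitIntercept = true,
        regParam = 0.0, maxIter = 25, tol = 1e-8)
        .fit(instances)
    }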

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 14 additions & 33 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.linalg.{BLAS, Vector}
 import org.apache.spark.ml.optim._
 import org.apache.spark.ml.param._
@@ -123,6 +123,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams
         s"with ${$(family)} family does not support ${$(link)} link function.")
     }
     val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+    if (isSet(offsetCol) && $(offsetCol).nonEmpty) {
+      SchemaUtils.checkNumericType(schema, $(offsetCol))
+    }
     if (hasLinkPredictionCol) {
       SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
     } else {
@@ -286,16 +289,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val off = if (!isDefined(offsetCol) || $(offsetCol).isEmpty) lit(0.0) else col($(offsetCol))
-    val instances: RDD[GLRInstance] =
+    val instances: RDD[OffsetInstance] =
       dataset.select(col($(labelCol)), w, off, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
-          new GLRInstance(label, weight, offset, features)
+          OffsetInstance(label, weight, offset, features)
       }
 
     val model = if (familyObj == Gaussian && linkObj == Identity) {
       // TODO: Make standardizeFeatures and standardizeLabel configurable.
       val wlsInstances: RDD[Instance] = instances.map { instance =>
-        new Instance(instance.label - instance.offset, instance.weight, instance.features)
+        Instance(instance.label - instance.offset, instance.weight, instance.features)
       }
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0,
         standardizeFeatures = true, standardizeLabel = true)
@@ -365,7 +368,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression]
   * Get the initial guess model for [[IterativelyReweightedLeastSquares]].
   */
  def initialize(
-     instances: RDD[GLRInstance],
+     instances: RDD[OffsetInstance],
      fitIntercept: Boolean,
      regParam: Double): WeightedLeastSquaresModel = {
    val newInstances = instances.map { instance =>
@@ -384,11 +387,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression]
   * The reweight function used to update offsets and weights
   * at each iteration of [[IterativelyReweightedLeastSquares]].
   */
- val reweightFunc: (GLRInstance, WeightedLeastSquaresModel) => (Double, Double) = {
-   (instance: GLRInstance, model: WeightedLeastSquaresModel) => {
-     val eta = model.predict(instance.features) + instance.offset
-     val mu = fitted(eta)
-     val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu)
+ val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = {
+   (instance: OffsetInstance, model: WeightedLeastSquaresModel) => {
+     val eta = model.predict(instance.features)
+     val mu = fitted(eta + instance.offset)
+     val newLabel = eta + (instance.label - mu) * link.deriv(mu)
      val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
      (newLabel, newWeight)
    }
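Note that the rewritten reweight function is algebraically equivalent to the old one: previously the offset was folded into eta and then subtracted back out of the working label, whereas now model.predict returns only the coefficient contribution, mu is evaluated at eta + instance.offset, and the working label handed to weighted least squares is eta + (label - mu) * link.deriv(mu) directly, so the offset never enters the WLS response.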
@@ -766,7 +769,7 @@ class GeneralizedLinearRegressionModel private[ml] (
       val eta = BLAS.dot(features, coefficients) + intercept
       familyAndLink.fitted(eta)
     } else {
-      throw new SparkException("Must supply offset value when offset is set.")
+      throw new SparkException("Must supply offset to predict when offset column is set.")
     }
   }
 
@@ -1201,25 +1204,3 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
     }
   }
 }
-
-/**
- * Case class that represents an instance of data point with
- * label, weight, offset and features.
- *
- * @param label Label for this data point.
- * @param weight The weight of this instance.
- * @param offset The offset used for this data point.
- * @param features The vector of features for this data point.
- */
-private[ml] case class GLRInstance(label: Double, weight: Double, offset: Double,
-    features: Vector) {
-
-  /** Constructs from an [[Instance]] object and offset */
-  def this(instance: Instance, offset: Double = 0.0) = {
-    this(instance.label, instance.weight, offset, instance.features)
-  }
-
-  /** Converts to an [[Instance]] object by leaving out the offset. */
-  private[ml] def toInstance: Instance = Instance(label, weight, features)
-
-}
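
As a usage illustration, here is a minimal, hypothetical sketch of fitting a Poisson GLM with an offset column on a DataFrame. It assumes a setOffsetCol setter accompanies the offsetCol param (the setter itself is not part of this diff); the column names and the family/link choices are placeholders.

    import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel}
    import org.apache.spark.sql.DataFrame

    def fitPoissonWithOffset(df: DataFrame): GeneralizedLinearRegressionModel = {
      // df is assumed to have numeric "label", "weight" and "offset" columns (the offset
      // being, e.g., log-exposure) plus a vector "features" column.
      new GeneralizedLinearRegression()
        .setFamily("poisson")
        .setLink("log")
        .setWeightCol("weight")
        .setOffsetCol("offset") // hypothetical setter; the column must be numeric, as checked above
        .fit(df)
    }

Internally, each selected row is mapped to OffsetInstance(label, weight, offset, features) before optimization.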

mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala

Lines changed: 8 additions & 9 deletions
@@ -18,17 +18,16 @@
 package org.apache.spark.ml.optim
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.regression.GLRInstance
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 
 class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  private var instances1: RDD[GLRInstance] = _
-  private var instances2: RDD[GLRInstance] = _
+  private var instances1: RDD[OffsetInstance] = _
+  private var instances2: RDD[OffsetInstance] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -44,7 +43,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
       Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
       Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
      Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
-    ), 2).map(new GLRInstance(_))
+    ), 2).map(new OffsetInstance(_))
     /*
        R code:
 
@@ -57,7 +56,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
       Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
       Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
-    ), 2).map(new GLRInstance(_))
+    ), 2).map(new OffsetInstance(_))
   }
 
   test("IRLS against GLM with Binomial errors") {
@@ -170,7 +169,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
 object IterativelyReweightedLeastSquaresSuite {
 
   def BinomialReweightFunc(
-      instance: GLRInstance,
+      instance: OffsetInstance,
       model: WeightedLeastSquaresModel): (Double, Double) = {
     val eta = model.predict(instance.features) + instance.offset
     val mu = 1.0 / (1.0 + math.exp(-1.0 * eta))
@@ -180,7 +179,7 @@ object IterativelyReweightedLeastSquaresSuite {
   }
 
   def PoissonReweightFunc(
-      instance: GLRInstance,
+      instance: OffsetInstance,
       model: WeightedLeastSquaresModel): (Double, Double) = {
     val eta = model.predict(instance.features) + instance.offset
     val mu = math.exp(eta)
@@ -190,7 +189,7 @@ object IterativelyReweightedLeastSquaresSuite {
   }
 
   def L1RegressionReweightFunc(
-      instance: GLRInstance,
+      instance: OffsetInstance,
       model: WeightedLeastSquaresModel): (Double, Double) = {
     val eta = model.predict(instance.features) + instance.offset
     val e = math.max(math.abs(eta - instance.label), 1e-7)

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 5 additions & 5 deletions
@@ -21,7 +21,7 @@ import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
-import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
@@ -604,10 +604,10 @@ class GeneralizedLinearRegressionSuite
       [1] -0.27378146  0.31599396 -0.06204946
      */
     val dataset = Seq(
-      GLRInstance(1.0, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
-      GLRInstance(2.0, 2.0, 0.5, Vectors.dense(1.0, 2.0)),
-      GLRInstance(1.0, 3.0, 1.0, Vectors.dense(2.0, 1.0)),
-      GLRInstance(2.0, 4.0, 0.0, Vectors.dense(3.0, 3.0))
+      OffsetInstance(1.0, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
+      OffsetInstance(2.0, 2.0, 0.5, Vectors.dense(1.0, 2.0)),
+      OffsetInstance(1.0, 3.0, 1.0, Vectors.dense(2.0, 1.0)),
+      OffsetInstance(2.0, 4.0, 0.0, Vectors.dense(3.0, 3.0))
     ).toDF()
 
     val expected = Seq(
