Skip to content

Commit 4852b7d

Browse files
actuaryzhang authored and yanboliang committed
[SPARK-21310][ML][PYSPARK] Expose offset in PySpark
## What changes were proposed in this pull request?

Add offset to PySpark in GLM as in #16699.

## How was this patch tested?

Python test

Author: actuaryzhang <[email protected]>

Closes #18534 from actuaryzhang/pythonOffset.
1 parent a386432 commit 4852b7d

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

python/pyspark/ml/regression.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1376,17 +1376,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
13761376
typeConverter=TypeConverters.toFloat)
13771377
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
13781378
"options: irls.", typeConverter=TypeConverters.toString)
1379+
offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " +
1380+
"or empty, we treat all instance offsets as 0.0",
1381+
typeConverter=TypeConverters.toString)
13791382

13801383
@keyword_only
13811384
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
13821385
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
13831386
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1384-
variancePower=0.0, linkPower=None):
1387+
variancePower=0.0, linkPower=None, offsetCol=None):
13851388
"""
13861389
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13871390
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
13881391
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1389-
variancePower=0.0, linkPower=None)
1392+
variancePower=0.0, linkPower=None, offsetCol=None)
13901393
"""
13911394
super(GeneralizedLinearRegression, self).__init__()
13921395
self._java_obj = self._new_java_obj(
@@ -1402,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred
14021405
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
14031406
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
14041407
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1405-
variancePower=0.0, linkPower=None):
1408+
variancePower=0.0, linkPower=None, offsetCol=None):
14061409
"""
14071410
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
14081411
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
14091412
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1410-
variancePower=0.0, linkPower=None)
1413+
variancePower=0.0, linkPower=None, offsetCol=None)
14111414
Sets params for generalized linear regression.
14121415
"""
14131416
kwargs = self._input_kwargs
@@ -1486,6 +1489,20 @@ def getLinkPower(self):
14861489
"""
14871490
return self.getOrDefault(self.linkPower)
14881491

1492+
@since("2.3.0")
1493+
def setOffsetCol(self, value):
1494+
"""
1495+
Sets the value of :py:attr:`offsetCol`.
1496+
"""
1497+
return self._set(offsetCol=value)
1498+
1499+
@since("2.3.0")
1500+
def getOffsetCol(self):
1501+
"""
1502+
Gets the value of offsetCol or its default value.
1503+
"""
1504+
return self.getOrDefault(self.offsetCol)
1505+
14891506

14901507
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
14911508
JavaMLReadable):

python/pyspark/ml/tests.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1291,6 +1291,20 @@ def test_tweedie_distribution(self):
12911291
self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
12921292
self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
12931293

1294+
def test_offset(self):
1295+
1296+
df = self.spark.createDataFrame(
1297+
[(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
1298+
(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
1299+
(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
1300+
(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])
1301+
1302+
glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
1303+
model = glr.fit(df)
1304+
self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581],
1305+
atol=1E-4))
1306+
self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))
1307+
12941308

12951309
class FPGrowthTests(SparkSessionTestCase):
12961310
def setUp(self):

0 commit comments

Comments
 (0)