Skip to content

Commit 7a5000f

Browse files
keypointtshivaram
authored andcommitted
[SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter
https://issues.apache.org/jira/browse/SPARK-17241 ## What changes were proposed in this pull request? Spark has configurable L2 regularization parameter for generalized linear regression. It is very important to have them in SparkR so that users can run ridge regression. ## How was this patch tested? Test manually on local laptop. Author: Xin Ren <[email protected]> Closes #14856 from keypointt/SPARK-17241.
1 parent d008638 commit 7a5000f

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

R/pkg/R/mllib.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
138138
#' This can be a character string naming a family function, a family function or
139139
#' the result of a call to a family function. Refer R family at
140140
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
141-
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
142-
#' weights as 1.0.
143141
#' @param tol positive convergence tolerance of iterations.
144142
#' @param maxIter integer giving the maximal number of IRLS iterations.
143+
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
144+
#' weights as 1.0.
145+
#' @param regParam regularization parameter for L2 regularization.
145146
#' @param ... additional arguments passed to the method.
146147
#' @aliases spark.glm,SparkDataFrame,formula-method
147148
#' @return \code{spark.glm} returns a fitted generalized linear model
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
171172
#' @note spark.glm since 2.0.0
172173
#' @seealso \link{glm}, \link{read.ml}
173174
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
174-
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
175+
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
176+
regParam = 0.0) {
175177
if (is.character(family)) {
176178
family <- get(family, mode = "function", envir = parent.frame())
177179
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
190192

191193
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
192194
"fit", formula, data@sdf, family$family, family$link,
193-
tol, as.integer(maxIter), as.character(weightCol))
195+
tol, as.integer(maxIter), as.character(weightCol), regParam)
194196
new("GeneralizedLinearRegressionModel", jobj = jobj)
195197
})
196198

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ test_that("spark.glm summary", {
148148
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
149149
baseSummary <- summary(baseModel)
150150
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
151+
152+
# Test spark.glm works with regularization parameter
153+
data <- as.data.frame(cbind(a1, a2, b))
154+
df <- suppressWarnings(createDataFrame(data))
155+
regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
156+
expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
151157
})
152158

153159
test_that("spark.glm save/load", {

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
6969
link: String,
7070
tol: Double,
7171
maxIter: Int,
72-
weightCol: String): GeneralizedLinearRegressionWrapper = {
72+
weightCol: String,
73+
regParam: Double): GeneralizedLinearRegressionWrapper = {
7374
val rFormula = new RFormula()
7475
.setFormula(formula)
7576
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
8687
.setTol(tol)
8788
.setMaxIter(maxIter)
8889
.setWeightCol(weightCol)
90+
.setRegParam(regParam)
8991
val pipeline = new Pipeline()
9092
.setStages(Array(rFormulaModel, glr))
9193
.fit(data)

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
10341034
.setFamily("gaussian")
10351035
.fit(datasetGaussianIdentity.as[LabeledPoint])
10361036
}
1037+
1038+
test("generalized linear regression: regularization parameter") {
1039+
/*
1040+
R code:
1041+
1042+
a1 <- c(0, 1, 2, 3)
1043+
a2 <- c(5, 2, 1, 3)
1044+
b <- c(1, 0, 1, 0)
1045+
data <- as.data.frame(cbind(a1, a2, b))
1046+
df <- suppressWarnings(createDataFrame(data))
1047+
1048+
for (regParam in c(0.0, 0.1, 1.0)) {
1049+
model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
1050+
print(as.vector(summary(model)$aic))
1051+
}
1052+
1053+
[1] 12.88188
1054+
[1] 12.92681
1055+
[1] 13.32836
1056+
*/
1057+
val dataset = spark.createDataFrame(Seq(
1058+
LabeledPoint(1, Vectors.dense(5, 0)),
1059+
LabeledPoint(0, Vectors.dense(2, 1)),
1060+
LabeledPoint(1, Vectors.dense(1, 2)),
1061+
LabeledPoint(0, Vectors.dense(3, 3))
1062+
))
1063+
val expected = Seq(12.88188, 12.92681, 13.32836)
1064+
1065+
var idx = 0
1066+
for (regParam <- Seq(0.0, 0.1, 1.0)) {
1067+
val trainer = new GeneralizedLinearRegression()
1068+
.setRegParam(regParam)
1069+
.setLabelCol("label")
1070+
.setFeaturesCol("features")
1071+
val model = trainer.fit(dataset)
1072+
val actual = model.summary.aic
1073+
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with regParam = $regParam.")
1074+
idx += 1
1075+
}
1076+
}
10371077
}
10381078

10391079
object GeneralizedLinearRegressionSuite {

0 commit comments

Comments
 (0)