10 changes: 6 additions & 4 deletions R/pkg/R/mllib.R
@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param regParam regularization parameter for L2 regularization.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
regParam = 0.0) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),

jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
tol, as.integer(maxIter), as.character(weightCol))
tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
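
For context, a minimal SparkR usage sketch of the new argument (not part of the patch; it reuses the four-point dataset from the tests below):

a1 <- c(0, 1, 2, 3)
a2 <- c(5, 2, 1, 3)
b <- c(1, 0, 1, 0)
df <- suppressWarnings(createDataFrame(as.data.frame(cbind(a1, a2, b))))
# regParam controls L2 regularization; the default 0.0 keeps the previous, unregularized behaviour
model <- spark.glm(df, b ~ a1 + a2, family = gaussian, regParam = 1.0)
summary(model)$aic  # expected to be close to 13.32836 per the tests below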

6 changes: 6 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
@@ -148,6 +148,12 @@ test_that("spark.glm summary", {
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)

# Test spark.glm works with regularization parameter
data <- as.data.frame(cbind(a1, a2, b))
df <- suppressWarnings(createDataFrame(data))
regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
})

test_that("spark.glm save/load", {
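As a sanity check (again not part of the patch), the default regParam = 0.0 should reduce to an unregularized GLM, so on the same data the AIC should line up with base R's stats::glm, consistent with the 12.88188 value quoted for regParam = 0.0 in the Scala suite below:

a1 <- c(0, 1, 2, 3)
a2 <- c(5, 2, 1, 3)
b <- c(1, 0, 1, 0)
# Base R reference fit: gaussian family, no regularization
AIC(stats::glm(b ~ a1 + a2, data = data.frame(a1, a2, b)))  # roughly 12.8819
# summary(spark.glm(df, b ~ a1 + a2))$aic with the default regParam = 0.0 should agree to ~1e-4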
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
weightCol: String): GeneralizedLinearRegressionWrapper = {
weightCol: String,
regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}

test("generalized linear regression: regularization parameter") {
/*
R code (SparkR), used to compute the expected AIC values:

a1 <- c(0, 1, 2, 3)
a2 <- c(5, 2, 1, 3)
b <- c(1, 0, 1, 0)
data <- as.data.frame(cbind(a1, a2, b))
df <- suppressWarnings(createDataFrame(data))

for (regParam in c(0.0, 0.1, 1.0)) {
model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
print(as.vector(summary(model)$aic))
}

[1] 12.88188
[1] 12.92681
[1] 13.32836
*/
val dataset = spark.createDataFrame(Seq(
LabeledPoint(1, Vectors.dense(5, 0)),
LabeledPoint(0, Vectors.dense(2, 1)),
LabeledPoint(1, Vectors.dense(1, 2)),
LabeledPoint(0, Vectors.dense(3, 3))
))
val expected = Seq(12.88188, 12.92681, 13.32836)

var idx = 0
for (regParam <- Seq(0.0, 0.1, 1.0)) {
val trainer = new GeneralizedLinearRegression()
.setRegParam(regParam)
.setLabelCol("label")
.setFeaturesCol("features")
val model = trainer.fit(dataset)
val actual = model.summary.aic
assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with regParam = $regParam.")
idx += 1
}
}
}

object GeneralizedLinearRegressionSuite {