Closed
Changes from 3 commits
10 changes: 6 additions & 4 deletions R/pkg/R/mllib.R
@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
-#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
+#' weights as 1.0.
+#' @param regParam regularization parameter for L2 regularization.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
-function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
+function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
+regParam = 0.0) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),

jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
-tol, as.integer(maxIter), as.character(weightCol))
+tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})

4 changes: 4 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
@@ -99,6 +99,10 @@ test_that("spark.glm summary", {
expect_match(out[2], "Deviance Residuals:")
expect_true(any(grepl("AIC: 59.22", out)))

+# Test spark.glm works with regularization parameter
+regStats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species, regParam = 0.3))
+expect_equal(regStats$aic, 136.7, tolerance = 1e-3)
Contributor

How was this number computed?

Contributor Author

Oh, I just checked the output of the model stats.

Maybe there is a better way to test it?

Contributor

As I recall, it should match the result of glmnet. Perhaps you can try the same example there, or take a look at https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala#L307

Contributor

Though it's very likely that the result would not change :)
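
One possible way to sanity-check a regularized fit without relying on a hard-coded number is to compare it against an independent reference. The base R sketch below illustrates the idea on the local `iris` data with a closed-form ridge (L2-penalized) least-squares fit. The helper `ridge_fit` and the names `fit0`, `fitReg`, and `ols` are hypothetical, and the objective it minimizes is an assumption: Spark may standardize features or scale the penalty differently, so this is not expected to reproduce the AIC of 136.7 used in the test above.

```r
# Sketch only: ridge (L2-penalized) least squares in base R on the built-in iris data.
# Assumed objective: RSS / (2n) + (lambda / 2) * sum(beta^2), with an unpenalized
# intercept and features that are NOT standardized. Spark may scale things differently.
ridge_fit <- function(X, y, lambda) {
  n <- nrow(X)
  x_means <- colMeans(X)
  y_mean <- mean(y)
  Xc <- sweep(X, 2, x_means)  # center columns so the intercept drops out
  yc <- y - y_mean
  # Solve (Xc'Xc / n + lambda * I) beta = Xc'yc / n
  beta <- solve(crossprod(Xc) / n + lambda * diag(ncol(Xc)), crossprod(Xc, yc) / n)
  intercept <- y_mean - sum(x_means * beta)
  c(intercept, drop(beta))
}

# Same formula as the test, using the local column names of iris.
X <- model.matrix(Sepal.Width ~ Sepal.Length + Species, data = iris)[, -1]
y <- iris$Sepal.Width

ols <- unname(coef(lm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
fit0 <- unname(ridge_fit(X, y, 0))
fitReg <- unname(ridge_fit(X, y, 0.3))

# With lambda = 0 the sketch reduces to ordinary least squares.
stopifnot(isTRUE(all.equal(fit0, ols, tolerance = 1e-8)))
# With lambda > 0 the slope vector shrinks towards zero.
stopifnot(sum(fitReg[-1]^2) < sum(ols[-1]^2))
```

A property check of this kind (the regularized coefficients shrink relative to the unregularized ones) could complement a comparison against glmnet, since it does not depend on the two libraries agreeing on how the penalty is scaled.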


# binomial family
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
-weightCol: String): GeneralizedLinearRegressionWrapper = {
+weightCol: String,
+regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
+.setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)