6 changes: 4 additions & 2 deletions R/pkg/R/mllib.R
@@ -141,6 +141,7 @@ predict_internal <- function(object, newData) {
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
+ #' @param regParam regularization parameter for L2 regularization.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
+ function(data, formula, family = gaussian, tol = 1e-6, regParam = 0.0, maxIter = 25,
+ weightCol = NULL) {
Contributor:
Perhaps we can add that to the end of the argument list so that it doesn't break the existing calls to the function?

Contributor Author:
Check the fit() method of the wrapper: as long as the parameter order matches, it's OK.

I've already tested it in an R terminal.

Contributor:
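Suppose an R user calls the function positionally, e.g. spark.glm(df, label ~ feature, gaussian, 1e-6, 25). With regParam inserted before maxIter, the 25 would bind to regParam instead of maxIter, so this change will break their code.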

Member:
+1 - we should try to avoid breaking existing callers.
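
The concern in this thread can be reproduced without Spark. The sketch below uses three hypothetical stand-in functions (`glm_old`, `glm_mid`, `glm_end` are illustrative names, not part of SparkR) that only echo the arguments they receive, to show how R's positional matching treats an existing call when a new parameter is inserted mid-list versus appended at the end.

```r
# Hypothetical stand-ins for the spark.glm signatures; they only echo the
# arguments they receive and never touch Spark.
glm_old <- function(tol = 1e-6, maxIter = 25, weightCol = NULL) {
  list(tol = tol, maxIter = maxIter)
}

# regParam inserted before maxIter, as in this diff.
glm_mid <- function(tol = 1e-6, regParam = 0.0, maxIter = 25, weightCol = NULL) {
  list(tol = tol, regParam = regParam, maxIter = maxIter)
}

# regParam appended after the existing arguments, as the reviewer suggests.
glm_end <- function(tol = 1e-6, maxIter = 25, weightCol = NULL, regParam = 0.0) {
  list(tol = tol, regParam = regParam, maxIter = maxIter)
}

# An existing positional call that meant tol = 1e-6, maxIter = 100.
old_call <- glm_old(1e-6, 100)
mid_call <- glm_mid(1e-6, 100)
end_call <- glm_end(1e-6, 100)

print(mid_call$regParam)  # 100: the value intended for maxIter lands in regParam
print(mid_call$maxIter)   # 25: maxIter silently falls back to its default
print(end_call$maxIter)   # 100: positional callers are unaffected
```

Callers who pass `maxIter = 100` by name are unaffected in either layout; only positional calls change meaning, and they do so without raising an error.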

if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),

jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
- tol, as.integer(maxIter), as.character(weightCol))
+ tol, regParam, as.integer(maxIter), as.character(weightCol))
new("GeneralizedLinearRegressionModel", jobj = jobj)
})

4 changes: 4 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
@@ -99,6 +99,10 @@ test_that("spark.glm summary", {
expect_match(out[2], "Deviance Residuals:")
expect_true(any(grepl("AIC: 59.22", out)))

+ # Test spark.glm works with regularization parameter
+ regStats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species, regParam = 0.3))
+ expect_equal(regStats$aic, 136.7, tolerance = 1e-3)
Contributor:
How was this number computed?

Contributor Author:
Oh, I just checked the output of the model stats.

Maybe there is a better way to test it?

Contributor:
As I recall, it should match the result of glmnet. Perhaps you can try the same example there, or take a look at https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala#L307

Contributor:
though it's very likely that the result would not change :)
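
One reason an independently derived reference value matters: a hard-coded expectation with a loose tolerance accepts a fairly wide band of results. The sketch below assumes `expect_equal` follows base R's `all.equal` semantics, where `tolerance` is a relative difference; `within_tol` is a hypothetical helper and the candidate values are made up for illustration, not Spark output.

```r
# Expected AIC copied from the test in this diff.
expected <- 136.7

# Hypothetical helper: TRUE when `actual` is within the relative tolerance
# of `expected`, using base R's all.equal.
within_tol <- function(actual, tol = 1e-3) {
  isTRUE(all.equal(expected, actual, tolerance = tol))
}

# Illustrative candidates, not values produced by Spark.
close_enough <- within_tol(136.8)  # relative difference is about 7.3e-4
too_far <- within_tol(137.0)       # relative difference is about 2.2e-3

print(close_enough)  # TRUE
print(too_far)       # FALSE
```

With a relative tolerance of 1e-3 around 136.7, anything within roughly 0.14 of the expected value passes, so the check guards against gross regressions rather than pinning the exact result.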


# binomial family
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
@@ -68,6 +68,7 @@ private[r] object GeneralizedLinearRegressionWrapper
family: String,
link: String,
tol: Double,
+ regParam: Double,
maxIter: Int,
weightCol: String): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
@@ -84,6 +85,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setLink(link)
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
+ .setRegParam(regParam)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
val pipeline = new Pipeline()