Skip to content

Commit 39820a4

Browse files
committed
[SPARK-17241] add sanity check
1 parent dd7cb82 commit 39820a4

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,6 @@ test_that("spark.glm summary", {
9999
expect_match(out[2], "Deviance Residuals:")
100100
expect_true(any(grepl("AIC: 59.22", out)))
101101

102-
# Test spark.glm works with regularization parameter
103-
regStats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species, regParam = 0.3))
104-
expect_equal(regStats$aic, 136.7, tolerance = 1e-3)
105-
106102
# binomial family
107103
df <- suppressWarnings(createDataFrame(iris))
108104
training <- df[df$Species %in% c("versicolor", "virginica"), ]
@@ -152,6 +148,12 @@ test_that("spark.glm summary", {
152148
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
153149
baseSummary <- summary(baseModel)
154150
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
151+
152+
# Test spark.glm works with regularization parameter
153+
data <- as.data.frame(cbind(a1, a2, b))
154+
df <- suppressWarnings(createDataFrame(data))
155+
regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
156+
expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
155157
})
156158

157159
test_that("spark.glm save/load", {

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
10341034
.setFamily("gaussian")
10351035
.fit(datasetGaussianIdentity.as[LabeledPoint])
10361036
}
1037+
1038+
test("generalized linear regression: regularization parameter") {
  /*
    Checks the AIC reported by the GLM training summary for several
    regularization parameters against reference values obtained from SparkR.

    R code:

    a1 <- c(0, 1, 2, 3)
    a2 <- c(5, 2, 1, 3)
    b <- c(1, 0, 1, 0)
    data <- as.data.frame(cbind(a1, a2, b))
    df <- suppressWarnings(createDataFrame(data))

    for (regParam in c(0.0, 0.1, 1.0)) {
      model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
      print(as.vector(summary(model)$aic))
    }

    [1] 12.88188
    [1] 12.92681
    [1] 13.32836
   */
  // Same four observations as the R snippet above: the label is `b` and each
  // feature vector is written as (a2, a1).
  val dataset = spark.createDataFrame(Seq(
    LabeledPoint(1, Vectors.dense(5, 0)),
    LabeledPoint(0, Vectors.dense(2, 1)),
    LabeledPoint(1, Vectors.dense(1, 2)),
    LabeledPoint(0, Vectors.dense(3, 3))
  ))
  val regParams = Seq(0.0, 0.1, 1.0)
  // Reference AIC values, listed in the same order as `regParams`.
  val expected = Seq(12.88188, 12.92681, 13.32836)

  // Pair each regParam with its reference AIC instead of tracking a mutable index.
  regParams.zip(expected).foreach { case (regParam, expectedAic) =>
    val trainer = new GeneralizedLinearRegression()
      .setRegParam(regParam)
      .setLabelCol("label")
      .setFeaturesCol("features")
    val model = trainer.fit(dataset)
    val actual = model.summary.aic
    // The `s` interpolator is required here: without it the failure message would
    // contain the literal text "$regParam" rather than the value under test.
    assert(actual ~= expectedAic absTol 1e-4,
      s"Model mismatch: GLM with regParam = $regParam.")
  }
}
10371077
}
10381078

10391079
object GeneralizedLinearRegressionSuite {

0 commit comments

Comments
 (0)