Skip to content

Commit 224e3e5

Browse files
yanbolianglw-lin
authored andcommitted
[SPARK-13925][ML][SPARKR] Expose R-like summary statistics in SparkR::glm for more family and link functions
## What changes were proposed in this pull request? Expose R-like summary statistics in SparkR::glm for more family and link functions. Note: Not all values in R [summary.glm](http://stat.ethz.ch/R-manual/R-patched/library/stats/html/summary.glm.html) are exposed, we only provide the most commonly used statistics in this PR. More statistics can be added in the followup work. ## How was this patch tested? Unit tests. SparkR Output: ``` Deviance Residuals: (Note: These are approximate quantiles with relative error <= 0.01) Min 1Q Median 3Q Max -0.95096 -0.16585 -0.00232 0.17410 0.72918 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) 1.6765 0.23536 7.1231 4.4561e-11 Sepal_Length 0.34988 0.046301 7.5566 4.1873e-12 Species_versicolor -0.98339 0.072075 -13.644 0 Species_virginica -1.0075 0.093306 -10.798 0 (Dispersion parameter for gaussian family taken to be 0.08351462) Null deviance: 28.307 on 149 degrees of freedom Residual deviance: 12.193 on 146 degrees of freedom AIC: 59.22 Number of Fisher Scoring iterations: 1 ``` R output: ``` Deviance Residuals: Min 1Q Median 3Q Max -0.95096 -0.16522 0.00171 0.18416 0.72918 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) 1.67650 0.23536 7.123 4.46e-11 *** Sepal.Length 0.34988 0.04630 7.557 4.19e-12 *** Speciesversicolor -0.98339 0.07207 -13.644 < 2e-16 *** Speciesvirginica -1.00751 0.09331 -10.798 < 2e-16 *** --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 (Dispersion parameter for gaussian family taken to be 0.08351462) Null deviance: 28.307 on 149 degrees of freedom Residual deviance: 12.193 on 146 degrees of freedom AIC: 59.217 Number of Fisher Scoring iterations: 2 ``` cc mengxr Author: Yanbo Liang <[email protected]> Closes apache#12393 from yanboliang/spark-13925.
1 parent dfcae9f commit 224e3e5

File tree

4 files changed

+143
-10
lines changed

4 files changed

+143
-10
lines changed

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ export("as.DataFrame",
292292
"tableToDF",
293293
"tableNames",
294294
"tables",
295-
"uncacheTable")
295+
"uncacheTable",
296+
"print.summary.GeneralizedLinearRegressionModel")
296297

297298
export("structField",
298299
"structField.jobj",

R/pkg/R/mllib.R

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,55 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
101101
jobj <- object@jobj
102102
features <- callJMethod(jobj, "rFeatures")
103103
coefficients <- callJMethod(jobj, "rCoefficients")
104-
coefficients <- as.matrix(unlist(coefficients))
105-
colnames(coefficients) <- c("Estimate")
104+
deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
105+
dispersion <- callJMethod(jobj, "rDispersion")
106+
null.deviance <- callJMethod(jobj, "rNullDeviance")
107+
deviance <- callJMethod(jobj, "rDeviance")
108+
df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull")
109+
df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom")
110+
aic <- callJMethod(jobj, "rAic")
111+
iter <- callJMethod(jobj, "rNumIterations")
112+
family <- callJMethod(jobj, "rFamily")
113+
114+
deviance.resid <- dataFrame(deviance.resid)
115+
coefficients <- matrix(coefficients, ncol = 4)
116+
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
106117
rownames(coefficients) <- unlist(features)
107-
return(list(coefficients = coefficients))
118+
ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
119+
dispersion = dispersion, null.deviance = null.deviance,
120+
deviance = deviance, df.null = df.null, df.residual = df.residual,
121+
aic = aic, iter = iter, family = family)
122+
class(ans) <- "summary.GeneralizedLinearRegressionModel"
123+
return(ans)
108124
})
109125

126+
#' Print the summary of GeneralizedLinearRegressionModel
127+
#'
128+
#' @rdname print
129+
#' @name print.summary.GeneralizedLinearRegressionModel
130+
#' @export
131+
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
132+
x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
133+
c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
134+
x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
135+
cat("\nDeviance Residuals: \n")
136+
cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
137+
print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
138+
139+
cat("\nCoefficients:\n")
140+
print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
141+
142+
cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion),
143+
")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"),
144+
format(unlist(x[c("null.deviance", "deviance")]), digits = 5L),
145+
" on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"),
146+
1L, paste, collapse = " "), sep = "")
147+
cat("AIC: ", format(x$aic, digits = 4L), "\n\n",
148+
"Number of Fisher Scoring iterations: ", x$iter, "\n", sep = "")
149+
cat("\n")
150+
invisible(x)
151+
}
152+
110153
#' Make predictions from a generalized linear model
111154
#'
112155
#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict().

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,55 @@ test_that("glm and predict", {
7777
expect_equal(length(predict(lm(y ~ x))), 15)
7878
})
7979

80+
test_that("glm summary", {
81+
# gaussian family
82+
training <- suppressWarnings(createDataFrame(sqlContext, iris))
83+
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
84+
85+
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
86+
87+
coefs <- unlist(stats$coefficients)
88+
rCoefs <- unlist(rStats$coefficients)
89+
expect_true(all(abs(rCoefs - coefs) < 1e-4))
90+
expect_true(all(
91+
rownames(stats$coefficients) ==
92+
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
93+
expect_equal(stats$dispersion, rStats$dispersion)
94+
expect_equal(stats$null.deviance, rStats$null.deviance)
95+
expect_equal(stats$deviance, rStats$deviance)
96+
expect_equal(stats$df.null, rStats$df.null)
97+
expect_equal(stats$df.residual, rStats$df.residual)
98+
expect_equal(stats$aic, rStats$aic)
99+
100+
# binomial family
101+
df <- suppressWarnings(createDataFrame(sqlContext, iris))
102+
training <- df[df$Species %in% c("versicolor", "virginica"), ]
103+
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
104+
family = binomial(link = "logit")))
105+
106+
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
107+
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
108+
family = binomial(link = "logit")))
109+
110+
coefs <- unlist(stats$coefficients)
111+
rCoefs <- unlist(rStats$coefficients)
112+
expect_true(all(abs(rCoefs - coefs) < 1e-4))
113+
expect_true(all(
114+
rownames(stats$coefficients) ==
115+
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
116+
expect_equal(stats$dispersion, rStats$dispersion)
117+
expect_equal(stats$null.deviance, rStats$null.deviance)
118+
expect_equal(stats$deviance, rStats$deviance)
119+
expect_equal(stats$df.null, rStats$df.null)
120+
expect_equal(stats$df.residual, rStats$df.residual)
121+
expect_equal(stats$aic, rStats$aic)
122+
123+
# Test summary works on base GLM models
124+
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
125+
baseSummary <- summary(baseModel)
126+
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
127+
})
128+
80129
test_that("kmeans", {
81130
newIris <- iris
82131
newIris$Species <- NULL

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,59 @@ private[r] class GeneralizedLinearRegressionWrapper private (
3030
private val glm: GeneralizedLinearRegressionModel =
3131
pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
3232

33+
lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
34+
Array("(Intercept)") ++ features
35+
} else {
36+
features
37+
}
38+
3339
lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
34-
Array(glm.intercept) ++ glm.coefficients.toArray
40+
Array(glm.intercept) ++ glm.coefficients.toArray ++
41+
rCoefficientStandardErrors ++ rTValues ++ rPValues
3542
} else {
36-
glm.coefficients.toArray
43+
glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
3744
}
3845

39-
lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
40-
Array("(Intercept)") ++ features
46+
private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
47+
Array(glm.summary.coefficientStandardErrors.last) ++
48+
glm.summary.coefficientStandardErrors.dropRight(1)
4149
} else {
42-
features
50+
glm.summary.coefficientStandardErrors
51+
}
52+
53+
private lazy val rTValues = if (glm.getFitIntercept) {
54+
Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
55+
} else {
56+
glm.summary.tValues
4357
}
4458

45-
def transform(dataset: DataFrame): DataFrame = {
59+
private lazy val rPValues = if (glm.getFitIntercept) {
60+
Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
61+
} else {
62+
glm.summary.pValues
63+
}
64+
65+
lazy val rDispersion: Double = glm.summary.dispersion
66+
67+
lazy val rNullDeviance: Double = glm.summary.nullDeviance
68+
69+
lazy val rDeviance: Double = glm.summary.deviance
70+
71+
lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
72+
73+
lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
74+
75+
lazy val rAic: Double = glm.summary.aic
76+
77+
lazy val rNumIterations: Int = glm.summary.numIterations
78+
79+
lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
80+
81+
lazy val rFamily: String = glm.getFamily
82+
83+
def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType)
84+
85+
def transform(dataset: Dataset[_]): DataFrame = {
4686
pipeline.transform(dataset).drop(glm.getFeaturesCol)
4787
}
4888
}

0 commit comments

Comments
 (0)