Commit 61d220b

Merge remote-tracking branch 'origin/master' into off-heap-storage-memory-bookkeeping
2 parents: c4d2aeb + 43b15e0

140 files changed: +3383 −1673 lines


R/pkg/DESCRIPTION

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@ Depends:
     methods,
 Suggests:
     testthat,
-    e1071
+    e1071,
+    survival
 Description: R frontend for Spark
 License: Apache License (== 2.0)
 Collate:

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@ exportMethods("glm",
               "summary",
               "kmeans",
               "fitted",
-              "naiveBayes")
+              "naiveBayes",
+              "survreg")
 
 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
@@ -1179,3 +1179,7 @@ setGeneric("fitted")
 #' @rdname naiveBayes
 #' @export
 setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+
+#' @rdname survreg
+#' @export
+setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })

R/pkg/R/mllib.R

Lines changed: 75 additions & 0 deletions
@@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' @export
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
 
+#' @title S4 class that represents an AFTSurvivalRegressionModel
+#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
+#' @export
+setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.

@@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
                                 formula, data@sdf, laplace)
             return(new("NaiveBayesModel", jobj = jobj))
           })
+
+#' Fit an accelerated failure time (AFT) survival regression model.
+#'
+#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#'                Note that operator '.' is not supported currently.
+#' @param data DataFrame for training.
+#' @return A fitted AFT survival regression model.
+#' @rdname survreg
+#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, ovarian)
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' }
+setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
+          function(formula, data, ...) {
+            formula <- paste(deparse(formula), collapse = "")
+            jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
+                                "fit", formula, data@sdf)
+            return(new("AFTSurvivalRegressionModel", jobj = jobj))
+          })
+
+#' Get the summary of an AFT survival regression model
+#'
+#' Returns the summary of an AFT survival regression model produced by survreg(),
+#' similarly to R's summary().
+#'
+#' @param object A fitted AFT survival regression model.
+#' @return coefficients: the model's coefficients, intercept and log(scale).
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- c("Value")
+            rownames(coefficients) <- unlist(features)
+            return(list(coefficients = coefficients))
+          })
+
+#' Make predictions from an AFT survival regression model
+#'
+#' Make predictions from a model produced by survreg(), similarly to R package survival's predict.
+#'
+#' @param object A fitted AFT survival regression model.
+#' @param newData DataFrame for testing.
+#' @return DataFrame containing predicted labels in a column named "prediction".
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+          })
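A note on what the wrapper delegates to: survreg() above hands the deparsed formula string and the DataFrame to org.apache.spark.ml.r.AFTSurvivalRegressionWrapper via callJStatic. As a rough Scala sketch of the equivalent fit against Spark ML's public AFTSurvivalRegression estimator (illustrative only, not the wrapper's actual implementation; the pre-existing DataFrame df, its column names, and the accessor calls on the fitted model are assumptions here):

import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.AFTSurvivalRegression

// Assumes a DataFrame `df` with columns time, status, x, sex (as in the test below).
// RFormula expands "time ~ x + sex" into the "features"/"label" columns that
// AFTSurvivalRegression expects; "status" stays in the frame as the censor column.
val transformed = new RFormula().setFormula("time ~ x + sex").fit(df).transform(df)

val aft = new AFTSurvivalRegression()
  .setCensorCol("status")   // 1.0 = event observed, 0.0 = censored
  .setFitIntercept(true)
val model = aft.fit(transformed)

// summary() on the R side reports these as (Intercept), one row per feature,
// and Log(scale).
println(model.intercept)
println(model.coefficients)
println(math.log(model.scale))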

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 49 additions & 0 deletions
@@ -200,3 +200,52 @@ test_that("naiveBayes", {
     expect_equal(as.character(predict(m, t1[1, ])), "Yes")
   }
 })
+
+test_that("survreg", {
+  # R code to reproduce the result.
+  #
+  #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+  #'               x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+  #' library(survival)
+  #' model <- survreg(Surv(time, status) ~ x + sex, rData)
+  #' summary(model)
+  #' predict(model, data)
+  #
+  # -- output of 'summary(model)'
+  #
+  #              Value Std. Error     z        p
+  # (Intercept)  1.315      0.270  4.88 1.07e-06
+  # x           -0.190      0.173 -1.10 2.72e-01
+  # sex         -0.253      0.329 -0.77 4.42e-01
+  # Log(scale)  -1.160      0.396 -2.93 3.41e-03
+  #
+  # -- output of 'predict(model, data)'
+  #
+  #        1        2        3        4        5        6        7
+  # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
+  #
+  data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
+               list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
+  df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
+  model <- survreg(Surv(time, status) ~ x + sex, df)
+  stats <- summary(model)
+  coefs <- as.vector(stats$coefficients[, 1])
+  rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
+  expect_equal(coefs, rCoefs, tolerance = 1e-4)
+  expect_true(all(
+    rownames(stats$coefficients) ==
+    c("(Intercept)", "x", "sex", "Log(scale)")))
+  p <- collect(select(predict(model, df), "prediction"))
+  expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
+               2.390146, 2.891269, 2.891269), tolerance = 1e-4)
+
+  # Test survival::survreg
+  if (requireNamespace("survival", quietly = TRUE)) {
+    rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+                  x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+    expect_that(
+      model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
+      not(throws_error()))
+    expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
+  }
+})

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 2 additions & 1 deletion
@@ -1817,7 +1817,8 @@ test_that("approxQuantile() on a DataFrame", {
 
 test_that("SQL error message is returned from JVM", {
   retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
-  expect_equal(grepl("Table not found: blah", retError), TRUE)
+  expect_equal(grepl("Table not found", retError), TRUE)
+  expect_equal(grepl("blah", retError), TRUE)
 })
 
 irisDF <- suppressWarnings(createDataFrame(sqlContext, iris))

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 34 additions & 15 deletions
@@ -746,7 +746,7 @@ private[spark] class BlockManager(
       // We will drop it to disk later if the memory store can't hold it.
       val putSucceeded = if (level.deserialized) {
         val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
-        memoryStore.putIterator(blockId, values, level, classTag) match {
+        memoryStore.putIteratorAsValues(blockId, values, classTag) match {
           case Right(_) => true
           case Left(iter) =>
             // If putting deserialized values in memory failed, we will put the bytes directly to

@@ -876,21 +876,40 @@ private[spark] class BlockManager(
       if (level.useMemory) {
         // Put it in memory first, even if it also has useDisk set to true;
         // We will drop it to disk later if the memory store can't hold it.
-        memoryStore.putIterator(blockId, iterator(), level, classTag) match {
-          case Right(s) =>
-            size = s
-          case Left(iter) =>
-            // Not enough space to unroll this block; drop to disk if applicable
-            if (level.useDisk) {
-              logWarning(s"Persisting block $blockId to disk instead.")
-              diskStore.put(blockId) { fileOutputStream =>
-                serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+        if (level.deserialized) {
+          memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
+            case Right(s) =>
+              size = s
+            case Left(iter) =>
+              // Not enough space to unroll this block; drop to disk if applicable
+              if (level.useDisk) {
+                logWarning(s"Persisting block $blockId to disk instead.")
+                diskStore.put(blockId) { fileOutputStream =>
+                  serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+                }
+                size = diskStore.getSize(blockId)
+              } else {
+                iteratorFromFailedMemoryStorePut = Some(iter)
               }
-              size = diskStore.getSize(blockId)
-            } else {
-              iteratorFromFailedMemoryStorePut = Some(iter)
-            }
+          }
+        } else { // !level.deserialized
+          memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
+            case Right(s) =>
+              size = s
+            case Left(partiallySerializedValues) =>
+              // Not enough space to unroll this block; drop to disk if applicable
+              if (level.useDisk) {
+                logWarning(s"Persisting block $blockId to disk instead.")
+                diskStore.put(blockId) { fileOutputStream =>
+                  partiallySerializedValues.finishWritingToStream(fileOutputStream)
+                }
+                size = diskStore.getSize(blockId)
+              } else {
+                iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
+              }
+          }
         }
+
       } else if (level.useDisk) {
         diskStore.put(blockId) { fileOutputStream =>
           serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)

@@ -991,7 +1010,7 @@ private[spark] class BlockManager(
       // Note: if we had a means to discard the disk iterator, we would do that here.
       memoryStore.getValues(blockId).get
     } else {
-      memoryStore.putIterator(blockId, diskIterator, level, classTag) match {
+      memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match {
        case Left(iter) =>
          // The memory store put() failed, so it returned the iterator back to us:
          iter
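The BlockManager changes above lean on a contract that is easy to miss in the diff: both putIteratorAsValues and putIteratorAsBytes return an Either, where Right carries the stored size and Left hands the not-yet-stored values back to the caller, so they can be spilled to disk (or surfaced via iteratorFromFailedMemoryStorePut) without recomputing the block. A minimal, self-contained Scala sketch of that protocol, using a toy store with a byte budget rather than Spark's actual MemoryStore (names and the size accounting are simplified assumptions):

import scala.collection.mutable.ArrayBuffer

// Toy stand-in for the memory store (an assumption, not Spark's API):
// Right(bytesStored) when the whole iterator fits, Left(remainingValues) when
// the budget runs out mid-unroll, so the caller can spill without recomputation.
class ToyMemoryStore(budgetBytes: Long) {
  def putIterator[T](values: Iterator[T], sizeOf: T => Long): Either[Iterator[T], Long] = {
    val unrolled = ArrayBuffer.empty[T]
    var used = 0L
    while (values.hasNext && used < budgetBytes) {
      val v = values.next()
      used += sizeOf(v)
      unrolled += v
    }
    if (!values.hasNext && used <= budgetBytes) {
      Right(used) // fully unrolled: the block now lives in memory
    } else {
      // Out of budget: return the already-unrolled prefix plus the untouched
      // remainder, mirroring the Left(iter) handled by diskStore.put above.
      Left(unrolled.iterator ++ values)
    }
  }
}

object ToyMemoryStoreDemo extends App {
  val store = new ToyMemoryStore(budgetBytes = 4)
  store.putIterator(Iterator("a", "bb", "ccc"), (s: String) => s.length.toLong) match {
    case Right(size) => println(s"stored in memory ($size bytes)")
    case Left(rest)  => println(s"spilling to disk: ${rest.mkString(", ")}")
  }
}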
