Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ exportMethods("glm",
"spark.isoreg",
"spark.gaussianMixture",
"spark.als",
"spark.kstest")
"spark.kstest",
"spark.decisionTree")

# Job group lifecycle management methods
export("setJobGroup",
Expand Down Expand Up @@ -347,7 +348,9 @@ export("as.DataFrame",
"uncacheTable",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml",
"print.summary.KSTest")
"print.summary.KSTest",
"print.summary.DecisionTreeRegressionModel",
"print.summary.DecisionTreeClassificationModel")

export("structField",
"structField.jobj",
Expand All @@ -372,6 +375,8 @@ S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
S3method(print, summary.DecisionTreeRegressionModel)
S3method(print, summary.DecisionTreeClassificationModel)
S3method(structField, character)
S3method(structField, jobj)
S3method(structType, jobj)
Expand Down
5 changes: 5 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,11 @@ setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.p
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })

#' @rdname spark.decisionTree
#' @export
setGeneric("spark.decisionTree",
function(data, formula, ...) { standardGeneric("spark.decisionTree") })

#' @rdname spark.gaussianMixture
#' @export
setGeneric("spark.gaussianMixture",
Expand Down
210 changes: 206 additions & 4 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ setClass("ALSModel", representation(jobj = "jobj"))
#' @note KSTest since 2.1.0
setClass("KSTest", representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
#' @export
#' @note DecisionTreeRegressionModel since 2.1.0
setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
#' @export
#' @note DecisionTreeClassificationModel since 2.1.0
setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))

#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
Expand All @@ -103,8 +117,9 @@ setClass("KSTest", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.als}, link{spark.decisionTree}, \link{spark.gaussianMixture},
#' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.mlp},
#' @seealso \link{spark.naiveBayes}, \link{spark.survreg},
#' @seealso \link{read.ml}
NULL

Expand All @@ -116,8 +131,9 @@ NULL
#' @name predict
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.als}, \link{spark.decisionTree}, \link{spark.gaussianMixture},
#' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.survreg},
NULL

write_internal <- function(object, path, overwrite = FALSE) {
Expand Down Expand Up @@ -932,6 +948,10 @@ read.ml <- function(path) {
new("GaussianMixtureModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
new("ALSModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
new("DecisionTreeRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
new("DecisionTreeClassificationModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
Expand Down Expand Up @@ -1427,3 +1447,185 @@ print.summary.KSTest <- function(x, ...) {
cat(summaryStr, "\n")
invisible(x)
}

#' Decision Tree Model for Regression and Classification
#'
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see \href{https://en.wikipedia.org/wiki/Decision_tree_learning}{Decision Tree}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you point this url to the Spark programming guide, like http://spark.apache.org/docs/latest/ml-classification-regression.html

#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add the types supported, eg. one of "regression" or "classification" as the type of model

#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#' how to split on features at each node. More bins give higher granularity. Must be
#' >= 2 and >= number of categories in any categorical feature. (default = 32)
#' @param ... additional arguments passed to the method.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it support other parameters, like numClasses, features, impurity?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or a future PR?

#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
#' @rdname spark.decisionTree
#' @name spark.decisionTree
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(longley)
#'
#' # fit a Decision Tree Regression Model
#' model <- spark.decisionTree(data, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add an example for "classification" too?

#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.decisionTree since 2.1.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32 ) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra space after 32 )

type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switch(type,
regression = {
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
"fit", data@sdf, formula, as.integer(maxDepth),
as.integer(maxBins))
new("DecisionTreeRegressionModel", jobj = jobj)
},
classification = {
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
"fit", data@sdf, formula, as.integer(maxDepth),
as.integer(maxBins))
new("DecisionTreeClassificationModel", jobj = jobj)
}
)
})

# Makes predictions from a Decision Tree Regression model or
# a model produced by spark.decisionTree()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the Decision Tree Regression model produced by spark.decisionTree()? could you clarify?


#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
#' "prediction"
#' @rdname spark.decisionTree
#' @export
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @Aliases

#' @note predict(decisionTreeRegressionModel) since 2.1.0
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})

#' @rdname spark.decisionTree
#' @export
#' @note predict(decisionTreeClassificationModel) since 2.1.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @aliases

setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})

#' Save the Decision Tree Regression model to the input path.
#'
#' @param object A fitted Decision tree regression model
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
#' @rdname spark.decisionTreeRegression
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be @rdname spark.decisionTree to match the other instances

#' @export
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' Save the Decision Tree Classification model to the input path.
#'
#' @param object A fitted Decision tree classification model
Copy link
Member

@felixcheung felixcheung Oct 9, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you check the output doc by running create-doc.sh - I think this will duplicate the object when the @rdname is changed - in that case, just have one instance of this and say "regression or classification model"

#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTreeClassification
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to @rdname spark.decisionTree

#' @export
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.1.0
setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

# Get the summary of an DecisionTreeRegressionModel model

#' @param object a fitted DecisionTreeRegressionModel or DecisionTreeClassificationModel
#' @return \code{summary} returns the model's features as lists, depth and number of nodes
#' or number of classes.
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeRegressionModel-method
#' @export
#' @note summary(DecisionTreeRegressionModel) since 2.1.0
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
function(object, ...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not put ... in signature here

jobj <- object@jobj
features <- callJMethod(jobj, "features")
depth <- callJMethod(jobj, "depth")
numNodes <- callJMethod(jobj, "numNodes")
ans <- list(features = features, depth = depth, numNodes = numNodes, jobj = jobj)
class(ans) <- "summary.DecisionTreeRegressionModel"
ans
})

# Get the summary of an DecisionTreeClassificationModel model

#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeClassificationModel-method
#' @export
#' @note summary(DecisionTreeClassificationModel) since 2.1.0
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
function(object, ...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

jobj <- object@jobj
features <- callJMethod(jobj, "features")
depth <- callJMethod(jobj, "depth")
numNodes <- callJMethod(jobj, "numNodes")
numClasses <- callJMethod(jobj, "numClasses")
ans <- list(features = features, depth = depth,
numNodes = numNodes, numClasses = numClasses, jobj = jobj)
class(ans) <- "summary.DecisionTreeClassificationModel"
ans
})

# Prints the summary of Decision Tree Regression Model

#' @rdname spark.decisionTree
#' @param x summary object of decisionTreeRegressionModel or decisionTreeClassificationModel
#' returned by \code{summary}.
#' @export
#' @note print.summary.DecisionTreeRegressionModel since 2.1.0
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
jobj <- x$jobj
summaryStr <- callJMethod(jobj, "summary")
cat(summaryStr, "\n")
invisible(x)
}

# Prints the summary of Decision Tree Classification Model

#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeClassificationModel since 2.1.0
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
jobj <- x$jobj
summaryStr <- callJMethod(jobj, "summary")
cat(summaryStr, "\n")
invisible(x)
}
55 changes: 55 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -791,4 +791,59 @@ test_that("spark.kstest", {
expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
})

test_that("spark.decisionTree Regression", {
data <- suppressWarnings(createDataFrame(longley))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a test for print (see spark.glm)

model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)

#Test summary
stats <- summary(model)
expect_equal(stats$depth, 5)
expect_equal(stats$numNodes, 31)

#Test model predict
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
63.221, 63.639, 64.989, 63.761,
66.019, 67.857, 68.169, 66.513,
68.655, 69.564, 69.331, 70.551),
tolerance = 1e-4)

# Test model save/load
modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
expect_equal(stats$depth, stats2$depth)
expect_equal(stats$numNodes, stats2$numNodes)

unlink(modelPath)
})

test_that("spark.decisionTree Classification", {
data <- suppressWarnings(createDataFrame(iris))
model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)

#Test summary
stats <- summary(model)
expect_equal(stats$depth, 5)
expect_equal(stats$numNodes, 19)
expect_equal(stats$numClasses, 3)

# Test model save/load
modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
expect_equal(stats$depth, stats2$depth)
expect_equal(stats$numNodes, stats2$numNodes)
expect_equal(stats$numClasses, stats2$numClasses)

unlink(modelPath)
})

sparkR.session.stop()
Loading