-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15767][R][ML] Decision Tree Regression wrapper in SparkR #13690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bee4868
463f965
b18b718
9787219
d107ab9
d034735
0694f84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,20 @@ setClass("ALSModel", representation(jobj = "jobj")) | |
| #' @note KSTest since 2.1.0 | ||
| setClass("KSTest", representation(jobj = "jobj")) | ||
|
|
||
| #' S4 class that represents a DecisionTreeRegressionModel | ||
| #' | ||
| #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel | ||
| #' @export | ||
| #' @note DecisionTreeRegressionModel since 2.1.0 | ||
| setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) | ||
|
|
||
| #' S4 class that represents a DecisionTreeClassificationModel | ||
| #' | ||
| #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel | ||
| #' @export | ||
| #' @note DecisionTreeClassificationModel since 2.1.0 | ||
| setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) | ||
|
|
||
| #' Saves the MLlib model to the input path | ||
| #' | ||
| #' Saves the MLlib model to the input path. For more information, see the specific | ||
|
|
@@ -103,8 +117,9 @@ setClass("KSTest", representation(jobj = "jobj")) | |
| #' @name write.ml | ||
| #' @export | ||
| #' @seealso \link{spark.glm}, \link{glm}, | ||
| #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, | ||
| #' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} | ||
| #' @seealso \link{spark.als}, \link{spark.decisionTree}, \link{spark.gaussianMixture}, | ||
| #' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.mlp}, | ||
| #' @seealso \link{spark.naiveBayes}, \link{spark.survreg}, | ||
| #' @seealso \link{read.ml} | ||
| NULL | ||
|
|
||
|
|
@@ -116,8 +131,9 @@ NULL | |
| #' @name predict | ||
| #' @export | ||
| #' @seealso \link{spark.glm}, \link{glm}, | ||
| #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, | ||
| #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} | ||
| #' @seealso \link{spark.als}, \link{spark.decisionTree}, \link{spark.gaussianMixture}, | ||
| #' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.mlp}, \link{spark.naiveBayes}, | ||
| #' @seealso \link{spark.survreg} | ||
| NULL | ||
|
|
||
| write_internal <- function(object, path, overwrite = FALSE) { | ||
|
|
@@ -932,6 +948,10 @@ read.ml <- function(path) { | |
| new("GaussianMixtureModel", jobj = jobj) | ||
| } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { | ||
| new("ALSModel", jobj = jobj) | ||
| } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { | ||
| new("DecisionTreeRegressionModel", jobj = jobj) | ||
| } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { | ||
| new("DecisionTreeClassificationModel", jobj = jobj) | ||
| } else { | ||
| stop("Unsupported model: ", jobj) | ||
| } | ||
|
|
@@ -1427,3 +1447,185 @@ print.summary.KSTest <- function(x, ...) { | |
| cat(summaryStr, "\n") | ||
| invisible(x) | ||
| } | ||
|
|
||
| #' Decision Tree Model for Regression and Classification | ||
| #' | ||
| #' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on | ||
| #' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree | ||
| #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to | ||
| #' save/load fitted models. | ||
| #' For more details, see \href{https://en.wikipedia.org/wiki/Decision_tree_learning}{Decision Tree} | ||
| #' | ||
| #' @param data a SparkDataFrame for training. | ||
| #' @param formula a symbolic description of the model to be fitted. Currently only a few formula | ||
| #' operators are supported, including '~', ':', '+', and '-'. | ||
| #' @param type type of model to fit, either "regression" (the default) or "classification". | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add the types supported, eg. |
||
| #' @param maxDepth Maximum depth of the tree (>= 0). | ||
| #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing | ||
| #' how to split on features at each node. More bins give higher granularity. Must be | ||
| #' >= 2 and >= number of categories in any categorical feature. (default = 32) | ||
| #' @param ... additional arguments passed to the method. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should it support other parameters, like numClasses, features, impurity?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or a future PR? |
||
| #' @aliases spark.decisionTree,SparkDataFrame,formula-method | ||
| #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. | ||
| #' @rdname spark.decisionTree | ||
| #' @name spark.decisionTree | ||
| #' @export | ||
| #' @examples | ||
| #' \dontrun{ | ||
| #' df <- createDataFrame(longley) | ||
| #' | ||
| #' # fit a Decision Tree Regression Model | ||
| #' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) | ||
| #' | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add an example for "classification" too? |
||
| #' # get the summary of the model | ||
| #' summary(model) | ||
| #' | ||
| #' # make predictions | ||
| #' predictions <- predict(model, df) | ||
| #' | ||
| #' # save and load the model | ||
| #' path <- "path/to/model" | ||
| #' write.ml(model, path) | ||
| #' savedModel <- read.ml(path) | ||
| #' summary(savedModel) | ||
| #' } | ||
| #' @note spark.decisionTree since 2.1.0 | ||
| setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), | ||
| function(data, formula, type = c("regression", "classification"), | ||
| maxDepth = 5, maxBins = 32 ) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: extra space after |
||
| type <- match.arg(type) | ||
| formula <- paste(deparse(formula), collapse = "") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use |
||
| switch(type, | ||
| regression = { | ||
| jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", | ||
| "fit", data@sdf, formula, as.integer(maxDepth), | ||
| as.integer(maxBins)) | ||
| new("DecisionTreeRegressionModel", jobj = jobj) | ||
| }, | ||
| classification = { | ||
| jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", | ||
| "fit", data@sdf, formula, as.integer(maxDepth), | ||
| as.integer(maxBins)) | ||
| new("DecisionTreeClassificationModel", jobj = jobj) | ||
| } | ||
| ) | ||
| }) | ||
|
|
||
| # Makes predictions from a Decision Tree Regression model or | ||
| # a model produced by spark.decisionTree() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't the |
||
|
|
||
| #' @param newData a SparkDataFrame for testing. | ||
| #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named | ||
| #' "prediction" | ||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add @Aliases |
||
| #' @note predict(decisionTreeRegressionModel) since 2.1.0 | ||
| setMethod("predict", signature(object = "DecisionTreeRegressionModel"), | ||
| function(object, newData) { | ||
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note predict(decisionTreeClassificationModel) since 2.1.0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
| setMethod("predict", signature(object = "DecisionTreeClassificationModel"), | ||
| function(object, newData) { | ||
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| #' Save the Decision Tree Regression model to the input path. | ||
| #' | ||
| #' @param object A fitted Decision tree regression model | ||
| #' @param path The directory where the model is saved | ||
| #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE | ||
| #' which means throw exception if the output path exists. | ||
| #' | ||
| #' @aliases write.ml,DecisionTreeRegressionModel,character-method | ||
| #' @rdname spark.decisionTreeRegression | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be |
||
| #' @export | ||
| #' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0 | ||
| setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), | ||
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| #' Save the Decision Tree Classification model to the input path. | ||
| #' | ||
| #' @param object A fitted Decision tree classification model | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you check the output doc by running create-doc.sh - I think this will duplicate the |
||
| #' @param path The directory where the model is saved | ||
| #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE | ||
| #' which means throw exception if the output path exists. | ||
| #' | ||
| #' @aliases write.ml,DecisionTreeClassificationModel,character-method | ||
| #' @rdname spark.decisionTreeClassification | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to |
||
| #' @export | ||
| #' @note write.ml(DecisionTreeClassificationModel, character) since 2.1.0 | ||
| setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), | ||
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| # Get the summary of a DecisionTreeRegressionModel | ||
|
|
||
| #' @param object a fitted DecisionTreeRegressionModel or DecisionTreeClassificationModel | ||
| #' @return \code{summary} returns the model's features as lists, depth and number of nodes | ||
| #' or number of classes. | ||
| #' @rdname spark.decisionTree | ||
| #' @aliases summary,DecisionTreeRegressionModel-method | ||
| #' @export | ||
| #' @note summary(DecisionTreeRegressionModel) since 2.1.0 | ||
| setMethod("summary", signature(object = "DecisionTreeRegressionModel"), | ||
| function(object, ...) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do not put |
||
| jobj <- object@jobj | ||
| features <- callJMethod(jobj, "features") | ||
| depth <- callJMethod(jobj, "depth") | ||
| numNodes <- callJMethod(jobj, "numNodes") | ||
| ans <- list(features = features, depth = depth, numNodes = numNodes, jobj = jobj) | ||
| class(ans) <- "summary.DecisionTreeRegressionModel" | ||
| ans | ||
| }) | ||
|
|
||
| # Get the summary of a DecisionTreeClassificationModel | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @aliases summary,DecisionTreeClassificationModel-method | ||
| #' @export | ||
| #' @note summary(DecisionTreeClassificationModel) since 2.1.0 | ||
| setMethod("summary", signature(object = "DecisionTreeClassificationModel"), | ||
| function(object, ...) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
| jobj <- object@jobj | ||
| features <- callJMethod(jobj, "features") | ||
| depth <- callJMethod(jobj, "depth") | ||
| numNodes <- callJMethod(jobj, "numNodes") | ||
| numClasses <- callJMethod(jobj, "numClasses") | ||
| ans <- list(features = features, depth = depth, | ||
| numNodes = numNodes, numClasses = numClasses, jobj = jobj) | ||
| class(ans) <- "summary.DecisionTreeClassificationModel" | ||
| ans | ||
| }) | ||
|
|
||
| # Prints the summary of Decision Tree Regression Model | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @param x summary object of decisionTreeRegressionModel or decisionTreeClassificationModel | ||
| #' returned by \code{summary}. | ||
| #' @export | ||
| #' @note print.summary.DecisionTreeRegressionModel since 2.1.0 | ||
| print.summary.DecisionTreeRegressionModel <- function(x, ...) { | ||
| jobj <- x$jobj | ||
| summaryStr <- callJMethod(jobj, "summary") | ||
| cat(summaryStr, "\n") | ||
| invisible(x) | ||
| } | ||
|
|
||
| # Prints the summary of Decision Tree Classification Model | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note print.summary.DecisionTreeClassificationModel since 2.1.0 | ||
| print.summary.DecisionTreeClassificationModel <- function(x, ...) { | ||
| jobj <- x$jobj | ||
| summaryStr <- callJMethod(jobj, "summary") | ||
| cat(summaryStr, "\n") | ||
| invisible(x) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -791,4 +791,59 @@ test_that("spark.kstest", { | |
| expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") | ||
| }) | ||
|
|
||
| test_that("spark.decisionTree Regression", { | ||
| data <- suppressWarnings(createDataFrame(longley)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a test for print (see spark.glm) |
||
| model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) | ||
|
|
||
| # Test summary | ||
| stats <- summary(model) | ||
| expect_equal(stats$depth, 5) | ||
| expect_equal(stats$numNodes, 31) | ||
|
|
||
| # Test model predict | ||
| predictions <- collect(predict(model, data)) | ||
| expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, | ||
| 63.221, 63.639, 64.989, 63.761, | ||
| 66.019, 67.857, 68.169, 66.513, | ||
| 68.655, 69.564, 69.331, 70.551), | ||
| tolerance = 1e-4) | ||
|
|
||
| # Test model save/load | ||
| modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") | ||
| write.ml(model, modelPath) | ||
| expect_error(write.ml(model, modelPath)) | ||
| write.ml(model, modelPath, overwrite = TRUE) | ||
| model2 <- read.ml(modelPath) | ||
| stats2 <- summary(model2) | ||
| expect_equal(stats$depth, stats2$depth) | ||
| expect_equal(stats$numNodes, stats2$numNodes) | ||
|
|
||
| unlink(modelPath) | ||
| }) | ||
|
|
||
| test_that("spark.decisionTree Classification", { | ||
| data <- suppressWarnings(createDataFrame(iris)) | ||
| model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification", | ||
| maxDepth = 5, maxBins = 16) | ||
|
|
||
| # Test summary | ||
| stats <- summary(model) | ||
| expect_equal(stats$depth, 5) | ||
| expect_equal(stats$numNodes, 19) | ||
| expect_equal(stats$numClasses, 3) | ||
|
|
||
| # Test model save/load | ||
| modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") | ||
| write.ml(model, modelPath) | ||
| expect_error(write.ml(model, modelPath)) | ||
| write.ml(model, modelPath, overwrite = TRUE) | ||
| model2 <- read.ml(modelPath) | ||
| stats2 <- summary(model2) | ||
| expect_equal(stats$depth, stats2$depth) | ||
| expect_equal(stats$numNodes, stats2$numNodes) | ||
| expect_equal(stats$numClasses, stats2$numClasses) | ||
|
|
||
| unlink(modelPath) | ||
| }) | ||
|
|
||
| sparkR.session.stop() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you point this url to the Spark programming guide, like http://spark.apache.org/docs/latest/ml-classification-regression.html