Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
0f2f56c
[SPARK-20736][PYTHON] PySpark StringIndexer supports StringOrderType
May 21, 2017
3c9eef3
[SPARK-20786][SQL] Improve ceil and floor handle the value which is n…
heary-cao May 22, 2017
833c8d4
[SPARK-20770][SQL] Improve ColumnStats
kiszk May 22, 2017
a2b3b67
[SPARK-19089][SQL] Add support for nested sequences
michalsenkyr May 22, 2017
06dda1d
[SPARK-20687][MLLIB] mllib.Matrices.fromBreeze may crash when convert…
ghoto May 22, 2017
be846db
[SPARK-20506][DOCS] Add HTML links to highlight list in MLlib guide f…
May 22, 2017
190d8b0
[SPARK-20591][WEB UI] Succeeded tasks num not equal in all jobs page …
fjh100456 May 22, 2017
f1ffc6e
[SPARK-20609][CORE] Run the SortShuffleSuite unit tests have residual…
heary-cao May 22, 2017
aea73be
[SPARK-20813][WEB UI] Fixed Web UI executor page tab search by status…
May 22, 2017
2597674
[SPARK-20801] Record accurate size of blocks in MapStatus when it's a…
May 22, 2017
f3ed62a
[SPARK-20831][SQL] Fix INSERT OVERWRITE data source tables with IF NO…
gatorsmile May 22, 2017
cfca011
[SPARK-20764][ML][PYSPARK] Fix visibility discrepancy with numInstanc…
May 22, 2017
3630911
[SPARK-20756][YARN] yarn-shuffle jar references unshaded guava
markgrover May 22, 2017
4be3375
[SPARK-15767][ML][SPARKR] Decision Tree wrapper in SparkR
zhengruifeng May 22, 2017
df64fa7
[SPARK-20814][MESOS] Restore support for spark.executor.extraClassPath.
May 22, 2017
9b09101
[SPARK-20751][SQL][FOLLOWUP] Add cot test in MathExpressionsSuite
wangyum May 22, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ exportMethods("glm",
"spark.als",
"spark.kstest",
"spark.logit",
"spark.decisionTree",
"spark.randomForest",
"spark.gbt",
"spark.bisectingKmeans",
Expand Down Expand Up @@ -414,6 +415,8 @@ export("as.DataFrame",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml",
"print.summary.KSTest",
"print.summary.DecisionTreeRegressionModel",
"print.summary.DecisionTreeClassificationModel",
"print.summary.RandomForestRegressionModel",
"print.summary.RandomForestClassificationModel",
"print.summary.GBTRegressionModel",
Expand Down Expand Up @@ -452,6 +455,8 @@ S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
S3method(print, summary.DecisionTreeRegressionModel)
S3method(print, summary.DecisionTreeClassificationModel)
S3method(print, summary.RandomForestRegressionModel)
S3method(print, summary.RandomForestClassificationModel)
S3method(print, summary.GBTRegressionModel)
Expand Down
5 changes: 5 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml
#' @export
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })

# S4 generic for fitting a decision tree; dispatch happens on `data` and `formula`,
# everything else is forwarded through `...` to the selected method.
#' @rdname spark.decisionTree
#' @export
setGeneric(
  name = "spark.decisionTree",
  def = function(data, formula, ...) {
    standardGeneric("spark.decisionTree")
  }
)

#' @rdname spark.randomForest
#' @export
setGeneric("spark.randomForest",
Expand Down
240 changes: 240 additions & 0 deletions R/pkg/R/mllib_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
#' @note RandomForestClassificationModel since 2.1.0
setClass("RandomForestClassificationModel", representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
#' @export
#' @note DecisionTreeRegressionModel since 2.3.0
setClass(Class = "DecisionTreeRegressionModel",
         representation = representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
#' @export
#' @note DecisionTreeClassificationModel since 2.3.0
setClass(Class = "DecisionTreeClassificationModel",
         representation = representation(jobj = "jobj"))

# Create the summary of a tree ensemble model (e.g. Random Forest, GBT)
summary.treeEnsemble <- function(model) {
jobj <- model@jobj
Expand Down Expand Up @@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
invisible(x)
}

# Collect the summary fields of a fitted decision tree model.
# `model` must expose its backing Java wrapper in the `jobj` slot. The returned list holds
# formula, numFeatures, features, featureImportances (as a string) and maxDepth, plus the
# wrapper reference itself under `jobj` so that the printer can query it again.
summary.decisionTree <- function(model) {
  wrapper <- model@jobj
  # Invoke a zero-argument method on the Java wrapper.
  ask <- function(method) callJMethod(wrapper, method)
  # list() evaluates its arguments left to right, so the JVM calls happen in field order.
  list(formula = ask("formula"),
       numFeatures = ask("numFeatures"),
       features = ask("features"),
       featureImportances = callJMethod(ask("featureImportances"), "toString"),
       maxDepth = ask("maxDepth"),
       jobj = wrapper)
}

# Print a decision tree summary list to the console: the stored fields first, then the
# string returned by the Java wrapper's "summary" method. Returns `x` invisibly.
print.summary.decisionTree <- function(x) {
  wrapper <- x$jobj
  # One labelled field per call; cat() inserts a single space between label and value.
  emit <- function(label, value) cat(label, value)
  emit("Formula: ", x$formula)
  emit("\nNumber of features: ", x$numFeatures)
  emit("\nFeatures: ", unlist(x$features))
  emit("\nFeature importances: ", x$featureImportances)
  emit("\nMax Depth: ", x$maxDepth)

  # The wrapper is only queried after the fields above have been written out.
  cat("\n", callJMethod(wrapper, "summary"), "\n")
  invisible(x)
}

#' Gradient Boosted Tree Model for Regression and Classification
#'
#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
Expand Down Expand Up @@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' Decision Tree Model for Regression and Classification
#'
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
#' Decision Tree Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
#' Decision Tree Classification}
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit, one of "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature.
#' @param impurity Criterion used for information gain calculation.
#'                 For regression, must be "variance". For classification, must be one of
#'                 "entropy" and "gini", default is "gini".
#' @param seed integer seed for random number generation.
#' @param minInstancesPerNode Minimum number of instances each child must have after split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#'                     can speed up training of deeper trees. Users can set how often the cache
#'                     should be checkpointed, or disable it, by setting checkpointInterval.
#' @param ... additional arguments passed to the method.
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
#' @rdname spark.decisionTree
#' @name spark.decisionTree
#' @export
#' @examples
#' \dontrun{
#' # fit a Decision Tree Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Decision Tree Classification Model
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
#' }
#' @note spark.decisionTree since 2.3.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
            type <- match.arg(type)
            # The formula is handed to the JVM as a single string.
            formula <- paste(deparse(formula), collapse = "")
            # A non-NULL seed is truncated to an integer and sent as a string.
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }
            # Only the impurity choices, the JVM wrapper class and the R model class depend
            # on the model type; the fit call itself is shared below.
            switch(type,
                   regression = {
                     if (is.null(impurity)) impurity <- "variance"
                     impurity <- match.arg(impurity, "variance")
                     wrapperClass <- "org.apache.spark.ml.r.DecisionTreeRegressorWrapper"
                     modelClass <- "DecisionTreeRegressionModel"
                   },
                   classification = {
                     if (is.null(impurity)) impurity <- "gini"
                     impurity <- match.arg(impurity, c("gini", "entropy"))
                     wrapperClass <- "org.apache.spark.ml.r.DecisionTreeClassifierWrapper"
                     modelClass <- "DecisionTreeClassificationModel"
                   }
            )
            jobj <- callJStatic(wrapperClass,
                                "fit", data@sdf, formula, as.integer(maxDepth),
                                as.integer(maxBins), impurity,
                                as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                as.integer(checkpointInterval), seed,
                                as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
            new(modelClass, jobj = jobj)
          })

# Summary method for a fitted Decision Tree Regression model: the list built by
# summary.decisionTree, tagged with the S3 class that selects the matching print method.

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes \code{formula} (formula),
#'         \code{numFeatures} (number of features), \code{features} (list of features),
#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees).
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeRegressionModel-method
#' @export
#' @note summary(DecisionTreeRegressionModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
          function(object) {
            structure(summary.decisionTree(object),
                      class = "summary.DecisionTreeRegressionModel")
          })

# Print method for the summary of a Decision Tree Regression model.
# Delegates to print.summary.decisionTree; any extra arguments in `...` are not used.

#' @param x summary object of Decision Tree regression model or classification model
#' returned by \code{summary}.
#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
print.summary.decisionTree(x)
}

# Summary method for a fitted Decision Tree Classification model: the list built by
# summary.decisionTree, tagged with the S3 class that selects the matching print method.

#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeClassificationModel-method
#' @export
#' @note summary(DecisionTreeClassificationModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
          function(object) {
            structure(summary.decisionTree(object),
                      class = "summary.DecisionTreeClassificationModel")
          })

# Print method for the summary of a Decision Tree Classification model.
# Delegates to print.summary.decisionTree; any extra arguments in `...` are not used.

#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
print.summary.decisionTree(x)
}

# Makes predictions from a Decision Tree Regression model.
# The work is delegated to predict_internal, whose result is returned as is.

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#' "prediction".
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeRegressionModel-method
#' @export
#' @note predict(DecisionTreeRegressionModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})

# Makes predictions from a Decision Tree Classification model.
# The work is delegated to predict_internal, whose result is returned as is.

#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeClassificationModel-method
#' @export
#' @note predict(DecisionTreeClassificationModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})

# Saves a Decision Tree Regression model to the input path.
# The work is delegated to write_internal, which receives the overwrite flag unchanged.

#' @param object A fitted Decision Tree regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
#' which means an exception is thrown if the output path exists.
#'
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

# Saves a Decision Tree Classification model to the input path.
# The work is delegated to write_internal, which receives the overwrite flag unchanged.

#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
14 changes: 10 additions & 4 deletions R/pkg/R/mllib_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
#' @rdname write.ml
#' @name write.ml
#' @export
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.logit},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
Expand All @@ -48,8 +49,9 @@ NULL
#' @rdname predict
#' @name predict
#' @export
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
Expand Down Expand Up @@ -110,6 +112,10 @@ read.ml <- function(path) {
new("RandomForestRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
new("RandomForestClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
new("DecisionTreeRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
new("DecisionTreeClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) {
new("GBTRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) {
Expand Down
Loading