diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 267a38c21530..ea06df9363f2 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -43,7 +43,8 @@ exportMethods("glm", "spark.isoreg", "spark.gaussianMixture", "spark.als", - "spark.kstest") + "spark.kstest", + "spark.decisionTree") # Job group lifecycle management methods export("setJobGroup", @@ -347,7 +348,9 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.DecisionTreeRegressionModel", + "print.summary.DecisionTreeClassificationModel") export("structField", "structField.jobj", @@ -372,6 +375,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.DecisionTreeRegressionModel) +S3method(print, summary.DecisionTreeClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 90a02e277831..02c3bfc1b00b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1358,6 +1358,11 @@ setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.p #' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) +#' @rdname spark.decisionTree +#' @export +setGeneric("spark.decisionTree", + function(data, formula, ...) 
{ standardGeneric("spark.decisionTree") }) + #' @rdname spark.gaussianMixture #' @export setGeneric("spark.gaussianMixture", diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b901307f8f40..ec779bf6fc08 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -95,6 +95,20 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) +#' S4 class that represents a DecisionTreeRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel +#' @export +#' @note DecisionTreeRegressionModel since 2.1.0 +setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a DecisionTreeClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel +#' @export +#' @note DecisionTreeClassificationModel since 2.1.0 +setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -103,8 +117,9 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.als}, \link{spark.decisionTree}, \link{spark.gaussianMixture}, +#' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.mlp}, +#' @seealso \link{spark.naiveBayes}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -116,8 +131,9 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.als}, \link{spark.decisionTree}, \link{spark.gaussianMixture}, +#' @seealso \link{spark.isoreg}, \link{spark.kmeans}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -932,6 +948,10 @@ read.ml <- function(path) { new("GaussianMixtureModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { new("ALSModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { + new("DecisionTreeRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { + new("DecisionTreeClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1427,3 +1447,185 @@ print.summary.KSTest <- function(x, ...) 
{ cat(summaryStr, "\n") invisible(x) } + +#' Decision Tree Model for Regression and Classification +#' +#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see \href{https://en.wikipedia.org/wiki/Decision_tree_learning}{Decision Tree} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model to fit, one of "regression" or "classification" (default = "regression") +#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' @param ... additional arguments passed to the method. +#' @aliases spark.decisionTree,SparkDataFrame,formula-method +#' @return \code{spark.decisionTree} returns a fitted Decision Tree model. 
+#' @rdname spark.decisionTree +#' @name spark.decisionTree +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(longley) +#' +#' # fit a Decision Tree Regression Model +#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.decisionTree since 2.1.0 +setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32 ) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + switch(type, + regression = { + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins)) + new("DecisionTreeRegressionModel", jobj = jobj) + }, + classification = { + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins)) + new("DecisionTreeClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Decision Tree Regression or Classification model +# produced by spark.decisionTree() + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction" +#' @rdname spark.decisionTree +#' @export +#' @note predict(decisionTreeRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "DecisionTreeRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.decisionTree +#' @export +#' @note predict(decisionTreeClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "DecisionTreeClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' Save the Decision Tree Regression model to the input path. +#' +#' @param object A fitted Decision tree regression model +#' @param path The directory where the model is saved +#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,DecisionTreeRegressionModel,character-method +#' @rdname spark.decisionTreeRegression +#' @export +#' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Save the Decision Tree Classification model to the input path. +#' +#' @param object A fitted Decision tree classification model +#' @param path The directory where the model is saved +#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
+#' +#' @aliases write.ml,DecisionTreeClassificationModel,character-method +#' @rdname spark.decisionTreeClassification +#' @export +#' @note write.ml(DecisionTreeClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a DecisionTreeRegressionModel + +#' @param object a fitted DecisionTreeRegressionModel or DecisionTreeClassificationModel +#' @return \code{summary} returns a list with the model's features, depth, number of nodes +#' and, for classification models, the number of classes. +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeRegressionModel-method +#' @export +#' @note summary(DecisionTreeRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "DecisionTreeRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + depth <- callJMethod(jobj, "depth") + numNodes <- callJMethod(jobj, "numNodes") + ans <- list(features = features, depth = depth, numNodes = numNodes, jobj = jobj) + class(ans) <- "summary.DecisionTreeRegressionModel" + ans + }) + +# Get the summary of a DecisionTreeClassificationModel + +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeClassificationModel-method +#' @export +#' @note summary(DecisionTreeClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "DecisionTreeClassificationModel"), + function(object, ...) 
{ + jobj <- object@jobj + features <- callJMethod(jobj, "features") + depth <- callJMethod(jobj, "depth") + numNodes <- callJMethod(jobj, "numNodes") + numClasses <- callJMethod(jobj, "numClasses") + ans <- list(features = features, depth = depth, + numNodes = numNodes, numClasses = numClasses, jobj = jobj) + class(ans) <- "summary.DecisionTreeClassificationModel" + ans + }) + +# Prints the summary of Decision Tree Regression Model + +#' @rdname spark.decisionTree +#' @param x summary object of decisionTreeRegressionModel or decisionTreeClassificationModel +#' returned by \code{summary}. +#' @export +#' @note print.summary.DecisionTreeRegressionModel since 2.1.0 +print.summary.DecisionTreeRegressionModel <- function(x, ...) { + jobj <- x$jobj + summaryStr <- callJMethod(jobj, "summary") + cat(summaryStr, "\n") + invisible(x) +} + +# Prints the summary of Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeClassificationModel since 2.1.0 +print.summary.DecisionTreeClassificationModel <- function(x, ...) 
{ + jobj <- x$jobj + summaryStr <- callJMethod(jobj, "summary") + cat(summaryStr, "\n") + invisible(x) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a1eaaf20916a..3833b0370667 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -791,4 +791,59 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.decisionTree Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) + + #Test summary + stats <- summary(model) + expect_equal(stats$depth, 5) + expect_equal(stats$numNodes, 31) + + #Test model predict + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + + unlink(modelPath) +}) + +test_that("spark.decisionTree Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + #Test summary + stats <- summary(model) + expect_equal(stats$depth, 5) + expect_equal(stats$numNodes, 19) + expect_equal(stats$numClasses, 3) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + 
write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala new file mode 100644 index 000000000000..36a8670ceb1f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeClassifierWrapper private ( + val pipeline: PipelineModel, + val features: Array[String], + val maxDepth: Int, + val maxBins: Int) extends MLWritable { + + private val DTModel: DecisionTreeClassificationModel = + pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel] + + lazy val depth: Int = DTModel.depth + lazy val numNodes: Int = DTModel.numNodes + lazy val numClasses: Int = DTModel.numClasses + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this) +} + +private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] { + def fit(data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int): DecisionTreeClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setFeaturesCol("features") + .setLabelCol("label") + + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val decisionTreeClassification = new 
DecisionTreeClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setFeaturesCol(rFormula.getFeaturesCol) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, decisionTreeClassification)) + .fit(data) + + new DecisionTreeClassifierWrapper(pipeline, features, maxDepth, maxBins) + } + + override def read: MLReader[DecisionTreeClassifierWrapper] = + new DecisionTreeClassifierWrapperReader + + override def load(path: String): DecisionTreeClassifierWrapper = super.load(path) + + class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("features" -> instance.features.toSeq) ~ + ("maxDepth" -> instance.maxDepth) ~ + ("maxBins" -> instance.maxBins) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] { + + override def load(path: String): DecisionTreeClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val features = (rMetadata \ "features").extract[Array[String]] + val maxDepth = (rMetadata \ "maxDepth").extract[Int] + val maxBins = (rMetadata \ "maxBins").extract[Int] + + new DecisionTreeClassifierWrapper(pipeline, features, maxDepth, maxBins) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala new file mode 100644 index 000000000000..138fb826dae1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeRegressorWrapper private ( + val pipeline: PipelineModel, + val features: Array[String], + val maxDepth: Int, + val maxBins: Int) extends MLWritable { + + private val DTModel: DecisionTreeRegressionModel = + pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel] + + lazy val depth: Int = DTModel.depth + lazy val numNodes: Int = DTModel.numNodes + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this) +} + +private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] { + def fit(data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int): DecisionTreeRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setFeaturesCol("features") + .setLabelCol("label") + + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val decisionTreeRegression = new DecisionTreeRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + 
.setFeaturesCol(rFormula.getFeaturesCol) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, decisionTreeRegression)) + .fit(data) + + new DecisionTreeRegressorWrapper(pipeline, features, maxDepth, maxBins) + } + + override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader + + override def load(path: String): DecisionTreeRegressorWrapper = super.load(path) + + class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("features" -> instance.features.toSeq) ~ + ("maxDepth" -> instance.maxDepth) ~ + ("maxBins" -> instance.maxBins) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] { + + override def load(path: String): DecisionTreeRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val features = (rMetadata \ "features").extract[Array[String]] + val maxDepth = (rMetadata \ "maxDepth").extract[Int] + val maxBins = (rMetadata \ "maxBins").extract[Int] + + new DecisionTreeRegressorWrapper(pipeline, features, maxDepth, maxBins) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index d64de1b6abb6..e31de4689afe 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -54,6 +54,10 @@ private[r] object RWrappers extends MLReader[Object] { GaussianMixtureWrapper.load(path) case "org.apache.spark.ml.r.ALSWrapper" => ALSWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" => + DecisionTreeRegressorWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" => + DecisionTreeClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") }