Skip to content

Commit 6fc2b65

Browse files
committed
[SPARK-11888][ML] Decision tree persistence in spark.ml
### What changes were proposed in this pull request?

Made these MLReadable and MLWritable: DecisionTreeClassifier, DecisionTreeClassificationModel, DecisionTreeRegressor, DecisionTreeRegressionModel
* The shared implementation is in treeModels.scala
* I use case classes to create a DataFrame to save, and I use the Dataset API to parse loaded files.

Other changes:
* Made CategoricalSplit.numCategories public (to use in persistence)
* Fixed a bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite, where it did not call the checkModelData function passed as an argument. This caused an error in LDASuite, which I fixed.

### How was this patch tested?

Persistence is tested via unit tests. For each algorithm, there are 2 non-trivial trees (depth 2). One is built with continuous features, and one with categorical; this ensures that both types of splits are tested.

Author: Joseph K. Bradley <[email protected]>

Closes #11581 from jkbradley/dt-io.
1 parent 3f06eb7 commit 6fc2b65

23 files changed

+428
-71
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,24 @@
1717

1818
package org.apache.spark.ml.classification
1919

20+
import org.apache.hadoop.fs.Path
21+
import org.json4s.{DefaultFormats, JObject}
22+
import org.json4s.JsonDSL._
23+
2024
import org.apache.spark.annotation.{Experimental, Since}
2125
import org.apache.spark.ml.param.ParamMap
22-
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
26+
import org.apache.spark.ml.tree._
27+
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
2328
import org.apache.spark.ml.tree.impl.RandomForest
24-
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
29+
import org.apache.spark.ml.util._
2530
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2631
import org.apache.spark.mllib.regression.LabeledPoint
2732
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
2833
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
2934
import org.apache.spark.rdd.RDD
3035
import org.apache.spark.sql.DataFrame
3136

37+
3238
/**
3339
* :: Experimental ::
3440
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
@@ -41,7 +47,7 @@ import org.apache.spark.sql.DataFrame
4147
final class DecisionTreeClassifier @Since("1.4.0") (
4248
@Since("1.4.0") override val uid: String)
4349
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
44-
with DecisionTreeParams with TreeClassifierParams {
50+
with DecisionTreeClassifierParams with DefaultParamsWritable {
4551

4652
@Since("1.4.0")
4753
def this() = this(Identifiable.randomUID("dtc"))
@@ -115,10 +121,13 @@ final class DecisionTreeClassifier @Since("1.4.0") (
115121

116122
@Since("1.4.0")
117123
@Experimental
118-
object DecisionTreeClassifier {
124+
object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
119125
/** Accessor for supported impurities: entropy, gini */
120126
@Since("1.4.0")
121127
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
128+
129+
@Since("2.0.0")
130+
override def load(path: String): DecisionTreeClassifier = super.load(path)
122131
}
123132

124133
/**
@@ -135,7 +144,7 @@ final class DecisionTreeClassificationModel private[ml] (
135144
@Since("1.6.0")override val numFeatures: Int,
136145
@Since("1.5.0")override val numClasses: Int)
137146
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
138-
with DecisionTreeModel with Serializable {
147+
with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable with Serializable {
139148

140149
require(rootNode != null,
141150
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
@@ -200,12 +209,57 @@ final class DecisionTreeClassificationModel private[ml] (
200209
private[ml] def toOld: OldDecisionTreeModel = {
201210
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
202211
}
212+
213+
@Since("2.0.0")
214+
override def write: MLWriter =
215+
new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this)
203216
}
204217

205-
private[ml] object DecisionTreeClassificationModel {
218+
@Since("2.0.0")
219+
object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] {
220+
221+
@Since("2.0.0")
222+
override def read: MLReader[DecisionTreeClassificationModel] =
223+
new DecisionTreeClassificationModelReader
224+
225+
@Since("2.0.0")
226+
override def load(path: String): DecisionTreeClassificationModel = super.load(path)
227+
228+
private[DecisionTreeClassificationModel]
229+
class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel)
230+
extends MLWriter {
231+
232+
override protected def saveImpl(path: String): Unit = {
233+
val extraMetadata: JObject = Map(
234+
"numFeatures" -> instance.numFeatures,
235+
"numClasses" -> instance.numClasses)
236+
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
237+
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
238+
val dataPath = new Path(path, "data").toString
239+
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
240+
}
241+
}
242+
243+
private class DecisionTreeClassificationModelReader
244+
extends MLReader[DecisionTreeClassificationModel] {
245+
246+
/** Checked against metadata when loading model */
247+
private val className = classOf[DecisionTreeClassificationModel].getName
248+
249+
override def load(path: String): DecisionTreeClassificationModel = {
250+
implicit val format = DefaultFormats
251+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
252+
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
253+
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
254+
val root = loadTreeNodes(path, metadata, sqlContext)
255+
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
256+
DefaultParamsReader.getAndSetParams(model, metadata)
257+
model
258+
}
259+
}
206260

207-
/** (private[ml]) Convert a model from the old API */
208-
def fromOld(
261+
/** Convert a model from the old API */
262+
private[ml] def fromOld(
209263
oldModel: OldDecisionTreeModel,
210264
parent: DecisionTreeClassifier,
211265
categoricalFeatures: Map[Int, Int],

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,26 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
101101
}
102102

103103
/** Decodes a param value from JSON. */
104-
def jsonDecode(json: String): T = {
104+
def jsonDecode(json: String): T = Param.jsonDecode[T](json)
105+
106+
private[this] val stringRepresentation = s"${parent}__$name"
107+
108+
override final def toString: String = stringRepresentation
109+
110+
override final def hashCode: Int = toString.##
111+
112+
override final def equals(obj: Any): Boolean = {
113+
obj match {
114+
case p: Param[_] => (p.parent == parent) && (p.name == name)
115+
case _ => false
116+
}
117+
}
118+
}
119+
120+
private[ml] object Param {
121+
122+
/** Decodes a param value from JSON. */
123+
def jsonDecode[T](json: String): T = {
105124
parse(json) match {
106125
case JString(x) =>
107126
x.asInstanceOf[T]
@@ -116,19 +135,6 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
116135
s"${this.getClass.getName} must override jsonDecode to support its value type.")
117136
}
118137
}
119-
120-
private[this] val stringRepresentation = s"${parent}__$name"
121-
122-
override final def toString: String = stringRepresentation
123-
124-
override final def hashCode: Int = toString.##
125-
126-
override final def equals(obj: Any): Boolean = {
127-
obj match {
128-
case p: Param[_] => (p.parent == parent) && (p.name == name)
129-
case _ => false
130-
}
131-
}
132138
}
133139

134140
/**

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717

1818
package org.apache.spark.ml.regression
1919

20+
import org.apache.hadoop.fs.Path
21+
import org.json4s.{DefaultFormats, JObject}
22+
import org.json4s.JsonDSL._
23+
2024
import org.apache.spark.annotation.{Experimental, Since}
2125
import org.apache.spark.ml.{PredictionModel, Predictor}
2226
import org.apache.spark.ml.param.ParamMap
2327
import org.apache.spark.ml.tree._
28+
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
2429
import org.apache.spark.ml.tree.impl.RandomForest
25-
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
30+
import org.apache.spark.ml.util._
2631
import org.apache.spark.mllib.linalg.Vector
2732
import org.apache.spark.mllib.regression.LabeledPoint
2833
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
@@ -31,6 +36,7 @@ import org.apache.spark.rdd.RDD
3136
import org.apache.spark.sql.DataFrame
3237
import org.apache.spark.sql.functions._
3338

39+
3440
/**
3541
* :: Experimental ::
3642
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
@@ -41,7 +47,7 @@ import org.apache.spark.sql.functions._
4147
@Experimental
4248
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
4349
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
44-
with DecisionTreeRegressorParams {
50+
with DecisionTreeRegressorParams with DefaultParamsWritable {
4551

4652
@Since("1.4.0")
4753
def this() = this(Identifiable.randomUID("dtr"))
@@ -107,9 +113,12 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
107113

108114
@Since("1.4.0")
109115
@Experimental
110-
object DecisionTreeRegressor {
116+
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
111117
/** Accessor for supported impurities: variance */
112118
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
119+
120+
@Since("2.0.0")
121+
override def load(path: String): DecisionTreeRegressor = super.load(path)
113122
}
114123

115124
/**
@@ -125,13 +134,13 @@ final class DecisionTreeRegressionModel private[ml] (
125134
override val rootNode: Node,
126135
override val numFeatures: Int)
127136
extends PredictionModel[Vector, DecisionTreeRegressionModel]
128-
with DecisionTreeModel with DecisionTreeRegressorParams with Serializable {
137+
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
129138

130139
/** @group setParam */
131140
def setVarianceCol(value: String): this.type = set(varianceCol, value)
132141

133142
require(rootNode != null,
134-
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
143+
"DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")
135144

136145
/**
137146
* Construct a decision tree regression model.
@@ -200,12 +209,55 @@ final class DecisionTreeRegressionModel private[ml] (
200209
private[ml] def toOld: OldDecisionTreeModel = {
201210
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
202211
}
212+
213+
@Since("2.0.0")
214+
override def write: MLWriter =
215+
new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this)
203216
}
204217

205-
private[ml] object DecisionTreeRegressionModel {
218+
@Since("2.0.0")
219+
object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] {
220+
221+
@Since("2.0.0")
222+
override def read: MLReader[DecisionTreeRegressionModel] =
223+
new DecisionTreeRegressionModelReader
224+
225+
@Since("2.0.0")
226+
override def load(path: String): DecisionTreeRegressionModel = super.load(path)
227+
228+
private[DecisionTreeRegressionModel]
229+
class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel)
230+
extends MLWriter {
231+
232+
override protected def saveImpl(path: String): Unit = {
233+
val extraMetadata: JObject = Map(
234+
"numFeatures" -> instance.numFeatures)
235+
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
236+
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
237+
val dataPath = new Path(path, "data").toString
238+
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
239+
}
240+
}
241+
242+
private class DecisionTreeRegressionModelReader
243+
extends MLReader[DecisionTreeRegressionModel] {
244+
245+
/** Checked against metadata when loading model */
246+
private val className = classOf[DecisionTreeRegressionModel].getName
247+
248+
override def load(path: String): DecisionTreeRegressionModel = {
249+
implicit val format = DefaultFormats
250+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
251+
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
252+
val root = loadTreeNodes(path, metadata, sqlContext)
253+
val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
254+
DefaultParamsReader.getAndSetParams(model, metadata)
255+
model
256+
}
257+
}
206258

207-
/** (private[ml]) Convert a model from the old API */
208-
def fromOld(
259+
/** Convert a model from the old API */
260+
private[ml] def fromOld(
209261
oldModel: OldDecisionTreeModel,
210262
parent: DecisionTreeRegressor,
211263
categoricalFeatures: Map[Int, Int],

mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.tree
1919

20-
import org.apache.spark.annotation.DeveloperApi
20+
import org.apache.spark.annotation.{DeveloperApi, Since}
2121
import org.apache.spark.mllib.linalg.Vector
2222
import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
2323
import org.apache.spark.mllib.tree.model.{Split => OldSplit}
@@ -76,7 +76,7 @@ private[tree] object Split {
7676
final class CategoricalSplit private[ml] (
7777
override val featureIndex: Int,
7878
_leftCategories: Array[Double],
79-
private val numCategories: Int)
79+
@Since("2.0.0") val numCategories: Int)
8080
extends Split {
8181

8282
require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +

0 commit comments

Comments
 (0)